
mildresponse t1_j4xhvkg wrote

Are there any easy and straightforward methods for moving ML models across different frameworks? Does it come down to just manually translating the parameters?

For instance, I am looking at a transformer model in PyTorch, whose parameters live in a series of nested objects of various types, exposed as an OrderedDict. I would like to extract all of these parameter tensors for use in a similar architecture built in TensorFlow or JAX. The naive method of manually collecting the parameters into a new dict seems tedious. And if the target is something like Haiku in JAX, the corresponding model will initialize its parameters into its own nested dict with a default naming structure, which then has to be matched up against the interim dict created from PyTorch. Are there any better ways of moving the parameters or models around?
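To be concrete, the "interim dict" step I mean looks roughly like this. A minimal sketch with plain Python lists standing in for tensors (with real PyTorch tensors you would call `t.detach().cpu().numpy()` on each value first), and the nesting convention here is just illustrative, not Haiku's actual default naming:

```python
# Sketch: convert a flat PyTorch-style state_dict (dotted keys)
# into a nested dict keyed by scope, the general shape that
# JAX/Haiku-style parameter containers use.

def nest_state_dict(flat):
    """Split dotted keys into a nested dict: 'a.b.w' -> {'a': {'b': {'w': ...}}}."""
    nested = {}
    for dotted_key, value in flat.items():
        *scopes, leaf = dotted_key.split(".")
        node = nested
        for scope in scopes:
            node = node.setdefault(scope, {})
        node[leaf] = value
    return nested

# Toy stand-in for model.state_dict(); values would be numpy arrays in practice.
state_dict = {
    "encoder.layer0.weight": [[1.0, 0.0], [0.0, 1.0]],
    "encoder.layer0.bias": [0.0, 0.0],
    "decoder.weight": [[2.0]],
}

params = nest_state_dict(state_dict)
# params["encoder"]["layer0"]["weight"] is now addressable by scope,
# but still has to be re-keyed to whatever names the target framework assigns.
```

The tedious part is exactly that last re-keying step, since the target framework picks its own module and parameter names.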
