I am using ml5.js, a wrapper around TensorFlow.js. I want to train a neural network in the browser, download the weights, process them as tensors in PyTorch, and load them back into the browser's TensorFlow.js model. How do I convert between these formats (tfjs <-> PyTorch)?
The browser model has a `save()` function which generates three files: a metadata file specific to ml5.js (JSON), a topology file describing the model architecture (JSON), and a binary weights file (`.bin`).
```
// Browser
model.save()

// HTTP/Download
model_meta.json    (needed by ml5.js)
model.json         (needed by tfjs)
model.weights.bin  (needed by tfjs)
```
```python
# python backend
import json

with open('model.weights.bin', 'rb') as weights_file:
    with open('model.json', 'rb') as model_file:
        weights = weights_file.read()
        model = json.loads(model_file.read())

####
pytorch_tensor = convert2tensor(weights, model)  # what's in this function?
####
# Do some processing in pytorch
####
new_weights_bin = convert2bin(pytorch_tensor, model)  # and in this?
####
```
Here is sample JavaScript code to generate and load the three files in the browser. To load, select all three files at once in the dialog box. If they are correct, a popup will show a sample prediction.
I was able to find a way to convert from tfjs `model.weights.bin` to numpy's `ndarray`s. It is trivial to convert from numpy arrays to a PyTorch `state_dict`, which is a dictionary of tensors and their names.

Theory
First, the tfjs representation of the model should be understood.

`model.json` describes the model. In Python, it can be read as a dictionary. Two of its keys matter here. The model architecture is described as another JSON object/dictionary under the key `modelTopology`. The key `weightsManifest` holds a list of weight groups; each group lists the files it is stored in (`paths`, here the corresponding `model.weights.bin`) and the `name`, `shape`, and `dtype` of every weight packed into those files. As an aside, the weights manifest allows for multiple `.bin` files to store weights.

TensorFlow.js has a companion Python package, `tensorflowjs`, which comes with utility functions to read and write weights between the tf.js binary format and numpy arrays. Each group in the manifest (whose weights may span several `.bin` files) is read as a "group". A group is a list of dictionaries with keys `name` and `data`, which hold the weight name and the numpy array containing the weights. There are optionally other keys too.

Application
Install `tensorflowjs` (`pip install tensorflowjs`). Unfortunately, it will also install TensorFlow.
Use these functions. Note that I changed the signatures for convenience.
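Once the weights are numpy arrays in the `{name, data}` group form, turning them into a PyTorch-style `state_dict` is mostly renaming, plus one layout detail: a tfjs Dense kernel has shape `[in_features, out_features]`, while `torch.nn.Linear.weight` is `[out_features, in_features]`, so kernels must be transposed in both directions. The sketch below assumes a network made only of Dense layers whose weights are named in the Keras style `"<layer>/kernel"` and `"<layer>/bias"`; the functions `tfjs_to_state_dict`/`state_dict_to_tfjs` and the `prefix_map` argument (tfjs layer name to PyTorch module prefix, e.g. `"0"` for the first entry of an `nn.Sequential`) are hypothetical.

```python
import numpy as np

# tfjs Dense parameter suffix -> nn.Linear attribute name (unknown suffixes raise KeyError).
_SUFFIX = {"kernel": "weight", "bias": "bias"}


def tfjs_to_state_dict(groups, prefix_map):
    """Decoded tfjs groups -> {torch_param_name: ndarray} for Dense/Linear layers."""
    state = {}
    for group in groups:
        for entry in group:
            layer, _, param = entry["name"].rpartition("/")
            key = f"{prefix_map[layer]}.{_SUFFIX[param]}"
            data = np.asarray(entry["data"])
            # tfjs Dense kernel is [in, out]; nn.Linear.weight is [out, in].
            state[key] = np.ascontiguousarray(data.T if param == "kernel" else data)
    return state


def state_dict_to_tfjs(state, weights_manifest, prefix_map):
    """Inverse: rebuild groups in the original manifest's order and names."""
    groups = []
    for group_spec in weights_manifest:
        group = []
        for spec in group_spec["weights"]:
            layer, _, param = spec["name"].rpartition("/")
            data = np.asarray(state[f"{prefix_map[layer]}.{_SUFFIX[param]}"])
            if param == "kernel":
                data = data.T  # back to tfjs [in, out] layout
            if list(data.shape) != list(spec["shape"]):
                raise ValueError(f"{spec['name']}: expected {spec['shape']}, "
                                 f"got {list(data.shape)}")
            group.append({"name": spec["name"],
                          "data": np.ascontiguousarray(data, dtype=np.float32)})
        groups.append(group)
    return groups
```

On the PyTorch side, `model.load_state_dict({k: torch.from_numpy(v) for k, v in state.items()})` loads the arrays, and `{k: v.detach().cpu().numpy() for k, v in model.state_dict().items()}` gets them back before calling `state_dict_to_tfjs`. Rebuilding from the original manifest keeps the weight names and order that tfjs expects; the PyTorch module itself has to mirror the architecture described in `modelTopology`.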