Copy Frozen Values From A Frozen Graph to Another Frozen Graph

132 Views Asked by At

I have two frozen graphs that were trained separately and stored as different .pb files. They share some of the same nodes. How can I transfer node values from one graph to the other? For example, how can I copy the FakeQuantWithMinMaxVars nodes to replace the nodes below?

  1. model 1

  2. model 2

1

There is 1 best solution below

0
On BEST ANSWER

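I figured it out by mapping each node to its counterpart in the other graph (nodes with the same name, or whose name matches after renaming `resnet` to `model`). I then connected the two graphs with `tf.import_graph_def` and removed unused nodes with the Graph Transform tool (`strip_unused_nodes`). To keep quantization working, avoid the merge-duplicate-nodes and fold-batch-norms transforms: they cause quantization errors because the min/max quantization ranges go missing.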

import tensorflow as tf
import numpy as np

# load graphs using pb file path
def load_graph(pb_file):
    """Load a frozen, binary-serialized ``GraphDef`` file into a new graph.

    Args:
        pb_file: Path to the ``.pb`` file (read in binary mode).

    Returns:
        A fresh ``tf.Graph`` holding every node of the file. The nodes are
        imported with ``name=''``, so their names carry no scope prefix.
    """
    with tf.gfile.GFile(pb_file, 'rb') as handle:
        raw_bytes = handle.read()

    graph_def = tf.GraphDef()
    graph_def.ParseFromString(raw_bytes)

    loaded = tf.Graph()
    with loaded.as_default():
        # Empty scope keeps the original node names unchanged.
        tf.import_graph_def(graph_def, name='')
    return loaded

# Frozen graph files to merge: graph1 supplies the shared tensors and
# graph2 is rewired onto them.
resnet_pretrained, trained = 'frozen_124.pb', 'frozen.pb'

# Intended name for the combined model.
# NOTE(review): nothing in the visible script reads this variable; the
# combined graph is written to 'optimized_model.pb' instead.
final_graph = 'final_graph.pb'

# Load both graphs, in order: graph1 from resnet_pretrained, graph2 from trained.
graph1, graph2 = map(load_graph, (resnet_pretrained, trained))
# Maps a tensor name in graph2 to the graph1 tensor that replaces it when
# graph2 is imported (used as input_map keys and return_elements values).
replace_dict = {}

# Index graph1's ops once by name, giving O(1) lookups in the loop below
# instead of scanning a list for every graph2 op.
graph1_ops = {op.name: op for op in graph1.get_operations()}

for op in graph2.get_operations():
    if op.name in graph1_ops:
        # Same node name in both graphs: map it directly (takes priority).
        target = op.name
    elif 'resnet' in op.name:
        # graph2 scopes these nodes under "resnet" where graph1 uses "model".
        target = op.name.replace("resnet", "model")
        if target not in graph1_ops:
            continue
    else:
        continue

    # Only ops that have outputs can be addressed as "<name>:0". Mapping an
    # output-less op (e.g. a NoOp) would make tf.import_graph_def fail.
    if op.outputs and graph1_ops[target].outputs:
        replace_dict[op.name + ':0'] = target + ':0'

# Build the merged graph. graph1 is imported as-is (scope "import"). graph2
# is imported next (scope "import_1"), with every tensor listed in
# replace_dict rewired to graph1's copy through input_map.
with tf.Graph().as_default() as final:
    shared_tensors = tf.import_graph_def(
        graph1.as_graph_def(),
        return_elements=replace_dict.values(),
    )
    # graph2 tensor name -> graph1 tensor object now living in `final`.
    input_map = dict(zip(replace_dict.keys(), shared_tensors))
    merged_outputs = tf.import_graph_def(
        graph2.as_graph_def(),
        input_map=input_map,
        return_elements=["concatenate_1/concat:0"],
    )

from tensorflow.tools.graph_transforms import TransformGraph

# Entry and exit points of the merged graph: the image input comes from
# graph1 ("import/..."), the box input and final output from graph2 ("import_1/...").
graph_inputs = ["import/input_image", "import_1/input_box"]
graph_outputs = ["import_1/concatenate_1/concat"]

# Only drop Identity nodes and anything not needed for graph_outputs. The
# merge-duplicate and fold-batch-norm transforms are deliberately omitted:
# per the author, they lose the min/max ranges that quantization needs.
transforms = [
    'remove_nodes(op=Identity)',
    'strip_unused_nodes',
]

output_graph_def = TransformGraph(
    final.as_graph_def(), graph_inputs, graph_outputs, transforms)

# Serialize the optimized graph in binary form into the current directory.
tf.train.write_graph(output_graph_def, '.', 'optimized_model.pb', as_text=False)
print('Graph optimized!')