Porting PyTorch code to TF 2.0: what is the equivalent of x_batch.requires_grad = True?


I am trying to port the PyTorch code from this repo to TF2. The overall logic of the code is as follows:

  • sample data from a Gaussian distribution with zero mean and unit variance
  • extract the means and variances from the batch-normalization layers of a trained model
  • feed the sampled data into the model and collect the outputs of each conv layer (before the BN layers)
  • compute the L2 loss between those outputs and the extracted means and variances (see the sketch after this list)
  • update the sampled data via the optimizer
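
For the L2-loss step, the per-layer term I have in mind looks roughly like this. This is only a minimal sketch with a hypothetical helper name bn_stat_loss, assuming channels-last conv activations and the running mean/std taken from the matching BN layer:

import tensorflow as tf

def bn_stat_loss(conv_out, bn_mean, bn_std, eps=1.0e-6):
    # conv_out: conv activations of shape (batch, H, W, C), channels last
    batch_mean = tf.math.reduce_mean(conv_out, axis=[1, 2])      # per-channel mean, shape (batch, C)
    batch_std = tf.math.reduce_std(conv_out, axis=[1, 2]) + eps  # per-channel std, shape (batch, C)
    # squared L2 distance to the BN running statistics, normalized by batch size
    mean_term = tf.reduce_sum(tf.square(batch_mean - bn_mean)) / conv_out.shape[0]
    std_term = tf.reduce_sum(tf.square(batch_std - bn_std)) / conv_out.shape[0]
    return mean_term + std_term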

My issue is that in order to update the data, it has to be made trainable, so I converted x_batch to x_batch = tf.Variable(x_batch, trainable=True). However, Variables in tf are not iterable, so there is a problem when updating the data via optimizer.apply_gradients(zip(gradients, x_batch)).
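
As far as I understand, apply_gradients expects an iterable of (gradient, variable) pairs, so with a single tf.Variable holding the whole batch the call would presumably have to be the single-pair form below, though I am not sure that alone makes the loop correct:

optimizer.apply_gradients([(gradients, x_batch)])  # one (gradient, variable) pair instead of zipping over the batch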

In PyTorch this is relatively simple and can be done like this:

for x_batch in dataloader:    
    x_batch.requires_grad = True
    .
    .
    .
    # update the distilled data
    loss.backward()
    optimizer.step()
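
My understanding is that the closest TF2 equivalent is to wrap the batch in a tf.Variable (or to call tape.watch on a plain tensor) inside a GradientTape, roughly as sketched below; compute_loss, optimizer and train_dataset are placeholders here, not the actual code from the repo:

for x_batch in train_dataset.batch(32):
    x_batch = tf.Variable(x_batch)                     # analogue of x_batch.requires_grad = True
    with tf.GradientTape() as tape:
        loss = compute_loss(x_batch)                   # hypothetical loss on the distilled batch
    gradients = tape.gradient(loss, x_batch)
    optimizer.apply_gradients([(gradients, x_batch)])  # single (gradient, variable) pair

# alternatively, watch a plain tensor and apply the update manually:
#   with tf.GradientTape() as tape:
#       tape.watch(x_batch)
#       loss = compute_loss(x_batch)
#   x_batch = x_batch - lr * tape.gradient(loss, x_batch)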

My attempt in TF 2.0 is below.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import regularizers
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D, Dense,
                                     Dropout, Flatten, Input, MaxPooling2D)
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical

def vgg_block(x, filters, layers, name, weight_decay):
    for i in range(layers):
        x = Conv2D(filters, (3, 3), padding='same', kernel_initializer='he_normal',
                     kernel_regularizer=regularizers.l2(weight_decay), name=f'{name}_conv_{i}')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x


def vgg8(x, weight_decay=1e-4, is_classifier=False):
    x = vgg_block(x, 16, 2, 'block_1', weight_decay)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = vgg_block(x, 32, 2, 'block_2', weight_decay)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = vgg_block(x, 64, 2, 'block_3', weight_decay)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Flatten()(x)
    x = Dense(512, kernel_initializer='he_normal', activation='relu', name = 'dense_1')(x)
    x = Dense(10)(x)
    return x
    
def l2_loss(A, B):
    """
    L-2 loss between A and B normalized by length.
    Shape of A should be (features_num, ), shape of B should be (batch_size, features_num)
    """
    # pytorch : (A - B).norm()**2 / B.size(0)
    diff = A - B
    l2loss = tf.nn.l2_loss(diff)  # sum(t ** 2) / 2 -- note the extra 1/2 factor vs. the PyTorch expression above
    return l2loss / B.shape[0]
    

inputs = Input((32, 32, 3))
model = Model(inputs, vgg8(inputs))

eps = 1.0e-6
bn_stats = []

for layer in model.layers:
    if isinstance(layer, BatchNormalization):
        bn_gamma, bn_beta, bn_mean, bn_var = layer.get_weights()
        #print(bn_mean.shape, bn_var.shape)  # tf.reshape(w, [-1]) 
        bn_stats.append((bn_mean, tf.math.sqrt(bn_var+eps)))

    

extractor = tf.keras.models.Model(inputs=model.inputs,
                        outputs=[layer.output for layer in model.layers if isinstance(layer, Conv2D)])


class UniformDataset(keras.utils.Sequence):
    """
    get random Gaussian (normal) samples with mean 0 and variance 1
    """
    def __init__(self, length, size):
        self.length = length
        self.size = size
    
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        sample = tf.random.normal(self.size)
        return sample

data = UniformDataset(length=100, size=(32, 32,3))

# convert to tf.data iterator
train_iter = iter(data)
train_data = []
for x in train_iter:
    train_data.append(x)
train_data = tf.stack(train_data, axis=0)
train_dataset = tf.data.Dataset.from_tensor_slices(train_data)

refined_gaussian = []
iterations = 500

for x_batch in train_dataset.batch(32):
    x_batch = tf.Variable(x_batch, trainable=True)  # make x_batch trainable
    
    outputs = extractor(x_batch)
    
    input_mean = tf.zeros([1,3]) 
    input_std = tf.ones([1,3])   
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.5)

    with tf.GradientTape(persistent=True) as tape:
        for it in range(iterations):
            mean_loss = 0
            std_loss = 0

            for cnt, (bn_stat, output) in enumerate(zip(bn_stats, outputs)):  
                output = tf.reshape(output, [output.shape[0], output.shape[-1], -1])
                tmp_mean = tf.math.reduce_mean(output, axis=2)
                tmp_std = tf.math.reduce_std(output, axis=2) + eps
                bn_mean, bn_std = bn_stat[0], bn_stat[1]
                mean_loss += l2_loss(bn_mean, tmp_mean)
                std_loss += l2_loss(bn_std, tmp_std)
            
            #print('mean_loss', mean_loss, 'std_loss', std_loss)
            x_reshape = tf.reshape(x_batch, [x_batch.shape[0], 3, -1])
            tmp_mean = tf.math.reduce_mean(x_reshape, axis=2)
            tmp_std = tf.math.reduce_std(x_reshape, axis=2) + eps

            mean_loss += l2_loss(input_mean, tmp_mean)
            std_loss += l2_loss(input_std, tmp_std)
            loss = mean_loss + std_loss
            gradients = tape.gradient(loss, x_batch)

            optimizer.apply_gradients(zip(gradients, x_batch))
            
        refined_gaussian.append(x_batch)
        