I am trying to port pytorch code to tf2 code for this repo. The overall logic of the code is as follows:
- sample data from Gaussian distribution with zero mean and unit variance
- extract mean and variance from batch-normalization layers of trained model
- input sampled data into model and get outputs from each conv-layers(before BN layers)
- calculate L2-loss between outputs and extracted means and variances
- update the sampled data via Optimizer
My issue is that in order to update the data, it has be made trainable and therefore I converted to x_batch
to x_batch = tf.Variable(x_batch, trainable=True)
but Variables in tf
are not iterable and hence there is an a problem when updating the weights via optimizer.apply_gradients(zip(gradients, x_batch))
.
In pytorch, it's relatively simple and can done simple by
for x_batch in dataloader:
x_batch.requires_grad = True
.
.
.
# update the distilled data
loss.backward()
optimizer.step()
My attempt in tf2.0 is below.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Activation, Input, MaxPooling2D, Dropout, Flatten
from tensorflow.keras.models import Model
from tensorflow import keras
from tensorflow.keras import regularizers
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import BatchNormalization
import numpy as np
def vgg_block(x, filters, layers, name, weight_decay):
for i in range(layers):
x = Conv2D(filters, (3, 3), padding='same', kernel_initializer='he_normal',
kernel_regularizer=regularizers.l2(weight_decay), name=f'{name}_conv_{i}')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
return x
def vgg8(x, weight_decay=1e-4, is_classifier=False):
x = vgg_block(x, 16, 2, 'block_1', weight_decay)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = vgg_block(x, 32, 2, 'block_2', weight_decay)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = vgg_block(x, 64, 2, 'block_3', weight_decay)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Flatten()(x)
x = Dense(512, kernel_initializer='he_normal', activation='relu', name = 'dense_1')(x)
x = Dense(10)(x)
return x
def l2_loss(A, B):
"""
L-2 loss between A and B normalized by length.
Shape of A should be (features_num, ), shape of B should be (batch_size, features_num)
"""
# pytorch : (A - B).norm()**2 / B.size(0)
diff = A-B
l2loss = tf.nn.l2_loss(diff) # sum(t ** 2) / 2
return l2loss / B.shape[0]
inputs = Input((32, 32, 3))
model = Model(inputs, vgg8(inputs))
eps = 1.0e-6
bn_stats = []
for layer in model.layers:
if isinstance(layer, BatchNormalization):
bn_gamma, bn_beta, bn_mean, bn_var = layer.get_weights()
#print(bn_mean.shape, bn_var.shape) # tf.reshape(w, [-1])
bn_stats.append((bn_mean, tf.math.sqrt(bn_var+eps)))
extractor = tf.keras.models.Model(inputs=model.inputs,
outputs=[layer.output for layer in model.layers if isinstance(layer, Conv2D)])
class UniformDataset(keras.utils.Sequence):
"""
get random uniform samples with mean 0 and variance 1
"""
def __init__(self, length, size):
self.length = length
self.size = size
def __len__(self):
return self.length
def __getitem__(self, idx):
sample = tf.random.normal(self.size)
return sample
data = UniformDataset(length=100, size=(32, 32,3))
# convert to tf.data iterator
train_iter = iter(data)
train_data = []
for x in train_iter:
train_data.append(x)
train_data = tf.stack(train_data, axis=0)
train_dataset = tf.data.Dataset.from_tensor_slices(train_data)
refined_gaussian = []
iterations = 500
for x_batch in train_dataset.batch(32):
x_batch = tf.Variable(x_batch, trainable=True) # make x_batch trainable
outputs = extractor(x_batch)
input_mean = tf.zeros([1,3])
input_std = tf.ones([1,3])
optimizer = tf.keras.optimizers.Adam(learning_rate=0.5)
with tf.GradientTape(persistent=True) as tape:
for it in range(iterations):
mean_loss = 0
std_loss = 0
for cnt, (bn_stat, output) in enumerate(zip(bn_stats, outputs)):
output = tf.reshape(output, [output.shape[0], output.shape[-1], -1])
tmp_mean = tf.math.reduce_mean(output, axis=2)
tmp_std = tf.math.reduce_std(output, axis=2) + eps
bn_mean, bn_std = bn_stat[0], bn_stat[1]
mean_loss += l2_loss(bn_mean, tmp_mean)
std_loss += l2_loss(bn_std, tmp_std)
#print('mean_loss', mean_loss, 'std_loss', std_loss)
x_reshape = tf.reshape(x_train, [x_train.shape[0], 3, -1])
tmp_mean = tf.math.reduce_mean(x_reshape, axis=2)
tmp_std = tf.math.reduce_std(x_reshape, axis=2) + eps
mean_loss += l2_loss(input_mean, tmp_mean)
std_loss += l2_loss(input_std, tmp_std)
loss = mean_loss + std_loss
gradients = tape.gradient(loss, x_batch)
optimizer.apply_gradients(zip(gradients, x_batch))
refined_gaussian.append(x_batch)