I'm reorganizing my code to be easy to read, but when compiling it says:
ValueError: No gradients provided for any variable: ['enc_conv_4/kernel:0'...
. I know my loss function is differentiable cuz the code worked before touch it, but now is missing the gradients of my model.
@tf.function
def train_disc(self,real_imgs,gen_imgs):
with tf.GradientTape() as disc_tape:
d_loss = self.wasserstein_loss(real_imgs,gen_imgs)
gradients_d = disc_tape.gradient(d_loss, self.discriminator.trainable_variables)
self.d_optimizer.apply_gradients(zip(gradients_d, self.discriminator.trainable_variables))
return d_loss
@tf.function
def train_gen(self,real_img,gen_imgs,mask,img_feat,rot_feat_mean):
with tf.GradientTape() as gen_tape:
g_loss_param = self.generator_loss(mask,img_feat,rot_feat_mean)
g_loss = g_loss_param(real_img, gen_imgs)
gradients_g = gen_tape.gradient(g_loss, self.generator.trainable_variables)
print(gradients_g)
self.g_optimizer.apply_gradients(zip(gradients_g, self.generator.trainable_variables))
As you can see, when I do the same for discriminator and generator, the generator gives me an empty list of gradients.
gen_imgs = self.generator([real_img, mask], training=True)
d_loss = self.train_disc(real_img,gen_imgs[:,:,:,:-1])
if step%self.n_critic == 0:
masked_images = real_img * mask
idx = 3 # index of desired layer
layer_input = Input(shape=(self.img_shape)) #
x = layer_input
for layer in self.generator.layers[idx:idx+12]:
x = layer(x)
model_feat = Model(inputs=layer_input,outputs=x)
model_feat.trainable = False
img_feat = model_feat(masked_images,training=False)
rot_feat_mean = []
for i in range(self.batch_size):
rot = []
for an in [180, 155, 130, 105, 80, 55, 20, 10]:
r = tf.keras.preprocessing.image.random_rotation(masked_images[i], an, row_axis=0, col_axis=1,
channel_axis=2)
rot.append(r)
rot = np.array(rot)
rot_feat_mean.append(np.mean(model_feat(rot,training=False),axis=0))
rot_feat_mean = np.array(rot_feat_mean)
g_loss = self.train_gen(real_img,gen_imgs[:,:,:,:-1],mask,img_feat,rot_feat_mean)
The last line in the last code gives me an error. I don't know if this erros is due to any semantic mistake.