I'm trying to create a 3D autoencoder (a 3D U-Net without BatchNorm and skip connections) with Keras.
When I train it in tf.float32 it learns, but when I use the mixed precision policy the training seems to run forever.
I managed to reproduce the error in Colab. Here is my code (I'm using TF 2.12.0, installed by following the TensorFlow installation instructions).
Is there something that I've missed?
import tensorflow as tf
from tqdm import tqdm
from tensorflow.keras.layers import Conv3D, Conv3DTranspose
from tensorflow.keras import Model
from tensorflow.keras import mixed_precision
## Uncomment the two lines below to train with mixed precision (run them before the model is built)
# policy = mixed_precision.Policy('mixed_float16')
# tf.keras.mixed_precision.set_global_policy(policy)
def autoencoder():
    """Build the 3D convolutional autoencoder as a Keras functional model.

    The network takes a single-channel volume of shape (160, 192, 160, 1).
    The encoder alternates 3x3x3 ReLU convolutions with three stride-2
    2x2x2 convolutions, halving each spatial axis three times. The decoder
    mirrors it with three stride-2 transposed convolutions. There are no
    skip connections and no normalisation layers.

    Returns:
        A ``Model`` mapping the input volume to a single-channel sigmoid
        output whose final layer is pinned to float32.
    """
    kernel = (3, 3, 3)
    step = (2, 2, 2)

    def conv(tensor, filters):
        # Shape-preserving feature convolution.
        return Conv3D(filters, kernel, padding='same', activation='relu')(tensor)

    def down(tensor, filters):
        # Learned downsampling: non-overlapping 2x2x2 windows, no activation.
        return Conv3D(filters, step, strides=step, padding='valid')(tensor)

    def up(tensor, filters):
        # Learned upsampling, the counterpart of ``down``.
        return Conv3DTranspose(filters, step, strides=step, padding='valid')(tensor)

    encoder_plan = [
        (conv, 32), (conv, 64), (down, 64),
        (conv, 64), (conv, 128), (down, 128),
        (conv, 128), (conv, 128), (down, 256),
        (conv, 256), (conv, 256),
    ]
    decoder_plan = [
        (up, 256), (conv, 256), (conv, 256),
        (up, 128), (conv, 128), (conv, 128),
        (up, 64), (conv, 64), (conv, 32),
    ]

    inputs = tf.keras.Input(shape=(160, 192, 160, 1))

    # Layers are created strictly in plan order, encoder first.
    features = inputs
    for build_layer, filters in encoder_plan + decoder_plan:
        features = build_layer(features, filters)

    # The output layer is given an explicit float32 dtype so the sigmoid
    # output stays in float32 regardless of the global dtype policy.
    out = Conv3D(
        1,
        (1, 1, 1),
        padding='same',
        activation='sigmoid',
        dtype=tf.float32,
    )(features)
    return Model(inputs=inputs, outputs=[out])
# Build the model once; it is shared by every training step defined below.
AE = autoencoder()
## without mixed precision policy
# All-ones dummy volume with a batch dimension of 1; train_step uses it both
# as the model input and as the reconstruction target.
irm = tf.ones((1,160,192,160,1), dtype = tf.float32)
# Plain Adam: the float32 run applies gradients without any loss scaling.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
@tf.function
def train_step(irm):
    """Run one optimisation step of the autoencoder without loss scaling.

    Uses the module-level ``AE`` model and ``optimizer``.

    Args:
        irm: Batch of input volumes; it is also the reconstruction target.

    Returns:
        The mean absolute error between ``irm`` and ``AE(irm)`` computed
        for this step, before the weights are updated.
    """
    mae = tf.keras.losses.MeanAbsoluteError()
    weights = AE.trainable_variables

    # Only the forward pass and the loss need to be recorded on the tape.
    with tf.GradientTape() as tape:
        reconstruction = AE(irm)
        loss_rec = mae(irm, reconstruction)

    gradients = tape.gradient(loss_rec, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    return loss_rec
# History of the loss returned by each training step, in order.
G_L = []
# Four iterations, numbered 1 to 4; each one is a single step on the same
# dummy batch, wrapped in tqdm for a progress bar.
for epoch in tqdm(range(1, 5)):
    L = train_step(irm)
    G_L.append(L)
    # Report the loss that was just recorded.
    print(f'Epoch {epoch} - train_loss = {G_L[-1]}')
# with mixed precision policy
# NOTE(review): this section reuses the AE instance built earlier. For the
# layers to compute in float16, the mixed_float16 policy (commented out near
# the imports) presumably has to be set before autoencoder() is called --
# confirm against the Keras mixed precision guide.
irm = tf.ones((1,160,192,160,1), dtype = tf.float32)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
# Wrap Adam so the training step can scale the loss and unscale the gradients.
optimizer = mixed_precision.LossScaleOptimizer(optimizer)
@tf.function
def train_step(irm):
    """Run one optimisation step of the autoencoder with loss scaling.

    Uses the module-level ``AE`` model and ``optimizer``. The optimizer
    must provide ``get_scaled_loss`` and ``get_unscaled_gradients``, as a
    ``LossScaleOptimizer`` does.

    Args:
        irm: Batch of input volumes; it is also the reconstruction target.

    Returns:
        The unscaled mean absolute error between ``irm`` and ``AE(irm)``
        computed for this step, before the weights are updated.
    """
    mae = tf.keras.losses.MeanAbsoluteError()
    weights = AE.trainable_variables

    with tf.GradientTape() as tape:
        reconstruction = AE(irm)
        loss_rec = mae(irm, reconstruction)
        # The loss is scaled inside the tape so that the gradients are
        # taken with respect to the scaled value.
        scaled_loss = optimizer.get_scaled_loss(loss_rec)

    scaled_gradients = tape.gradient(scaled_loss, weights)
    # Undo the scaling before the update so the step size is unaffected.
    gradients = optimizer.get_unscaled_gradients(scaled_gradients)
    optimizer.apply_gradients(zip(gradients, weights))
    return loss_rec
# Fresh loss history for the mixed-precision run (rebinds the earlier list).
G_L = []
# Four iterations, numbered 1 to 4; each one is a single step on the same
# dummy batch, wrapped in tqdm for a progress bar.
for epoch in tqdm(range(1, 5)):
    L = train_step(irm)
    G_L.append(L)
    # Report the loss that was just recorded.
    print(f'Epoch {epoch} - train_loss = {G_L[-1]}')