TensorFlow model can't use mixed precision

I'm trying to build a 3D autoencoder (a 3D U-Net without BatchNorm or skip connections) with Keras. When I train it in tf.float32 it learns, but when I enable the mixed precision policy, training seems to hang indefinitely.

I managed to reproduce the error in Colab. Here is my code (I'm using TF 2.12.0, the version from a default TensorFlow install).

Is there something that I've missed?

import tensorflow as tf
from tqdm import tqdm

from tensorflow.keras.layers import Conv3D, Conv3DTranspose
from tensorflow.keras import Model
from tensorflow.keras import mixed_precision

## Uncomment this block to enable mixed precision

# policy = mixed_precision.Policy('mixed_float16')
# mixed_precision.set_global_policy(policy)
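# Sanity check (my assumption: this should show mixed_float16 when the policy
# above is active, and float32 otherwise):
# print(mixed_precision.global_policy())
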
def autoencoder():
    inputs = tf.keras.Input(shape = (160, 192, 160, 1))
    #Encoder
    conv_1 = Conv3D(32, (3, 3, 3), padding='same', activation = 'relu')(inputs) 
    conv_2 = Conv3D(64, (3, 3, 3), padding='same', activation = 'relu')(conv_1) 
    dw_1 = Conv3D(64, (2, 2, 2), strides = (2, 2, 2), padding = 'valid')(conv_2)
    conv_3 = Conv3D(64, (3, 3, 3), padding='same', activation = 'relu')(dw_1)
    conv_4 = Conv3D(128, (3, 3, 3), padding='same', activation = 'relu')(conv_3)
    dw_2 = Conv3D(128, (2, 2, 2), strides = (2, 2, 2), padding = 'valid')(conv_4)
    conv_5 = Conv3D(128, (3, 3, 3), padding='same', activation = 'relu')(dw_2)
    conv_6 = Conv3D(128, (3, 3, 3), padding='same', activation = 'relu')(conv_5)
    dw_3 = Conv3D(256, (2, 2, 2), strides = (2, 2, 2), padding = 'valid')(conv_6)
    conv_7 = Conv3D(256, (3, 3, 3), padding='same', activation = 'relu')(dw_3)
    conv_8 = Conv3D(256, (3, 3, 3), padding='same', activation = 'relu')(conv_7)
    #Decoder
    up_1 = Conv3DTranspose(256, (2, 2, 2), strides = (2, 2, 2), padding = 'valid')(conv_8)
    conv_9 = Conv3D(256, (3, 3, 3), padding='same', activation = 'relu')(up_1)
    conv_10 = Conv3D(256, (3, 3, 3), padding='same', activation = 'relu')(conv_9)
    up_2 = Conv3DTranspose(128, (2, 2, 2), strides = (2, 2, 2), padding = 'valid')(conv_10)
    conv_11 = Conv3D(128, (3, 3, 3), padding='same', activation = 'relu')(up_2)
    conv_12 = Conv3D(128, (3, 3, 3), padding='same', activation = 'relu')(conv_11)
    up_3 = Conv3DTranspose(64, (2, 2, 2), strides = (2, 2, 2), padding = 'valid')(conv_12)
    conv_13 = Conv3D(64, (3, 3, 3), padding='same', activation = 'relu')(up_3)
    conv_14 = Conv3D(32, (3, 3, 3), padding='same', activation = 'relu')(conv_13)
    
    # Keep the output layer in float32 even under mixed precision
    out = Conv3D(1, (1, 1, 1), padding='same', activation = 'sigmoid', dtype = tf.float32)(conv_14)

    return Model(inputs=inputs, outputs=[out])

AE = autoencoder()
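# Quick dtype check (my assumption about the expected values): under
# mixed_float16 each layer computes in float16 but keeps float32 variables.
# print(AE.layers[1].compute_dtype, AE.layers[1].variable_dtype)
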
## without mixed precision policy

irm = tf.ones((1,160,192,160,1), dtype = tf.float32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

@tf.function
def train_step(irm):
    with tf.GradientTape() as tape:
        irm_reconstruction = AE(irm)
        loss_rec = tf.keras.losses.MeanAbsoluteError()(irm, irm_reconstruction)
    grads = tape.gradient(loss_rec, AE.trainable_variables)
    optimizer.apply_gradients(zip(grads, AE.trainable_variables))
    return loss_rec

G_L = []
for epoch in tqdm(range(1, 5)):
    L = train_step(irm)
    G_L.append(L)
    print(f'Epoch {epoch} - train_loss = {G_L[-1]}')
## With mixed precision policy (note: the mixed_float16 policy must have been
## set before the model above was built, since the global policy only affects
## layers created after it is set)

irm = tf.ones((1,160,192,160,1), dtype = tf.float32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
# LossScaleOptimizer multiplies the loss by a (dynamic) scale factor before
# backprop so small float16 gradients don't underflow, then unscales them
optimizer = mixed_precision.LossScaleOptimizer(optimizer)

@tf.function
def train_step(irm):
    with tf.GradientTape() as tape:
        irm_reconstruction = AE(irm)
        loss_rec = tf.keras.losses.MeanAbsoluteError()(irm, irm_reconstruction)
        # Scale the loss so small float16 gradients don't underflow to zero
        scaled_loss = optimizer.get_scaled_loss(loss_rec)
    scaled_gradients = tape.gradient(scaled_loss, AE.trainable_variables)
    # Unscale the gradients before applying the update
    grads = optimizer.get_unscaled_gradients(scaled_gradients)
    optimizer.apply_gradients(zip(grads, AE.trainable_variables))
    return loss_rec

G_L = []
for epoch in tqdm(range(1, 5)):
    L = train_step(irm)
    G_L.append(L)
    print(f'Epoch {epoch} - train_loss = {G_L[-1]}')
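
For reference, here is a minimal compile/fit sketch of the same training (my assumption: when the global policy is mixed_float16, Keras wraps the optimizer in a LossScaleOptimizer automatically during compile, so no manual scaling is needed):

AE = autoencoder()
AE.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
           loss=tf.keras.losses.MeanAbsoluteError())
irm = tf.ones((1, 160, 192, 160, 1), dtype=tf.float32)
AE.fit(irm, irm, epochs=4, batch_size=1)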