Keras: ValueError while running DeepLabV3+ with a ResNet50 backbone

55 Views Asked by At

I am trying to train a DeepLabV3+ model with a ResNet50 backbone (`weights=None`, because my input has 8 channels: shape (224, 224, 8)). I keep running into a ValueError about the data type, even though I've followed the guidelines: my input data is float32 and the class labels, which are categorical (1, 2, 3), are integers. As a last resort I also tried casting the labels to float, but it made no difference to the error.

This is my code:

def convolution_block(
    block_input,
    num_filters=256,
    kernel_size=3,
    dilation_rate=1,
    padding="same",
    use_bias=False,
):
    """Apply Conv2D -> BatchNormalization -> ReLU to ``block_input``.

    Args:
        block_input: Input Keras tensor.
        num_filters: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        dilation_rate: Dilation (atrous) rate of the convolution.
        padding: Padding mode forwarded to ``Conv2D`` (default ``"same"``).
        use_bias: Whether the convolution has a bias term.

    Returns:
        The ReLU-activated output tensor.
    """
    x = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        # Forward the caller's choice; this used to be hard-coded to "same",
        # silently ignoring the ``padding`` argument.
        padding=padding,
        use_bias=use_bias,
        kernel_initializer=keras.initializers.HeNormal(),
    )(block_input)
    x = layers.BatchNormalization()(x)
    return tf.nn.relu(x)


def DilatedSpatialPyramidPooling(dspp_input):
    """Atrous spatial pyramid pooling head used by DeepLabV3+.

    Runs five parallel branches over ``dspp_input``:
    - an image-level average-pooling branch;
    - a 1x1 convolution;
    - dilated 3x3 convolutions with rates 6, 12 and 18.
    The branches are concatenated on the last axis and fused with a 1x1
    ``convolution_block``.

    Args:
        dspp_input: 4-D Keras tensor whose height/width (``shape[-3]``,
            ``shape[-2]``) are statically known.

    Returns:
        Keras tensor with the same spatial size as ``dspp_input`` and 256
        channels (the ``convolution_block`` default).
    """
    dims = dspp_input.shape
    # Image-level features: pool over the whole spatial extent, project, then
    # upsample back to the input resolution.
    x = layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
    x = convolution_block(x, kernel_size=1, use_bias=True)
    out_pool = layers.UpSampling2D(
        size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]),
        interpolation="bilinear",
    )(x)

    out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
    out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
    out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
    out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)

    x = layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
    return convolution_block(x, kernel_size=1)


def DeeplabV3Plus(image_size, num_classes, num_channels=8):
    """Build a DeepLabV3+ segmentation model on a randomly initialised ResNet50.

    The backbone is created with ``weights=None`` so that it accepts a
    non-RGB channel count.

    Args:
        image_size: Height and width of the square input images.
        num_classes: Number of output channels (one logit per class).
        num_channels: Number of input channels (8 in the original setup).

    Returns:
        An uncompiled ``keras.Model`` mapping
        ``(image_size, image_size, num_channels)`` inputs to per-pixel logits
        of shape ``(image_size, image_size, num_classes)``.
    """
    # Derive the input shape from ``image_size``: the upsampling factors below
    # are computed from it, so a hard-coded size would disagree with them.
    input_layer = keras.Input(shape=(image_size, image_size, num_channels))
    resnet50 = keras.applications.ResNet50(
        weights=None, include_top=False, input_tensor=input_layer
    )

    # Deep backbone features feed the ASPP head.
    x = resnet50.get_layer("conv4_block6_2_relu").output
    x = DilatedSpatialPyramidPooling(x)

    # Upsample the ASPP output to 1/4 of the input size so it can be
    # concatenated with the shallower backbone features below.
    input_a = layers.UpSampling2D(
        size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
        interpolation="bilinear",
    )(x)
    input_b = resnet50.get_layer("conv2_block3_2_relu").output
    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)

    x = layers.Concatenate(axis=-1)([input_a, input_b])
    x = convolution_block(x)
    x = convolution_block(x)
    x = layers.UpSampling2D(
        size=(image_size // x.shape[1], image_size // x.shape[2]),
        interpolation="bilinear",
    )(x)
    # No activation: raw logits, matching from_logits=True in the loss below.
    model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
    return keras.Model(inputs=input_layer, outputs=model_output)


model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES)
model.summary()
# NOTE(review): sparse labels must lie in [0, NUM_CLASSES). The class values
# are described as 1, 2, 3 -- confirm NUM_CLASSES >= 4 or shift labels to 0..2.
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=loss,
    metrics=["accuracy"],
)


def preprocess_input_data(x, y):
    """Cast one batch to the dtypes/shape the model and loss expect.

    Args:
        x: Batched input tensor; each example must hold 224 * 224 * 8 values.
        y: Batched integer label mask.

    Returns:
        ``(x, y)`` with ``x`` as float32 of shape ``(batch, 224, 224, 8)`` and
        ``y`` cast to int32 class ids.
    """
    # Cast *before* reshaping. The old ``tf.map_fn(..., dtype=tf.float32)``
    # declared a float32 result while reshaping the un-cast input, which breaks
    # with a dtype mismatch whenever ``x`` arrives as e.g. uint8. A single
    # batched reshape is equivalent to the per-example map and cheaper.
    x = tf.reshape(tf.cast(x, tf.float32), (-1, 224, 224, 8))
    y = tf.cast(y, tf.int32)
    return x, y


# NOTE(review): the reported error shows a (224, 224, 4) uint8 tensor, which
# is neither this 8-channel input nor the 1-channel mask. It presumably comes
# from the loader that builds train_dataset/val_dataset (e.g. an RGBA PNG
# decoded where float32 is declared). Check that code as well, e.g. decode
# masks with channels=1.
train_dataset = train_dataset.map(preprocess_input_data)
val_dataset = val_dataset.map(preprocess_input_data)

print(train_dataset)
print(val_dataset)

history = model.fit(train_dataset, validation_data=val_dataset, epochs=1)
    

This is the printout of my datasets' shapes and dtypes:

<MapDataset element_spec=(TensorSpec(shape=(5, 224, 224, 8), dtype=tf.float32, name=None), TensorSpec(shape=(5, 224, 224, 1), dtype=tf.int8, name=None))>
<MapDataset element_spec=(TensorSpec(shape=(5, 224, 224, 8), dtype=tf.float32, name=None), TensorSpec(shape=(5, 224, 224, 1), dtype=tf.int8, name=None))>

This is the error I get from `model.fit`:

2023-10-30 12:29:57.446049: W tensorflow/core/framework/op_kernel.cc:1818] INVALID_ARGUMENT: ValueError: Tensor conversion requested dtype float32 for Tensor with dtype uint8: <tf.Tensor: shape=(224, 224, 4), dtype=uint8, numpy=
array([[[  0,   0,   0, 255],
        [  0,   0,   0, 255],
        [  0,   0,   0, 255],
        ...,
        [  0,   0,   0, 255],
        [  0,   0,   0, 255],
        [  0,   0,   0, 255]],

       [[  0,   0,   0, 255],
        [  0,   0,   0, 255],
        [  0,   0,   0, 255],
        ...,
        [ 11,  11,  11, 255],
        [ 11,  11,  11, 255],
        [ 11,  11,  11, 255]],

       [[  6,   6,   6, 255],
        [  6,   6,   6, 255],
        [  6,   6,   6, 255],
        ...,
        [  6,   6,   6, 255],
        [  6,   6,   6, 255],
        [  6,   6,   6, 255]],

       ...,

       [[  0,   0,   0, 255],
        [  0,   0,   0, 255],
        [  0,   0,   0, 255],
        ...,
        [  0,   0,   0, 255],
        [  0,   0,   0, 255],
        [  0,   0,   0, 255]],

       [[  7,   7,   7, 255],
        [  7,   7,   7, 255],
        [  7,   7,   7, 255],
        ...,
        [  6,   6,   6, 255],
        [  6,   6,   6, 255],
        [  6,   6,   6, 255]],

       [[ 18,  18,  18, 255],
        [ 18,  18,  18, 255],
        [ 18,  18,  18, 255],
        ...,
        [ 19,  19,  19, 255],
        [ 19,  19,  19, 255],
        [ 19,  19,  19, 255]]], dtype=uint8)>

I've run out of things to try, and I can't find any related Stack Overflow questions with a problem like mine. Any help would be appreciated!

0

There are 0 best solutions below