I created a model with the image-segmentation-keras library, initializing it as follows:
import keras_segmentation
from keras_segmentation.models.unet import vgg_unet

# U-Net with a VGG encoder for 21 classes and 256x448 input images
model = vgg_unet(n_classes=21, input_height=256, input_width=448)
I then train it as follows:
model.train(
    train_images="/content/drive/MyDrive/imgs_train/",
    train_annotations="/content/drive/MyDrive/masks_train/",
    val_images="/content/drive/MyDrive/mgs_validation/",
    val_annotations="/content/drive/MyDrive/masks_validation/",
    checkpoints_path="/content/drive/MyDrive/tmp/vgg_unet_1",
    epochs=28,
    validate=True,
    callbacks=[myCallback],
)
# checkpoint_filepath is a variable holding the checkpoint path (presumably the file myCallback writes)
model.load_weights(checkpoint_filepath)
And save it like so:
model.save('/content/drive/MyDrive/vgg_unet_segmentation.h5')
Then load it like so:
from tensorflow.keras.models import load_model
model = load_model('/content/drive/MyDrive/vgg_unet_segmentation.h5')
However, when I try to make a prediction with
out = model.predict_segmentation(inp=inp, out_fname="/tmp/out.png")
I get the following error:
AttributeError: 'Functional' object has no attribute 'predict_segmentation'
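A likely explanation, based on how keras_segmentation appears to build its models: the constructor creates an ordinary Keras Model and then attaches predict_segmentation, train, and metadata such as output_width to that one instance at runtime (via types.MethodType). model.save() stores the architecture and weights, not arbitrary Python attributes, so load_model() hands back a plain Functional object without them. The pure-Python sketch below reproduces the effect; FakeModel, save and load are hypothetical stand-ins, not Keras APIs, and the 128x224 size is purely illustrative.

```python
from types import MethodType


class FakeModel:
    """Hypothetical stand-in for a Keras model: only `weights` gets 'serialized'."""

    def __init__(self, weights):
        self.weights = weights


def predict_segmentation(self, inp):
    # Reads metadata that exists only on the live instance.
    return f"{self.output_height}x{self.output_width} mask for {inp}"


def save(model):
    # Like model.save(): persist declared state (architecture/weights) only.
    return {"weights": model.weights}


def load(saved):
    # Like load_model(): rebuild a fresh object from the saved state.
    return FakeModel(saved["weights"])


# Build the model, then patch extras onto this one instance (illustrative values).
model = FakeModel(weights=[0.1, 0.2])
model.output_height, model.output_width = 128, 224
model.predict_segmentation = MethodType(predict_segmentation, model)

restored = load(save(model))
print(hasattr(model, "predict_segmentation"))     # True
print(hasattr(restored, "predict_segmentation"))  # False
print(hasattr(restored, "output_width"))          # False
```

Rebinding the method on the reloaded object, as in the workaround below, restores predict_segmentation but not the metadata it reads.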
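So, to work around this, I did the following:
from types import MethodType
from tensorflow.keras.models import load_model
import keras_segmentation.predict

model = load_model('/content/drive/MyDrive/vgg_unet_segmentation.h5')
# Bind the library's predict function to the loaded model as a method
model.predict_segmentation = MethodType(keras_segmentation.predict.predict, model)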
However, this led to another error that I haven't been able to resolve:
<ipython-input-7-a4b7d02cd9a2> in <module>()
4 out = model.predict_segmentation(
5 inp=inp,
----> 6 out_fname="/tmp/out.png")

/content/image-segmentation-keras/keras_segmentation/predict.py in predict(model, inp, out_fname, checkpoints_path, overlay_img, class_names, show_legends, colors, prediction_width, prediction_height, read_image_type)
148 assert (len(inp.shape) == 3 or len(inp.shape) == 1 or len(inp.shape) == 4), "Image should be h,w,3 "
149
--> 150 output_width = model.output_width
151 output_height = model.output_height
152 input_width = model.input_width

AttributeError: 'Functional' object has no attribute 'output_width'
Any idea why this might be happening, and if so, how it can be resolved?
Any help is appreciated!
Thanks!
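Rebinding predict_segmentation alone appears insufficient: per the traceback, keras_segmentation's predict also reads model.output_width, model.output_height and model.input_width, and (assuming the library's current source) model.input_height and model.n_classes. A minimal sketch, assuming those five attributes are all predict needs: the hypothetical helper restore_segmentation_metadata sets them and binds the method in one place. predict_fn is a parameter (intended to be keras_segmentation.predict.predict) so the helper itself does not import the library.

```python
from types import MethodType


def restore_segmentation_metadata(model, predict_fn, input_height, input_width,
                                  output_height, output_width, n_classes):
    """Hypothetical helper: re-attach what a freshly built keras_segmentation
    model carries but a model reloaded with load_model() lacks.

    Assumes predict_fn is keras_segmentation.predict.predict and that these
    five attributes are all it reads from the model.
    """
    # Plain Python attributes are not stored in the .h5 file, so set them again.
    model.input_height = input_height
    model.input_width = input_width
    model.output_height = output_height
    model.output_width = output_width
    model.n_classes = n_classes
    # Bind predict_fn so model.predict_segmentation(...) receives model as its first argument.
    model.predict_segmentation = MethodType(predict_fn, model)
    return model
```

For the model in the question, input_height=256, input_width=448 and n_classes=21 come from the vgg_unet(...) call; the output size should be checked rather than guessed, since a U-Net's output may be smaller than its input. An alternative that avoids setting anything by hand is to rebuild the model with vgg_unet(n_classes=21, input_height=256, input_width=448) and call model.load_weights(...) on the saved .h5 file, because the constructor attaches these attributes itself.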
Try model.predict() to get the output, as in the following code:
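As a separate sketch of the model.predict() route, under stated assumptions: if the reloaded graph still ends with the Reshape and softmax layers keras_segmentation adds, model.predict() returns per-pixel probabilities shaped (batch, output_height * output_width, n_classes) rather than a label map, and the input must already be resized and normalized the way the library does during training. The hypothetical helper probs_to_mask covers only the last step, reshaping one image's output and taking the per-pixel argmax, which appears to be what predict_segmentation does internally.

```python
import numpy as np


def probs_to_mask(probs, output_height, output_width):
    """Hypothetical helper: turn one image's flattened softmax output into an (H, W) label map.

    probs: array of shape (output_height * output_width, n_classes), e.g.
    model.predict(batch)[0] for a keras_segmentation model (assumed layout).
    """
    n_classes = probs.shape[-1]
    grid = probs.reshape((output_height, output_width, n_classes))
    # Index of the most likely class for each pixel.
    return grid.argmax(axis=2)


# Toy input instead of real model output: a 2x3 grid with 4 classes.
rng = np.random.default_rng(0)
fake_probs = rng.random((2 * 3, 4))
mask = probs_to_mask(fake_probs, output_height=2, output_width=3)
print(mask.shape)  # (2, 3)
```

With a real model, the call would be roughly probs_to_mask(model.predict(x[np.newaxis])[0], output_height, output_width), where x is the preprocessed image and the output size is verified for the model in use.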