I'm trying to use DL4J to import a Keras model that was trained in Python, and I'm getting the following error:

11:39:22.588 [main] Exception in thread "main" java.lang.IllegalStateException: Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got InputTypeRecurrent(10,timeSeriesLength=24,format=NWC) for layer index -1, layer name = batch_normalization_1
    at org.deeplearning4j.nn.conf.layers.BatchNormalization.getOutputType(BatchNormalization.java:130)
    at org.deeplearning4j.nn.modelimport.keras.layers.normalization.KerasBatchNormalization.getOutputType(KerasBatchNormalization.java:165)
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.inferOutputTypes(KerasModel.java:473)
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.<init>(KerasModel.java:186)
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.<init>(KerasModel.java:99)
    at org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder.buildModel(KerasModelBuilder.java:311)
    at org.deeplearning4j.nn.modelimport.keras.KerasModelImport.importKerasModelAndWeights(KerasModelImport.java:167)
    at edu.mit.ll.seamnet.SpeechEnhancement.runBatchNormErrModel(SpeechEnhancement.java:161)
    at edu.mit.ll.seamnet.SpeechEnhancement.main(SpeechEnhancement.java:172)

This error seems to be documented and, according to the issue report, a fix was implemented in later versions of DL4J. That said, I'm still seeing this error. Am I missing something?

I'm using DL4J version 1.0.0-M1.

Python code to save a simple model that generates this error:

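# Imports assume tf.keras; with standalone Keras, import from keras.layers and keras.models instead.
from tensorflow.keras.layers import Input, Conv1D, BatchNormalization
from tensorflow.keras.models import Model

# Input of shape (timesteps=25, features=25), followed by Conv1D and BatchNormalization
in_layer = Input((25, 25))
x = Conv1D(filters=10, kernel_size=2)(in_layer)
out_layer = BatchNormalization()(x)
model = Model(in_layer, out_layer)
model.save("batchNormError.h5")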
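
For reference, the shape reported in the error, InputTypeRecurrent(10,timeSeriesLength=24,format=NWC), looks consistent with what the Conv1D layer above should produce. Below is a minimal sketch of that arithmetic, assuming the Keras defaults for Conv1D (strides=1, padding="valid"); the helper name conv1d_output_length is hypothetical and is not part of Keras or DL4J.

```python
def conv1d_output_length(input_length, kernel_size, stride=1, padding="valid"):
    """Length of the time axis after a 1D convolution (hypothetical helper)."""
    if padding == "same":
        # "same" padding keeps ceil(input_length / stride) steps
        return -(-input_length // stride)
    # "valid" padding: no zero padding is added
    return (input_length - kernel_size) // stride + 1


# Model above: Input((25, 25)) followed by Conv1D(filters=10, kernel_size=2)
steps = conv1d_output_length(input_length=25, kernel_size=2)
filters = 10
print((steps, filters))  # (24, 10)
```

So the batch normalization layer receives a 3D (batch, 24, 10) sequence input, which is the recurrent-style input type the exception complains about.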

Here's the DL4J code I'm using to import the model:

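import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;

String modelPath = "batchNormError.h5";
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelPath);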

Any help will be greatly appreciated.
