I'm trying to import a Keras model trained in Python with DL4J and I'm getting the following error:
11:39:22.588 [main] Exception in thread "main" java.lang.IllegalStateException: Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got InputTypeRecurrent(10,timeSeriesLength=24,format=NWC) for layer index -1, layer name = batch_normalization_1
at org.deeplearning4j.nn.conf.layers.BatchNormalization.getOutputType(BatchNormalization.java:130)
at org.deeplearning4j.nn.modelimport.keras.layers.normalization.KerasBatchNormalization.getOutputType(KerasBatchNormalization.java:165)
at org.deeplearning4j.nn.modelimport.keras.KerasModel.inferOutputTypes(KerasModel.java:473)
at org.deeplearning4j.nn.modelimport.keras.KerasModel.<init>(KerasModel.java:186)
at org.deeplearning4j.nn.modelimport.keras.KerasModel.<init>(KerasModel.java:99)
at org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder.buildModel(KerasModelBuilder.java:311)
at org.deeplearning4j.nn.modelimport.keras.KerasModelImport.importKerasModelAndWeights(KerasModelImport.java:167)
at edu.mit.ll.seamnet.SpeechEnhancement.runBatchNormErrModel(SpeechEnhancement.java:161)
at edu.mit.ll.seamnet.SpeechEnhancement.main(SpeechEnhancement.java:172)
This error seems to be documented and, according to the issue report, a fix was implemented in later versions of DL4J. That said, I'm still seeing this error. Am I missing something?
I'm using DL4J version 1.0.0-M1.
Python code to save a simple model that generates this error:
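# Imports assumed to come from tf.keras
from tensorflow.keras.layers import Input, Conv1D, BatchNormalization
from tensorflow.keras.models import Model

# Conv1D followed directly by BatchNormalization
in_layer = Input((25, 25,))
x = Conv1D(filters=10, kernel_size=2)(in_layer)
out_layer = BatchNormalization()(x)
model = Model(in_layer, out_layer)
model.save("batchNormError.h5")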
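For what it's worth, the input type in the error message, InputTypeRecurrent(10,timeSeriesLength=24,format=NWC), is consistent with what this model feeds into the batch norm layer: Conv1D with 10 filters and a kernel size of 2 over 25 timesteps gives 24 timesteps of 10 features each. Here is a minimal sketch of that arithmetic, assuming the Keras defaults (padding="valid", stride 1, no dilation). Note that `conv1d_output_length` is a hypothetical helper written for illustration, not a Keras or DL4J function.

```python
def conv1d_output_length(input_length, kernel_size, stride=1, padding="valid"):
    """Length of the time axis after a 1D convolution (no dilation).

    Hypothetical helper for illustration only; not part of Keras or DL4J.
    """
    if padding == "valid":
        # No padding: the kernel must fit entirely inside the sequence.
        return (input_length - kernel_size) // stride + 1
    if padding == "same":
        # Padded so that only the stride shrinks the sequence (ceil division).
        return -(-input_length // stride)
    raise ValueError("unsupported padding: " + padding)


timesteps = conv1d_output_length(input_length=25, kernel_size=2)
filters = 10  # Conv1D(filters=10, ...) sets the feature dimension

# Per-sample shape that reaches BatchNormalization: (timesteps, features)
print((timesteps, filters))  # (24, 10)
```

So the batch norm layer really is receiving a 3D (batch, time, features) tensor, which DL4J treats as recurrent input.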
Here's the DL4J code I'm using to import the model:
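import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
String modelPath = "batchNormError.h5";
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelPath);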
Any help will be greatly appreciated.