Hi sorry for the cringe "tell me what this means" question but I can't seem to figure it out... Here's the code
from imageai.Classification.Custom import ClassificationModelTrainer
model_trainer = ClassificationModelTrainer()
model_trainer.setModelTypeAsResNet50()
model_trainer.setDataDirectory("images")
model_trainer.trainModel(num_objects=10, num_experiments=200, enhance_data=True, batch_size=8, show_network_summary=True)
and it yields this scary error:
Traceback (most recent call last):
File "marco.py", line 15, in <module>
model_trainer.trainModel(num_objects=10, num_experiments=200, enhance_data=True, batch_size=8, show_network_summary=True)
File "/home/lollo/.local/lib/python3.8/site-packages/imageai/Classification/Custom/__init__.py", line 393, in trainModel
model.fit_generator(train_generator, steps_per_epoch=int(num_train / batch_size), epochs=self.__num_epochs,
File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1847, in fit_generator
return self.fit(
File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1100, in fit
tmp_logs = self.train_function(iterator)
File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
result = self._call(*args, **kwds)
File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 888, in _call
return self._stateless_fn(*args, **kwds)
File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2942, in __call__
return graph_function._call_flat(
File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
return self._build_call_outputs(self._inference_function.call(
File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
outputs = execute.execute(
File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: logits and labels must be broadcastable: logits_size=[8,10] labels_size=[8,2]
[[node categorical_crossentropy/softmax_cross_entropy_with_logits (defined at /home/lollo/.local/lib/python3.8/site-packages/imageai/Classification/Custom/__init__.py:393) ]] [Op:__inference_train_function_11908]
Function call stack:
train_function
Thank you for any suggestion <3