Python ImageAI error: `Function call stack: train_function`


Hi, sorry for the "tell me what this means" question, but I can't figure this one out. Here's the code:

from imageai.Classification.Custom import ClassificationModelTrainer

model_trainer = ClassificationModelTrainer()
model_trainer.setModelTypeAsResNet50()
model_trainer.setDataDirectory("images")
model_trainer.trainModel(num_objects=10, num_experiments=200, enhance_data=True, batch_size=8, show_network_summary=True)

It fails with this error:

Traceback (most recent call last):
  File "marco.py", line 15, in <module>
    model_trainer.trainModel(num_objects=10, num_experiments=200, enhance_data=True, batch_size=8, show_network_summary=True)
  File "/home/lollo/.local/lib/python3.8/site-packages/imageai/Classification/Custom/__init__.py", line 393, in trainModel
    model.fit_generator(train_generator, steps_per_epoch=int(num_train / batch_size), epochs=self.__num_epochs,
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1847, in fit_generator
    return self.fit(
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1100, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 888, in _call
    return self._stateless_fn(*args, **kwds)
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2942, in __call__
    return graph_function._call_flat(
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 1918, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 555, in call
    outputs = execute.execute(
  File "/home/lollo/.local/lib/python3.8/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError:  logits and labels must be broadcastable: logits_size=[8,10] labels_size=[8,2]
     [[node categorical_crossentropy/softmax_cross_entropy_with_logits (defined at /home/lollo/.local/lib/python3.8/site-packages/imageai/Classification/Custom/__init__.py:393) ]] [Op:__inference_train_function_11908]

Function call stack:
train_function

Thank you for any suggestions <3
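For anyone hitting the same message, the shapes in the `InvalidArgumentError` line point at the cause. `logits_size=[8,10]` is `batch_size=8` times the `num_objects=10` outputs the model was built with, while `labels_size=[8,2]` means each batch arrived with one-hot labels for only 2 classes. That most likely means the training folder contains only 2 class subfolders, so `num_objects` does not match the data. The helper below is a minimal sketch for counting the classes; it assumes the layout `images/train/<class_name>/<images>` (the `train` subfolder name is an assumption about ImageAI's default), and `count_classes` is a hypothetical helper, not part of ImageAI.

```python
from pathlib import Path


def count_classes(data_directory: str, train_subdirectory: str = "train") -> int:
    """Count class folders in the training split.

    Assumes (hypothetically) the layout
    <data_directory>/<train_subdirectory>/<class_name>/<image files>,
    with exactly one subfolder per class.
    """
    train_dir = Path(data_directory) / train_subdirectory
    if not train_dir.is_dir():
        raise FileNotFoundError(f"No training folder found at {train_dir}")
    # Each immediate subdirectory is counted as one class; loose files are ignored.
    return sum(1 for entry in train_dir.iterdir() if entry.is_dir())
```

Under that layout, passing `num_objects=count_classes("images")` to `trainModel` would keep the model's output size in sync with the folders on disk. If 10 classes were intended, the other fix is to add the missing class folders to the training directory.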
