I am using tf.keras.preprocessing.image_dataset_from_directory to get a BatchDataset, where the dataset has 10 classes.
I am trying to integrate this BatchDataset with a Keras VGG16 (docs) network. From the docs:
Note: each Keras Application expects a specific kind of input preprocessing. For VGG16, call tf.keras.applications.vgg16.preprocess_input on your inputs before passing them to the model.
However, I am struggling to get this preprocess_input working with a BatchDataset. Can you please help me figure out how to connect these two dots?
Please see the code below:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(train_data_dir, image_size=(224, 224))
train_ds = tf.keras.applications.vgg16.preprocess_input(train_ds)
This will throw TypeError: 'BatchDataset' object is not subscriptable:
Traceback (most recent call last):
...
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/vgg16.py", line 232, in preprocess_input
return imagenet_utils.preprocess_input(
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 117, in preprocess_input
return _preprocess_symbolic_input(
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 278, in _preprocess_symbolic_input
x = x[..., ::-1]
TypeError: 'BatchDataset' object is not subscriptable
In the question "TypeError: 'DatasetV1Adapter' object is not subscriptable" (which I reached from "BatchDataset not subscriptable when trying to format Python dictionary as table"), the suggestion was to use:
train_ds = tf.keras.applications.vgg16.preprocess_input(
list(train_ds.as_numpy_iterator())
)
However, this also fails:
Traceback (most recent call last):
...
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/vgg16.py", line 232, in preprocess_input
return imagenet_utils.preprocess_input(
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 117, in preprocess_input
return _preprocess_symbolic_input(
File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 278, in _preprocess_symbolic_input
x = x[..., ::-1]
TypeError: list indices must be integers or slices, not tuple
This is all using Python==3.10.3 with tensorflow==2.8.0.
How can I get this working? Thank you in advance.
Okay, I figured it out. I needed to pass a tf.Tensor, not a tf.data.Dataset. One can get a Tensor out by iterating over the Dataset. This can be done in a few ways:
If you convert option 2 into a generator, it can be directly passed into the downstream model.fit. Cheers!
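For anyone landing here later, this is a minimal sketch of two ways to hand preprocess_input a tf.Tensor instead of the Dataset. It is not necessarily the exact set of options meant above. To keep it self-contained it assumes a small synthetic dataset in place of image_dataset_from_directory, and the helper names preprocess and preprocessed_batches are made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in (assumption) for image_dataset_from_directory:
# 8 RGB images of size 224x224 with labels from 10 classes, in batches of 4.
images = np.random.uniform(0, 255, size=(8, 224, 224, 3)).astype("float32")
labels = np.random.randint(0, 10, size=(8,)).astype("int32")
train_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)


def preprocess(batch_images, batch_labels):
    # Inside map(), batch_images is a tf.Tensor, which preprocess_input accepts.
    return tf.keras.applications.vgg16.preprocess_input(batch_images), batch_labels


# Way A: apply the preprocessing per batch with Dataset.map.
train_ds_mapped = train_ds.map(preprocess)


# Way B: iterate over the Dataset in a generator and preprocess each batch.
def preprocessed_batches(dataset):
    for batch_images, batch_labels in dataset:
        yield tf.keras.applications.vgg16.preprocess_input(batch_images), batch_labels


mapped_images, mapped_labels = next(iter(train_ds_mapped))
gen_images, gen_labels = next(preprocessed_batches(train_ds))
```

Either train_ds_mapped or the generator returned by preprocessed_batches(train_ds) should then be usable as the input to model.fit. The map version keeps everything as a tf.data.Dataset, so it can be iterated again on every epoch, whereas a plain generator is exhausted after one pass.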