I'm new to tensorflow and trying to write a custom dataset class derived from tf.data.dataset like this example code:
import tensorflow as tf
class CustomDataset(tf.data.Dataset):
def __init__(self, num_of_images: int):
self.num_of_images = num_of_images
def generator(self):
return tf.ones(shape=(5, 5, 3), dtype=tf.int32)* self.num_of_images
def __len__(self) -> int:
return self.num_of_images
def __call__(self):
for _ in range(self.__len__()):
yield self.generator()
def _inputs(self):
return ()
def element_spec(self):
return tf.TensorSpec(shape=(5, 5, 3), dtype=tf.int32)
if __name__ == "__main__":
custom_dataset1 = CustomDataset(3)
custom_dataset2 = CustomDataset(4)
all_ds=[]
all_ds.append(custom_dataset1)
all_ds.append(custom_dataset2)
sampled_ds = tf.data.Dataset.sample_from_datasets(all_ds, seed=1)
#rest of the code
- I'm not quite sure what methods exactly to override from tf.data.Dataset
- When trying to sample it gives me this error:
result = a.most_specific_common_supertype([b])
AttributeError: 'function' object has no attribute 'most_specific_common_supertype'
any help?
I've tried this, and it worked
ds1 = tf.data.Dataset.from_generator(generator= CustomDataset(1),
output_signature=tf.TensorSpec(shape=(5, 5, 3), dtype=tf.int32))
ds2 = tf.data.Dataset.from_generator(generator= CustomDataset(2),
output_signature=tf.TensorSpec(shape=(5, 5, 3), dtype=tf.int32))
all_ds=[]
all_ds.append(ds1)
all_ds.append(ds2)
sampled_ds = tf.data.Dataset.sample_from_datasets(all_ds, seed=1)
but is there any other way rather than using from_generator?