I have 1000 files, and each file is processed with a sliding window to generate many training samples. I wrote this iterable-style dataset. Everything works; the only issue is that each batch consists of data from a single file. Is there a way to implement a custom collate function, or something similar, to merge data from multiple files into one batch?
class TextFileDataset(IterableDataset):
    """Iterable-style dataset that streams sliding-window samples from text files.

    Each yielded sample is a dict with the source file's integer id plus two
    tensors produced by the sliding window. Under a multi-worker DataLoader,
    the file list is sharded so each worker reads a disjoint slice of files.
    """

    def __init__(self, file_pathes: List) -> None:
        """
        Args:
            file_pathes: paths of the text files to stream samples from.
        """
        self.file_pathes = file_pathes
        # Stable integer id per file so every sample can be traced back
        # to the file it came from.
        self.file_id_map = {file_path: idx
                            for idx, file_path in enumerate(file_pathes)}

    def process_file(self, file_path) -> "Iterator[dict]":
        """Yield sliding-window samples from a single file.

        NOTE(review): relies on ``self.load_data`` being defined elsewhere,
        and ``text1``/``text2`` are placeholders for the window outputs.
        """
        doc_indices = self.load_data(file_path=file_path)
        idx = self.file_id_map[file_path]
        for i in range(len(doc_indices)):
            # run a sliding window
            ...  # placeholder: compute text1 / text2 for window i
            yield {"idx": torch.tensor(idx),
                   "text1": torch.tensor(text1),
                   "text2": torch.tensor(text2)}

    def __iter__(self) -> "Iterator[dict]":
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: iterate every file ourselves.
            for file_path in self.file_pathes:
                yield from self.process_file(file_path=file_path)
            # BUG FIX: without this return, execution fell through to the
            # multi-worker path and crashed on worker_info.num_workers
            # (worker_info is None here) once all files were exhausted.
            return
        # Multi-worker loading: give each worker a contiguous, disjoint
        # slice of the file list so files are not read twice.
        per_worker = int(np.ceil(len(self.file_pathes) /
                                 float(worker_info.num_workers)))
        worker_id = worker_info.id
        self.iter_file_paths = self.file_pathes[worker_id *
                                                per_worker:(worker_id + 1) * per_worker]
        for file_path in self.iter_file_paths:
            yield from self.process_file(file_path=file_path)