IterableDataset on multiple files

99 Views Asked by At

I have 1000 files, and each file is processed with a sliding window to generate many training samples. I wrote the iterable-style dataset below. Everything works; the only issue is that each batch consists of data from a single file. Is there a way to implement a custom collate function, or something similar, to merge data from multiple files into one batch?

class TextFileDataset(IterableDataset):
    """Stream sliding-window samples from a collection of text files.

    Files are consumed one after another, so consecutive samples (and hence
    whole batches) come from the same file. To mix files within a batch you
    would need to interleave the per-file generators (e.g. round-robin over
    several open iterators) rather than chaining them as done here.
    """

    def __init__(self, file_pathes: List[str]) -> None:
        """
        Args:
            file_pathes: paths of the files to stream samples from.
        """
        self.file_pathes = file_pathes
        # Stable integer id per file, emitted alongside every sample.
        self.file_id_map = {file_path: idx
                            for idx, file_path in enumerate(file_pathes)}

    def process_file(self, file_path) -> Iterator[dict]:
        """Yield one sample dict per sliding-window position of *file_path*.

        NOTE(review): relies on self.load_data, which is not shown in this
        snippet — presumably it returns the tokenized documents of the file.
        """
        doc_indices = self.load_data(file_path=file_path)
        idx = self.file_id_map[file_path]
        for i in range(len(doc_indices)):
            # run a sliding window (body elided in the original snippet;
            # it is expected to define text1 and text2)
            ...
            yield {"idx": torch.tensor(idx),
                   "text1": torch.tensor(text1),
                   "text2": torch.tensor(text2)}

    def __iter__(self) -> Iterator[dict]:
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process data loading: iterate over every file.
            for file_path in self.file_pathes:
                yield from self.process_file(file_path=file_path)
            # BUG FIX: without this return, execution fell through to the
            # sharding code below and raised AttributeError by reading
            # worker_info.num_workers on None.
            return
        # Multi-worker loading: shard the file list contiguously so each
        # DataLoader worker reads a disjoint subset of files.
        per_worker = int(np.ceil(len(self.file_pathes) /
                                 float(worker_info.num_workers)))
        worker_id = worker_info.id
        self.iter_file_paths = self.file_pathes[worker_id * per_worker:
                                                (worker_id + 1) * per_worker]
        for file_path in self.iter_file_paths:
            yield from self.process_file(file_path=file_path)
0

There are 0 best solutions below