class Hdf5DataSet(Dataset):
    """``Dataset`` subclass serving ``(data, y)`` pairs stored in HDF5 files.

    Construction only reads the number of samples.  The files are opened, and
    the two named datasets are read fully into memory, on the first
    ``__getitem__`` call.  Deferring the open presumably lets every loader
    worker process create its own h5py handle -- confirm against the caller.

    Call :meth:`close` to release the file handles and the cached arrays.
    """

    def __init__(self, data_name, y_name, dataset_data_path, dataset_y_path):
        """Record where the data lives and how many samples there are.

        Args:
            data_name: Name of the sample dataset inside ``dataset_data_path``.
            y_name: Name of the target dataset inside ``dataset_y_path``.
            dataset_data_path: Path of the HDF5 file holding the samples.
            dataset_y_path: Path of the HDF5 file holding the targets.
        """
        super(Hdf5DataSet, self).__init__()
        # Open briefly just to learn the length; no sample data is loaded.
        with h5py.File(dataset_data_path, 'r') as f:
            self.length = len(f[data_name])
        self.dataset_data_path = dataset_data_path
        self.dataset_y_path = dataset_y_path
        self.data_name = data_name
        self.y_name = y_name

    def __len__(self):
        """Return the number of samples found at construction time."""
        return self.length

    @staticmethod
    def _read_all(path, name):
        """Open ``path`` read-only and read dataset ``name`` completely.

        Returns:
            ``(handle, values)``: the still-open file and the in-memory copy
            of the dataset.

        The handle is closed before any error propagates, so a failed read
        never leaks an open file.
        """
        handle = h5py.File(path, 'r')
        try:
            values = handle[name][:]
        except BaseException:
            handle.close()
            raise
        return handle, values

    def open_data_hdf5(self):
        """Open the sample file and cache its dataset in ``dataset_data``."""
        # Assign both attributes only after the read succeeded, so a failure
        # cannot leave the handle set while the cached values are missing.
        self.data_hdf5, self.dataset_data = self._read_all(
            self.dataset_data_path, self.data_name)

    def open_y_hdf5(self):
        """Open the target file and cache its dataset in ``dataset_y``."""
        self.y_hdf5, self.dataset_y = self._read_all(
            self.dataset_y_path, self.y_name)

    def close(self):
        """Close any open HDF5 handles and drop the cached arrays.

        Safe to call repeatedly.  The next ``__getitem__`` call re-opens the
        files and reloads the data.
        """
        pairs = (('data_hdf5', 'dataset_data'), ('y_hdf5', 'dataset_y'))
        for handle_attr, values_attr in pairs:
            handle = self.__dict__.pop(handle_attr, None)
            self.__dict__.pop(values_attr, None)
            if handle is not None:
                handle.close()

    def __getstate__(self):
        """Return picklable state without file handles or cached arrays.

        h5py file objects cannot be pickled, so an instance that has already
        been indexed could otherwise not be sent to another process.  The
        unpickled copy re-opens the files lazily on first access.
        """
        state = self.__dict__.copy()
        for key in ('data_hdf5', 'dataset_data', 'y_hdf5', 'dataset_y'):
            state.pop(key, None)
        return state

    def __getitem__(self, index):
        """Return ``(data[index], y[index])``, loading the files on first use."""
        if not hasattr(self, 'data_hdf5'):
            self.open_data_hdf5()
        if not hasattr(self, 'y_hdf5'):
            self.open_y_hdf5()
        return self.dataset_data[index], self.dataset_y[index]