import logging
from pathlib import Path
from typing import Iterator, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data.dataset import Dataset

RetT = Tuple[torch.Tensor, torch.Tensor]
PathLike = Union[Path, str]

msg_logger = logging.getLogger(__name__)


class SingleFileDataset(Dataset):
    """In-memory dataset of (input, label) pairs.

    ``from_file`` loads the inputs from one flat binary file and the labels
    from another. Subclasses set ``image_shape`` to the per-sample input shape.
    ``from_file`` uses that shape to split the flat input file into samples.
    """

    # Per-sample input shape; concrete subclasses must override this.
    image_shape: Optional[Tuple[int, ...]] = None

    def __init__(self, inputs: torch.Tensor, outputs: torch.Tensor):
        self.inputs, self.outputs = inputs, outputs

    @classmethod
    def from_file(
        cls,
        input_file: PathLike,
        labels_file: PathLike,
        count: int = -1,
        offset: int = 0,
    ) -> "SingleFileDataset":
        """Load ``count`` samples, starting at sample ``offset``, from raw files.

        The inputs are read as float32 and reshaped to ``(N, *cls.image_shape)``.
        The labels are read as int32 and cast to int64.

        Args:
            input_file: flat binary file of float32 input values.
            labels_file: flat binary file of int32 labels.
            count: number of samples to read; -1 reads everything after
                ``offset``. Larger values are clamped to what is available.
            offset: number of samples to skip at the start of both files.

        Returns:
            An instance of ``cls`` holding the loaded tensors.

        Raises:
            TypeError: if ``cls`` does not define ``image_shape``.
            ValueError: if the two files yield different numbers of samples,
                or ``count``/``offset`` are out of range.
        """
        if cls.image_shape is None:
            raise TypeError(f"{cls.__name__} does not define image_shape")
        # NOTE: assuming (N, *) ordering of inputs (such as NCHW, NHWC)
        sample_size = int(np.prod(cls.image_shape))  # elements per input sample
        input_count = -1 if count == -1 else count * sample_size
        inputs = read_tensor_from_file(
            input_file,
            -1,
            *cls.image_shape,
            count=input_count,
            offset=offset * sample_size,
        )
        labels = read_tensor_from_file(
            labels_file,
            -1,
            read_ty=np.int32,
            cast_ty=np.int64,
            count=count,
            offset=offset,
        )
        if inputs.shape[0] != labels.shape[0]:
            raise ValueError("Input and output have different number of data points")
        msg_logger.info("%d entries loaded from dataset.", inputs.shape[0])
        return cls(inputs, labels)

    @property
    def sample_input(self) -> torch.Tensor:
        """Input tensor of the first sample (raises IndexError if empty)."""
        # Go through __getitem__ so subclass overrides are respected.
        return self[0][0]

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx) -> RetT:
        return self.inputs[idx], self.outputs[idx]

    def __iter__(self) -> Iterator[RetT]:
        # Index through self[i] (not zip over the tensors) so that subclasses
        # overriding __getitem__ see the same items when iterated.
        for i in range(len(self)):
            yield self[i]


class MNIST(SingleFileDataset):
    """Dataset whose input samples have shape (1, 28, 28)."""

    image_shape = 1, 28, 28


class CIFAR(SingleFileDataset):
    """Dataset whose input samples have shape (3, 32, 32)."""

    image_shape = 3, 32, 32


class ImageNet(SingleFileDataset):
    """Dataset whose input samples have shape (3, 224, 224)."""

    image_shape = 3, 224, 224


def read_tensor_from_file(
    filename: PathLike,
    *shape: int,
    read_ty=np.float32,
    cast_ty=np.float32,
    count: int = -1,
    offset: int = 0,
) -> torch.Tensor:
    """Read a flat binary file into a tensor of the given shape.

    Args:
        filename: path to the raw binary file.
        *shape: target shape passed to ``reshape`` (may contain one -1).
        read_ty: on-disk element dtype (anything ``np.dtype`` accepts).
        cast_ty: dtype of the returned tensor.
        count: number of elements to read; -1 reads to the end of the file.
            Values larger than what is available are clamped.
        offset: number of ``read_ty`` elements to skip at the start of the file.

    Returns:
        A tensor that owns its memory (it is not backed by the file).

    Raises:
        ValueError: if ``count < -1``, ``offset < 0``, or the data read cannot
            be reshaped to ``shape``.
    """
    if count < -1:
        raise ValueError(f"count must be -1 or non-negative, got {count}")
    if offset < 0:
        raise ValueError(f"offset must be non-negative, got {offset}")
    byte_offset = offset * np.dtype(read_ty).itemsize
    mmap = np.memmap(filename, dtype=read_ty, mode="r", offset=byte_offset)
    n_entries = mmap.shape[0] if count == -1 else min(mmap.shape[0], count)
    # np.array always copies into a plain, writable ndarray detached from the
    # memmap, so torch.from_numpy can share it without a second (.clone) copy.
    np_array = np.array(mmap[:n_entries], dtype=cast_ty).reshape(shape)
    return torch.from_numpy(np_array)