# dataset.py (2.45 KiB) — authored by Yifan Zhao
import logging
from pathlib import Path
from typing import Iterator, Tuple, Union
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
RetT = Tuple[torch.Tensor, torch.Tensor]
PathLike = Union[Path, str]
msg_logger = logging.getLogger(__name__)
class SingleFileDataset(Dataset):
image_shape = None
def __init__(self, inputs: torch.Tensor, outputs: torch.Tensor):
self.inputs, self.outputs = inputs, outputs
@classmethod
def from_file(
cls,
input_file: PathLike,
labels_file: PathLike,
count: int = -1,
offset: int = 0,
):
# NOTE: assuming (N, *) ordering of inputs (such as NCHW, NHWC)
channel_size = np.prod(np.array(cls.image_shape))
inputs_count_byte = -1 if count == -1 else count * channel_size
inputs = read_tensor_from_file(
input_file,
-1,
*cls.image_shape,
count=inputs_count_byte,
offset=offset * channel_size,
)
labels = read_tensor_from_file(
labels_file,
-1,
read_ty=np.int32,
cast_ty=np.int64,
count=count,
offset=offset,
)
if inputs.shape[0] != labels.shape[0]:
raise ValueError("Input and output have different number of data points")
msg_logger.info(f"%d entries loaded from dataset.", inputs.shape[0])
return cls(inputs, labels)
@property
def sample_input(self):
inputs, outputs = next(iter(self))
return inputs
def __len__(self) -> int:
return len(self.inputs)
def __getitem__(self, idx) -> RetT:
return self.inputs[idx], self.outputs[idx]
def __iter__(self) -> Iterator[RetT]:
for i in range(len(self)):
yield self[i]
class MNIST(SingleFileDataset):
image_shape = 1, 28, 28
class CIFAR(SingleFileDataset):
image_shape = 3, 32, 32
class ImageNet(SingleFileDataset):
image_shape = 3, 224, 224
def read_tensor_from_file(
filename: Union[str, Path],
*shape: int,
read_ty=np.float32,
cast_ty=np.float32,
count: int = -1,
offset: int = 0,
) -> torch.Tensor:
offset = offset * read_ty().itemsize
mmap = np.memmap(filename, dtype=read_ty, mode="r", offset=offset)
n_entries = min(mmap.shape[0], count) if count != -1 else mmap.shape[0]
np_array = mmap[:n_entries].reshape(shape).astype(cast_ty)
return torch.from_numpy(np_array).clone()