#!/usr/bin/env python
# -*- coding: utf-8 -*-
from abc import abstractmethod
from typing import Callable, Generic, Iterable, Iterator, Optional, TypeVar, Union
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.dataset import Subset as TorchSubset
from torchoutil.pyoutil.collections import is_sorted
from torchoutil.pyoutil.typing.classes import (
SupportsLenAndGetItem,
SupportsLenAndGetItemAndIter,
)
from torchoutil.types.tensor_subclasses import LongTensor1D
# Covariant type variables used for dataset item types.
T = TypeVar("T", covariant=True)
U = TypeVar("U", covariant=True)

# Short aliases for the structural "dataset-like" protocols.
SizedDatasetLike = SupportsLenAndGetItem
SizedIterableDatasetLike = SupportsLenAndGetItemAndIter

# Type variables bound to a dataset class or to one of the protocols above.
# The aliases refer to the very same protocol objects, so the bounds are unchanged.
T_Dataset = TypeVar("T_Dataset", bound=Dataset)
T_SizedDatasetLike = TypeVar("T_SizedDatasetLike", bound=SizedDatasetLike)
T_SizedIterableDataset = TypeVar(
    "T_SizedIterableDataset",
    bound=SizedIterableDatasetLike,
)
class EmptyDataset(Dataset[None]):
    """Placeholder dataset that contains no items.

    Its length is always 0, and indexing it with any index raises
    ``StopIteration``.
    """

    def __len__(self) -> int:
        # Nothing is ever stored here.
        return 0

    def __getitem__(self, idx, /) -> None:  # type: ignore
        # There is no valid index, whatever ``idx`` is.
        raise StopIteration
class Wrapper(Generic[T], Dataset[T]):
    """Base class for datasets that wrap another dataset.

    The wrapped object is stored in ``self.dataset``. ``__getitem__`` and
    ``__len__`` are declared abstract and are expected to be provided by
    subclasses.
    """

    def __init__(self, dataset: SupportsLenAndGetItem[T]) -> None:
        Dataset.__init__(self)
        self.dataset = dataset

    @abstractmethod
    def __getitem__(self, idx, /) -> T:  # type: ignore
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        raise NotImplementedError

    def unwrap(self, recursive: bool = True) -> Union[SupportsLenAndGetItem, Dataset]:
        """Return the dataset wrapped by this object.

        Args:
            recursive: If False, return ``self.dataset`` as-is. If True, keep
                following ``.dataset`` through nested ``Wrapper`` and
                ``torch.utils.data.Subset`` layers until reaching an object
                that is neither.

        Returns:
            The innermost (or immediately wrapped) dataset.
        """
        dataset = self.dataset
        if not recursive:
            return dataset
        # Both Wrapper and torch's Subset expose their inner dataset as
        # `.dataset`, so a single loop peels either kind of layer. (Previously
        # the loop only continued on Wrapper, so the Subset branch was dead.)
        while isinstance(dataset, (Wrapper, TorchSubset)):
            dataset = dataset.dataset
        return dataset

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({repr(self.dataset)})"
class IterableWrapper(Generic[T], IterableDataset[T], Wrapper[T]):
    """``Wrapper`` that is also an ``IterableDataset``.

    ``__iter__`` is declared abstract and must be provided by subclasses.
    """

    def __init__(self, dataset: SupportsLenAndGetItem[T]) -> None:
        # Both parent initializers are invoked explicitly rather than via super().
        IterableDataset.__init__(self)
        Wrapper.__init__(self, dataset)

    @abstractmethod
    def __iter__(self) -> Iterator[T]:
        raise NotImplementedError

    def _get_dataset_iter(self) -> Iterator[T]:
        """Return an iterator over the items of the wrapped dataset.

        The dataset's own ``__iter__`` is used when it has one. Otherwise the
        items are fetched by indexing positions ``0 .. len(dataset) - 1``.
        """
        if not hasattr(self.dataset, "__iter__"):
            # Index-based fallback for datasets that only support len + getitem.
            return (self.dataset[pos] for pos in range(len(self.dataset)))
        return iter(self.dataset)
class IterableSubset(IterableWrapper[T], Generic[T]):
    """Iterable view over selected positions of a wrapped dataset.

    Iteration walks the wrapped dataset's iterator once, front to back, and
    yields the item at each position listed in ``indices``. This is why the
    indices must be non-negative and sorted. An index that appears several
    times yields the same item several times, provided ``is_sorted`` accepts
    repeated values (NOTE(review): confirm ``is_sorted`` semantics).
    """

    def __init__(
        self,
        dataset: SupportsLenAndGetItem[T],
        indices: Union[Iterable[int], LongTensor1D],
    ) -> None:
        """
        Args:
            dataset: Dataset to draw items from.
            indices: Positions to keep. Must be non-negative and sorted
                (as checked by ``is_sorted``). A ``LongTensor1D`` is
                converted with ``tolist()``; any other iterable with ``list()``.

        Raises:
            ValueError: If an index is negative or the indices are not sorted.
        """
        if isinstance(indices, LongTensor1D):
            indices = indices.tolist()
        else:
            indices = list(indices)

        if not all(idx >= 0 for idx in indices) or not is_sorted(indices):
            msg = f"Invalid argument {indices=}. (expected a sorted list of non-negative integers)"
            raise ValueError(msg)

        super().__init__(dataset)
        self._indices = indices

    def __iter__(self) -> Iterator[T]:
        """Yield the wrapped dataset's items at ``self._indices``, in order.

        Raises:
            IndexError: If an index points past the end of the wrapped dataset.
        """
        it = super()._get_dataset_iter()
        # Position of `item` in the wrapped dataset; -1 means nothing read yet.
        # Reading lazily means an empty `indices` never touches `it`, so an
        # empty dataset no longer crashes.
        cur_idx = -1
        item = None
        for idx in self._indices:
            # Indices are sorted, so the underlying iterator only moves forward.
            while cur_idx < idx:
                try:
                    item = next(it)
                except StopIteration:
                    # A StopIteration escaping a generator becomes an opaque
                    # RuntimeError (PEP 479); report the real problem instead.
                    msg = f"Index {idx} is out of range for the wrapped dataset (it yielded only {cur_idx + 1} items)."
                    raise IndexError(msg) from None
                cur_idx += 1
            yield item  # type: ignore

    def __len__(self) -> int:
        return len(self._indices)
class Subset(Generic[T], TorchSubset[T], Wrapper[T]):
    """``torch.utils.data.Subset`` that is also a ``Wrapper``.

    ``TorchSubset`` precedes ``Wrapper`` in the bases, so its methods take
    precedence in the MRO.
    """

    def __init__(self, dataset: SizedDatasetLike[T], indices: Iterable[int]) -> None:
        # Materialize the indices as a list before handing them to TorchSubset.
        index_list = list(indices)
        # Both parent initializers are called explicitly, in this order.
        TorchSubset.__init__(self, dataset, index_list)  # type: ignore
        Wrapper.__init__(self, dataset)