Source code for torchoutil.utils.data.slicer

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from abc import ABC, abstractmethod
from typing import Any, Generic, Iterable, List, Tuple, TypeVar, Union, final, overload

import torch
from torch.utils.data.dataset import Dataset
from typing_extensions import TypeAlias

from torchoutil.extras.numpy.functional import is_numpy_bool_array
from torchoutil.nn.functional.transform import as_tensor
from torchoutil.pyoutil.typing import isinstance_guard
from torchoutil.pyoutil.typing.classes import SupportsLenAndGetItem
from torchoutil.types._typing import BoolTensor1D, Tensor1D, TensorOrArray
from torchoutil.types.guards import is_number_like, is_tensor_or_array
from torchoutil.utils.data.dataset import Wrapper

# Invariant type variables parameterizing the generic dataset classes of this module
# (TypeVar is invariant unless covariant/contravariant is requested).
T = TypeVar("T")
U = TypeVar("U")

# Non-scalar index types accepted by ``DatasetSlicer.__getitem__`` (a single
# number-like index is handled separately and returns one item):
# boolean mask, integer indices, ``None``, a slice, or a 1D tensor.
Indices: TypeAlias = Union[
    Iterable[bool],
    Iterable[int],
    None,
    slice,
    Tensor1D,
]


class DatasetSlicer(Generic[T], ABC, Dataset[T]):
    """Abstract dataset whose ``__getitem__`` dispatches on the type of the index.

    Subclasses implement ``__len__`` and ``get_item``. ``__getitem__`` then routes:

    - number-like index -> ``get_item``
    - ``slice`` -> ``get_items_slice``
    - bool iterable, 1D bool tensor or 1D numpy bool array -> ``get_items_mask``
    - int iterable, tensor or array -> ``get_items_indices``
    - ``None`` -> ``get_items_none``

    A tuple index with more than one element is split as ``(idx, *args)`` and the
    extra elements are forwarded as positional arguments to the selected method.

    Each ``add_*_support`` flag, when False, makes the corresponding method pass
    the raw index directly to ``get_item`` instead of expanding it into items.
    """

    def __init__(
        self,
        *,
        add_slice_support: bool = True,
        add_indices_support: bool = True,
        add_mask_support: bool = True,
        add_none_support: bool = True,
    ) -> None:
        Dataset.__init__(self)
        self._add_slice_support = add_slice_support
        self._add_indices_support = add_indices_support
        self._add_mask_support = add_mask_support
        self._add_none_support = add_none_support

    @abstractmethod
    def __len__(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_item(self, idx, /, *args, **kwargs) -> Any:
        """Return the item at ``idx``. Must be implemented by subclasses."""
        raise NotImplementedError

    @overload
    @final
    def __getitem__(self, idx: int, /) -> T:  # type: ignore
        ...

    @overload
    @final
    def __getitem__(self, idx: Indices, /) -> List[T]:  # type: ignore
        ...

    @overload
    @final
    def __getitem__(self, idx: Tuple[Any, ...], /) -> Any:  # type: ignore
        ...

    @final
    def __getitem__(self, idx) -> Any:
        # ds[idx, a, b] arrives here as the tuple (idx, a, b): the first element is
        # the index, the rest are extra arguments for the dispatched method.
        if isinstance(idx, tuple) and len(idx) > 1:
            idx, *args = idx
        else:
            args = ()

        if is_number_like(idx):
            return self.get_item(idx, *args)
        elif isinstance(idx, slice):
            return self.get_items_slice(idx, *args)
        elif (
            isinstance_guard(idx, Iterable[bool])
            or isinstance(idx, BoolTensor1D)
            or (is_numpy_bool_array(idx) and idx.ndim == 1)
        ):
            # Mask detection runs before the integer-indices check; presumably a
            # bool iterable would also pass the Iterable[int] check (bool <: int).
            return self.get_items_mask(idx, *args)
        elif isinstance_guard(idx, Iterable[int]) or is_tensor_or_array(idx):
            return self.get_items_indices(idx, *args)
        elif idx is None:
            return self.get_items_none(idx, *args)
        else:
            raise TypeError(f"Invalid argument type {type(idx)=} with {args=}.")

    @final
    def __getitems__(
        self,
        indices: Indices,
        *args,
    ) -> List[T]:
        """Batched access hook; equivalent to ``self[indices, *args]``."""
        # __getitem__ takes a single positional index, so extra arguments must be
        # packed into a tuple (which __getitem__ splits back into idx and args);
        # passing them positionally would raise TypeError.
        if len(args) == 0:
            return self.__getitem__(indices)
        return self.__getitem__((indices,) + args)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def get_items_indices(
        self,
        indices: Union[Iterable[int], TensorOrArray],
        *args,
    ) -> List[T]:
        """Return ``[get_item(i, *args) for i in indices]``.

        If indices support is disabled, ``indices`` is passed as-is to ``get_item``.
        """
        if self._add_indices_support:
            return [self.get_item(idx, *args) for idx in indices]
        else:
            return self.get_item(indices, *args)

    def get_items_mask(
        self,
        mask: Union[Iterable[bool], TensorOrArray],
        *args,
    ) -> List[T]:
        """Return the items at positions where ``mask`` is True.

        Raises:
            ValueError: if ``mask`` is non-empty and its length differs from ``len(self)``.

        If mask support is disabled, ``mask`` is passed as-is to ``get_item``.
        """
        if self._add_mask_support:
            mask = as_tensor(mask, dtype=torch.bool)
            # An empty mask is accepted whatever len(self) is and selects no items
            # (an empty list index may be routed here by __getitem__).
            if len(mask) > 0 and len(mask) != len(self):  # type: ignore
                msg = f"Invalid mask size {len(mask)}. (expected {len(self)})"
                raise ValueError(msg)
            indices = torch.where(mask)[0]
            return self.get_items_indices(indices, *args)
        else:
            return self.get_item(mask, *args)

    def get_items_slice(
        self,
        slice_: slice,
        *args,
    ) -> List[T]:
        """Return the items selected by ``slice_`` over ``range(len(self))``.

        If slice support is disabled, ``slice_`` is passed as-is to ``get_item``.
        """
        if self._add_slice_support:
            return self.get_items_indices(range(len(self))[slice_], *args)
        else:
            return self.get_item(slice_, *args)

    def get_items_none(
        self,
        none: None,
        *args,
    ) -> List[T]:
        """Return all items (``None`` behaves like ``slice(None)``).

        If None support is disabled, ``None`` is passed as-is to ``get_item``.
        """
        if self._add_none_support:
            return self.get_items_slice(slice(None), *args)
        else:
            return self.get_item(none, *args)
class DatasetSlicerWrapper(Generic[T], DatasetSlicer[T], Wrapper[T]):
    """Wrap a sized, indexable dataset so it accepts slice, indices, mask and None indexes."""

    def __init__(
        self,
        dataset: SupportsLenAndGetItem[T],
        *,
        add_slice_support: bool = True,
        add_indices_support: bool = True,
        add_mask_support: bool = True,
        add_none_support: bool = True,
    ) -> None:
        """Wrap a sequence to support slice, indices, mask and None arguments types.

        Args:
            dataset: Object supporting ``len()`` and ``__getitem__`` to wrap.
            add_slice_support: Expand slices into individual ``get_item`` calls.
            add_indices_support: Expand index iterables into individual ``get_item`` calls.
            add_mask_support: Expand boolean masks into individual ``get_item`` calls.
            add_none_support: Treat ``None`` as "all items" (like ``slice(None)``).
                Defaults to True, matching ``DatasetSlicer``.
        """
        DatasetSlicer.__init__(
            self,
            add_slice_support=add_slice_support,
            add_indices_support=add_indices_support,
            add_mask_support=add_mask_support,
            add_none_support=add_none_support,
        )
        # NOTE(review): presumably Wrapper.__init__ stores `dataset` as `self.dataset`,
        # which __len__ and get_item read below.
        Wrapper.__init__(self, dataset)

    def __len__(self) -> int:
        return len(self.dataset)

    def get_item(self, idx: int, *args) -> T:
        """Return ``self.dataset[idx]``, or ``self.dataset[(idx, *args)]`` when args are given."""
        # note: we need to split calls here, because self.dataset[idx] gives an int as
        # argument while self.dataset[idx, *args] always gives a tuple even if args == ()
        if len(args) == 0:
            return self.dataset[idx]
        else:
            # equivalent to self.dataset[idx, *args], but only in recent python versions
            return self.dataset.__getitem__((idx,) + args)