#!/usr/bin/env python
# -*- coding: utf-8 -*-
from dataclasses import dataclass
from typing import Any, Generic, Iterable, Tuple, TypeVar, Union
import torch
from torch import Tensor
import torchoutil as to
from torchoutil import pyoutil as po
from torchoutil.pyoutil import BuiltinScalar, get_current_fn_name
from .definitions import ACCEPTED_NUMPY_DTYPES, np
# Type of the user-supplied `invalid` sentinel accepted by the torch dtype helpers below.
T_Invalid = TypeVar("T_Invalid", covariant=True)
# Type of the user-supplied placeholder (`empty` / `empty_np`) used on the numpy side.
T_EmptyNp = TypeVar("T_EmptyNp", covariant=True)
# Type of the user-supplied placeholder (`empty` / `empty_torch`) used on the torch side.
T_EmptyTorch = TypeVar("T_EmptyTorch", covariant=True)
class InvalidTorchDType(metaclass=po.Singleton):
    """Sentinel used in place of a torch dtype when none can be determined.

    An instance of this class is the default ``invalid`` value of
    ``scan_torch_dtype`` and is returned, for example, when a ``str`` is scanned.
    """

    def __repr__(self) -> str:
        # Rendered like a constructor call, e.g. "InvalidTorchDType()".
        cls_name = type(self).__name__
        return f"{cls_name}()"
@dataclass(frozen=True)
class ShapeDTypeInfo(Generic[T_Invalid, T_EmptyTorch, T_EmptyNp]):
    """Frozen record holding the shape and the torch/numpy dtypes found for some data."""

    # Shape of the scanned data.
    shape: Tuple[int, ...]
    # Torch dtype, or the "invalid" / "empty" placeholder when no dtype applies.
    torch_dtype: Union[torch.dtype, T_Invalid, T_EmptyTorch]
    # Numpy dtype, or the "empty" placeholder when no dtype applies.
    numpy_dtype: Union[np.dtype, T_EmptyNp]
    # Validity flag stored alongside the shape by the caller.
    valid_shape: bool

    @property
    def fill_value(self) -> BuiltinScalar:
        """Fill value computed by ``numpy_dtype_to_fill_value`` for ``numpy_dtype``."""
        np_dtype = self.numpy_dtype
        return numpy_dtype_to_fill_value(np_dtype)

    @property
    def get_ndim(self) -> int:
        """Number of dimensions, i.e. the length of ``shape``."""
        n_dims = len(self.shape)
        return n_dims

    @property
    def kind(self) -> str:
        """Kind character of ``numpy_dtype``; "V" when it is not a ``np.dtype``."""
        np_dtype = self.numpy_dtype
        return np_dtype.kind if isinstance(np_dtype, np.dtype) else "V"
def scan_shape_dtypes(
    x: Any,
    *,
    accept_heterogeneous_shape: bool = False,
    empty_torch: T_EmptyTorch = None,
    empty_np: T_EmptyNp = np.dtype("V"),
) -> ShapeDTypeInfo[InvalidTorchDType, T_EmptyTorch, T_EmptyNp]:
    """Scan an input and gather its shape, torch dtype and numpy dtype.

    Args:
        x: Data to scan.
        accept_heterogeneous_shape: When False, an input whose shape is
            reported as not valid triggers a ValueError.
        empty_torch: Forwarded as ``empty`` to ``scan_torch_dtype``.
        empty_np: Forwarded as ``empty`` to ``scan_numpy_dtype``.

    Returns:
        A ``ShapeDTypeInfo`` with the shape, both dtypes and the validity flag.

    Raises:
        ValueError: If the shape is not valid and heterogeneous shapes are
            not accepted.
    """
    is_valid, dims = to.get_shape(x, return_valid=True)
    if not (accept_heterogeneous_shape or is_valid):
        msg = f"Invalid argument {x} for {get_current_fn_name()}. (cannot compute shape for heterogeneous data)"
        raise ValueError(msg)

    # Keyword arguments are evaluated in order: torch dtype first, then numpy dtype.
    return ShapeDTypeInfo[InvalidTorchDType, T_EmptyTorch, T_EmptyNp](
        shape=dims,
        torch_dtype=scan_torch_dtype(x, empty=empty_torch),
        numpy_dtype=scan_numpy_dtype(x, empty=empty_np),
        valid_shape=is_valid,
    )
def scan_torch_dtype(
    x: Any,
    *,
    invalid: T_Invalid = InvalidTorchDType(),
    empty: T_EmptyTorch = None,
) -> Union[torch.dtype, T_Invalid, T_EmptyTorch]:
    """Find the torch dtype of an arbitrary object, recursing into lists and tuples.

    Args:
        x: Object to scan.
        invalid: Value returned for text-like inputs (str, bytes, bytearray)
            and for numpy dtypes that cannot be converted.
        empty: Value returned for an empty list or tuple.

    Returns:
        A torch dtype, or ``invalid`` / ``empty`` as described above.

    Raises:
        TypeError: If the type of ``x`` is not supported.
    """
    # NOTE: the order of the checks below matters and mirrors the type dispatch
    # of scan_numpy_dtype.
    if isinstance(x, (int, float, bool, complex)):
        return torch.as_tensor(x).dtype

    if isinstance(x, Tensor):
        return x.dtype

    if isinstance(x, (np.ndarray, np.generic)):
        return numpy_dtype_to_torch_dtype(x.dtype, invalid=invalid)

    if isinstance(x, (str, bytes, bytearray)):
        return invalid

    if not isinstance(x, (list, tuple)):
        msg = f"Unsupported type {x.__class__.__name__} in function {po.get_current_fn_name()}."
        raise TypeError(msg)

    if len(x) == 0:
        return empty

    # Scan every element, then reduce the collected dtypes to a single one.
    item_dtypes = []
    for item in x:
        item_dtypes.append(scan_torch_dtype(item, invalid=invalid, empty=empty))
    return merge_torch_dtypes(item_dtypes, invalid=invalid, empty=empty)
def scan_numpy_dtype(
    x: Any,
    *,
    empty: T_EmptyNp = np.dtype("V"),
) -> Union[np.dtype, T_EmptyNp]:
    """Find the numpy dtype of an arbitrary object, recursing into lists and tuples.

    Args:
        x: Object to scan.
        empty: Value returned for an empty list or tuple.

    Returns:
        A numpy dtype, or ``empty`` as described above.

    Raises:
        TypeError: If the type of ``x`` is not supported.
    """
    # NOTE: the order of the checks below matters and mirrors the type dispatch
    # of scan_torch_dtype.
    if isinstance(x, (int, float, bool, complex)):
        return np.array(x).dtype

    if isinstance(x, Tensor):
        return torch_dtype_to_numpy_dtype(x.dtype)

    if isinstance(x, (np.ndarray, np.generic)):
        return x.dtype

    if isinstance(x, (str, bytes, bytearray)):
        return np.array(x).dtype

    if not isinstance(x, (list, tuple)):
        msg = f"Unsupported type {x.__class__.__name__} in function {po.get_current_fn_name()}."
        raise TypeError(msg)

    if len(x) == 0:
        return empty

    # Scan every element, then reduce the collected dtypes to a single one.
    item_dtypes = []
    for item in x:
        item_dtypes.append(scan_numpy_dtype(item, empty=empty))
    return merge_numpy_dtypes(item_dtypes, empty=empty)
def torch_dtype_to_numpy_dtype(dtype: torch.dtype) -> np.dtype:
    """Return the numpy dtype obtained by converting an empty tensor of ``dtype``."""
    # A zero-element tensor is enough: only the dtype of the converted array is read.
    empty_tensor = torch.empty((0,), dtype=dtype)
    return to.tensor_to_numpy(empty_tensor).dtype
def numpy_dtype_to_torch_dtype(
    dtype: np.dtype,
    *,
    invalid: T_Invalid = InvalidTorchDType(),
) -> Union[torch.dtype, T_Invalid]:
    """Return the torch dtype obtained by converting an empty array of ``dtype``.

    Args:
        dtype: Numpy dtype to convert.
        invalid: Value returned when ``dtype`` is not in ACCEPTED_NUMPY_DTYPES.

    Returns:
        The torch dtype, or ``invalid`` if the numpy dtype is not accepted.
    """
    if dtype not in ACCEPTED_NUMPY_DTYPES:
        return invalid

    # A zero-element array is enough: only the dtype of the converted tensor is read.
    empty_array = np.empty((0,), dtype=dtype)
    return to.numpy_to_tensor(empty_array).dtype
def numpy_dtype_to_fill_value(dtype: Any) -> BuiltinScalar:
    """Return a builtin fill value matching the kind of a numpy dtype.

    Args:
        dtype: Expected to be a ``np.dtype``; any other object yields None.

    Returns:
        False for booleans, 0 for (un)signed integers, 0.0 for floats, 0j for
        complex numbers, "" for unicode/bytes strings, None for non-dtype input.

    Raises:
        ValueError: If ``dtype.kind`` is not one of the kinds listed above.
    """
    if not isinstance(dtype, np.dtype):
        return None

    # Maps a dtype kind character to its fill value; key order is used in the error message.
    fill_by_kind = {
        "b": False,
        "u": 0,
        "i": 0,
        "f": 0.0,
        "c": 0j,
        "U": "",
        "S": "",
    }
    kind = dtype.kind
    if kind in fill_by_kind:
        return fill_by_kind[kind]

    KINDS = tuple(fill_by_kind)
    msg = f"Invalid argument {dtype=}. (expected dtype.kind in {KINDS})"
    raise ValueError(msg)
def merge_numpy_dtypes(
    dtypes: Iterable[Union[np.dtype, T_EmptyNp]],
    *,
    empty: T_EmptyNp = np.dtype("V"),
) -> Union[np.dtype, T_EmptyNp]:
    """Reduce several numpy dtypes to the single dtype numpy gives when stacking them.

    Args:
        dtypes: Dtypes to merge; entries equal to ``empty`` are ignored.
        empty: Placeholder returned when no dtype remains after filtering.

    Returns:
        The dtype of the stacked dummy arrays, or ``empty``.
    """
    # dict.fromkeys removes duplicates while keeping first-seen order.
    candidates = [dt for dt in dict.fromkeys(dtypes) if dt != empty]
    if len(candidates) == 0:
        return empty

    # Stack zero-element arrays so that numpy itself decides the resulting dtype.
    placeholders = [np.empty((0,), dtype=dt) for dt in candidates]  # type: ignore
    return np.stack(placeholders).dtype
def merge_torch_dtypes(
    dtypes: Iterable[Union[torch.dtype, T_Invalid, T_EmptyTorch]],
    *,
    invalid: T_Invalid = InvalidTorchDType(),
    empty: T_EmptyTorch = None,
) -> Union[torch.dtype, T_Invalid, T_EmptyTorch]:
    """Reduce several torch dtypes to a single promoted dtype.

    Args:
        dtypes: Dtypes to merge; entries equal to ``empty`` are ignored.
        invalid: Sentinel marking an invalid dtype. If any remaining entry
            equals it, it is returned as the merged result.
        empty: Placeholder returned when no dtype remains after filtering.

    Returns:
        ``empty`` if nothing remains, ``invalid`` if any entry is invalid,
        otherwise the dtype obtained by promoting all remaining dtypes.
    """
    # dict.fromkeys removes duplicates while keeping first-seen order.
    candidates = [dt for dt in dict.fromkeys(dtypes) if dt != empty]
    if len(candidates) == 0:
        return empty
    if any(dt == invalid for dt in candidates):
        return invalid

    # Promote dtypes directly instead of stacking dummy tensors: no allocation
    # is needed and the result does not rely on torch.stack promoting mixed dtypes.
    merged = candidates[0]
    for dt in candidates[1:]:
        merged = torch.promote_types(merged, dt)  # type: ignore
    return merged
def get_default_numpy_dtype() -> np.dtype:
    """Return the dtype numpy gives to an array created without an explicit dtype."""
    # A zero-element array is enough: only its dtype is read.
    placeholder = np.empty((0,))
    return placeholder.dtype