#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Any, Dict, Iterable, List, Tuple
import torch
from torch import Tensor
from typing_extensions import TypeGuard, TypeIs
from torchoutil.extras.numpy import is_numpy_number_like, is_numpy_scalar_like, np
from torchoutil.pyoutil.typing import (
is_builtin_number,
is_builtin_scalar,
isinstance_guard,
)
from torchoutil.pyoutil.warnings import deprecated_alias, deprecated_function
from ._typing import (
BoolTensor,
BoolTensor1D,
ComplexFloatingTensor,
FloatingTensor,
IntegralTensor,
IntegralTensor1D,
NumberLike,
ScalarLike,
Tensor0D,
TensorOrArray,
)
def is_number_like(x: Any) -> TypeGuard[NumberLike]:
    """Check whether ``x`` is a number-like object.

    The following kinds of objects are accepted:
    - Python numbers (int, float, bool, complex)
    - Numpy numbers
    - Numpy zero-dimensional arrays
    - PyTorch zero-dimensional tensors

    The checks run in order (builtin, numpy, ``Tensor0D``) and stop at the
    first one that succeeds.
    """
    builtin_match = is_builtin_number(x)
    return builtin_match or is_numpy_number_like(x) or isinstance(x, Tensor0D)
def is_scalar_like(x: Any) -> TypeGuard[ScalarLike]:
    """Check whether ``x`` is a scalar-like object.

    The following kinds of objects are accepted:
    - Python scalars like (int, float, bool, complex, None, str, bytes)
    - Numpy generic
    - Numpy zero-dimensional arrays
    - PyTorch zero-dimensional tensors

    The checks run in order (builtin, numpy, ``Tensor0D``) and stop at the
    first one that succeeds.
    """
    return (
        is_builtin_scalar(x)
        or is_numpy_scalar_like(x)
        or isinstance(x, Tensor0D)
    )
def is_tensor_or_array(x: Any) -> TypeIs[TensorOrArray]:
    """Return True if ``x`` is a ``Tensor`` or an ``np.ndarray`` instance."""
    accepted_types = (Tensor, np.ndarray)
    return isinstance(x, accepted_types)
# Deprecated name kept for backward compatibility.
# NOTE(review): presumably `deprecated_alias` forwards calls to
# `is_tensor_or_array`, so this stub body is never the real implementation
# -- confirm against torchoutil.pyoutil.warnings.
@deprecated_alias(is_tensor_or_array)
def is_tensor_like(*args, **kwargs):
    ...
def is_integral_dtype(dtype: torch.dtype) -> bool:
    """Return True if tensors of the given dtype are ``IntegralTensor`` instances.

    Args:
        dtype: The PyTorch dtype to test.

    Returns:
        The result of ``isinstance(probe, IntegralTensor)`` where ``probe`` is
        an empty tensor created with ``dtype``.
    """
    # An empty tensor is enough: only its dtype is inspected, no data is needed.
    probe = torch.empty((0,), dtype=dtype)
    # Check against IntegralTensor directly instead of going through
    # `is_integral_tensor`, which is marked as deprecated in this module; a
    # non-deprecated helper should not route through a deprecated one.
    return isinstance(probe, IntegralTensor)
@deprecated_function("{fn_name}, use `isinstance(x, to.BoolTensor)` instead.")
def is_bool_tensor(x: Any) -> TypeIs[BoolTensor]:
    """Deprecated check that ``x`` is an instance of ``BoolTensor``.

    Prefer ``isinstance(x, to.BoolTensor)`` in new code.
    """
    matches = isinstance(x, BoolTensor)
    return matches
@deprecated_function("{fn_name}, use `isinstance(x, to.BoolTensor1D)` instead.")
def is_bool_tensor1d(x: Any) -> TypeIs[BoolTensor1D]:
    """Deprecated check that ``x`` is an instance of ``BoolTensor1D``.

    Prefer ``isinstance(x, to.BoolTensor1D)`` in new code.
    """
    matches = isinstance(x, BoolTensor1D)
    return matches
@deprecated_function(
    "{fn_name}, use `isinstance(x, to.ComplexFloatingTensor)` instead."
)
def is_complex_tensor(x: Any) -> TypeIs[ComplexFloatingTensor]:
    """Deprecated check that ``x`` is an instance of ``ComplexFloatingTensor``.

    Prefer ``isinstance(x, to.ComplexFloatingTensor)`` in new code.
    """
    matches = isinstance(x, ComplexFloatingTensor)
    return matches
@deprecated_function(
    "{fn_name}, use `to.isinstance_guard(x, Dict[str, Tensor])` instead."
)
def is_dict_str_tensor(x: Any) -> TypeIs[Dict[str, Tensor]]:
    """Deprecated check of ``x`` against ``Dict[str, Tensor]``.

    Delegates to ``isinstance_guard``; prefer calling
    ``to.isinstance_guard(x, Dict[str, Tensor])`` directly in new code.
    """
    expected_type = Dict[str, Tensor]
    return isinstance_guard(x, expected_type)
@deprecated_function("{fn_name}, use `isinstance(x, to.FloatingTensor)` instead.")
def is_floating_tensor(x: Any) -> TypeIs[FloatingTensor]:
    """Deprecated check that ``x`` is an instance of ``FloatingTensor``.

    Prefer ``isinstance(x, to.FloatingTensor)`` in new code.
    """
    matches = isinstance(x, FloatingTensor)
    return matches
@deprecated_function("{fn_name}, use `isinstance(x, to.IntegralTensor)` instead.")
def is_integral_tensor(x: Any) -> TypeIs[IntegralTensor]:
    """Deprecated check that ``x`` is an instance of ``IntegralTensor``.

    Prefer ``isinstance(x, to.IntegralTensor)`` in new code.
    """
    matches = isinstance(x, IntegralTensor)
    return matches
@deprecated_function(
    "{fn_name}, use `isinstance(x, to.IntegralTensor1D)` instead."
)
def is_integral_tensor1d(x: Any) -> TypeIs[IntegralTensor1D]:
    """Deprecated check that ``x`` is an instance of ``IntegralTensor1D``.

    Prefer ``isinstance(x, to.IntegralTensor1D)`` in new code.
    """
    matches = isinstance(x, IntegralTensor1D)
    return matches
@deprecated_function(
    "{fn_name}, use `to.isinstance_guard(x, Iterable[Tensor])` instead."
)
def is_iterable_tensor(x: Any) -> TypeIs[Iterable[Tensor]]:
    """Deprecated check of ``x`` against ``Iterable[Tensor]``.

    Delegates to ``isinstance_guard``; prefer calling
    ``to.isinstance_guard(x, Iterable[Tensor])`` directly in new code.
    """
    expected_type = Iterable[Tensor]
    return isinstance_guard(x, expected_type)
@deprecated_function(
    "{fn_name}, use `to.isinstance_guard(x, List[Tensor])` instead."
)
def is_list_tensor(x: Any) -> TypeIs[List[Tensor]]:
    """Deprecated check of ``x`` against ``List[Tensor]``.

    Delegates to ``isinstance_guard``; prefer calling
    ``to.isinstance_guard(x, List[Tensor])`` directly in new code.
    """
    expected_type = List[Tensor]
    return isinstance_guard(x, expected_type)
@deprecated_function("{fn_name}, use `isinstance(x, Tensor0D)` instead.")
def is_tensor0d(x: Any) -> TypeIs[Tensor0D]:
    """Deprecated check that ``x`` is an instance of ``Tensor0D``.

    Prefer ``isinstance(x, Tensor0D)`` in new code.
    """
    matches = isinstance(x, Tensor0D)
    return matches
@deprecated_function(
    "{fn_name}, use `to.isinstance_guard(x, Tuple[Tensor, ...])` instead."
)
def is_tuple_tensor(x: Any) -> TypeIs[Tuple[Tensor, ...]]:
    """Deprecated check of ``x`` against ``Tuple[Tensor, ...]``.

    Delegates to ``isinstance_guard``; prefer calling
    ``to.isinstance_guard(x, Tuple[Tensor, ...])`` directly in new code.
    """
    expected_type = Tuple[Tensor, ...]
    return isinstance_guard(x, expected_type)