Source code for torchoutil.nn.modules.tensor

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Module versions of tensor functions that do not already exist in PyTorch."""

from typing import List, Optional, Sequence, Tuple, Union, overload

import torch
from torch import Tensor
from torch.nn import functional as F
from torch.types import Number

from torchoutil.nn.functional.make import DTypeLike, as_dtype
from torchoutil.pyoutil.collections import dump_dict
from torchoutil.pyoutil.semver import Version
from torchoutil.utils import return_types

from .module import Module


class Abs(Module):
    """Module version of :func:`~torch.abs`.

    Computes the element-wise absolute value of its input.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``|x|`` element-wise."""
        return torch.abs(x)
class Angle(Module):
    """Module version of :func:`~torch.angle`.

    Computes the element-wise angle (argument) of its input, in radians.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the element-wise angle of ``x``."""
        return torch.angle(x)
class Exp(Module):
    """Module version of :func:`~torch.exp`.

    Computes ``e ** x`` element-wise.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the element-wise natural exponential of ``x``."""
        return torch.exp(x)
class Exp2(Module):
    """Module version of :func:`~torch.exp2`.

    Computes ``2 ** x`` element-wise.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the element-wise base-2 exponential of ``x``."""
        return torch.exp2(x)
class FFT(Module):
    """Module version of :func:`~torch.fft.fft`.

    Applies :func:`torch.fft.fft` with its default arguments (1-D transform along the
    last dimension, no padding/trimming, default normalization).
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the discrete Fourier transform of ``x``."""
        spectrum = torch.fft.fft(x)
        return spectrum
class IFFT(Module):
    """Module version of :func:`~torch.fft.ifft`.

    Applies :func:`torch.fft.ifft` with its default arguments (1-D inverse transform
    along the last dimension, default normalization).
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the inverse discrete Fourier transform of ``x``."""
        signal = torch.fft.ifft(x)
        return signal
class Imag(Module):
    """Module version of :attr:`~torch.Tensor.imag`.

    Returns the imaginary part of a complex tensor.
    """

    def __init__(self, *, return_zeros: bool = False) -> None:
        """Return the imaginary part of a complex tensor.

        Args:
            return_zeros: If True and the input is not a complex tensor, the module
                returns a zero-filled tensor with the same shape as the input (built
                with ``torch.zeros_like``). If False and the input is not a complex
                tensor, the default RuntimeError of PyTorch is raised by ``Tensor.imag``.
        """
        super().__init__()
        self.return_zeros = return_zeros

    def forward(self, x: Tensor) -> Tensor:
        # Complex inputs always yield their imaginary part. Real inputs only take the
        # zeros shortcut when it was explicitly requested; otherwise ``x.imag`` raises.
        if x.is_complex() or not self.return_zeros:
            return x.imag
        return torch.zeros_like(x)
class Interpolate(Module):
    """Module version of :func:`~torch.nn.functional.interpolate`.

    All constructor arguments are stored on the module and forwarded to
    :func:`torch.nn.functional.interpolate` on every call.

    Note:
        ``recompute_scale_factor`` and ``antialias`` are only forwarded when the
        installed torch version is >= 2.0.0; with older versions they are not passed,
        so ``F.interpolate`` uses its own defaults for them.
    """

    def __init__(
        self,
        size: Union[int, Tuple[int, ...], None] = None,
        scale_factor: Union[float, Tuple[float, ...], None] = None,
        mode: str = "nearest",
        align_corners: Optional[bool] = None,
        recompute_scale_factor: Optional[bool] = None,
        antialias: bool = False,
    ) -> None:
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners
        self.recompute_scale_factor = recompute_scale_factor
        self.antialias = antialias

    def forward(self, x: Tensor) -> Tensor:
        kwds = {}
        # Version gate kept as-is: these two options are only forwarded on torch>=2.0.
        if Version(torch.__version__) >= Version("2.0.0"):
            kwds.update(
                recompute_scale_factor=self.recompute_scale_factor,
                antialias=self.antialias,
            )
        return F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            **kwds,
        )

    def extra_repr(self) -> str:
        # Expose the configuration in the module repr. ``ignore_lst=(None,)`` is meant
        # to hide options left to None (assumes dump_dict filters entries by value).
        return dump_dict(
            dict(
                size=self.size,
                scale_factor=self.scale_factor,
                mode=self.mode,
                align_corners=self.align_corners,
                recompute_scale_factor=self.recompute_scale_factor,
                antialias=self.antialias,
            ),
            ignore_lst=(None,),
        )
class Log(Module):
    """Module version of :func:`~torch.log`.

    Computes the element-wise natural logarithm of its input.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``ln(x)`` element-wise."""
        return torch.log(x)
class Log10(Module):
    """Module version of :func:`~torch.log10`.

    Computes the element-wise base-10 logarithm of its input.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``log10(x)`` element-wise."""
        return torch.log10(x)
class Log2(Module):
    """Module version of :func:`~torch.log2`.

    Computes the element-wise base-2 logarithm of its input.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``log2(x)`` element-wise."""
        return torch.log2(x)
class Max(Module):
    """Module version of :func:`~torch.max`.

    Args:
        dim: Dimension to reduce. If None, the maximum is taken over all elements and
            indices refer to the flattened input.
        keepdim: Keep the reduced dimension with size 1. Only allowed when ``dim`` is set.
        return_values: Whether to return the maximal values.
        return_indices: Whether to return the indices of the maximal values. When None,
            it defaults to True if ``dim`` is given and False otherwise.

    When both flags are True, ``forward`` returns a ``return_types.max`` pair with
    ``values`` and ``indices`` fields; otherwise only the requested tensor is returned.

    Raises:
        ValueError: If both ``return_values`` and ``return_indices`` end up False, or if
            ``keepdim`` is True while ``dim`` is None.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        keepdim: bool = False,
        *,
        return_values: bool = True,
        return_indices: Optional[bool] = None,
    ) -> None:
        # By default, indices are only returned for a per-dimension reduction.
        return_indices = (dim is not None) if return_indices is None else return_indices

        if not (return_values or return_indices):
            msg = f"Invalid combinaison of arguments {return_values=} and {return_indices=}. (at least one of them must be True)"
            raise ValueError(msg)
        if keepdim and dim is None:
            msg = f"Invalid combinaison of arguments {dim=} and {keepdim=}. (expected dim is not None or keepdim=False)"
            raise ValueError(msg)

        super().__init__()
        self.dim = dim
        self.return_values = return_values
        self.return_indices = return_indices
        self.keepdim = keepdim

    def forward(self, x: Tensor) -> Union[Tensor, return_types.max]:
        if self.dim is not None:
            result = x.max(dim=self.dim, keepdim=self.keepdim)
        else:
            # Global reduction: locate the max in the flattened tensor so that both the
            # value and its (flat) index are available.
            flat_idx = x.argmax()
            result = return_types.max([x.flatten()[flat_idx], flat_idx])

        want_values, want_indices = self.return_values, self.return_indices
        if want_values and want_indices:
            return result  # type: ignore
        if want_values:
            return result.values
        if want_indices:
            return result.indices

        # Unreachable when the flags were validated in __init__, kept as a safeguard
        # against later attribute mutation.
        msg = f"Invalid combinaison of arguments {self.return_values=} and {self.return_indices=}. (at least one of them must be True)"
        raise ValueError(msg)

    def extra_repr(self) -> str:
        return dump_dict(
            {
                "dim": self.dim,
                "return_values": self.return_values,
                "return_indices": self.return_indices,
                "keepdim": self.keepdim,
            },
        )
class Mean(Module):
    """Module version of :func:`~torch.mean`.

    Args:
        dim: Dimension to reduce. If None, the mean is computed over all elements.
        keepdim: If True, reduced dimensions are kept with size 1.
        dtype: Optional dtype, converted with ``as_dtype`` and forwarded to
            ``Tensor.mean``.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        keepdim: bool = False,
        dtype: DTypeLike = None,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim
        self.dtype = dtype

    def forward(self, x: Tensor) -> Tensor:
        dtype = as_dtype(self.dtype)
        if (Version(torch.__version__) >= Version("2.0.0")) or (self.dim is not None):
            return x.mean(dim=self.dim, keepdim=self.keepdim, dtype=dtype)  # type: ignore

        # Support for older torch versions, where ``dim=None`` is not passed to
        # ``Tensor.mean``: reduce over all elements first.
        result = x.mean(dtype=dtype)
        if self.keepdim:
            # ``keepdim=True`` with ``dim=None`` keeps every dimension with size 1 (the
            # shape torch>=2.0 returns). Reshaping the scalar mean, instead of filling a
            # tensor of the input's full shape from ``.item()``, gives that shape and
            # also preserves dtype, device and the autograd graph.
            result = result.reshape((1,) * x.ndim)
        return result

    def extra_repr(self) -> str:
        return dump_dict(
            dict(
                dim=self.dim,
                keepdim=self.keepdim,
                dtype=self.dtype,
            ),
            ignore_lst=(None,),
        )
class Min(Module):
    """Module version of :func:`~torch.min`.

    Args:
        dim: Dimension to reduce. If None, the minimum is taken over all elements and
            indices refer to the flattened input.
        keepdim: Keep the reduced dimension with size 1. Only allowed when ``dim`` is set.
        return_values: Whether to return the minimal values.
        return_indices: Whether to return the indices of the minimal values. When None,
            it defaults to True if ``dim`` is given and False otherwise.

    When both flags are True, ``forward`` returns a ``return_types.min`` pair with
    ``values`` and ``indices`` fields; otherwise only the requested tensor is returned.

    Raises:
        ValueError: If both ``return_values`` and ``return_indices`` end up False, or if
            ``keepdim`` is True while ``dim`` is None.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        keepdim: bool = False,
        *,
        return_values: bool = True,
        return_indices: Optional[bool] = None,
    ) -> None:
        # By default, indices are only returned for a per-dimension reduction.
        return_indices = (dim is not None) if return_indices is None else return_indices

        if not (return_values or return_indices):
            msg = f"Invalid combinaison of arguments {return_values=} and {return_indices=}. (at least one of them must be True)"
            raise ValueError(msg)
        if keepdim and dim is None:
            msg = f"Invalid combinaison of arguments {dim=} and {keepdim=}. (expected dim is not None or keepdim=False)"
            raise ValueError(msg)

        super().__init__()
        self.dim = dim
        self.return_values = return_values
        self.return_indices = return_indices
        self.keepdim = keepdim

    def forward(self, x: Tensor) -> Union[Tensor, return_types.min]:
        if self.dim is not None:
            result = x.min(dim=self.dim, keepdim=self.keepdim)
        else:
            # Global reduction: locate the min in the flattened tensor so that both the
            # value and its (flat) index are available.
            flat_idx = x.argmin()
            result = return_types.min([x.flatten()[flat_idx], flat_idx])

        want_values, want_indices = self.return_values, self.return_indices
        if want_values and want_indices:
            return result  # type: ignore
        if want_values:
            return result.values
        if want_indices:
            return result.indices

        # Unreachable when the flags were validated in __init__, kept as a safeguard
        # against later attribute mutation.
        msg = f"Invalid combinaison of arguments {self.return_values=} and {self.return_indices=}. (at least one of them must be True)"
        raise ValueError(msg)

    def extra_repr(self) -> str:
        return dump_dict(
            {
                "dim": self.dim,
                "return_values": self.return_values,
                "return_indices": self.return_indices,
                "keepdim": self.keepdim,
            },
        )
class Normalize(Module):
    """Module version of :func:`~torch.nn.functional.normalize`.

    Args:
        p: Exponent of the norm used for normalization.
        dim: Dimension along which the input is normalized.
        eps: Small value forwarded to ``F.normalize`` to avoid division by zero.
    """

    def __init__(
        self,
        p: float = 2.0,
        dim: int = 1,
        eps: float = 1e-12,
    ) -> None:
        super().__init__()
        self.p = p
        self.dim = dim
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.normalize(x, p=self.p, dim=self.dim, eps=self.eps)

    def extra_repr(self) -> str:
        return dump_dict({"p": self.p, "dim": self.dim, "eps": self.eps})
class Permute(Module):
    """Module version of :func:`~torch.permute`.

    Dims may be given as separate integers (``Permute(0, 2, 1)``) or as a single
    sequence (``Permute((0, 2, 1))``), like :meth:`torch.Tensor.permute` accepts.
    """

    def __init__(self, *args: Union[int, Sequence[int]]) -> None:
        super().__init__()
        # A single sequence argument is unpacked; otherwise it would be stored as a
        # nested tuple that ``Tensor.permute`` rejects at forward time.
        if len(args) == 1 and isinstance(args[0], Sequence):
            dims = tuple(args[0])
        else:
            dims = tuple(args)
        self.dims = dims

    def forward(self, x: Tensor) -> Tensor:
        return x.permute(self.dims)

    def extra_repr(self) -> str:
        return dump_dict(
            dict(
                dims=self.dims,
            ),
            fmt="{value}",
        )
class Pow(Module):
    """Module version of :meth:`~torch.Tensor.pow`.

    Args:
        exponent: Scalar or tensor exponent, stored as a plain attribute and applied
            element-wise on every call.
    """

    def __init__(self, exponent: Union[Number, Tensor]) -> None:
        super().__init__()
        self.exponent = exponent

    def forward(self, x: Tensor) -> Tensor:
        return torch.pow(x, self.exponent)

    def extra_repr(self) -> str:
        return dump_dict(exponent=self.exponent)
class Real(Module):
    """Module version of :attr:`~torch.Tensor.real`.

    Returns the real part of its input.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the real part of ``x``."""
        return torch.real(x)
class Repeat(Module):
    """Module version of :meth:`~torch.Tensor.repeat`.

    Repeat counts may be given as separate integers (``Repeat(2, 1)``) or as a single
    sequence (``Repeat((2, 1))``), like :meth:`torch.Tensor.repeat` accepts.
    """

    def __init__(self, *repeats: Union[int, Sequence[int]]) -> None:
        super().__init__()
        # A single sequence argument is unpacked; otherwise it would be stored as a
        # nested tuple that ``Tensor.repeat`` rejects at forward time.
        if len(repeats) == 1 and isinstance(repeats[0], Sequence):
            counts = tuple(repeats[0])
        else:
            counts = repeats
        self.repeats = counts

    def forward(self, x: Tensor) -> Tensor:
        return x.repeat(self.repeats)

    def extra_repr(self) -> str:
        return dump_dict(repeats=self.repeats)
class RepeatInterleave(Module):
    """Module version of :func:`~torch.repeat_interleave`.

    Args:
        repeats: Number of repetitions per element, as an int or a tensor.
        dim: Dimension along which values are repeated.
        output_size: Optional total output size along ``dim``, forwarded as-is.
    """

    def __init__(
        self,
        repeats: Union[int, Tensor],
        dim: int,
        output_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.repeats = repeats
        self.dim = dim
        self.output_size = output_size

    def forward(self, x: Tensor) -> Tensor:
        return torch.repeat_interleave(
            x,
            self.repeats,
            dim=self.dim,
            output_size=self.output_size,
        )

    def extra_repr(self) -> str:
        return dump_dict(
            {
                "repeats": self.repeats,
                "dim": self.dim,
                "output_size": self.output_size,
            },
            ignore_lst=(None,),
        )
class Reshape(Module):
    """Module version of :func:`~torch.reshape`.

    The target shape may be given as separate integers (``Reshape(2, 3)``) or as a
    single sequence such as a tuple, list or ``torch.Size`` (``Reshape((2, 3))``), like
    :meth:`torch.Tensor.reshape` accepts.
    """

    def __init__(self, *shape: Union[int, Sequence[int]]) -> None:
        super().__init__()
        # A single sequence argument is unpacked; otherwise it would be stored as a
        # nested tuple that ``Tensor.reshape`` rejects at forward time.
        if len(shape) == 1 and isinstance(shape[0], Sequence):
            target = tuple(shape[0])
        else:
            target = shape
        self.shape = target

    def forward(self, x: Tensor) -> Tensor:
        return x.reshape(self.shape)

    def extra_repr(self) -> str:
        return dump_dict(
            dict(
                shape=self.shape,
            ),
        )
class TensorTo(Module):
    """Module version of :meth:`~torch.Tensor.to`.

    Keyword arguments given at construction (e.g. ``dtype`` or ``device``) are stored
    and forwarded unchanged to ``Tensor.to`` on every call.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.kwargs = kwargs

    def forward(self, x: Tensor) -> Tensor:
        conversion_kwargs = self.kwargs
        return x.to(**conversion_kwargs)

    def extra_repr(self) -> str:
        return dump_dict(self.kwargs)
class ToList(Module):
    """Module version of :meth:`~torch.Tensor.tolist`.

    Returns a (nested) Python list for tensors with at least one dimension, and a plain
    Python number for 0-d tensors, exactly as ``Tensor.tolist`` does.
    """

    def forward(self, x: Tensor) -> Union[List, Number]:
        # Annotation widened: ``tolist()`` on a 0-d tensor returns a scalar, not a list.
        return x.tolist()
class Transpose(Module):
    """Module version of :func:`~torch.transpose`.

    Args:
        dim0: First dimension to swap.
        dim1: Second dimension to swap.
        copy: If True, use ``torch.transpose_copy`` instead of ``torch.transpose``.
            A ValueError is raised at call time when the installed torch does not
            provide ``transpose_copy``.
    """

    def __init__(self, dim0: int, dim1: int, copy: bool = False) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1
        self.copy = copy

    def forward(self, x: Tensor) -> Tensor:
        if not self.copy:
            return torch.transpose(x, self.dim0, self.dim1)

        # ``transpose_copy`` is missing from some torch versions, so check before calling.
        if not hasattr(torch, "transpose_copy"):
            msg = f"Invalid argument {self.copy=} in torch {torch.__version__}."
            raise ValueError(msg)
        return torch.transpose_copy(x, self.dim0, self.dim1)  # type: ignore

    def extra_repr(self) -> str:
        return dump_dict({"dim0": self.dim0, "dim1": self.dim1}, fmt="{value}")
class View(Module):
    """Module version of :meth:`~torch.Tensor.view`.

    The constructor arguments are stored as-is and unpacked into ``Tensor.view`` on
    every call. The accepted forms are therefore those of that method: a single dtype,
    a single sequence of sizes, or sizes given as separate integers.
    """

    @overload
    def __init__(self, dtype: torch.dtype, /) -> None:
        ...

    @overload
    def __init__(self, size: Sequence[int], /) -> None:
        ...

    @overload
    def __init__(self, *size: int) -> None:
        ...

    def __init__(self, *args) -> None:
        super().__init__()
        self.args = args

    def forward(self, x: Tensor) -> Tensor:
        view_args = self.args
        return x.view(*view_args)