#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Module versions of tensor functions that do not already exist in PyTorch."""
from typing import List, Optional, Sequence, Tuple, Union, overload
import torch
from torch import Tensor
from torch.nn import functional as F
from torch.types import Number
from torchoutil.nn.functional.make import DTypeLike, as_dtype
from torchoutil.pyoutil.collections import dump_dict
from torchoutil.pyoutil.semver import Version
from torchoutil.utils import return_types
from .module import Module
class Abs(Module):
    """
    Module wrapper around :func:`~torch.abs` (element-wise absolute value).
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the absolute value of every element of ``x``."""
        return torch.abs(x)
class Angle(Module):
    """
    Module wrapper around :func:`~torch.angle`.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the element-wise angle (argument) of ``x``."""
        return torch.angle(x)
class Exp(Module):
    """
    Module wrapper around :func:`~torch.exp`.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``e`` raised to the power of each element of ``x``."""
        return torch.exp(x)
class Exp2(Module):
    """
    Module wrapper around :func:`~torch.exp2`.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``2`` raised to the power of each element of ``x``."""
        return torch.exp2(x)
class FFT(Module):
    """
    Module version of :func:`~torch.fft.fft`.
    """

    def __init__(
        self,
        n: Optional[int] = None,
        dim: int = -1,
        norm: Optional[str] = None,
    ) -> None:
        """Configure the FFT.

        The defaults are the same as those of :func:`torch.fft.fft`, so ``FFT()``
        behaves exactly like ``torch.fft.fft(x)``.

        Args:
            n: Signal length forwarded to :func:`torch.fft.fft` (None keeps the input length).
            dim: Dimension along which the FFT is computed. defaults to -1.
            norm: Normalization mode forwarded to :func:`torch.fft.fft` (see its docs for valid values).
        """
        super().__init__()
        self.n = n
        self.dim = dim
        self.norm = norm

    def forward(self, x: Tensor) -> Tensor:
        """Return the one-dimensional discrete Fourier transform of ``x``."""
        return torch.fft.fft(x, n=self.n, dim=self.dim, norm=self.norm)
class IFFT(Module):
    """
    Module version of :func:`~torch.fft.ifft`.
    """

    def __init__(
        self,
        n: Optional[int] = None,
        dim: int = -1,
        norm: Optional[str] = None,
    ) -> None:
        """Configure the inverse FFT.

        The defaults are the same as those of :func:`torch.fft.ifft`, so ``IFFT()``
        behaves exactly like ``torch.fft.ifft(x)``.

        Args:
            n: Signal length forwarded to :func:`torch.fft.ifft` (None keeps the input length).
            dim: Dimension along which the inverse FFT is computed. defaults to -1.
            norm: Normalization mode forwarded to :func:`torch.fft.ifft` (see its docs for valid values).
        """
        super().__init__()
        self.n = n
        self.dim = dim
        self.norm = norm

    def forward(self, x: Tensor) -> Tensor:
        """Return the one-dimensional inverse discrete Fourier transform of ``x``."""
        return torch.fft.ifft(x, n=self.n, dim=self.dim, norm=self.norm)
class Imag(Module):
    """
    Module wrapper around :attr:`~torch.Tensor.imag`.
    """

    def __init__(self, *, return_zeros: bool = False) -> None:
        """Build a module extracting the imaginary part of a complex tensor.

        Args:
            return_zeros:
                Controls what happens when the input is not complex.
                If True, a zero tensor with the same shape as the input is returned.
                If False, PyTorch's own RuntimeError from ``x.imag`` is raised.
        """
        super().__init__()
        self.return_zeros = return_zeros

    def forward(self, x: Tensor) -> Tensor:
        # Only synthesize zeros when explicitly requested; otherwise defer to
        # ``x.imag``, which raises for non-complex tensors.
        use_zeros = self.return_zeros and not x.is_complex()
        return torch.zeros_like(x) if use_zeros else x.imag
class Interpolate(Module):
    """
    Module version of :func:`~torch.nn.functional.interpolate`.
    """

    def __init__(
        self,
        size: Union[int, Tuple[int, ...], None] = None,
        scale_factor: Union[float, Tuple[float, ...], None] = None,
        mode: str = "nearest",
        align_corners: Optional[bool] = None,
        recompute_scale_factor: Optional[bool] = None,
        antialias: bool = False,
    ) -> None:
        """Store the arguments forwarded to :func:`~torch.nn.functional.interpolate`.

        Args:
            size: Output spatial size.
            scale_factor: Multiplier for the spatial size.
            mode: Interpolation algorithm name.
            align_corners: Forwarded as-is to ``F.interpolate``.
            recompute_scale_factor: Forwarded only when not None (its default).
            antialias: Forwarded only when True (its default is False).
        """
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners
        self.recompute_scale_factor = recompute_scale_factor
        self.antialias = antialias

    def forward(self, x: Tensor) -> Tensor:
        """Resize ``x`` with the parameters given at construction time."""
        # Optional arguments are only passed when they differ from the defaults
        # of F.interpolate. A torch version lacking ``antialias`` still works when
        # it is left unset, and a user-provided value is never silently dropped.
        # The previous check (torch >= 2.0) ignored both arguments on torch 1.x.
        kwds = {}
        if self.recompute_scale_factor is not None:
            kwds["recompute_scale_factor"] = self.recompute_scale_factor
        if self.antialias:
            kwds["antialias"] = self.antialias
        return F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            **kwds,
        )
class Log(Module):
    """
    Module wrapper around :func:`~torch.log` (natural logarithm).
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the element-wise natural logarithm of ``x``."""
        return torch.log(x)
class Log10(Module):
    """
    Module wrapper around :func:`~torch.log10`.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the element-wise base-10 logarithm of ``x``."""
        return torch.log10(x)
class Log2(Module):
    """
    Module wrapper around :func:`~torch.log2`.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the element-wise base-2 logarithm of ``x``."""
        return torch.log2(x)
class Max(Module):
    """
    Module wrapper around :func:`~torch.max`.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        keepdim: bool = False,
        *,
        return_values: bool = True,
        return_indices: Optional[bool] = None,
    ) -> None:
        """Configure the reduction.

        Args:
            dim: Dimension to reduce. If None, the maximum over all elements is taken.
            keepdim: Keep the reduced dimension (only allowed when ``dim`` is given).
            return_values: Whether the maximum values are part of the output.
            return_indices: Whether the indices are part of the output.
                If None, indices are returned only when ``dim`` is given.

        Raises:
            ValueError: If both outputs are disabled, or if ``keepdim`` is set without ``dim``.
        """
        if return_indices is None:
            return_indices = dim is not None

        if not (return_values or return_indices):
            msg = f"Invalid combinaison of arguments {return_values=} and {return_indices=}. (at least one of them must be True)"
            raise ValueError(msg)
        if keepdim and dim is None:
            msg = f"Invalid combinaison of arguments {dim=} and {keepdim=}. (expected dim is not None or keepdim=False)"
            raise ValueError(msg)

        super().__init__()
        self.dim = dim
        self.return_values = return_values
        self.return_indices = return_indices
        self.keepdim = keepdim

    def forward(self, x: Tensor) -> Union[Tensor, return_types.max]:
        if self.dim is not None:
            result = x.max(dim=self.dim, keepdim=self.keepdim)
        else:
            # Global reduction: argmax gives a flat index, used to fetch the value.
            flat_idx = x.argmax()
            result = return_types.max([x.flatten()[flat_idx], flat_idx])

        if self.return_values and self.return_indices:
            return result  # type: ignore
        if self.return_values:
            return result.values
        if self.return_indices:
            return result.indices

        # Unreachable after __init__ validation, kept as a defensive check.
        msg = f"Invalid combinaison of arguments {self.return_values=} and {self.return_indices=}. (at least one of them must be True)"
        raise ValueError(msg)
class Mean(Module):
    """
    Module version of :func:`~torch.mean`.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        keepdim: bool = False,
        dtype: DTypeLike = None,
    ) -> None:
        """Configure the reduction.

        Args:
            dim: Dimension to reduce. If None, the mean over all elements is taken.
            keepdim: If True, reduced dimensions are kept with size 1.
            dtype: Optional dtype, converted with ``as_dtype`` and passed to ``Tensor.mean``.
        """
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim
        self.dtype = dtype

    def forward(self, x: Tensor) -> Tensor:
        dtype = as_dtype(self.dtype)
        if self.dim is not None:
            return x.mean(dim=self.dim, keepdim=self.keepdim, dtype=dtype)  # type: ignore

        # Global mean. ``x.mean(dim=None, ...)`` is not available on older torch
        # versions, so reduce all elements and, when keepdim is set, restore every
        # reduced dimension with size 1. This matches the torch >= 2.0 output shape.
        # It also keeps autograd and the result dtype intact, which the former
        # ``torch.full(x.shape, result.item())`` fallback did not (it also
        # produced the wrong shape).
        result = x.mean(dtype=dtype)
        if self.keepdim:
            result = result.reshape((1,) * x.ndim)
        return result
class Min(Module):
    """
    Module wrapper around :func:`~torch.min`.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        keepdim: bool = False,
        *,
        return_values: bool = True,
        return_indices: Optional[bool] = None,
    ) -> None:
        """Configure the reduction.

        Args:
            dim: Dimension to reduce. If None, the minimum over all elements is taken.
            keepdim: Keep the reduced dimension (only allowed when ``dim`` is given).
            return_values: Whether the minimum values are part of the output.
            return_indices: Whether the indices are part of the output.
                If None, indices are returned only when ``dim`` is given.

        Raises:
            ValueError: If both outputs are disabled, or if ``keepdim`` is set without ``dim``.
        """
        if return_indices is None:
            return_indices = dim is not None

        if not (return_values or return_indices):
            msg = f"Invalid combinaison of arguments {return_values=} and {return_indices=}. (at least one of them must be True)"
            raise ValueError(msg)
        if keepdim and dim is None:
            msg = f"Invalid combinaison of arguments {dim=} and {keepdim=}. (expected dim is not None or keepdim=False)"
            raise ValueError(msg)

        super().__init__()
        self.dim = dim
        self.return_values = return_values
        self.return_indices = return_indices
        self.keepdim = keepdim

    def forward(self, x: Tensor) -> Union[Tensor, return_types.min]:
        if self.dim is not None:
            result = x.min(dim=self.dim, keepdim=self.keepdim)
        else:
            # Global reduction: argmin gives a flat index, used to fetch the value.
            flat_idx = x.argmin()
            result = return_types.min([x.flatten()[flat_idx], flat_idx])

        if self.return_values and self.return_indices:
            return result  # type: ignore
        if self.return_values:
            return result.values
        if self.return_indices:
            return result.indices

        # Unreachable after __init__ validation, kept as a defensive check.
        msg = f"Invalid combinaison of arguments {self.return_values=} and {self.return_indices=}. (at least one of them must be True)"
        raise ValueError(msg)
class Normalize(Module):
    """
    Module wrapper around :func:`~torch.nn.functional.normalize`.
    """

    def __init__(
        self,
        p: float = 2.0,
        dim: int = 1,
        eps: float = 1e-12,
    ) -> None:
        """
        Args:
            p: Exponent of the norm used for normalization.
            dim: Dimension along which the norm is computed.
            eps: Small value forwarded to ``F.normalize`` to avoid division by zero.
        """
        super().__init__()
        self.p = p
        self.dim = dim
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.normalize(x, p=self.p, dim=self.dim, eps=self.eps)
class Permute(Module):
    """
    Module version of :func:`~torch.permute`.
    """

    def __init__(self, *args: Union[int, Sequence[int]]) -> None:
        """Store the target order of dimensions.

        Accepts either the dims as separate ints (``Permute(0, 2, 1)``) or a
        single tuple/list of dims (``Permute((0, 2, 1))``), like ``Tensor.permute``.
        """
        super().__init__()
        # Unwrap a single sequence argument so that ``self.dims`` is always a flat tuple of ints.
        if len(args) == 1 and isinstance(args[0], (tuple, list)):
            args = tuple(args[0])
        self.dims = tuple(args)

    def forward(self, x: Tensor) -> Tensor:
        return x.permute(self.dims)
class Pow(Module):
    """
    Module wrapper around :meth:`~torch.Tensor.pow`.
    """

    def __init__(self, exponent: Union[Number, Tensor]) -> None:
        """
        Args:
            exponent: Scalar or tensor exponent applied to every input.
        """
        super().__init__()
        self.exponent = exponent

    def forward(self, x: Tensor) -> Tensor:
        return torch.pow(x, self.exponent)
class Real(Module):
    """
    Module wrapper around :attr:`~torch.Tensor.real`.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return the real part of ``x`` (the tensor itself for real dtypes)."""
        return torch.real(x)
class Repeat(Module):
    """
    Module version of :meth:`~torch.Tensor.repeat`.
    """

    def __init__(self, *repeats: Union[int, Sequence[int]]) -> None:
        """Store the number of repetitions along each dimension.

        Accepts either separate ints (``Repeat(2, 1)``) or a single tuple/list
        (``Repeat((2, 1))``), like ``Tensor.repeat``.
        """
        super().__init__()
        # Unwrap a single sequence argument so that ``self.repeats`` is always a flat tuple of ints.
        if len(repeats) == 1 and isinstance(repeats[0], (tuple, list)):
            repeats = tuple(repeats[0])
        self.repeats = tuple(repeats)

    def forward(self, x: Tensor) -> Tensor:
        return x.repeat(self.repeats)
class RepeatInterleave(Module):
    """
    Module wrapper around :func:`~torch.repeat_interleave`.
    """

    def __init__(
        self,
        repeats: Union[int, Tensor],
        dim: int,
        output_size: Optional[int] = None,
    ) -> None:
        """
        Args:
            repeats: Number of repetitions for each element (int or tensor).
            dim: Dimension along which values are repeated.
            output_size: Optional total output size along ``dim``, forwarded to PyTorch.
        """
        super().__init__()
        self.repeats = repeats
        self.dim = dim
        self.output_size = output_size

    def forward(self, x: Tensor) -> Tensor:
        return x.repeat_interleave(
            self.repeats,
            dim=self.dim,
            output_size=self.output_size,
        )
class Reshape(Module):
    """
    Module version of :func:`~torch.reshape`.
    """

    def __init__(self, *shape: Union[int, Sequence[int]]) -> None:
        """Store the target shape.

        Accepts either separate ints (``Reshape(3, -1)``) or a single
        tuple/list/``torch.Size`` (``Reshape((3, -1))``), like ``Tensor.reshape``.
        """
        super().__init__()
        # Unwrap a single sequence argument so that ``self.shape`` is always a flat tuple of ints.
        # torch.Size is a tuple subclass, so it is handled as well.
        if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
            shape = tuple(shape[0])
        self.shape = tuple(shape)

    def forward(self, x: Tensor) -> Tensor:
        return x.reshape(self.shape)
class TensorTo(Module):
    """
    Module wrapper around :meth:`~torch.Tensor.to`.
    """

    def __init__(self, **kwargs) -> None:
        """
        Args:
            **kwargs: Keyword arguments forwarded unchanged to ``Tensor.to`` (e.g. ``dtype``, ``device``).
        """
        super().__init__()
        self.kwargs = kwargs

    def forward(self, x: Tensor) -> Tensor:
        options = self.kwargs
        return x.to(**options)
class ToList(Module):
    """
    Module wrapper around :meth:`~torch.Tensor.tolist`.
    """

    def forward(self, x: Tensor) -> List:
        """Convert ``x`` to (nested) Python lists; a 0-dim tensor gives a scalar."""
        return Tensor.tolist(x)
class Transpose(Module):
    """
    Module wrapper around :func:`~torch.transpose`.
    """

    def __init__(self, dim0: int, dim1: int, copy: bool = False) -> None:
        """
        Args:
            dim0: First dimension to swap.
            dim1: Second dimension to swap.
            copy: If True, use ``torch.transpose_copy`` (only when the installed torch provides it).
        """
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1
        self.copy = copy

    def forward(self, x: Tensor) -> Tensor:
        if not self.copy:
            return torch.transpose(x, self.dim0, self.dim1)

        # transpose_copy is missing from older torch releases.
        if not hasattr(torch, "transpose_copy"):
            msg = f"Invalid argument {self.copy=} in torch {torch.__version__}."
            raise ValueError(msg)
        return torch.transpose_copy(x, self.dim0, self.dim1)  # type: ignore
class View(Module):
    """
    Module wrapper around :meth:`~torch.Tensor.view`.

    The constructor arguments are forwarded unchanged to ``Tensor.view``: either
    a dtype, a single size sequence, or the sizes as separate ints.
    """

    @overload
    def __init__(self, dtype: torch.dtype, /) -> None:
        ...

    @overload
    def __init__(self, size: Sequence[int], /) -> None:
        ...

    @overload
    def __init__(self, *size: int) -> None:
        ...

    def __init__(self, *args) -> None:
        super().__init__()
        self.args = args

    def forward(self, x: Tensor) -> Tensor:
        view_args = self.args
        return x.view(*view_args)