Source code for torchoutil.nn.modules.numpy
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Union
from torch import Tensor
from torchoutil.nn.modules.module import Module
from torchoutil.core.make import DeviceLike, DTypeLike
from torchoutil.extras.numpy.definitions import np
from torchoutil.extras.numpy.functional import (
numpy_to_tensor,
tensor_to_numpy,
to_numpy,
)
class ToNumpy(Module):
    """Module wrapper around ``to_numpy``.

    For more information, see :func:`~torchoutil.nn.functional.numpy.to_numpy`.

    Args:
        dtype: Forwarded unchanged as the ``dtype`` keyword of ``to_numpy``
            on every call. ``None`` is passed through as-is.
        force: Forwarded unchanged as the ``force`` keyword of ``to_numpy``.
            NOTE(review): presumably mirrors ``torch.Tensor.numpy(force=...)``
            -- confirm against ``to_numpy``.
    """

    def __init__(
        self,
        *,
        dtype: Union[str, np.dtype, None] = None,
        force: bool = False,
    ) -> None:
        super().__init__()
        # Stored as-is; both values are read again on every forward call.
        self.dtype = dtype
        self.force = force

    def forward(self, x: Union[Tensor, np.ndarray, list]) -> np.ndarray:
        """Return ``to_numpy(x)`` using the options given at construction."""
        return to_numpy(x, dtype=self.dtype, force=self.force)

    def extra_repr(self) -> str:
        """Describe the stored options for ``repr(module)``.

        NOTE(review): relies on the project ``Module`` base deriving from
        ``torch.nn.Module`` (whose ``__repr__`` calls this hook) -- confirm.
        """
        return f"dtype={self.dtype!r}, force={self.force!r}"
class TensorToNumpy(Module):
    """Module wrapper around ``tensor_to_numpy``.

    For more information, see :func:`~torchoutil.nn.functional.numpy.tensor_to_numpy`.

    Args:
        dtype: Forwarded unchanged as the ``dtype`` keyword of
            ``tensor_to_numpy`` on every call. ``None`` is passed through
            as-is.
        force: Forwarded unchanged as the ``force`` keyword of
            ``tensor_to_numpy``.
            NOTE(review): presumably mirrors ``torch.Tensor.numpy(force=...)``
            -- confirm against ``tensor_to_numpy``.
    """

    def __init__(
        self,
        *,
        dtype: Union[str, np.dtype, None] = None,
        force: bool = False,
    ) -> None:
        super().__init__()
        # Stored as-is; both values are read again on every forward call.
        self.dtype = dtype
        self.force = force

    def forward(self, x: Tensor) -> np.ndarray:
        """Return ``tensor_to_numpy(x)`` using the options given at construction."""
        return tensor_to_numpy(x, dtype=self.dtype, force=self.force)

    def extra_repr(self) -> str:
        """Describe the stored options for ``repr(module)``.

        NOTE(review): relies on the project ``Module`` base deriving from
        ``torch.nn.Module`` (whose ``__repr__`` calls this hook) -- confirm.
        """
        return f"dtype={self.dtype!r}, force={self.force!r}"
class NumpyToTensor(Module):
    """Module wrapper around ``numpy_to_tensor``.

    For more information, see :func:`~torchoutil.nn.functional.numpy.numpy_to_tensor`.

    Args:
        device: Forwarded unchanged as the ``device`` keyword of
            ``numpy_to_tensor`` on every call. ``None`` is passed through
            as-is.
        dtype: Forwarded unchanged as the ``dtype`` keyword of
            ``numpy_to_tensor`` on every call. ``None`` is passed through
            as-is.
    """

    def __init__(
        self,
        *,
        device: DeviceLike = None,
        dtype: DTypeLike = None,
    ) -> None:
        super().__init__()
        # Stored as-is; both values are read again on every forward call.
        self.device = device
        self.dtype = dtype

    def forward(self, x: np.ndarray) -> Tensor:
        """Return ``numpy_to_tensor(x)`` using the options given at construction."""
        return numpy_to_tensor(x, dtype=self.dtype, device=self.device)

    def extra_repr(self) -> str:
        """Describe the stored options for ``repr(module)``.

        NOTE(review): relies on the project ``Module`` base deriving from
        ``torch.nn.Module`` (whose ``__repr__`` calls this hook) -- confirm.
        """
        return f"device={self.device!r}, dtype={self.dtype!r}"