Source code for torchoutil.nn.modules.pad

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Iterable, Union

from torch import Tensor
from torch.types import Number

from torchoutil.core.make import DeviceLike, DTypeLike, GeneratorLike
from torchoutil.nn.functional.padding import (
    PadAlign,
    PadMode,
    PadValue,
    pad_and_stack_rec,
    pad_dim,
    pad_dims,
)
from torchoutil.pyoutil.collections import dump_dict

from .module import Module


class PadDim(Module):
    """Module wrapper that pads one dimension of its input.

    Each call to :meth:`forward` delegates to
    :func:`~torchoutil.nn.functional.padding.pad_dim`, passing the options
    captured at construction time. Refer to that function for the exact
    meaning of every option.
    """

    def __init__(
        self,
        target_length: int,
        *,
        dim: int = -1,
        align: PadAlign = "left",
        pad_value: PadValue = 0.0,
        mode: PadMode = "constant",
        generator: GeneratorLike = None,
    ) -> None:
        super().__init__()
        # All of these are forwarded unchanged to pad_dim in forward().
        self.target_length = target_length
        self.dim = dim
        self.align: PadAlign = align
        self.pad_value = pad_value
        self.mode: PadMode = mode
        self.generator: GeneratorLike = generator

    def forward(self, x: Tensor) -> Tensor:
        """Return ``pad_dim(x, ...)`` computed with the stored options."""
        padded = pad_dim(
            x,
            generator=self.generator,
            mode=self.mode,
            pad_value=self.pad_value,
            align=self.align,
            dim=self.dim,
            target_length=self.target_length,
        )
        return padded

    def extra_repr(self) -> str:
        """Describe the stored options (the generator is not included)."""
        fields = {
            "target_length": self.target_length,
            "dim": self.dim,
            "align": self.align,
            "pad_value": self.pad_value,
            "mode": self.mode,
        }
        return dump_dict(fields)
class PadDims(Module):
    """Module wrapper that pads several dimensions of its input.

    Each call to :meth:`forward` delegates to
    :func:`~torchoutil.nn.functional.padding.pad_dims`, passing the options
    captured at construction time. Refer to that function for the exact
    meaning of every option.

    ``target_lengths``, ``dims`` and ``aligns`` are read on every forward
    call (and by :meth:`extra_repr`). A one-shot iterator such as a generator
    is therefore frozen into a tuple at construction time. Any other iterable
    (list, tuple, str, ...) is stored exactly as given.
    """

    def __init__(
        self,
        target_lengths: Iterable[int],
        *,
        dims: Iterable[int] = (-1,),
        aligns: Iterable[PadAlign] = ("left",),
        pad_value: PadValue = 0.0,
        mode: PadMode = "constant",
        generator: GeneratorLike = None,
    ) -> None:
        super().__init__()
        # Without freezing, an iterator would be exhausted by the first
        # forward() call. Later calls would then hand pad_dims an empty
        # sequence.
        self.target_lengths = self._freeze_iterator(target_lengths)
        self.aligns = self._freeze_iterator(aligns)
        self.pad_value = pad_value
        self.dims = self._freeze_iterator(dims)
        self.mode: PadMode = mode
        self.generator: GeneratorLike = generator

    @staticmethod
    def _freeze_iterator(values):
        """Return a tuple of ``values`` if it is a one-shot iterator, else ``values`` unchanged."""
        from collections.abc import Iterator

        if isinstance(values, Iterator):
            return tuple(values)
        return values

    def forward(
        self,
        x: Tensor,
    ) -> Tensor:
        """Return ``pad_dims(x, ...)`` computed with the stored options."""
        return pad_dims(
            x,
            target_lengths=self.target_lengths,
            dims=self.dims,
            aligns=self.aligns,
            pad_value=self.pad_value,
            mode=self.mode,
            generator=self.generator,
        )

    def extra_repr(self) -> str:
        """Describe the stored options (the generator is not included)."""
        return dump_dict(
            dict(
                target_lengths=self.target_lengths,
                dims=self.dims,
                aligns=self.aligns,
                pad_value=self.pad_value,
                mode=self.mode,
            )
        )
class PadAndStackRec(Module):
    """Module wrapper around recursive pad-and-stack.

    Each call to :meth:`forward` delegates to
    :func:`~torchoutil.nn.functional.padding.pad_and_stack_rec`, passing the
    options captured at construction time. Refer to that function for the
    exact meaning of every option.
    """

    def __init__(
        self,
        pad_value: Number = 0,
        *,
        align: PadAlign = "left",
        device: DeviceLike = None,
        dtype: DTypeLike = None,
    ) -> None:
        super().__init__()
        # All of these are forwarded unchanged to pad_and_stack_rec in forward().
        self.pad_value = pad_value
        self.align: PadAlign = align
        self.device = device
        self.dtype = dtype

    def forward(
        self,
        sequence: Union[Tensor, int, float, tuple, list],
    ) -> Tensor:
        """Return ``pad_and_stack_rec(sequence, pad_value, ...)`` with the stored options."""
        keyword_options = {
            "align": self.align,
            "device": self.device,
            "dtype": self.dtype,
        }
        # pad_value is passed positionally, as in the original call.
        return pad_and_stack_rec(sequence, self.pad_value, **keyword_options)

    def extra_repr(self) -> str:
        """Describe the stored options."""
        fields = {
            "pad_value": self.pad_value,
            "align": self.align,
            "device": self.device,
            "dtype": self.dtype,
        }
        return dump_dict(fields)