Source code for torchoutil.nn.modules.mask

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Iterable, Union

from torch import Tensor

from torchoutil.nn.functional.mask import masked_mean, masked_sum
from torchoutil.pyoutil.collections import dump_dict

from .module import Module


class MaskedMean(Module):
    """Module wrapper around :func:`~torchoutil.nn.functional.mask.masked_mean`.

    For more information, see :func:`~torchoutil.nn.functional.mask.masked_mean`.
    """

    def __init__(self, dim: Union[None, int, Iterable[int]] = None) -> None:
        """
        Args:
            dim: Dimension(s) forwarded as ``dim=`` to ``masked_mean`` on every
                forward call. May be ``None``, an int, or an iterable of ints.
                One-shot iterators (e.g. generators) are materialized into a
                tuple here; otherwise they would be consumed by the first
                forward call and every later call would see no dims.
        """
        # Local import keeps this fix self-contained (no change to module imports).
        from collections.abc import Iterator

        super().__init__()
        if isinstance(dim, Iterator):
            dim = tuple(dim)
        self.dim = dim

    def forward(self, tensor: Tensor, non_pad_mask: Tensor) -> Tensor:
        """Return ``masked_mean(tensor, non_pad_mask, dim=self.dim)``.

        Args:
            tensor: Values to reduce.
            non_pad_mask: Mask passed straight to ``masked_mean``; presumably
                marks the valid (non-padded) positions — see ``masked_mean``
                for the exact convention.
        """
        reduced = masked_mean(tensor, non_pad_mask, dim=self.dim)
        return reduced

    def extra_repr(self) -> str:
        """Show ``dim`` in the module repr; omitted when it is ``None``."""
        return dump_dict(
            dict(
                dim=self.dim,
            ),
            ignore_lst=(None,),
        )
class MaskedSum(Module):
    """Module wrapper around :func:`~torchoutil.nn.functional.mask.masked_sum`.

    For more information, see :func:`~torchoutil.nn.functional.mask.masked_sum`.
    """

    def __init__(self, dim: Union[None, int, Iterable[int]] = None) -> None:
        """
        Args:
            dim: Dimension(s) forwarded as ``dim=`` to ``masked_sum`` on every
                forward call. May be ``None``, an int, or an iterable of ints.
                One-shot iterators (e.g. generators) are materialized into a
                tuple here; otherwise they would be consumed by the first
                forward call and every later call would see no dims.
        """
        # Local import keeps this fix self-contained (no change to module imports).
        from collections.abc import Iterator

        super().__init__()
        if isinstance(dim, Iterator):
            dim = tuple(dim)
        self.dim = dim

    def forward(self, tensor: Tensor, non_pad_mask: Tensor) -> Tensor:
        """Return ``masked_sum(tensor, non_pad_mask, dim=self.dim)``.

        Args:
            tensor: Values to reduce.
            non_pad_mask: Mask passed straight to ``masked_sum``; presumably
                marks the valid (non-padded) positions — see ``masked_sum``
                for the exact convention.
        """
        reduced = masked_sum(tensor, non_pad_mask, dim=self.dim)
        return reduced

    def extra_repr(self) -> str:
        """Show ``dim`` in the module repr; omitted when it is ``None``."""
        return dump_dict(
            dict(
                dim=self.dim,
            ),
            ignore_lst=(None,),
        )