# Source code for torchoutil.nn.modules.mask
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Iterable, Union
from torch import Tensor
from torchoutil.nn.functional.mask import masked_mean, masked_sum
from torchoutil.pyoutil.collections import dump_dict
from .module import Module
class MaskedMean(Module):
    """Module wrapper around :func:`~torchoutil.nn.functional.mask.masked_mean`.

    For more information, see :func:`~torchoutil.nn.functional.mask.masked_mean`.

    Args:
        dim: Dimension(s) forwarded to ``masked_mean``. ``None`` and ``int``
            are stored as-is. Any other iterable is copied into a ``tuple``
            once here. This way a one-shot iterator (e.g. a generator) is not
            exhausted by the first ``forward`` call.
    """

    def __init__(self, dim: Union[None, int, Iterable[int]] = None) -> None:
        super().__init__()
        if dim is not None and not isinstance(dim, int):
            # Materialize so repeated forward calls see the same dims.
            dim = tuple(dim)
        self.dim = dim

    def extra_repr(self) -> str:
        """Show the configured reduction dim(s) in the module repr."""
        return f"dim={self.dim}"

    def forward(self, tensor: Tensor, non_pad_mask: Tensor) -> Tensor:
        """Return ``masked_mean(tensor, non_pad_mask, dim=self.dim)``.

        Args:
            tensor: Values to reduce.
            non_pad_mask: Mask passed through to ``masked_mean``. The name
                suggests True marks non-padded positions; TODO confirm the
                exact semantics against the functional implementation.
        """
        reduced = masked_mean(tensor, non_pad_mask, dim=self.dim)
        return reduced
class MaskedSum(Module):
    """Module wrapper around :func:`~torchoutil.nn.functional.mask.masked_sum`.

    For more information, see :func:`~torchoutil.nn.functional.mask.masked_sum`.

    Args:
        dim: Dimension(s) forwarded to ``masked_sum``. ``None`` and ``int``
            are stored as-is. Any other iterable is copied into a ``tuple``
            once here. This way a one-shot iterator (e.g. a generator) is not
            exhausted by the first ``forward`` call.
    """

    def __init__(self, dim: Union[None, int, Iterable[int]] = None) -> None:
        super().__init__()
        if dim is not None and not isinstance(dim, int):
            # Materialize so repeated forward calls see the same dims.
            dim = tuple(dim)
        self.dim = dim

    def extra_repr(self) -> str:
        """Show the configured reduction dim(s) in the module repr."""
        return f"dim={self.dim}"

    def forward(self, tensor: Tensor, non_pad_mask: Tensor) -> Tensor:
        """Return ``masked_sum(tensor, non_pad_mask, dim=self.dim)``.

        Args:
            tensor: Values to reduce.
            non_pad_mask: Mask passed through to ``masked_sum``. The name
                suggests True marks non-padded positions; TODO confirm the
                exact semantics against the functional implementation.
        """
        reduced = masked_sum(tensor, non_pad_mask, dim=self.dim)
        return reduced