Source code for torchoutil.nn.modules.activation
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Iterable, Union
from torch import Tensor
from torchoutil.nn.functional.activation import log_softmax_multidim, softmax_multidim
from torchoutil.pyoutil.collections import dump_dict
from .module import Module
class SoftmaxMultidim(Module):
    """Module wrapper around the functional ``softmax_multidim``.

    For more information, see :func:`~torchoutil.nn.functional.activation.softmax_multidim`.
    """

    def __init__(
        self,
        dims: Union[Iterable[int], None] = (-1,),
    ) -> None:
        """
        :param dims: Dimensions forwarded as ``dims=`` to ``softmax_multidim`` on
            every forward call. Any iterable is stored as a tuple; ``None`` is
            stored and forwarded unchanged. Defaults to ``(-1,)``.
        """
        super().__init__()
        # Materialize once: a one-shot iterable (e.g. a generator) stored as-is
        # would be exhausted by the first forward and be empty on later calls.
        self.dims = None if dims is None else tuple(dims)

    def forward(
        self,
        input: Tensor,
    ) -> Tensor:
        """Return ``softmax_multidim(input, dims=self.dims)``.

        :param input: Tensor passed to ``softmax_multidim``.
        :returns: The result of ``softmax_multidim``.
        """
        return softmax_multidim(input, dims=self.dims)

    def extra_repr(self) -> str:
        """Return the constructor arguments displayed in this module's repr."""
        return f"dims={self.dims}"
class LogSoftmaxMultidim(Module):
    """Module wrapper around the functional ``log_softmax_multidim``.

    For more information, see :func:`~torchoutil.nn.functional.activation.log_softmax_multidim`.
    """

    def __init__(
        self,
        dims: Union[Iterable[int], None] = (-1,),
    ) -> None:
        """
        :param dims: Dimensions forwarded as ``dims=`` to ``log_softmax_multidim``
            on every forward call. Any iterable is stored as a tuple; ``None`` is
            stored and forwarded unchanged. Defaults to ``(-1,)``.
        """
        super().__init__()
        # Materialize once: a one-shot iterable (e.g. a generator) stored as-is
        # would be exhausted by the first forward and be empty on later calls.
        self.dims = None if dims is None else tuple(dims)

    def forward(
        self,
        input: Tensor,
    ) -> Tensor:
        """Return ``log_softmax_multidim(input, dims=self.dims)``.

        :param input: Tensor passed to ``log_softmax_multidim``.
        :returns: The result of ``log_softmax_multidim``.
        """
        return log_softmax_multidim(input, dims=self.dims)

    def extra_repr(self) -> str:
        """Return the constructor arguments displayed in this module's repr."""
        return f"dims={self.dims}"