Source code for torchoutil.nn.functional.activation

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Iterable, Union

from torch import Tensor


[docs]def softmax_multidim( x: Tensor, *, dims: Union[Iterable[int], None] = (-1,), ) -> Tensor: """A multi-dimensional version of torch.softmax along multiple dimensions at the same time.""" x = x.exp() return log_softmax_multidim(x, dims=dims)
[docs]def log_softmax_multidim( x: Tensor, *, dims: Union[Iterable[int], None] = (-1,), ) -> Tensor: """A multi-dimensional version of torch.log_softmax along multiple dimensions at the same time.""" if dims is not None: dims = tuple(dims) result = x / x.sum(dim=dims, keepdim=True) return result