#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Generic, List, Mapping, Optional, Sequence, Union
import torch
from torch import Tensor, nn
from torchoutil.core.make import DeviceLike, DTypeLike
from torchoutil.nn.functional.multiclass import (
T_Name,
index_to_name,
index_to_onehot,
name_to_index,
name_to_onehot,
onehot_to_index,
onehot_to_name,
probs_to_index,
probs_to_name,
probs_to_onehot,
)
from torchoutil.pyoutil.collections import dump_dict
from .module import Module
class IndexToOnehot(Module):
    """
    Module form of :func:`~torchoutil.nn.functional.multiclass.index_to_onehot`.

    The constructor arguments are stored as attributes and forwarded unchanged
    to ``index_to_onehot`` on every call to ``forward``.

    Args:
        num_classes: Passed as the second positional argument of ``index_to_onehot``.
        padding_idx: Forwarded as ``padding_idx``. Defaults to None.
        device: Forwarded as ``device``. Defaults to None.
        dtype: Forwarded as ``dtype``. Defaults to ``torch.bool``.

    For more information, see :func:`~torchoutil.nn.functional.multiclass.index_to_onehot`.
    """

    def __init__(
        self,
        num_classes: int,
        *,
        padding_idx: Optional[int] = None,
        device: DeviceLike = None,
        dtype: DTypeLike = torch.bool,
    ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.padding_idx = padding_idx
        self.device = device
        self.dtype = dtype

    def forward(
        self,
        index: Union[List[int], Tensor],
    ) -> Tensor:
        """Apply ``index_to_onehot`` to ``index`` using the stored configuration."""
        options = dict(
            padding_idx=self.padding_idx,
            device=self.device,
            dtype=self.dtype,
        )
        return index_to_onehot(index, self.num_classes, **options)
class IndexToName(Generic[T_Name], nn.Module):
    """
    Module form of :func:`~torchoutil.nn.functional.multiclass.index_to_name`.

    Args:
        idx_to_name: Mapping (int keys) or sequence of names. It is stored as
            ``self.idx_to_name`` and passed to ``index_to_name`` on every call.

    For more information, see :func:`~torchoutil.nn.functional.multiclass.index_to_name`.
    """

    def __init__(
        self,
        idx_to_name: Union[Mapping[int, T_Name], Sequence[T_Name]],
    ) -> None:
        super().__init__()
        self.idx_to_name = idx_to_name

    def forward(
        self,
        index: Union[List[int], Tensor],
    ) -> List[T_Name]:
        """Apply ``index_to_name`` to ``index`` with the stored ``idx_to_name``."""
        table = self.idx_to_name
        return index_to_name(index, table)
class OnehotToIndex(Module):
    """
    Module form of :func:`~torchoutil.nn.functional.multiclass.onehot_to_index`.

    Args:
        dim: Forwarded as ``dim`` to ``onehot_to_index``. Defaults to -1.

    For more information, see :func:`~torchoutil.nn.functional.multiclass.onehot_to_index`.
    """

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def forward(
        self,
        onehot: Tensor,
    ) -> Tensor:
        """Apply ``onehot_to_index`` to ``onehot`` along the stored ``dim``."""
        return onehot_to_index(onehot, dim=self.dim)
class OnehotToName(Generic[T_Name], nn.Module):
    """
    Module form of :func:`~torchoutil.nn.functional.multiclass.onehot_to_name`.

    Args:
        idx_to_name: Mapping (int keys) or sequence of names, passed to
            ``onehot_to_name`` on every call.
        dim: Forwarded as ``dim`` to ``onehot_to_name``. Defaults to -1.

    For more information, see :func:`~torchoutil.nn.functional.multiclass.onehot_to_name`.
    """

    def __init__(
        self,
        idx_to_name: Union[Mapping[int, T_Name], Sequence[T_Name]],
        dim: int = -1,
    ) -> None:
        super().__init__()
        self.idx_to_name = idx_to_name
        self.dim = dim

    def forward(
        self,
        onehot: Tensor,
    ) -> List[T_Name]:
        """Apply ``onehot_to_name`` to ``onehot`` with the stored settings."""
        table, axis = self.idx_to_name, self.dim
        return onehot_to_name(onehot, table, dim=axis)
class NameToIndex(Generic[T_Name], nn.Module):
    """
    Module form of :func:`~torchoutil.nn.functional.multiclass.name_to_index`.

    Args:
        idx_to_name: Mapping (int keys) or sequence of names, passed to
            ``name_to_index`` on every call.

    For more information, see :func:`~torchoutil.nn.functional.multiclass.name_to_index`.
    """

    def __init__(
        self,
        idx_to_name: Union[Mapping[int, T_Name], Sequence[T_Name]],
    ) -> None:
        super().__init__()
        self.idx_to_name = idx_to_name

    def forward(
        self,
        name: List[T_Name],
    ) -> Tensor:
        """Apply ``name_to_index`` to ``name`` with the stored ``idx_to_name``."""
        return name_to_index(name, self.idx_to_name)
class NameToOnehot(Generic[T_Name], nn.Module):
    """
    Module form of :func:`~torchoutil.nn.functional.multiclass.name_to_onehot`.

    Args:
        idx_to_name: Mapping (int keys) or sequence of names, passed to
            ``name_to_onehot`` on every call.
        device: Forwarded as ``device``. Defaults to None.
        dtype: Forwarded as ``dtype``. Defaults to ``torch.bool``.

    For more information, see :func:`~torchoutil.nn.functional.multiclass.name_to_onehot`.
    """

    def __init__(
        self,
        idx_to_name: Union[Mapping[int, T_Name], Sequence[T_Name]],
        *,
        device: DeviceLike = None,
        dtype: DTypeLike = torch.bool,
    ) -> None:
        super().__init__()
        self.idx_to_name = idx_to_name
        self.device = device
        self.dtype = dtype

    def forward(
        self,
        name: List[T_Name],
    ) -> Tensor:
        """Apply ``name_to_onehot`` to ``name`` using the stored configuration."""
        options = dict(device=self.device, dtype=self.dtype)
        return name_to_onehot(name, self.idx_to_name, **options)
class ProbsToIndex(Module):
    """
    Module form of :func:`~torchoutil.nn.functional.multiclass.probs_to_index`.

    Args:
        dim: Forwarded as ``dim`` to ``probs_to_index``. Defaults to -1.

    For more information, see :func:`~torchoutil.nn.functional.multiclass.probs_to_index`.
    """

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def forward(
        self,
        probs: Tensor,
    ) -> Tensor:
        """Apply ``probs_to_index`` to ``probs`` along the stored ``dim``."""
        return probs_to_index(probs, dim=self.dim)
class ProbsToOnehot(Module):
    """
    Module form of :func:`~torchoutil.nn.functional.multiclass.probs_to_onehot`.

    All arguments are keyword-only. They are stored as attributes and
    forwarded unchanged to ``probs_to_onehot``.

    Args:
        dim: Forwarded as ``dim``. Defaults to -1.
        device: Forwarded as ``device``. Defaults to None.
        dtype: Forwarded as ``dtype``. Defaults to ``torch.bool``.

    For more information, see :func:`~torchoutil.nn.functional.multiclass.probs_to_onehot`.
    """

    def __init__(
        self,
        *,
        dim: int = -1,
        device: DeviceLike = None,
        dtype: DTypeLike = torch.bool,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.device = device
        self.dtype = dtype

    def forward(
        self,
        probs: Tensor,
    ) -> Tensor:
        """Apply ``probs_to_onehot`` to ``probs`` using the stored configuration."""
        options = dict(dim=self.dim, device=self.device, dtype=self.dtype)
        return probs_to_onehot(probs, **options)
class ProbsToName(Generic[T_Name], nn.Module):
    """
    Module form of :func:`~torchoutil.nn.functional.multiclass.probs_to_name`.

    Args:
        idx_to_name: Mapping (int keys) or sequence of names, passed to
            ``probs_to_name`` on every call.
        dim: Forwarded as ``dim`` to ``probs_to_name``. Defaults to -1.

    For more information, see :func:`~torchoutil.nn.functional.multiclass.probs_to_name`.
    """

    def __init__(
        self,
        idx_to_name: Union[Mapping[int, T_Name], Sequence[T_Name]],
        dim: int = -1,
    ) -> None:
        super().__init__()
        self.idx_to_name = idx_to_name
        self.dim = dim

    def forward(
        self,
        probs: Tensor,
    ) -> List[T_Name]:
        """Apply ``probs_to_name`` to ``probs`` with the stored settings."""
        table, axis = self.idx_to_name, self.dim
        return probs_to_name(probs, table, dim=axis)