Source code for torchoutil.nn.modules.container

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Callable, Dict, Generic, Iterable, List, Mapping, Optional, overload

from torch import nn
from typing_extensions import Concatenate, ParamSpec

from ._mixins import EModule  # noqa: F401
from ._mixins import ESequential  # noqa: F401
from ._mixins import (
    _DEFAULT_DEVICE_DETECT_MODE,
    ConfigModule,
    DeviceDetectMode,
    InType,
    OutType,
    OutType3,
    TypedModule,
    TypedModuleLike,
)

# Parameter specification for the extra arguments (everything after the first
# positional one) of the callable wrapped by EModulePartial.
P = ParamSpec("P")


class EModuleList(
    Generic[InType, OutType3],
    EModule[InType, List[OutType3]],
    nn.ModuleList,
):
    """Enriched ``torch.nn.ModuleList`` with proxy device, forward typing and
    automatic configuration detection from attributes.

    Intended to be used together with `torchoutil.nn.EModule` instances.

    It behaves like the PyTorch ModuleList class, except that calling the
    container runs every contained module separately on the same arguments and
    returns the list of their outputs.

    Args:
        modules: Optional iterable of modules, forwarded to
            ``nn.ModuleList.__init__``.
        strict_load: Forwarded to ``EModule.__init__``.
        config_to_extra_repr: Forwarded to ``EModule.__init__``.
        device_detect_mode: Forwarded to ``EModule.__init__``.
    """

    @overload
    def __init__(
        self,
        modules: Optional[Iterable[TypedModuleLike[InType, OutType3]]] = None,
        *,
        strict_load: bool = False,
        config_to_extra_repr: bool = False,
        device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE,
    ) -> None: ...

    @overload
    def __init__(
        self,
        modules: Optional[Iterable[nn.Module]] = None,
        *,
        strict_load: bool = False,
        config_to_extra_repr: bool = False,
        device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE,
    ) -> None: ...

    def __init__(
        self,
        modules=None,
        *,
        strict_load: bool = False,
        config_to_extra_repr: bool = False,
        device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE,
    ) -> None:
        # Both bases are initialised explicitly: EModule first with its
        # keyword options, then the list container with the given modules.
        emodule_options = dict(
            strict_load=strict_load,
            config_to_extra_repr=config_to_extra_repr,
            device_detect_mode=device_detect_mode,
        )
        EModule.__init__(self, **emodule_options)
        nn.ModuleList.__init__(self, modules)

    def forward(self, *args: InType, **kwargs: InType) -> List[OutType3]:
        """Call every contained module with the same arguments.

        Returns:
            The outputs of the modules, in iteration order of the container.
        """
        outputs: List[OutType3] = []
        for submodule in self:
            outputs.append(submodule(*args, **kwargs))
        return outputs
class EModuleDict(
    Generic[InType, OutType3],
    EModule[InType, Dict[str, OutType3]],
    nn.ModuleDict,
):
    """Enriched ``torch.nn.ModuleDict`` with proxy device, forward typing and
    automatic configuration detection from attributes.

    Intended to be used together with `torchoutil.nn.EModule` instances.

    It behaves like the PyTorch ModuleDict class, except that calling the
    container runs every contained module separately on the same arguments and
    returns a dict mapping each module name to its output.

    Args:
        modules: Optional mapping of names to modules, forwarded to
            ``nn.ModuleDict.__init__``.
        strict_load: Forwarded to ``EModule.__init__``.
        config_to_extra_repr: Forwarded to ``EModule.__init__``.
        device_detect_mode: Forwarded to ``EModule.__init__``.
    """

    @overload
    def __init__(
        self,
        modules: Optional[Mapping[str, TypedModuleLike[InType, OutType3]]] = None,
        *,
        strict_load: bool = False,
        config_to_extra_repr: bool = False,
        device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE,
    ) -> None: ...

    @overload
    def __init__(
        self,
        modules: Optional[Mapping[str, nn.Module]] = None,
        *,
        strict_load: bool = False,
        config_to_extra_repr: bool = False,
        device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE,
    ) -> None: ...

    def __init__(
        self,
        modules=None,
        *,
        strict_load: bool = False,
        config_to_extra_repr: bool = False,
        device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE,
    ) -> None:
        # Both bases are initialised explicitly: EModule first with its
        # keyword options, then the dict container with the given modules.
        emodule_options = dict(
            strict_load=strict_load,
            config_to_extra_repr=config_to_extra_repr,
            device_detect_mode=device_detect_mode,
        )
        EModule.__init__(self, **emodule_options)
        nn.ModuleDict.__init__(self, modules)

    def forward(self, *args: InType, **kwargs: InType) -> Dict[str, OutType3]:
        """Call every contained module with the same arguments.

        Returns:
            A dict mapping each module name to the output of that module.
        """
        outputs: Dict[str, OutType3] = {}
        for key, submodule in self.items():
            outputs[key] = submodule(*args, **kwargs)
        return outputs
class EModulePartial(
    Generic[InType, OutType],
    EModule[InType, OutType],
):
    """Module wrapping a callable together with pre-bound extra arguments.

    Calling the module with ``x`` evaluates ``fn(x, *args, **kwargs)``, where
    ``args`` and ``kwargs`` are the values given at construction time.

    Args:
        fn: Callable invoked by `forward`; the module input is passed as its
            first positional argument.
        *args: Extra positional arguments appended after the input.
        **kwargs: Extra keyword arguments passed to ``fn``.
    """

    def __init__(
        self,
        fn: Callable[Concatenate[InType, P], OutType],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        super().__init__()
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def forward(self, x: InType) -> OutType:  # type: ignore
        """Return ``fn(x, *args, **kwargs)`` using the stored arguments."""
        return self.fn(x, *self.args, **self.kwargs)

    def extra_repr(self) -> str:
        """Return the wrapped callable's name followed by the config repr.

        The two parts are separated by ``", "``; an empty part is omitted so
        that no dangling separator is produced.
        """
        fn = self.fn
        # Not every callable exposes ``__name__`` (e.g. functools.partial
        # objects or callable instances); fall back to the type name instead
        # of raising AttributeError while building the repr.
        fn_name = getattr(fn, "__name__", None)
        if fn_name is None:
            fn_name = type(fn).__name__
        config_repr = ConfigModule.extra_repr(self)
        return ", ".join(part for part in (fn_name, config_repr) if part)
# Aliases exposing the enriched containers without the "E" prefix.
ModuleList = EModuleList
ModuleDict = EModuleDict
ModulePartial = EModulePartial
Sequential = ESequential


def __test_typing_1() -> None:
    """Exercise module composition and check the runtime output types.

    Small ad-hoc layers are combined through ``ESequential``, the ``|``
    operator and ``chain``; the results are checked with ``isinstance``.
    """
    import torch
    from torch import Tensor

    class LayerA(EModule[Tensor, Tensor]):
        def forward(self, x: Tensor) -> Tensor:
            return x * x

    class LayerB(EModule[Tensor, int]):
        def forward(self, x: Tensor) -> int:
            return int(x.sum().item())

    class LayerC(EModule[int, Tensor]):
        def forward(self, x: int) -> Tensor:
            return torch.as_tensor(x)

    x = torch.rand(10)
    xa = LayerA()(x)
    xb = LayerB()(x)

    # Same pipeline built twice: explicitly, then with the `|` operator.
    seq = ESequential(LayerA(), LayerA(), LayerB())
    xab = seq(x)

    seq = LayerA() | LayerA() | LayerB()
    xab = seq(x)

    seq = LayerC().chain(LayerA())
    xc = seq(2)

    assert isinstance(xa, Tensor)
    assert isinstance(xb, int)
    assert isinstance(xab, int)
    assert isinstance(xc, Tensor)

    # Plain nn.Module layers (no typing information) inside ESequential.
    class LayerD(nn.Module):
        def forward(self, x: Tensor) -> int:
            return int(x.item())

    class LayerE(nn.Module):
        def forward(self, x: bool) -> str:
            return str(x)

    seq = ESequential(LayerD(), LayerE())
    # torch.rand needs a size argument; the empty shape yields the 0-dim
    # tensor that LayerD converts with `.item()`.
    y = seq(torch.rand(()))

    assert isinstance(y, str)

    class LayerF(TypedModule[bool, str]):
        def forward(self, x):
            return str(x)

    seq = ESequential(LayerF())
    y = seq(True)