#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Sized,
Tuple,
Union,
get_args,
)
import torch
from torch import Size, Tensor
from torch.types import Number
from typing_extensions import TypeAlias
from torchoutil.core.make import (
DeviceLike,
DTypeLike,
GeneratorLike,
as_device,
as_dtype,
as_generator,
)
from torchoutil.nn import functional as F
from torchoutil.types import is_number_like
# Where the original content ends up along a padded dimension:
# "left" keeps it at the start (padding added after it), "right" keeps it at
# the end (padding added before it), "center" splits the padding on both sides
# (the extra element of an odd amount goes before the content), and "random"
# draws the split between both sides at random.
PadAlign: TypeAlias = Literal["left", "right", "center", "random"]
# Fill value for padding: either a number, or a callable that receives the
# tensor to pad and returns the number to use.
PadValue: TypeAlias = Union[Number, Callable[[Tensor], Number]]
# Padding mode names forwarded as the `mode` argument of `F.pad`.
PadMode: TypeAlias = Literal["constant", "reflect", "replicate", "circular"]
def pad_dim(
    x: Tensor,
    target_length: int,
    *,
    dim: int = -1,
    align: PadAlign = "left",
    pad_value: PadValue = 0.0,
    mode: PadMode = "constant",
    generator: GeneratorLike = None,
) -> Tensor:
    """Pad a single dimension of a tensor.

    Convenience wrapper that forwards to :func:`pad_dims` with one-element
    sequences for the target length, the dimension and the alignment.

    Args:
        x: Tensor to pad.
        target_length: Length wanted along ``dim``.
        dim: Index of the dimension to pad. defaults to -1.
        align: Alignment of the original content in the padded dimension. defaults to "left".
        pad_value: Fill value, or a callable computing it from ``x``. defaults to 0.0.
        mode: Padding mode forwarded to :func:`pad_dims`. defaults to "constant".
        generator: Optional generator used when align="random". defaults to None.

    Returns:
        The tensor returned by :func:`pad_dims`.
    """
    forwarded: Dict[str, Any] = dict(
        dims=(dim,),
        aligns=(align,),
        pad_value=pad_value,
        mode=mode,
        generator=generator,
    )
    return pad_dims(x, [target_length], **forwarded)
def pad_dims(
    x: Tensor,
    target_lengths: Iterable[int],
    *,
    dims: Iterable[int] = (-1,),
    aligns: Union[PadAlign, Iterable[PadAlign]] = ("left",),
    pad_value: PadValue = 0.0,
    mode: PadMode = "constant",
    generator: GeneratorLike = None,
) -> Tensor:
    """Pad several dimensions of a tensor in a single call.

    Args:
        x: Tensor to pad.
        target_lengths: Length wanted for each dimension listed in ``dims``.
            No padding is requested for a dimension that is already long enough.
        dims: Indices of the dimensions to pad. defaults to (-1,).
        aligns: Alignment for each dimension. A single string is applied to
            every dimension in ``dims``. defaults to ("left",).
        pad_value: Fill value, or a callable that receives ``x`` and returns
            the fill value. defaults to 0.0.
        mode: Padding mode forwarded to ``F.pad``. defaults to "constant".
        generator: Optional generator used when an align is "random". defaults to None.

    Returns:
        The result of ``F.pad`` applied to ``x``.

    Raises:
        ValueError: If ``dims`` is empty, or if ``target_lengths`` or ``aligns``
            does not contain one entry per dimension.
    """
    if callable(pad_value):
        pad_value = pad_value(x)

    # Materialize once: `dims` may be a one-shot iterator and its length is
    # needed to repeat a scalar alignment.
    dims = list(dims)
    if isinstance(aligns, str):
        # A lone align used to be wrapped as a 1-element list, which failed the
        # length check as soon as more than one dim was given.
        aligns = [aligns] * len(dims)

    pad_seq = __generate_pad_seq(
        x.shape,
        target_lengths=target_lengths,
        dims=dims,
        aligns=aligns,
        generator=generator,
    )
    return F.pad(x, pad_seq, mode=mode, value=pad_value)
def pad_and_stack_rec(
    sequence: Union[Tensor, int, float, tuple, list],
    pad_value: Number = 0,
    *,
    align: PadAlign = "left",
    device: DeviceLike = None,
    dtype: DTypeLike = None,
) -> Tensor:
    """Recursively pad and stack a nested sequence into a single Tensor.

    Recursive counterpart of torch.nn.utils.rnn.pad_sequence that also pads Tensors.

    Args:
        sequence: The sequence to pad. Must be convertible to a tensor, i.e. all
            sublists of a given level must have the same number of dims.
        pad_value: The pad value used. defaults to 0.
        align: Alignment of the original values in every padded dim. defaults to "left".
        device: The device of the output Tensor. defaults to None.
        dtype: The dtype of the output Tensor. defaults to None.

    Returns:
        The padded and stacked Tensor.

    Raises:
        ValueError: If the elements of a sublist do not share the same number of dims.
        TypeError: If an element is not a Tensor, a number, a list or a tuple.

    Example 1::
    -----------
    >>> sequence = [[1, 2], [3], [], [4, 5]]
    >>> output = pad_and_stack_rec(sequence, 0)
    tensor([[1, 2], [3, 0], [0, 0], [4, 5]])

    Example 2::
    -----------
    >>> invalid_sequence = [[1, 2, 3], 3]
    >>> output = pad_and_stack_rec(invalid_sequence, 0)
    ValueError : Cannot pad sequence of tensors of differents number of dims.
    """
    device = as_device(device)
    dtype = as_dtype(dtype)

    def _convert(item: Union[Tensor, int, float, tuple, list]) -> Tensor:
        # Leaf: an existing tensor is only cast/moved.
        if isinstance(item, Tensor):
            return item.to(dtype=dtype, device=device)

        # Leaf: numbers and empty containers are converted directly.
        # The number test must stay first so len() is never called on a number-like.
        if is_number_like(item) or (isinstance(item, Sized) and len(item) == 0):
            return torch.as_tensor(item, dtype=dtype, device=device)  # type: ignore

        if not isinstance(item, (list, tuple)):
            msg = f"Invalid type {type(item)}. (expected Tensor, int, float, list or tuple)"
            raise TypeError(msg)

        # Node: convert the children first, then stack them (padding if required).
        children = [_convert(child) for child in item]
        if F.is_stackable(children):
            return torch.stack(children)

        shapes = [child.shape for child in children]
        ndim = len(shapes[0])
        if any(len(shape) != ndim for shape in shapes[1:]):
            msg = f"Cannot pad sequence of tensors of differents number of dims. (with {shapes=})"
            raise ValueError(msg)

        # Every child has `ndim` dims here, so the per-dim maximum is a transpose away.
        max_lens = [max(sizes) for sizes in zip(*shapes)]
        all_dims = range(ndim)
        all_aligns = [align] * ndim
        padded = [
            pad_dims(
                child,
                target_lengths=max_lens,
                pad_value=pad_value,
                aligns=all_aligns,  # type: ignore
                dims=all_dims,
            )
            for child in children
        ]
        return torch.stack(padded)

    return _convert(sequence)
def cat_padded_batch(
    x1: Tensor,
    x1_lens: Tensor,
    x2: Tensor,
    x2_lens: Tensor,
    seq_dim: int = -1,
    batch_dim: int = 0,
) -> Tuple[Tensor, Tensor]:
    """Concatenate two padded batches of sequences along the sequence dim.

    For each batch element, the values of x2 are written right after the first
    ``x1_lens`` values of x1.

    Args:
        x1: First batch with D dims of shape (batch_size, ..., N1, ...)
        x1_lens: First lengths of each element in sequence dim of shape (batch_size,).
        x2: Second batch with D dims of shape (batch_size, ..., N2, ...)
            The shape must be the same as x1 except for the sequence dimension.
        x2_lens: Second lengths of each element in sequence dim of shape (batch_size,).
        seq_dim: Dimension index of sequence. defaults to -1.
        batch_dim: Batch dimension index. defaults to 0.

    Returns:
        A tuple (x12, x12_lens) where x12 is the concatenated batch, cut along
        the sequence dim to the largest summed length, and x12_lens is
        ``x1_lens + x2_lens``.

    Raises:
        ValueError: If the arguments fail the checks of _check_cat_padded_batch.
    """
    _check_cat_padded_batch(x1, x1_lens, x2, x2_lens, seq_dim, batch_dim)

    x12_lens = x1_lens + x2_lens
    sum_size_12 = x1.shape[seq_dim] + x2.shape[seq_dim]

    # Start from x1 padded to the largest possible output size.
    x12 = pad_dim(x1, sum_size_12, dim=seq_dim)

    kwd: Dict[str, Any] = dict(device=x1.device, dtype=torch.long)
    indices = torch.arange(x2_lens.max().item(), **kwd)

    # Reshape `indices` to vary only along seq_dim and `x1_lens` to vary only
    # along batch_dim, so that their sum broadcasts to per-element offsets.
    unsq_x1_lens = x1_lens
    ndim = x1.ndim
    seq_axis = seq_dim % ndim
    batch_axis = batch_dim % ndim
    for i in range(ndim):
        if i != seq_axis:
            indices = indices.unsqueeze(dim=i)
        if i != batch_axis:
            unsq_x1_lens = unsq_x1_lens.unsqueeze(dim=i)

    expand_size = list(x2.shape)
    expand_size[seq_dim] = -1
    indices = indices.expand(*expand_size)
    indices = indices + unsq_x1_lens

    # Write x2 after the valid part of x1 for every batch element.
    x12.scatter_(seq_dim, indices, x2)

    # Drop the trailing positions that no batch element uses.
    max_size_12 = int(x12_lens.max().item())
    if max_size_12 < sum_size_12:
        slices = [slice(None)] * ndim
        slices[seq_dim] = slice(max_size_12)
        # Index with a tuple: indexing with a list of slices is deprecated.
        x12 = x12[tuple(slices)]

    return x12, x12_lens
def __generate_pad_seq(
    x_shape: Union[Size, Tuple[int, ...]],
    target_lengths: Iterable[int],
    *,
    dims: Iterable[int] = (-1,),
    aligns: Iterable[PadAlign] = ("left",),
    generator: GeneratorLike = None,
) -> List[int]:
    """Generate pad sequence for torch.nn.functional.pad from target lengths, dims, aligns and generator.

    Args:
        x_shape: Shape of the tensor to pad.
        target_lengths: Expected lengths of each dim in the tensor.
        dims: Specified dims indices. defaults to (-1,).
        aligns: Alignment for each dim. defaults to ("left",).
        generator: Optional generator when align="random". defaults to None.

    Returns:
        A list of ``2 * len(x_shape)`` integers ordered from the last dim to the
        first one: [before_dim_-1, after_dim_-1, before_dim_-2, after_dim_-2, ...].
        Dims that are not listed, or that are already long enough, get 0.

    Raises:
        ValueError: If dims is empty, if target_lengths or aligns does not have
            one entry per dim, if an align is unknown, or if a dim that already
            received a non-zero padding is specified again.
    """
    target_lengths = list(target_lengths)
    aligns = list(aligns)
    dims = list(dims)
    generator = as_generator(generator)

    if len(dims) == 0:
        msg = f"Invalid argument {dims=}. (cannot use an empty list of dimensions)"
        raise ValueError(msg)
    if len(target_lengths) != len(dims):
        msg = f"Invalid number of targets lengths ({len(target_lengths)}) with the number of dimensions ({len(dims)})."
        raise ValueError(msg)
    if len(aligns) != len(dims):
        msg = f"Invalid number of aligns ({len(aligns)}) with the number of dimensions ({len(dims)})."
        raise ValueError(msg)

    ndim = len(x_shape)
    pad_seq = [0] * (ndim * 2)

    for target_length, dim, align in zip(target_lengths, dims, aligns):
        # Never negative: a dim longer than its target is left as it is.
        missing = max(target_length - x_shape[dim], 0)

        if align == "left":
            missing_left = 0
            missing_right = missing
        elif align == "right":
            missing_left = missing
            missing_right = 0
        elif align == "center":
            # The extra element of an odd amount goes before the content.
            missing_left = missing // 2 + missing % 2
            missing_right = missing // 2
        elif align == "random":
            missing_left = int(
                torch.randint(
                    low=0,
                    high=missing + 1,
                    size=(),
                    generator=generator,
                ).item()
            )
            missing_right = missing - missing_left
        else:
            msg = f"Invalid argument {align=}. (expected one of {get_args(PadAlign)})"
            raise ValueError(msg)

        # Note: pad_seq : [pad_left_dim_-1, pad_right_dim_-1, pad_left_dim_-2, pad_right_dim_-2, ...]
        idx = ndim - (dim % ndim) - 1

        # Explicit check instead of an assert, so it still runs under `python -O`
        # and raises the same exception type as the other validations.
        if pad_seq[idx * 2] != 0 or pad_seq[idx * 2 + 1] != 0:
            msg = f"Invalid argument {dims=}. (dimension {dim} is padded more than once)"
            raise ValueError(msg)

        pad_seq[idx * 2] = missing_left
        pad_seq[idx * 2 + 1] = missing_right

    return pad_seq
def _check_cat_padded_batch(
    x1: Tensor,
    x1_lens: Tensor,
    x2: Tensor,
    x2_lens: Tensor,
    seq_dim: int,
    batch_dim: int,
) -> None:
    """Validate the arguments of cat_padded_batch.

    Args:
        x1: First batch.
        x1_lens: Lengths of the first batch, expected of shape (batch_size,).
        x2: Second batch.
        x2_lens: Lengths of the second batch, expected of shape (batch_size,).
        seq_dim: Sequence dimension index, the only dim where x1 and x2 may differ.
        batch_dim: Batch dimension index, used to read batch_size from x1.

    Raises:
        ValueError: If x1 and x2 do not have the same number of dims, if they
            have fewer than 2 dims, if a lengths tensor is not of shape
            (batch_size,), or if x1 and x2 differ on a dim other than seq_dim.
    """
    if x1.ndim != x2.ndim:
        msg = f"Invalid arguments ndims. (found {x1.ndim=} != {x2.ndim=})"
        raise ValueError(msg)
    if x1.ndim < 2:
        msg = f"Invalid arguments ndims. (found {x1.ndim=} < 2)"
        raise ValueError(msg)

    # Both lengths tensors must be 1-D with one entry per batch element.
    expected_lens_shape = Size((x1.shape[batch_dim],))
    if x1_lens.shape != expected_lens_shape or x2_lens.shape != expected_lens_shape:
        msg = f"Invalid arguments shape. (with {x1_lens.shape=} and {x2_lens.shape=})"
        raise ValueError(msg)

    # Compare the shapes dim by dim, then ignore the sequence dim.
    dims_match = torch.as_tensor(x1.shape) == torch.as_tensor(x2.shape)
    dims_match[seq_dim] = True
    if bool(dims_match.all()):
        return None

    raise ValueError(f"Invalid arguments shape. (with {x1.shape=} and {x2.shape=})")