Source code for torchoutil.nn.functional.cropping
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Iterable, List, Literal, Union, get_args
import torch
from torch import Tensor
from torchoutil.core.make import GeneratorLike, as_generator
CropAlign = Literal["left", "right", "center", "random"]
[docs]def crop_dim(
x: Tensor,
target_length: int,
*,
dim: int = -1,
align: CropAlign = "left",
generator: GeneratorLike = None,
) -> Tensor:
"""Generic function to crop a single dimension."""
return crop_dims(
x,
[target_length],
dims=[dim],
aligns=[align],
generator=generator,
)
[docs]def crop_dims(
x: Tensor,
target_lengths: Iterable[int],
*,
dims: Union[Iterable[int], Literal["auto"]] = "auto",
aligns: Union[CropAlign, Iterable[CropAlign]] = "left",
generator: GeneratorLike = None,
) -> Tensor:
"""Generic function to crop multiple dimensions."""
target_lengths = list(target_lengths)
aligns_lst: List[CropAlign]
if isinstance(aligns, str):
aligns_lst = [aligns] * len(target_lengths)
else:
aligns_lst = list(aligns)
del aligns
if dims == "auto":
dims = list(range(-len(target_lengths), 0))
else:
dims = list(dims)
generator = as_generator(generator)
if len(target_lengths) != len(dims):
msg = f"Invalid number of targets lengths ({len(target_lengths)}) with the number of dimensions ({len(dims)})."
raise ValueError(msg)
if len(aligns_lst) != len(dims):
msg = f"Invalid number of aligns ({len(aligns_lst)}) with the number of dimensions ({len(dims)})."
raise ValueError(msg)
slices = [slice(None)] * len(x.shape)
for target_length, dim, align in zip(target_lengths, dims, aligns_lst):
if x.shape[dim] <= target_length:
continue
if align == "left":
start = 0
end = target_length
elif align == "right":
start = x.shape[dim] - target_length
end = None
elif align == "center":
diff = x.shape[dim] - target_length
start = diff // 2 + diff % 2
end = start + target_length
elif align == "random":
diff = x.shape[dim] - target_length
start = torch.randint(low=0, high=diff, size=(), generator=generator).item()
end = start + target_length
else:
msg = f"Invalid argument {align=}. (expected one of {get_args(CropAlign)})"
raise ValueError(msg)
slices[dim] = slice(start, end)
x = x[slices]
return x