torchoutil.nn.functional.indices module

get_inverse_perm(indices: Tensor, dim: int = -1) -> Tensor

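Return the inverse permutation indices, i.e. the indices that undo the permutation given by indices. The output is a tensor of shape (…, N).

Args:

    indices: Original permutation indices as a tensor of shape (…, N).
    dim: Dimension of indices along which the permutation is defined. Defaults to -1.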

Example 1::

>>> import torch
>>> from torchoutil.nn.functional.indices import get_inverse_perm
>>> x = torch.as_tensor([2, 4, 8, 10])
>>> indices = torch.randperm(len(x))
>>> x = x[indices]
>>> # x is now shuffled; the inverse permutation restores the original order
>>> inv_indices = get_inverse_perm(indices)
>>> x_reordered = x[inv_indices]
>>> x_reordered
tensor([ 2,  4,  8, 10])
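For intuition, the inverse of a permutation p is the permutation inv with inv[p[i]] = i for every i. The following is a minimal pure-Python sketch of that rule on a 1-D list; the helper name inverse_perm is hypothetical and this is not the library's implementation. With PyTorch tensors, torch.argsort(indices, dim=dim) gives the same result, since sorting a permutation of 0..N-1 returns the position of each value.

```python
from typing import List, Sequence


def inverse_perm(perm: Sequence[int]) -> List[int]:
    """Hypothetical helper: return inv such that inv[perm[i]] == i for all i."""
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        # Value p sits at position i, so the inverse maps p back to i.
        inv[p] = i
    return inv


x = [2, 4, 8, 10]
perm = [2, 0, 3, 1]                    # a fixed permutation instead of a random one
shuffled = [x[p] for p in perm]        # [8, 2, 10, 4]
inv = inverse_perm(perm)               # [1, 3, 0, 2]
restored = [shuffled[j] for j in inv]  # [2, 4, 8, 10]
```

Inverting twice gives back the original permutation, which is a quick sanity check for any implementation.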

get_perm_indices(x1: Tensor, x2: Tensor) -> LongTensor

Find the permutation between two vectors x1 and x2 that both contain the values 0 to N-1. The returned indices satisfy x2[indices] == x1.

Example 1::

>>> import torch
>>> from torchoutil.nn.functional.indices import get_perm_indices
>>> x1 = torch.as_tensor([0, 1, 2, 4, 3, 6, 5, 7])
>>> x2 = torch.as_tensor([0, 2, 1, 4, 3, 5, 6, 7])
>>> indices = get_perm_indices(x1, x2)
>>> torch.equal(x1, x2[indices])
True
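Because both vectors hold the values 0 to N-1, the required indices can be read off by first recording where each value sits in x2 (the inverse of x2) and then looking up each element of x1. A minimal pure-Python sketch of that idea follows; perm_indices is a hypothetical name, not the library's function. With tensors, x2.argsort()[x1] expresses the same lookup.

```python
from typing import List, Sequence


def perm_indices(x1: Sequence[int], x2: Sequence[int]) -> List[int]:
    """Hypothetical helper: return indices such that [x2[i] for i in indices] == x1.

    Assumes x1 and x2 are both permutations of 0..N-1.
    """
    if sorted(x1) != list(range(len(x1))) or sorted(x2) != sorted(x1):
        raise ValueError("x1 and x2 must both be permutations of 0..N-1")
    # pos_in_x2[v] is the position of value v inside x2 (the inverse of x2).
    pos_in_x2 = [0] * len(x2)
    for j, v in enumerate(x2):
        pos_in_x2[v] = j
    # For each element of x1, look up where that value lives in x2.
    return [pos_in_x2[v] for v in x1]


x1 = [0, 1, 2, 4, 3, 6, 5, 7]
x2 = [0, 2, 1, 4, 3, 5, 6, 7]
indices = perm_indices(x1, x2)  # [0, 2, 1, 3, 4, 6, 5, 7]
```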

insert_at_indices(
    x: Tensor,
    indices: Tensor | List | bool | int | float | complex,
    values: bool | int | float | complex | Tensor,
) -> Tensor1D

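Insert value(s) into a vector at the specified indices. Each index refers to a position in the original vector, and the value is inserted before the element at that position.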

Example 1::

>>> import torch
>>> from torchoutil.nn.functional.indices import insert_at_indices
>>> x = torch.as_tensor([1, 1, 2, 2, 2, 3])
>>> indices = torch.as_tensor([2, 5])
>>> values = 4
>>> insert_at_indices(x, indices, values)
tensor([1, 1, 4, 2, 2, 2, 4, 3])
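The example shows that indices refer to positions in the original vector: the second 4 lands just before the element that was originally at index 5. Below is a minimal pure-Python sketch of that behaviour, assuming distinct indices in [0, len(x)) and a single scalar value (the library also accepts tensors for values); insert_at is a hypothetical name, not part of torchoutil.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def insert_at(x: Sequence[T], indices: Sequence[int], value: T) -> List[T]:
    """Hypothetical helper: insert value before each original position in indices.

    Assumes distinct indices in [0, len(x)) and a single scalar value.
    """
    targets = set(indices)
    out: List[T] = []
    for i, v in enumerate(x):
        if i in targets:
            out.append(value)  # inserted value comes before the original element
        out.append(v)
    return out


print(insert_at([1, 1, 2, 2, 2, 3], [2, 5], 4))  # [1, 1, 4, 2, 2, 2, 4, 3]
```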

randperm_diff(
    size: int,
    generator: Generator | None | Literal['default'] | int = None,
    device: device | None | Literal['default', 'cuda_if_available'] | str | int = None,
) -> LongTensor1D

Return a random permutation of the integers 0 to size-1 in which no value i appears at index i (a derangement). The output is a tensor of shape (size,).

Args:

    size: The number of indices. Must be at least 2.
    generator: The seed or torch.Generator used to generate the permutation.
    device: The PyTorch device of the output indices tensor.

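Example 1::

>>> import torch
>>> from torchoutil.nn.functional.indices import randperm_diff
>>> torch.randperm(5)  # random; here 2 is at index 2
tensor([1, 4, 2, 0, 3])
>>> randperm_diff(5)  # random; no value i is at index i
tensor([2, 0, 4, 1, 3])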
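One simple way to sample such a permutation is rejection sampling: shuffle 0..size-1 and retry until no value sits at its own index. Roughly 1/e of all permutations are derangements, so only a few shuffles are needed on average, and accepted results are uniform over all derangements. A size of 1 has no derangement at all (the only permutation is [0]), which is why size must be at least 2. This pure-Python sketch uses random.Random in place of a torch.Generator; random_derangement is a hypothetical name, and it is not claimed to be how randperm_diff works internally.

```python
import random
from typing import List, Optional


def random_derangement(size: int, seed: Optional[int] = None) -> List[int]:
    """Hypothetical helper: random permutation of 0..size-1 with perm[i] != i for all i.

    Uses rejection sampling, so the result is uniform over all derangements.
    """
    if size < 2:
        # size 1 has no derangement (the only permutation is [0]).
        raise ValueError(f"size must be at least 2, got {size}")
    rng = random.Random(seed)  # stands in for a torch.Generator or seed
    perm = list(range(size))
    while True:
        rng.shuffle(perm)  # uniform random permutation (Fisher-Yates)
        if all(p != i for i, p in enumerate(perm)):
            return perm  # accept: no value sits at its own index
```

A fixed seed makes the result reproducible, much like passing an int or a torch.Generator as generator.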

remove_at_indices(
    x: Tensor,
    indices: Tensor | List | bool | int | float | complex,
) -> Tensor1D

Remove value(s) from a vector at the specified indices.
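Removal is the reverse of insertion: every element whose position is listed in indices is dropped and the remaining elements keep their order. A minimal pure-Python sketch, assuming non-negative positions in the input; remove_at is a hypothetical name, not part of torchoutil. For a 1-D tensor, a boolean mask (mask = torch.ones(len(x), dtype=torch.bool); mask[indices] = False; x[mask]) achieves the same effect.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def remove_at(x: Sequence[T], indices: Sequence[int]) -> List[T]:
    """Hypothetical helper: drop the elements of x at the given positions.

    Assumes non-negative positions; duplicate indices are harmless.
    """
    drop = set(indices)
    # Keep every element whose position is not listed, preserving order.
    return [v for i, v in enumerate(x) if i not in drop]


print(remove_at([1, 1, 4, 2, 2, 2, 4, 3], [2, 6]))  # [1, 1, 2, 2, 2, 3]
```

Applied to the output of the insert_at_indices example, removing positions 2 and 6 (where the two inserted 4s ended up) recovers the original vector.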