Source code for torchoutil.serialization.common

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import re
from argparse import Namespace
from dataclasses import asdict
from enum import Enum
from pathlib import Path
from re import Pattern
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Literal,
    Mapping,
    Optional,
    TypeVar,
    Union,
    overload,
)

import torch
from torch import Tensor

from torchoutil.core.packaging import (
    _H5PY_AVAILABLE,
    _NUMPY_AVAILABLE,
    _OMEGACONF_AVAILABLE,
    _PANDAS_AVAILABLE,
    _SAFETENSORS_AVAILABLE,
    _TORCHAUDIO_AVAILABLE,
    _YAML_AVAILABLE,
)
from torchoutil.pyoutil.importlib import Placeholder
from torchoutil.pyoutil.typing import (
    BuiltinNumber,
    DataclassInstance,
    NamedTupleInstance,
    T_BuiltinScalar,
    is_builtin_scalar,
    is_dataclass_instance,
    is_namedtuple_instance,
)

if _OMEGACONF_AVAILABLE:
    from omegaconf import DictConfig, ListConfig, OmegaConf  # type: ignore

if _PANDAS_AVAILABLE:
    import pandas as pd  # type: ignore

    DataFrame = pd.DataFrame  # type: ignore
else:

    class DataFrame(Placeholder):
        ...


pylog = logging.getLogger(__name__)

T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")

UnkMode = Literal["identity", "error"]
SavingBackend = Literal[
    "csv",
    "json",
    "h5py",
    "numpy",
    "pickle",
    "safetensors",
    "torch",
    "torchaudio",
    "yaml",
]

# Note: order matter here: last extension of a backend is the default/recommanded one
PATTERN_TO_BACKEND: Dict[str, SavingBackend] = {
    r"^.+\.tsv$": "csv",
    r"^.+\.csv$": "csv",
    r"^.+\.json$": "json",
    r"^.+\.pkl$": "pickle",
    r"^.+\.pickle$": "pickle",
    r"^.+\.torch$": "torch",
    r"^.+\.ckpt$": "torch",
    r"^.+\.pt$": "torch",
}

if _H5PY_AVAILABLE:
    PATTERN_TO_BACKEND.update(
        {
            r"^.+\.h5$": "h5py",
            r"^.+\.hdf$": "h5py",
        }
    )


if _NUMPY_AVAILABLE:
    import numpy as np

    PATTERN_TO_BACKEND.update(
        {
            r"^.+\.npz$": "numpy",
            r"^.+\.npy$": "numpy",
        }
    )

if _SAFETENSORS_AVAILABLE:
    PATTERN_TO_BACKEND.update(
        {
            r"^.+\.safetensors$": "safetensors",
        }
    )

if _TORCHAUDIO_AVAILABLE:
    PATTERN_TO_BACKEND.update(
        {
            r"^.+\.mp3$": "torchaudio",
            r"^.+\.wav$": "torchaudio",
            r"^.+\.aac$": "torchaudio",
            r"^.+\.ogg$": "torchaudio",
            r"^.+\.flac$": "torchaudio",
        }
    )

if _YAML_AVAILABLE:
    PATTERN_TO_BACKEND.update(
        {
            r".+\.yml$": "yaml",
            r".+\.yaml$": "yaml",
        }
    )


def _fpath_to_saving_backend(
    fpath: Union[str, Path],
    verbose: int = 0,
) -> SavingBackend:
    fname = Path(fpath).name

    saving_backend: Optional[SavingBackend] = None
    for pattern, backend in PATTERN_TO_BACKEND.items():
        if re.match(pattern, fname) is not None:
            saving_backend = backend
            break

    if saving_backend is None:
        msg = f"Unknown file pattern '{fname}'. (expected one of {tuple(PATTERN_TO_BACKEND.keys())} or specify the backend argument with `to.load(..., saving_backend=\"backend\")`)"
        raise ValueError(msg)

    if verbose >= 2:
        pylog.debug(f"Loading file '{str(fpath)}' using {saving_backend=}.")
    return saving_backend


BACKEND_TO_PATTERN: Dict[SavingBackend, str] = {
    backend: ext for ext, backend in PATTERN_TO_BACKEND.items()
}
UNK_MODES = ("identity", "error")


@overload
def to_builtin(
    x: Enum,
    *,
    unk_mode: UnkMode = "error",
) -> str:
    ...


@overload
def to_builtin(
    x: Path,
    *,
    unk_mode: UnkMode = "error",
) -> str:
    ...


@overload
def to_builtin(
    x: Pattern,
    *,
    unk_mode: UnkMode = "error",
) -> str:
    ...


@overload
def to_builtin(
    x: Namespace,
    *,
    unk_mode: UnkMode = "error",
) -> Dict[str, Any]:
    ...


@overload
def to_builtin(
    x: Tensor,
    *,
    unk_mode: UnkMode = "error",
) -> Union[List, BuiltinNumber]:
    ...


@overload
def to_builtin(
    x: Mapping[K, V],
    *,
    unk_mode: UnkMode = "error",
) -> Dict[K, V]:
    ...


@overload
def to_builtin(
    x: DataclassInstance,
    *,
    unk_mode: UnkMode = "error",
) -> Dict[str, Any]:
    ...


@overload
def to_builtin(
    x: NamedTupleInstance,
    *,
    unk_mode: UnkMode = "error",
) -> Dict[str, Any]:
    ...


@overload
def to_builtin(
    x: T_BuiltinScalar,
    *,
    unk_mode: UnkMode = "error",
) -> T_BuiltinScalar:
    ...


@overload
def to_builtin(
    x: Any,
    *,
    unk_mode: UnkMode = "error",
) -> Any:
    ...


[docs]def to_builtin( x: Any, *, unk_mode: UnkMode = "error", ) -> Any: """Convert an object to a python builtin equivalent. This function can be used to sanitize data before saving to YAML or CSV file. Args: x: Object to convert to built-in equivalent. unk_mode: When an object type is not recognized, unk_mode defines the behaviour. If unk_mode == "identity", the object is returned unchanged. If unk_mode == "error", a TypeError is raised. """ # Terminal cases if is_builtin_scalar(x, strict=True): return x if isinstance(x, Enum): return x.name if isinstance(x, Path): return str(x) if isinstance(x, Pattern): return x.pattern if isinstance(x, Tensor): return x.tolist() if isinstance(x, torch.dtype): return str(x) if _NUMPY_AVAILABLE and isinstance(x, np.ndarray): # type: ignore return x.tolist() if _NUMPY_AVAILABLE and isinstance(x, np.generic): # type: ignore return x.item() if _NUMPY_AVAILABLE and isinstance(x, np.dtype): # type: ignore return str(x) # Non-terminal cases (iterables and mappings) if _OMEGACONF_AVAILABLE and isinstance(x, (DictConfig, ListConfig)): # type: ignore return to_builtin(OmegaConf.to_container(x, resolve=False, enum_to_str=True)) # type: ignore if _PANDAS_AVAILABLE and isinstance(x, DataFrame): # type: ignore return to_builtin(x.to_dict("list")) if isinstance(x, Namespace): return to_builtin(x.__dict__) if is_dataclass_instance(x): return to_builtin(asdict(x)) if is_namedtuple_instance(x): return to_builtin(x._asdict()) if isinstance(x, Mapping): return {to_builtin(k): to_builtin(v) for k, v in x.items()} if isinstance(x, Iterable): return [to_builtin(xi) for xi in x] if unk_mode == "identity": return x elif unk_mode == "error": raise TypeError(f"Unsupported argument type {type(x)}.") else: raise ValueError(f"Invalid argument {unk_mode=}. (expected one of {UNK_MODES})")