# Source code for torchoutil.core.packaging
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Callable, Final, Iterable, Union
import torch
from torchoutil.pyoutil.functools import identity
from torchoutil.pyoutil.importlib import is_available_package
from torchoutil.pyoutil.semver import Version
def _get_extra_version(name: str) -> str:
try:
module = __import__(name)
return str(module.__version__)
except ImportError:
return "not_installed"
except AttributeError:
return "unknown"
# Optional dependencies whose presence and version are probed when this
# module is imported. Each entry is passed both to `is_available_package`
# and to `_get_extra_version` (which imports it via `__import__`).
_EXTRAS_PACKAGES = (
    "colorlog",
    "h5py",
    "numpy",
    "omegaconf",
    "pandas",
    "safetensors",
    "scipy",
    "tensorboard",
    "torchaudio",
    "tqdm",
    "yaml",
)

# name -> availability flag, then name -> version string (same key order as
# _EXTRAS_PACKAGES). Availability is computed for all packages first.
_EXTRA_AVAILABLE = dict(zip(_EXTRAS_PACKAGES, map(is_available_package, _EXTRAS_PACKAGES)))
_EXTRA_VERSION = dict(zip(_EXTRAS_PACKAGES, map(_get_extra_version, _EXTRAS_PACKAGES)))

# Per-package flags, spelled out individually so static type checkers see
# each one as a Final[bool] constant.
_COLORLOG_AVAILABLE: Final[bool] = _EXTRA_AVAILABLE["colorlog"]
_H5PY_AVAILABLE: Final[bool] = _EXTRA_AVAILABLE["h5py"]
_NUMPY_AVAILABLE: Final[bool] = _EXTRA_AVAILABLE["numpy"]
_OMEGACONF_AVAILABLE: Final[bool] = _EXTRA_AVAILABLE["omegaconf"]
_PANDAS_AVAILABLE: Final[bool] = _EXTRA_AVAILABLE["pandas"]
_SAFETENSORS_AVAILABLE: Final[bool] = _EXTRA_AVAILABLE["safetensors"]
_SCIPY_AVAILABLE: Final[bool] = _EXTRA_AVAILABLE["scipy"]
_TENSORBOARD_AVAILABLE: Final[bool] = _EXTRA_AVAILABLE["tensorboard"]
_TORCHAUDIO_AVAILABLE: Final[bool] = _EXTRA_AVAILABLE["torchaudio"]
_TQDM_AVAILABLE: Final[bool] = _EXTRA_AVAILABLE["tqdm"]
_YAML_AVAILABLE: Final[bool] = _EXTRA_AVAILABLE["yaml"]
def requires_packages(packages: Union[str, Iterable[str]]) -> Callable:
    """Check that optional packages are available.

    Args:
        packages: A single package name, or an iterable of package names,
            each checked with ``is_available_package``.

    Returns:
        ``identity`` (from ``torchoutil.pyoutil.functools``) when every
        requested package is available.

    Raises:
        ImportError: If at least one package is unavailable. The message lists
            each missing package once, in first-seen order, and suggests
            installing the ``torchoutil[extras]`` extra.
    """
    # A bare string is itself iterable; wrap it so it is not split into characters.
    names = [packages] if isinstance(packages, str) else list(packages)
    # dict.fromkeys drops repeated names while keeping first-seen order, so a
    # duplicated package is neither re-checked nor listed twice in the error.
    missing = [pkg for pkg in dict.fromkeys(names) if not is_available_package(pkg)]
    if not missing:
        return identity

    prefix = "\n - "
    missing_str = prefix.join(missing)
    msg = (
        "Cannot use/import objects because the following optional dependencies are missing:"
        f"{prefix}{missing_str}\n"
        "Please install them using `pip install torchoutil[extras]`."
    )
    raise ImportError(msg)
def torch_version_ge_1_13() -> bool:
    """Tell whether the installed PyTorch is version 1.13.0 or newer.

    Returns:
        True when ``torch.__version__``, parsed with ``Version.from_str``,
        compares greater than or equal to ``Version("1.13.0")``.
    """
    installed = Version.from_str(str(torch.__version__))
    return installed >= Version("1.13.0")