torchoutil.hub.registry module

class RegistryEntry(
*args,
**kwargs,
)[source]

Bases: dict

architecture: typing_extensions.NotRequired[str]
fname: str
hash_type: typing_extensions.NotRequired[Literal['sha256', 'md5']]
hash_value: typing_extensions.NotRequired[str]
state_dict_key: typing_extensions.NotRequired[str]
url: str
class RegistryHub(
infos: Mapping[T_Hashable, RegistryEntry],
register_root: str | Path = '~/.cache/torch/hub/checkpoints',
)[source]

Bases: Generic[T_Hashable]

download_file(
name: T_Hashable,
force: bool = False,
check_hash: bool = True,
verbose: int = 0,
) → Tuple[Path, bool][source]

Download checkpoint file.

classmethod from_file(
path: str | Path,
) → RegistryHub[source]

Load register info from JSON file.

get_path(
name: T_Hashable,
) → Path[source]

Returns the expected filepath of an element.

property infos: Dict[T_Hashable, RegistryEntry]
is_valid_hash(
name: T_Hashable,
) → bool[source]

Returns True if target file hash is valid. If no hash is provided in infos, this function also returns True.

load_state_dict(name_or_path: ~torchoutil.hub.registry.T_Hashable | str | ~pathlib.Path, *, device: ~torch.device | None | ~typing.Literal['default', 'cuda_if_available'] | str | int = None, offline: bool = False, load_fn: ~typing.Callable[[~typing.Any], ~torchoutil.serialization.load_fn.T] | ~typing.Literal['csv', 'json', 'h5py', 'numpy', 'pickle', 'safetensors', 'torch', 'torchaudio', 'yaml'] = <function load_torch>, load_kwds: ~typing.Dict[str, ~typing.Any] | None = None, verbose: int = 0) → Dict[str, Tensor][source]

Load state_dict weights.

Args:

name_or_path: Model name (case sensitive) or path to checkpoint file.
device: Device of checkpoint weights. (deprecated)
offline: If False, the checkpoint from a model name will be automatically downloaded.
load_fn: Load function backend, given either as a callable or as one of the backend names 'csv', 'json', 'h5py', 'numpy', 'pickle', 'safetensors', 'torch', 'torchaudio', 'yaml'. Defaults to load_torch.
load_kwds: Optional keyword arguments passed to load_fn. Defaults to None.
verbose: Verbose level. Defaults to 0.

Returns:

Loaded file content.

property names: List[T_Hashable]
property paths: List[Path]
property register_root: Path
remove_file(
name: T_Hashable,
) → None[source]
save(
path: str | Path,
) → None[source]

Save info to JSON file.

get_default_register_root() → Path[source]

Returns the default register root path, ~/.cache/torch/hub/checkpoints, which is based on torch.hub.get_dir.