Source code for torchoutil.hub.registry

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import os
import os.path as osp
from pathlib import Path
from typing import (
    Any,
    Dict,
    Generic,
    Hashable,
    List,
    Mapping,
    Optional,
    Tuple,
    TypedDict,
    TypeVar,
    Union,
)

import torch
from torch import Tensor
from typing_extensions import NotRequired

from torchoutil.core.make import DeviceLike, as_device
from torchoutil.pyoutil.hashlib import HashName, hash_file
from torchoutil.pyoutil.warnings import deprecated_function, warn_once
from torchoutil.serialization.json import dump_json, load_json
from torchoutil.serialization.load_fn import LOAD_FNS, LoadFnLike, load_torch

from .paths import get_cache_dir

T_Hashable = TypeVar("T_Hashable", bound=Hashable)

pylog = logging.getLogger(__name__)


class RegistryEntry(TypedDict):
    """Description of one downloadable checkpoint stored in a registry.

    Only ``url`` and ``fname`` are mandatory; every other key may be omitted.
    """

    # Address the checkpoint file is downloaded from.
    url: str
    # File name of the checkpoint, joined to the registry root directory.
    fname: str
    # Expected hash of the checkpoint file, compared against the computed one.
    hash_value: NotRequired[str]
    # Name of the hash algorithm used to compute ``hash_value``.
    hash_type: NotRequired[HashName]
    # Key under which the state dict is stored inside the loaded file content.
    state_dict_key: NotRequired[str]
    # Free-form architecture tag.
    # NOTE(review): not read by the visible registry code — confirm its consumers.
    architecture: NotRequired[str]
class RegistryHub(Generic[T_Hashable]):
    """Registry of checkpoint files that can be downloaded, verified and loaded by name.

    Each registered name maps to a `RegistryEntry` giving the download url,
    the local file name and, optionally, an expected hash and the key of the
    state dict inside the loaded file.
    """

    def __init__(
        self,
        infos: Mapping[T_Hashable, RegistryEntry],
        register_root: Union[str, Path, None] = "~/.cache/torch/hub/checkpoints",
    ) -> None:
        """
        Args:
            infos: Maps model_name to their checkpoint information, with download url, filename, hash value, hash type and state_dict key.
            register_root: Directory where checkpoints are saved. If None, the `checkpoints` sub-directory of `get_cache_dir()` is used. defaults to `~/.cache/torch/hub/checkpoints`.
        """
        # Shallow copy: later changes to the caller's mapping do not affect the hub.
        infos = dict(infos)

        if register_root is None:
            register_root = get_cache_dir().joinpath("checkpoints")
        else:
            # "~" must be expanded BEFORE resolving: resolve() makes the path
            # absolute relative to the current directory, after which
            # expanduser() no longer sees a leading "~".
            register_root = Path(register_root).expanduser().resolve()

        super().__init__()
        self._infos = infos
        self._register_root = register_root

    @property
    def infos(self) -> Dict[T_Hashable, RegistryEntry]:
        """Internal mapping from name to checkpoint information (not a copy)."""
        return self._infos

    @property
    def register_root(self) -> Path:
        """Resolved directory where checkpoint files are stored."""
        return self._register_root.resolve()

    @property
    def names(self) -> List[T_Hashable]:
        """List of registered names, in registration order."""
        return list(self._infos.keys())

    @property
    def paths(self) -> List[Path]:
        """Expected file path of every registered name, in the same order as `names`."""
        return [self.get_path(model_name) for model_name in self.names]

    def get_path(self, name: T_Hashable) -> Path:
        """Returns the expected filepath of an element.

        Args:
            name: Registered name.

        Returns:
            Path of the checkpoint file inside `register_root` (the file may not exist yet).

        Raises:
            ValueError: If name is not registered.
        """
        try:
            info = self._infos[name]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments, which are reported like any other unknown name.
            msg = f"Invalid argument {name=}. (expected one of {self.names})"
            raise ValueError(msg) from None

        return self.register_root.joinpath(info["fname"])

    def load_state_dict(
        self,
        name_or_path: Union[T_Hashable, str, Path],
        *,
        device: DeviceLike = None,
        offline: bool = False,
        load_fn: LoadFnLike = load_torch,
        load_kwds: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
    ) -> Dict[str, Tensor]:
        """Load state_dict weights.

        Args:
            name_or_path: Model name (case sensitive) or path to checkpoint file.
            device: Device of checkpoint weights. (deprecated, use `load_kwds=dict(map_location=...)` instead)
            offline: If False, the checkpoint from a model name will be automatically downloaded.
            load_fn: Load function backend, or its name in LOAD_FNS. defaults to torch.load.
            load_kwds: Optional keywords arguments passed to load_fn. This dict is not modified. defaults to None.
            verbose: Verbose level. defaults to 0.

        Returns:
            Loaded file content, or its `state_dict_key` item when the registry entry defines one.

        Raises:
            ValueError: If load_fn is an unknown backend name, or if name_or_path is neither an existing file nor a registered name.
            FileNotFoundError: If the checkpoint file is missing and offline is True.
        """
        if isinstance(load_fn, str):
            if load_fn not in LOAD_FNS:
                msg = f"Invalid argument {load_fn=}. (expected one of {tuple(LOAD_FNS.keys())})"
                raise ValueError(msg)
            load_fn = LOAD_FNS[load_fn]

        # Work on a copy so the caller's dict never receives the "map_location" key.
        load_kwds = {} if load_kwds is None else dict(load_kwds)

        if device is not None:
            src_device = device
            device = as_device(device)
            msg = f"Deprecated argument device={src_device}. Use `load_kwds=dict(map_location={device})` with function torch.load instead."
            warn_once(msg)
            if device is not None:
                load_kwds["map_location"] = device

        if isinstance(name_or_path, (str, Path)) and osp.isfile(name_or_path):
            # Existing file: use it directly; name is None when the file is not a registered one.
            path = Path(name_or_path)
            name = self._get_name(path)
        else:
            name = name_or_path
            try:
                path = self.get_path(name_or_path)  # type: ignore
            except ValueError as err:
                msg = f"Invalid argument {name_or_path=}. (expected a path to a checkpoint file or a model name in {self.names})"
                raise ValueError(msg) from err

        if not path.is_file():
            if offline:
                msg = f"Cannot find checkpoint model file in '{path}' for model '{name}' with mode {offline=}."
                raise FileNotFoundError(msg)
            self.download_file(name, verbose=verbose)  # type: ignore

        del name_or_path

        # Unregistered files have no entry, hence no state_dict_key.
        info = self._infos.get(name, {})  # type: ignore
        state_dict_key = info.get("state_dict_key", None)

        if verbose >= 1:
            msg = f"Loading encoder weights from '{path}'..."
            pylog.info(msg)

        data = load_fn(path, **load_kwds)

        if state_dict_key is None:
            return data
        return data[state_dict_key]

    def download_file(
        self,
        name: T_Hashable,
        force: bool = False,
        check_hash: bool = True,
        verbose: int = 0,
    ) -> Tuple[Path, bool]:
        """Download checkpoint file.

        Args:
            name: Registered name of the checkpoint.
            force: If True, remove any existing file and download it again. defaults to False.
            check_hash: If True, validate the downloaded file with `is_valid_hash`. defaults to True.
            verbose: Verbose level; a progress bar is shown when >= 1. defaults to 0.

        Returns:
            A tuple (path of the checkpoint file, True if a download happened).

        Raises:
            ValueError: If name is not registered, or if the downloaded file has an invalid hash.
        """
        model_path = self.get_path(name)
        exists = model_path.exists()

        if exists and not force:
            return model_path, False
        if exists:
            os.remove(model_path)

        model_path.parent.mkdir(parents=True, exist_ok=True)
        url = self._infos[name]["url"]
        torch.hub.download_url_to_file(url, str(model_path), progress=verbose >= 1)

        if check_hash and not self.is_valid_hash(name):
            # Do not keep a file that failed validation: later calls would find
            # it on disk and load it without any further check.
            os.remove(model_path)
            raise ValueError(f"Invalid hash for file '{model_path}'.")

        return model_path, True

    def remove_file(
        self,
        name: T_Hashable,
    ) -> None:
        """Remove the checkpoint file of a registered name, if it exists.

        Args:
            name: Registered name of the checkpoint.

        Raises:
            ValueError: If name is not registered, or if its path exists but is not a file.
        """
        path = self.get_path(name)
        if path.is_file():
            os.remove(path)
        elif path.exists():
            msg = f"Invalid argument {name=}, which redirect to a non-file {path=}."
            raise ValueError(msg)

    def is_valid_hash(
        self,
        name: T_Hashable,
    ) -> bool:
        """Returns True if target file hash is valid. If no hash is provided in infos, this function also returns True.

        Args:
            name: Registered name of the checkpoint.

        Returns:
            True if the computed hash equals the expected one, or if the entry has no hash value or no hash type.
        """
        info = self.infos[name]
        if "hash_type" not in info or "hash_value" not in info:
            msg = f"Cannot check hash for {name}. (cannot find any expected hash value or type)"
            pylog.warning(msg)
            return True

        hash_type = info["hash_type"]
        expected_hash_value = info["hash_value"]
        model_path = self.get_path(name)
        hash_value = hash_file(model_path, hash_type)
        return hash_value == expected_hash_value

    def save(self, path: Union[str, Path]) -> None:
        """Save info to JSON file.

        The file stores the constructor arguments (`infos` and `register_root`),
        so it can be read back with `from_file`.
        NOTE(review): JSON object keys are strings, so non-string names are
        presumably not restored with their original type — confirm with dump_json.
        """
        args = {
            "infos": self._infos,
            "register_root": str(self._register_root),
        }
        dump_json(args, path)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> "RegistryHub":
        """Load register info from JSON file.

        Args:
            path: JSON file holding the `infos` and `register_root` constructor arguments.

        Returns:
            A new instance of the class this method is called on.
        """
        args = load_json(path)
        # Build with cls so that subclasses get an instance of their own type.
        return cls(**args)

    def _get_name(self, path: Union[str, Path]) -> Optional[T_Hashable]:
        """Returns the registered name whose checkpoint path equals the given path, or None."""
        # Expand "~" before resolving, otherwise the expansion never happens.
        path_to_name = {
            path_i.expanduser().resolve(): name_i
            for path_i, name_i in zip(self.paths, self.names)
        }
        path = Path(path).expanduser().resolve()
        return path_to_name.get(path, None)
@deprecated_function()
def get_default_register_root() -> Path:
    """Return the `checkpoints` sub-directory of the torch hub directory.

    The hub directory is the one given by `torch.hub.get_dir`, so the result is
    `~/.cache/torch/hub/checkpoints` with the default torch configuration.
    """
    hub_dir = Path(torch.hub.get_dir())
    return hub_dir / "checkpoints"