Source code for torchoutil.extras.safetensors

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from pathlib import Path
from typing import Any, Dict, Literal, Optional, Tuple, Union, overload

from safetensors import safe_open
from safetensors.torch import save
from torch import Tensor

from torchoutil.nn import functional as F
from torchoutil.pyoutil.inspect import get_fullname
from torchoutil.pyoutil.io import _setup_path
from torchoutil.pyoutil.typing.guards import isinstance_guard
from torchoutil.pyoutil.warnings import deprecated_alias


@overload
def load_safetensors(
    fpath: Union[str, Path],
    *,
    device: str = "cpu",
    return_metadata: Literal[False] = False,
) -> Dict[str, Tensor]:
    ...


@overload
def load_safetensors(
    fpath: Union[str, Path],
    *,
    device: str = "cpu",
    return_metadata: Literal[True],
) -> Tuple[Dict[str, Tensor], Dict[str, str]]:
    ...


[docs]def load_safetensors( fpath: Union[str, Path], *, device: str = "cpu", return_metadata: bool = False, ) -> Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Dict[str, str]]]: tensors = {} with safe_open(fpath, framework="pt", device=device) as f: for k in f.keys(): tensors[k] = f.get_tensor(k) if return_metadata: metadata = f.metadata() result = tensors, metadata else: result = tensors return result
@overload def dump_safetensors( tensors: Dict[str, Tensor], fpath: Union[str, Path, None] = None, metadata: Optional[Dict[str, str]] = None, *, overwrite: bool = True, make_parents: bool = True, convert_to_tensor: Literal[False] = False, ) -> bytes: ... @overload def dump_safetensors( tensors: Dict[str, Any], fpath: Union[str, Path, None] = None, metadata: Optional[Dict[str, str]] = None, *, overwrite: bool = True, make_parents: bool = True, convert_to_tensor: Literal[True], ) -> bytes: ...
[docs]def dump_safetensors( tensors: Dict[str, Any], fpath: Union[str, Path, None] = None, metadata: Optional[Dict[str, str]] = None, *, overwrite: bool = True, make_parents: bool = True, convert_to_tensor: bool = False, ) -> bytes: """Dump tensors to safetensors format. Requires safetensors package installed.""" if convert_to_tensor: tensors = {k: F.as_tensor(v) for k, v in tensors.items()} elif not isinstance_guard(tensors, Dict[str, Tensor]): msg = f"Invalid argument type {type(tensors)}. (expected dict[str, Tensor] but found {get_fullname(type(tensors))})" raise TypeError(msg) fpath = _setup_path(fpath, overwrite, make_parents) content = save(tensors, metadata) if fpath is not None: fpath.write_bytes(content) return content
[docs]@deprecated_alias(dump_safetensors) def to_safetensors(*args, **kwargs): ...