Source code for torchoutil.extras.numpy.saving
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import warnings
from io import BytesIO
from pathlib import Path
from typing import Literal, Optional, Union
from torchoutil.pyoutil.io import _setup_path
from .definitions import np
NumpyFormat = Literal["npy", "npz"]
[docs]def dump_numpy(
obj: np.ndarray,
fpath: Union[str, Path, None] = None,
*args,
np_format: Optional[NumpyFormat] = "npy",
overwrite: bool = True,
make_parents: bool = True,
**kwargs,
) -> bytes:
fpath = _setup_path(fpath, overwrite, make_parents)
if np_format is not None:
pass
elif fpath is None or fpath.suffix == ".npy":
np_format = "npy"
elif fpath.suffix == ".npz":
np_format = "npz"
else:
NUMPY_EXTENSIONS = (".npy", ".npz")
msg = f"Unknown numpy extension '{fpath.suffix}'. (expected one of {NUMPY_EXTENSIONS})"
warnings.warn(msg)
np_format = "npy"
if np_format == "npy":
save_fn = np.save
elif np_format == "npz":
save_fn = np.savez
else:
raise ValueError(f"Invalid argument {np_format=}.")
buffer = BytesIO()
save_fn(buffer, obj, *args, **kwargs)
buffer.seek(0)
content = buffer.read()
if fpath is not None:
fpath.write_bytes(content)
return content
[docs]def load_numpy(fpath: Union[str, Path], *args, **kwargs) -> np.ndarray:
return np.load(fpath, *args, **kwargs)