Source code for torchoutil.nn.functional.checksum

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import functools
import itertools
import math
import pickle
import re
import struct
import zlib
from dataclasses import asdict
from functools import lru_cache
from types import FunctionType, MethodType
from typing import Callable, Iterable, Literal, Mapping, Union, get_args

import torch
from torch import Tensor, nn
from typing_extensions import TypeAlias

from torchoutil.core.packaging import _NUMPY_AVAILABLE, _PANDAS_AVAILABLE
from torchoutil.extras.numpy import np
from torchoutil.nn.functional.predicate import is_complex, is_floating_point
from torchoutil.pyoutil.functools import function_alias
from torchoutil.pyoutil.importlib import Placeholder
from torchoutil.pyoutil.inspect import get_fullname
from torchoutil.pyoutil.typing import (
    BuiltinNumber,
    BuiltinScalar,
    DataclassInstance,
    NamedTupleInstance,
    NoneType,
)

# `DataFrame` is always defined at module level: it is the real pandas class
# when pandas is installed, otherwise a Placeholder subclass so that
# annotations and isinstance() checks in this module can still name it.
if not _PANDAS_AVAILABLE:

    class DataFrame(Placeholder):
        ...

else:
    import pandas as pd

    DataFrame = pd.DataFrame  # type: ignore

# Union of the object types advertised as supported by `checksum_any`; it is
# also what `checksum_any` prints in its TypeError message via `get_args`.
# NOTE(review): `checksum_any` additionally dispatches on `torch.dtype`,
# `np.dtype` and `DataFrame`, which are not listed here - confirm whether the
# omission is intentional.
Checksumable: TypeAlias = Union[
    int,
    bool,
    complex,
    float,
    NoneType,
    str,
    bytes,
    bytearray,
    re.Pattern,
    nn.Module,
    Tensor,
    np.ndarray,
    np.generic,
    NamedTupleInstance,
    DataclassInstance,
    Mapping,
    Iterable,
    MethodType,
    FunctionType,
    functools.partial,
    type,
    slice,
]
# Policy used by `checksum_any` when no supported type matches:
# "pickle" checksums `pickle.dumps(x)`, "error" raises a TypeError.
UnkMode = Literal["pickle", "error"]


# Recursive functions for union of types
def checksum_any(
    x: Checksumable,
    *,
    unk_mode: UnkMode = "error",
    allow_protocol: bool = True,
    **kwargs,
) -> int:
    """Compute the checksum of an arbitrary python object.

    The intended property of a checksum is: for all supported objects a and b,
    `(a == b) => (checksum(a) == checksum(b))`.

    The object is dispatched to the first matching type-specific function;
    the order of the checks below is significant (for instance `str` and
    `Tensor` are tested before the generic `Iterable` branch).

    Args:
        x: Object to checksum.
        unk_mode: Behaviour when x (or one of its elements) has no supported type.
            "error": raises a TypeError.
            "pickle": checksum the bytes produced by `pickle.dumps(x)`; the result
                then depends on how the object pickles.
            defaults to "error".
        allow_protocol: Whether NamedTuples, Dataclasses, Mappings and Iterables
            are detected through isinstance checks against their protocol/ABC
            types. When False, only dict, list, tuple, set, frozenset and range
            are accepted as containers.
            defaults to True.
        **kwargs: Optional arguments forwarded to the type-specific functions
            (e.g. the running "accumulator").

    Raises:
        TypeError: if x is not supported and unk_mode is "error".
        ValueError: if x is not supported and unk_mode is not a valid mode.
    """
    # Forward both options so that nested elements are handled the same way.
    kwargs["unk_mode"] = unk_mode
    kwargs["allow_protocol"] = allow_protocol

    # --- scalars and builtin leaves
    if isinstance(x, (int, bool, complex, float)):
        return checksum_builtin_number(x, **kwargs)
    if x is None:
        return checksum_none(x, **kwargs)
    if isinstance(x, str):
        return checksum_str(x, **kwargs)
    if isinstance(x, bytes):
        return checksum_bytes(x, **kwargs)
    if isinstance(x, bytearray):
        return checksum_bytearray(x, **kwargs)
    if isinstance(x, slice):
        return checksum_slice(x, **kwargs)
    if isinstance(x, re.Pattern):
        return checksum_pattern(x, **kwargs)

    # --- torch / numpy / pandas objects
    if isinstance(x, nn.Module):
        return checksum_module(x, **kwargs)
    if isinstance(x, Tensor):
        return checksum_tensor(x, **kwargs)
    if _NUMPY_AVAILABLE and isinstance(x, (np.ndarray, np.generic)):
        return checksum_ndarray(x, **kwargs)
    if isinstance(x, torch.dtype) or (_NUMPY_AVAILABLE and isinstance(x, np.dtype)):
        return checksum_dtype(x, **kwargs)
    if _PANDAS_AVAILABLE and isinstance(x, DataFrame):
        return checksum_dataframe(x, **kwargs)

    # --- containers (most specific first)
    if allow_protocol and isinstance(x, NamedTupleInstance):
        return checksum_namedtuple(x, **kwargs)
    if allow_protocol and isinstance(x, DataclassInstance):
        return checksum_dataclass(x, **kwargs)
    if (allow_protocol and isinstance(x, Mapping)) or isinstance(x, dict):
        return checksum_mapping(x, **kwargs)
    if (allow_protocol and isinstance(x, Iterable)) or isinstance(
        x, (list, tuple, set, frozenset, range)
    ):
        return checksum_iterable(x, **kwargs)

    # --- callables and types
    if isinstance(x, MethodType):
        return checksum_method(x, **kwargs)
    if isinstance(x, FunctionType):
        return checksum_function(x, **kwargs)
    if isinstance(x, functools.partial):
        return checksum_partial(x, **kwargs)
    if isinstance(x, type):
        return checksum_type(x, **kwargs)

    # --- unsupported object: apply the fallback policy
    if unk_mode == "pickle":
        return checksum_bytes(pickle.dumps(x), **kwargs)
    if unk_mode == "error":
        msg = f"Invalid argument type {type(x)}. (expected one of {get_args(Checksumable)})"
        raise TypeError(msg)

    msg = f"Invalid argument {unk_mode=}. (expected one of {get_args(UnkMode)})"
    raise ValueError(msg)
# Short public name for `checksum_any`. The stub body is intentionally empty:
# the `function_alias` decorator is handed `checksum_any` as its target.
@function_alias(checksum_any)
def checksum(*args, **kwargs):
    ...
def checksum_dataclass(x: DataclassInstance, **kwargs) -> int:
    """Checksum a dataclass instance through its `asdict` representation.

    The CRC-32 of `get_fullname(x)` is added to the running "accumulator"
    before the resulting dict is delegated to `checksum_mapping`.
    """
    type_salt = __cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_salt
    fields_as_dict = asdict(x)
    return checksum_mapping(fields_as_dict, **kwargs)
def checksum_dataframe(x: DataFrame, **kwargs) -> int:
    """Checksum a pandas DataFrame through its `to_dict()` representation.

    Raises:
        NotImplementedError: if pandas is not installed.
    """
    if not _PANDAS_AVAILABLE:
        raise NotImplementedError(
            "Cannot call function 'checksum_dataframe' because optional dependency 'pandas' is not installed. Please install it using 'pip install torchoutil[extras]'"
        )

    # Salt the accumulator with the CRC-32 of the object's full name.
    type_salt = __cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_salt

    as_mapping = x.to_dict()
    return checksum_mapping(as_mapping, **kwargs)  # type: ignore
def checksum_dtype(x: Union[torch.dtype, np.dtype], **kwargs) -> int:
    """Checksum a torch or numpy dtype from its `str()` representation.

    The CRC-32 of `get_fullname(x)` is added to the running "accumulator"
    before delegating to `checksum_str`.
    """
    type_salt = __cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_salt
    return checksum_str(str(x), **kwargs)
def checksum_iterable(x: Iterable, **kwargs) -> int:
    """Checksum an iterable by combining the checksums of its elements.

    For ordered iterables, the element at 0-based index i is checksummed with
    `accumulator + (i + 1)` and its checksum is multiplied by `(i + 1)`, so the
    result depends on the order of the elements.

    `set` and `frozenset` instances are handled separately: their iteration
    order depends on insertion history (and, for str elements, on hash
    randomisation), so two equal sets could otherwise yield different
    checksums. Their elements are all checksummed with the same accumulator
    and the element checksums are sorted before the positional weights are
    applied, which makes the result independent of iteration order.
    Empty and single-element sets give the same value as before.

    Args:
        x: Iterable to checksum. It is consumed exactly once.
        **kwargs: Options forwarded to `checksum_any`. The optional
            "accumulator" entry is the running salt inherited from the caller.

    Returns:
        The weighted sum of the element checksums plus the accumulator (which
        includes the CRC-32 of `get_fullname(x)`).
    """
    accumulator = kwargs.pop("accumulator", 0) + __cached_checksum_str(get_fullname(x))

    if isinstance(x, (set, frozenset)):
        # Order-independent path: identical salt for every element, then a
        # canonical (sorted) order before weighting.
        element_csums = sorted(
            checksum_any(xi, accumulator=accumulator + 1, **kwargs) for xi in x
        )
    else:
        element_csums = (
            checksum_any(xi, accumulator=accumulator + (i + 1), **kwargs)
            for i, xi in enumerate(x)
        )

    csum = sum(
        element_csum * (i + 1) for i, element_csum in enumerate(element_csums)
    )
    return csum + accumulator
def checksum_mapping(x: Mapping, **kwargs) -> int:
    """Checksum a mapping as the iterable of its `(key, value)` items.

    The CRC-32 of `get_fullname(x)` is added to the running "accumulator"
    before `x.items()` is delegated to `checksum_iterable`.
    """
    type_salt = __cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_salt
    items = x.items()
    return checksum_iterable(items, **kwargs)
def checksum_method(x: MethodType, **kwargs) -> int:
    """Checksum a bound method from its owner object and its callable.

    The callable is looked up again on the owner (`x.__self__`) by name, then
    the owner checksum and the callable checksum are combined with
    `checksum_iterable`.
    """
    owner = x.__self__
    resolved = getattr(owner, x.__name__)
    owner_csum = checksum_any(owner, **kwargs)  # type: ignore
    callable_csum = checksum_function(resolved, **kwargs)
    return checksum_iterable([owner_csum, callable_csum], **kwargs)
def checksum_module(
    x: nn.Module,
    *,
    only_trainable: bool = False,
    with_names: bool = False,
    buffers: bool = False,
    training: bool = False,
    **kwargs,
) -> int:
    """Compute a simple checksum over module parameters.

    Args:
        x: Module to checksum.
        only_trainable: If True, skip parameters with `requires_grad=False`.
        with_names: If True, checksum `(name, tensor)` pairs instead of tensors only.
        buffers: If True, buffers are appended after the parameters.
        training: Training mode applied to the module while the checksum is computed.
        **kwargs: Options forwarded to `checksum_iterable`.

    Returns:
        The checksum of the selected parameters (and buffers).

    Note:
        The training flag of `x` and of each of its submodules is restored to
        its previous value before returning, even if an error occurs.
        Previously the `training` argument was overwritten by `x.training` and
        `x.train(...)` left every submodule in the root module's mode.
    """
    # Snapshot per-submodule modes: `x.train(...)` propagates one flag to all
    # children, so restoring only the root flag would lose e.g. a frozen
    # submodule kept in eval mode.
    previous_modes = [(module, module.training) for module in x.modules()]
    x.train(training)

    try:
        if with_names:
            params_it = (
                (n, p)
                for n, p in x.named_parameters()
                if not only_trainable or p.requires_grad
            )
        else:
            params_it = (
                param
                for param in x.parameters()
                if not only_trainable or param.requires_grad
            )

        if not buffers:
            iterator = params_it
        elif with_names:
            iterator = itertools.chain(params_it, x.named_buffers())
        else:
            iterator = itertools.chain(params_it, x.buffers())

        # The generators above are lazy: they are consumed here, while the
        # requested mode is still active.
        return checksum_iterable(iterator, **kwargs)
    finally:
        for module, mode in previous_modes:
            module.training = mode
def checksum_namedtuple(x: NamedTupleInstance, **kwargs) -> int:
    """Checksum a named tuple through its `_asdict()` representation.

    The CRC-32 of `get_fullname(x)` is added to the running "accumulator"
    before delegating to `checksum_mapping`.
    """
    type_salt = __cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_salt
    fields_as_dict = x._asdict()
    return checksum_mapping(fields_as_dict, **kwargs)
def checksum_partial(x: functools.partial, **kwargs) -> int:
    """Checksum a `functools.partial` from its function, args and keywords.

    The CRC-32 of `get_fullname(x)` is added to the running "accumulator"
    before the `(func, args, keywords)` triple goes to `checksum_iterable`.
    """
    type_salt = __cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_salt
    components = (x.func, x.args, x.keywords)
    return checksum_iterable(components, **kwargs)
def checksum_pattern(x: re.Pattern, **kwargs) -> int:
    """Checksum a compiled regular expression from its `str()` representation.

    The CRC-32 of `get_fullname(x)` is added to the running "accumulator"
    before delegating to `checksum_str`.
    """
    type_salt = __cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_salt
    pattern_repr = str(x)
    return checksum_str(pattern_repr, **kwargs)
def checksum_slice(x: slice, **kwargs) -> int:
    """Checksum a slice from its `(start, stop, step)` triple.

    The CRC-32 of `get_fullname(x)` is added to the running "accumulator"
    before the triple is delegated to `checksum_iterable`.
    """
    type_salt = __cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_salt
    bounds = (x.start, x.stop, x.step)
    return checksum_iterable(bounds, **kwargs)
# Intermediate functions
def checksum_builtin_number(x: BuiltinNumber, **kwargs) -> int:
    """Compute a simple checksum of a builtin scalar number.

    Raises:
        TypeError: if x is not a bool, int, float or complex.
    """
    # Note: bool must be tested before int because isinstance(True, int) is True.
    dispatch = (
        (bool, checksum_bool),
        (int, checksum_int),
        (float, checksum_float),
        (complex, checksum_complex),
    )
    for number_type, checksum_fn in dispatch:
        if isinstance(x, number_type):
            return checksum_fn(x, **kwargs)

    msg = f"Invalid argument type {type(x)}. (expected one of {get_args(BuiltinNumber)})"
    raise TypeError(msg)
def checksum_builtin_scalar(x: BuiltinScalar, **kwargs) -> int:
    """Compute the checksum of a builtin scalar (number, bytes, None or str).

    Args:
        x: Scalar to checksum.
        **kwargs: Options forwarded to the type-specific function.

    Raises:
        TypeError: if x is not one of the supported scalar types.
    """
    # Use concrete classes rather than `isinstance(x, BuiltinNumber)`:
    # `BuiltinNumber` is passed to `typing.get_args` elsewhere in this module,
    # which suggests it is a `typing.Union` alias, and isinstance() rejects
    # those with a TypeError before Python 3.10.
    if isinstance(x, (bool, int, float, complex)):
        return checksum_builtin_number(x, **kwargs)
    elif isinstance(x, bytes):
        return checksum_bytes(x, **kwargs)
    elif x is None:
        return checksum_none(x, **kwargs)
    elif isinstance(x, str):
        return checksum_str(x, **kwargs)
    else:
        msg = f"Invalid argument type {type(x)}. (expected one of {get_args(BuiltinScalar)})"
        raise TypeError(msg)
def checksum_bytearray(x: bytearray, **kwargs) -> int:
    """Checksum a bytearray.

    The CRC-32 of `get_fullname(x)` is added to the running "accumulator"
    before delegating to `_checksum_bytes_bytearray`.
    """
    type_salt = __cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_salt
    return _checksum_bytes_bytearray(x, **kwargs)
def checksum_bytes(x: Union[bytes, bytearray], **kwargs) -> int:
    """Checksum a bytes-like object by delegating to `_checksum_bytes_bytearray`."""
    csum = _checksum_bytes_bytearray(x, **kwargs)
    return csum
def checksum_complex(x: complex, **kwargs) -> int:
    """Checksum a complex number from its real and imaginary parts.

    The CRC-32 of `get_fullname(x)` is added to the running "accumulator",
    then the two parts are packed into a tensor handled by `checksum_tensor`.
    """
    type_salt = __cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_salt
    parts = torch.as_tensor([x.real, x.imag])
    return checksum_tensor(parts, **kwargs)
def checksum_function(x: FunctionType, **kwargs) -> int:
    """Checksum a function from its `__qualname__`.

    The CRC-32 of `get_fullname(x)` is added to the running "accumulator"
    before delegating to `checksum_str`.
    """
    type_salt = __cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_salt
    qualified_name = x.__qualname__
    return checksum_str(qualified_name, **kwargs)
def checksum_none(x: None, **kwargs) -> int:
    """Checksum `None` from the checksum of its class.

    The CRC-32 of `get_fullname(x)` is added to the running "accumulator".
    That updated accumulator is both forwarded to `checksum_type` and added
    once more to the value it returns.
    """
    type_salt = __cached_checksum_str(get_fullname(x))
    accumulator = kwargs.get("accumulator", 0) + type_salt
    kwargs["accumulator"] = accumulator
    class_csum = checksum_type(x.__class__, **kwargs)
    return class_csum + accumulator
def checksum_ndarray(x: Union[np.ndarray, np.generic], **kwargs) -> int:
    """Checksum a numpy array or numpy scalar.

    Arrays whose dtype kind is byte string, unicode string, object or void are
    converted with `tolist()` and checksummed as regular python objects; other
    dtypes go through `_checksum_tensor_array_like`.

    Raises:
        NotImplementedError: if numpy is not installed.
    """
    if not _NUMPY_AVAILABLE:
        raise NotImplementedError(
            "Cannot call function 'checksum_ndarray' because optional dependency 'numpy' is not installed. Please install it using 'pip install torchoutil[extras]'"
        )

    # dtype kinds without a raw numeric representation: S, U, O, V
    non_numeric_kinds = ("S", "U", "O", "V")
    if x.dtype.kind in non_numeric_kinds:
        as_python = x.tolist()
        return checksum_any(as_python, **kwargs)

    return _checksum_tensor_array_like(x, nan_to_num_fn=np.nan_to_num, **kwargs)
def checksum_str(x: str, **kwargs) -> int:
    """Checksum a string from its encoded bytes (`x.encode()`).

    The CRC-32 of `get_fullname(x)` is added to the running "accumulator"
    before delegating to `checksum_bytes`.
    """
    type_salt = __cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_salt
    encoded = x.encode()
    return checksum_bytes(encoded, **kwargs)
@torch.inference_mode()
def checksum_tensor(x: Tensor, **kwargs) -> int:
    """Compute a simple checksum of a tensor. Order of values matter for the checksum.

    Runs under `torch.inference_mode()` and delegates to
    `_checksum_tensor_array_like` with `torch.nan_to_num` as the replacement
    function for non-finite values.
    """
    csum = _checksum_tensor_array_like(x, nan_to_num_fn=torch.nan_to_num, **kwargs)
    return csum
def checksum_type(x: type, **kwargs) -> int:
    """Checksum a class from its `__qualname__`, delegating to `checksum_str`."""
    qualified_name = x.__qualname__
    return checksum_str(qualified_name, **kwargs)
# Terminal functions (they end the recursion by calling __terminate_checksum)
def checksum_bool(x: bool, **kwargs) -> int:
    """Checksum a bool: `int(x)` salted with the CRC-32 of `get_fullname(x)`."""
    fullname = get_fullname(x)
    return __terminate_checksum(int(x), fullname, **kwargs)
def checksum_float(x: float, **kwargs) -> int:
    """Checksum a float from the integer reading of its binary representation.

    The bits are obtained with `_interpret_float_as_int` and salted with the
    CRC-32 of `get_fullname(x)`.
    """
    bits = _interpret_float_as_int(x)
    fullname = get_fullname(x)
    return __terminate_checksum(bits, fullname, **kwargs)
def checksum_int(x: int, **kwargs) -> int:
    """Checksum an int: the value itself salted with the CRC-32 of `get_fullname(x)`."""
    fullname = get_fullname(x)
    return __terminate_checksum(x, fullname, **kwargs)
def _checksum_bytes_bytearray(x: Union[bytes, bytearray], **kwargs) -> int:
    """Checksum raw bytes with CRC-32, salted with the object's full name."""
    xint = zlib.crc32(x) % (1 << 32)
    return __terminate_checksum(
        xint,
        get_fullname(x),
        **kwargs,
    )


def __terminate_checksum(x: int, fullname: str, **kwargs) -> int:
    """Return `x` plus the CRC-32 of `fullname` plus the running accumulator (0 if absent)."""
    return x + __cached_checksum_str(fullname) + kwargs.get("accumulator", 0)


@lru_cache(maxsize=None)
def __cached_checksum_str(x: str) -> int:
    """Return the unsigned CRC-32 of the encoded string; results are memoized."""
    return zlib.crc32(x.encode()) % (1 << 32)


def _interpret_float_as_int(x: float) -> int:
    """Reinterpret the big-endian 8-byte double encoding of `x` as a signed 64-bit int."""
    xbytes = struct.pack(">d", x)
    xint = struct.unpack(">q", xbytes)[0]
    return xint


def _checksum_tensor_array_like(
    x: Union[Tensor, np.ndarray, np.generic],
    *,
    nan_to_num_fn: Callable,
    **kwargs,
) -> int:
    """Checksum a tensor, a numpy array or a numpy scalar from its raw bytes.

    Steps:
      1. For floating point / complex data, NaN, -inf and +inf are replaced by
         the checksums of the corresponding python floats using `nan_to_num_fn`.
      2. The accumulator is increased by the checksums of the dtype and of the
         shape, and by the CRC-32 of `get_fullname(x)`.
      3. The raw bytes of the data are checksummed with `checksum_bytes`.

    Args:
        x: Data to checksum.
        nan_to_num_fn: Function with the `nan_to_num(x, nan=, neginf=, posinf=)`
            signature matching the type of x.
        **kwargs: Options forwarded to the nested checksum functions.

    Raises:
        TypeError: if x is neither a numpy array/scalar nor a Tensor.
    """
    if is_floating_point(x) or is_complex(x):
        nan_csum = checksum_float(math.nan, **kwargs)
        neginf_csum = checksum_float(-math.inf, **kwargs)
        posinf_csum = checksum_float(math.inf, **kwargs)
        x = nan_to_num_fn(
            x,
            nan=nan_csum,
            neginf=neginf_csum,
            posinf=posinf_csum,
        )

    # Ensure that accumulator exists
    kwargs["accumulator"] = kwargs.get("accumulator", 0)
    kwargs["accumulator"] += checksum_dtype(x.dtype, **kwargs)
    kwargs["accumulator"] += checksum_iterable(x.shape, **kwargs)
    kwargs["accumulator"] += __cached_checksum_str(get_fullname(x))

    # numpy scalars (np.generic) must be accepted as well: `checksum_ndarray`
    # routes them here, and `np.nan_to_num` returns a scalar for 0-d input.
    if isinstance(x, (np.ndarray, np.generic)):
        xbytes = x.tobytes()
    elif isinstance(x, Tensor):
        if _NUMPY_AVAILABLE:
            try:
                xbytes = x.cpu().numpy().tobytes()
            except TypeError:
                # Tensor dtype without numpy equivalent (e.g. bfloat16):
                # fall back to the pure-torch serialization.
                xbytes = _serialize_tensor_to_bytes(x)
        else:
            xbytes = _serialize_tensor_to_bytes(x)
    else:
        msg = f"invalid argument type {type(x)}. (expected ndarray or Tensor)"
        raise TypeError(msg)

    return checksum_bytes(xbytes, **kwargs)


def _serialize_tensor_to_bytes(x: Tensor) -> bytes:
    """Convert tensor data to bytes without numpy (slower than numpy's tobytes()).

    The tensor is made contiguous and flattened before being viewed as bytes,
    so 0-d and non-contiguous tensors are supported. The bytes follow the
    row-major order of the elements.
    """
    flat = x.detach().cpu().contiguous().reshape(-1)
    return bytes(flat.view(torch.uint8).tolist())