Source code for upathlib.serializer

import gc
import json
import pickle
import threading
import zlib
from contextlib import contextmanager
from typing import Protocol, TypeVar

import zstandard

# zstandard has a good compression ratio and is also quite fast.
# It is very "balanced".
# lz4 has a lower compression ratio than zstandard but is much faster.
#
# See:
#   https://gregoryszorc.com/blog/2017/03/07/better-compression-with-zstandard/
#   https://stackoverflow.com/questions/67537111/how-do-i-decide-between-lz4-and-snappy-compression
#   https://gist.github.com/oldcai/7230548

# Generic placeholder for "the object being serialized / deserialized".
T = TypeVar("T")


# Number of bytes in one megabyte (1048576).
MEGABYTE = 1024 * 1024

# Default zlib compression level; official default is 6.
ZLIB_LEVEL = 3

# Default zstandard compression level; official default is 3.
ZSTD_LEVEL = 3

# Default lz4 compression level; official default is 0.
# High-compression value is 3, much slower at compressing.
LZ4_LEVEL = 0

# Pickle protocol used when the caller does not choose one.
PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL


@contextmanager
def _gc(data):
    """Suspend garbage collection while working on a large payload.

    If ``data`` is at least 10 megabytes long and the collector is currently
    enabled, the collector is disabled for the duration of the ``with`` body
    and enabled again on exit, including when the body raises. Otherwise this
    context manager does nothing.
    """
    # Remember whether *we* switched the collector off, so that a collector
    # that was already disabled by someone else is left untouched on exit.
    paused = len(data) >= MEGABYTE * 10 and gc.isenabled()
    if paused:
        gc.disable()
    try:
        yield
    finally:
        if paused:
            gc.enable()

class Serializer(Protocol):
    """Interface for converting objects to and from ``bytes``.

    ``serialize`` and ``deserialize`` are stubs in this class; ``dump`` and
    ``load`` are implemented on top of them.
    """

    @classmethod
    def serialize(cls, x: T, **kwargs) -> bytes:
        """Convert the object ``x`` to bytes."""
        ...

    @classmethod
    def deserialize(cls, y: bytes, **kwargs) -> T:
        """Rebuild an object from the bytes ``y``."""
        ...

    @classmethod
    def dump(cls, x: T, file, *, overwrite: bool = False, **kwargs) -> None:
        """Serialize ``x`` and write the resulting bytes to ``file``.

        ``file`` is a `Upath` object. ``kwargs`` are forwarded to
        ``serialize``; ``overwrite`` is forwarded to ``file.write_bytes``.
        """
        file.write_bytes(cls.serialize(x, **kwargs), overwrite=overwrite)

    @classmethod
    def load(cls, file, **kwargs) -> T:
        """Read all bytes from ``file`` and deserialize them.

        ``file`` is a `Upath` object. ``kwargs`` are forwarded to
        ``deserialize``.
        """
        return cls.deserialize(file.read_bytes(), **kwargs)
class JsonSerializer(Serializer):
    """Serializer that stores objects as encoded JSON text."""

    @classmethod
    def serialize(cls, x, *, encoding=None, errors=None, **kwargs) -> bytes:
        """Dump ``x`` with ``json.dumps`` and encode the text to bytes.

        ``encoding`` falls back to ``"utf-8"`` and ``errors`` to ``"strict"``
        when not given. Remaining ``kwargs`` go to ``json.dumps``.
        """
        text = json.dumps(x, **kwargs)
        return text.encode(encoding=encoding or "utf-8", errors=errors or "strict")

    @classmethod
    def deserialize(cls, y, *, encoding=None, errors=None, **kwargs):
        """Decode the bytes ``y`` and parse the text with ``json.loads``.

        ``encoding`` falls back to ``"utf-8"`` and ``errors`` to ``"strict"``
        when not given. Remaining ``kwargs`` go to ``json.loads``.
        """
        # Decoding and parsing both happen inside the `_gc` context.
        with _gc(y):
            text = y.decode(encoding=encoding or "utf-8", errors=errors or "strict")
            return json.loads(text, **kwargs)
class PickleSerializer(Serializer):
    """Serializer based on the standard-library ``pickle`` module."""

    @classmethod
    def serialize(cls, x, *, protocol=None, **kwargs) -> bytes:
        """Pickle ``x`` and return the bytes.

        Parameters
        ----------
        protocol
            Pickle protocol to use. ``None`` (the default) selects
            ``PICKLE_PROTOCOL``. Any other value, including ``0``, is passed
            to ``pickle.dumps`` unchanged.
        **kwargs
            Forwarded to ``pickle.dumps``.
        """
        # Test against None explicitly: protocol 0 is a valid pickle protocol
        # and a truthiness test would silently replace it with the default.
        if protocol is None:
            protocol = PICKLE_PROTOCOL
        return pickle.dumps(x, protocol=protocol, **kwargs)

    @classmethod
    def deserialize(cls, y, **kwargs):
        """Unpickle the bytes ``y``; ``kwargs`` are forwarded to ``pickle.loads``.

        NOTE: unpickling can execute arbitrary code. Only call this on data
        that comes from a trusted source.
        """
        with _gc(y):
            return pickle.loads(y, **kwargs)
class ZPickleSerializer(PickleSerializer):
    """Pickle serializer whose output is compressed with ``zlib``."""

    # In general, this is not the best choice of compression.
    # Use `zstandard` or `lz4` instead.

    @classmethod
    def serialize(cls, x, *, level=ZLIB_LEVEL, **kwargs) -> bytes:
        """Pickle ``x`` via the parent class, then zlib-compress at ``level``.

        ``kwargs`` are forwarded to the parent ``serialize``.
        """
        pickled = super().serialize(x, **kwargs)
        return zlib.compress(pickled, level=level)

    @classmethod
    def deserialize(cls, y, **kwargs):
        """Zlib-decompress ``y``, then unpickle it via the parent class.

        ``kwargs`` are forwarded to the parent ``deserialize``.
        """
        pickled = zlib.decompress(y)
        return super().deserialize(pickled, **kwargs)
class ZstdCompressor(threading.local):
    """Holder of reusable zstandard compressor and decompressor objects.

    Being a ``threading.local`` subclass, the cached objects are kept
    separately for each thread.
    """

    # See doc string in ``cpython / Lib / _threading_local.py``.
    # See doc on `ZstdCompressor` and `ZstdDecompressor` in
    # (github python-zstandard) `zstandard / backend_cffi.py`.

    # The `ZstdCompressor` and `ZstdDecompressor` objects can't be pickled.
    # If there are issues related to forking, check out ``os.register_at_fork``.

    def __init__(self):
        # One compressor per (level, threads) combination, created on demand.
        self._compressor: dict[tuple[int, int], zstandard.ZstdCompressor] = {}
        # Single decompressor, created on first use.
        self._decompressor: zstandard.ZstdDecompressor = None

    def compress(self, x, *, level=ZSTD_LEVEL, threads=0):
        """
        Compress ``x`` with a cached compressor for ``(level, threads)``.

        Parameters
        ----------
        level
            Compression level passed to ``zstandard.ZstdCompressor``.
        threads
            Number of threads to use to compress data concurrently.
            When set, compression operations are performed on multiple
            threads. The default value (0) disables multi-threaded
            compression. A value of ``-1`` means to set the number of
            threads to the number of detected logical CPUs.
        """
        key = (level, threads)
        if self._compressor.get(key) is None:
            self._compressor[key] = zstandard.ZstdCompressor(
                level=level, threads=threads
            )
        return self._compressor[key].compress(x)

    def decompress(self, y):
        """Decompress ``y``, creating the decompressor on first use."""
        engine = self._decompressor
        if engine is None:
            engine = self._decompressor = zstandard.ZstdDecompressor()
        return engine.decompress(y)
class ZstdPickleSerializer(PickleSerializer):
    """Pickle serializer whose output is compressed with zstandard."""

    # Shared helper that owns the zstandard compressor/decompressor objects.
    _compressor = ZstdCompressor()

    @classmethod
    def serialize(cls, x, *, level=ZSTD_LEVEL, threads=0, **kwargs) -> bytes:
        """Pickle ``x`` via the parent class, then compress the bytes.

        ``level`` and ``threads`` are passed to the compressor; ``kwargs``
        are forwarded to the parent ``serialize``.
        """
        pickled = super().serialize(x, **kwargs)
        return cls._compressor.compress(pickled, level=level, threads=threads)

    @classmethod
    def deserialize(cls, y, **kwargs):
        """Decompress ``y``, then unpickle it via the parent class.

        ``kwargs`` are forwarded to the parent ``deserialize``.
        """
        pickled = cls._compressor.decompress(y)
        return super().deserialize(pickled, **kwargs)
try:
    # To use this, just install the Python package `lz4`.
    import lz4.frame
except ImportError:
    pass
else:

    class Lz4PickleSerializer(PickleSerializer):
        """Pickle serializer whose output is compressed with lz4 (frame format)."""

        @classmethod
        def serialize(cls, x, *, level=LZ4_LEVEL, **kwargs) -> bytes:
            """Pickle ``x`` via the parent class, then lz4-compress at ``level``."""
            pickled = super().serialize(x, **kwargs)
            return lz4.frame.compress(pickled, compression_level=level)

        @classmethod
        def deserialize(cls, y, **kwargs):
            """Lz4-decompress ``y``, then unpickle it via the parent class."""
            pickled = lz4.frame.decompress(y)
            return super().deserialize(pickled, **kwargs)


try:
    # To use this, just install the Python package `orjson`.
    import orjson
except ImportError:
    pass
else:

    class OrjsonSerializer(Serializer):
        """JSON serializer backed by the `orjson` package."""

        @classmethod
        def serialize(cls, x, **kwargs) -> bytes:
            """Return ``orjson.dumps(x, **kwargs)``."""
            return orjson.dumps(x, **kwargs)

        @classmethod
        def deserialize(cls, y: bytes, **kwargs):
            """Return ``orjson.loads(y, **kwargs)``."""
            return orjson.loads(y, **kwargs)

    class ZOrjsonSerializer(OrjsonSerializer):
        """Orjson serializer whose output is compressed with ``zlib``."""

        # In general, this is not the best choice of compression.
        # Use `zstandard` or `lz4` instead.

        @classmethod
        def serialize(cls, x, *, level=ZLIB_LEVEL, **kwargs) -> bytes:
            """Serialize ``x`` via the parent class, then zlib-compress at ``level``."""
            encoded = super().serialize(x, **kwargs)
            return zlib.compress(encoded, level=level)

        @classmethod
        def deserialize(cls, y, **kwargs):
            """Zlib-decompress ``y``, then deserialize it via the parent class."""
            encoded = zlib.decompress(y)
            return super().deserialize(encoded, **kwargs)

    class ZstdOrjsonSerializer(OrjsonSerializer):
        """Orjson serializer whose output is compressed with zstandard."""

        # Shared helper that owns the zstandard compressor/decompressor objects.
        _compressor = ZstdCompressor()

        @classmethod
        def serialize(cls, x, *, level=ZSTD_LEVEL, threads=0, **kwargs) -> bytes:
            """Serialize ``x`` via the parent class, then compress the bytes.

            ``level`` and ``threads`` are passed to the compressor.
            """
            encoded = super().serialize(x, **kwargs)
            return cls._compressor.compress(encoded, level=level, threads=threads)

        @classmethod
        def deserialize(cls, y, **kwargs):
            """Decompress ``y``, then deserialize it via the parent class."""
            encoded = cls._compressor.decompress(y)
            return super().deserialize(encoded, **kwargs)

    try:
        # To use this, just install the Python package `lz4`.
        import lz4.frame
    except ImportError:
        pass
    else:

        class Lz4OrjsonSerializer(OrjsonSerializer):
            """Orjson serializer whose output is compressed with lz4 (frame format)."""

            @classmethod
            def serialize(cls, x, *, level=LZ4_LEVEL, **kwargs) -> bytes:
                """Serialize ``x`` via the parent class, then lz4-compress at ``level``."""
                encoded = super().serialize(x, **kwargs)
                return lz4.frame.compress(encoded, compression_level=level)

            @classmethod
            def deserialize(cls, y, **kwargs):
                """Lz4-decompress ``y``, then deserialize it via the parent class."""
                encoded = lz4.frame.decompress(y)
                return super().deserialize(encoded, **kwargs)