chia-blockchain/chia/util/streamable.py

359 lines
13 KiB
Python

# flake8: noqa
# pylint: disable
from __future__ import annotations
import dataclasses
import io
import pprint
import sys
from enum import Enum
from typing import Any, BinaryIO, Dict, List, Tuple, Type, Callable, Optional, Iterator
from blspy import G1Element, G2Element, PrivateKey
from chia.types.blockchain_format.program import Program, SerializedProgram
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
from chia.util.hash import std_hash
from chia.util.ints import int64, int512, uint32, uint64, uint128
from chia.util.type_checking import is_type_List, is_type_SpecificOptional, is_type_Tuple, strictdataclass
if sys.version_info < (3, 8):
def get_args(t: Type[Any]) -> Tuple[Any, ...]:
return getattr(t, "__args__", ())
else:
from typing import get_args
pp = pprint.PrettyPrinter(indent=1, width=120, compact=True)
# TODO: Remove hack, this allows streaming these objects from binary
size_hints = {
"PrivateKey": PrivateKey.PRIVATE_KEY_SIZE,
"G1Element": G1Element.SIZE,
"G2Element": G2Element.SIZE,
"ConditionOpcode": 1,
}
unhashable_types = [
PrivateKey,
G1Element,
G2Element,
Program,
SerializedProgram,
]
# JSON does not support big ints, so these types must be serialized differently in JSON
big_ints = [uint64, int64, uint128, int512]
def dataclass_from_dict(klass, d):
"""
Converts a dictionary based on a dataclass, into an instance of that dataclass.
Recursively goes through lists, optionals, and dictionaries.
"""
if is_type_SpecificOptional(klass):
# Type is optional, data is either None, or Any
if not d:
return None
return dataclass_from_dict(get_args(klass)[0], d)
elif is_type_Tuple(klass):
# Type is tuple, can have multiple different types inside
i = 0
klass_properties = []
for item in d:
klass_properties.append(dataclass_from_dict(klass.__args__[i], item))
i = i + 1
return tuple(klass_properties)
elif dataclasses.is_dataclass(klass):
# Type is a dataclass, data is a dictionary
fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
elif is_type_List(klass):
# Type is a list, data is a list
return [dataclass_from_dict(get_args(klass)[0], item) for item in d]
elif issubclass(klass, bytes):
# Type is bytes, data is a hex string
return klass(hexstr_to_bytes(d))
elif klass in unhashable_types:
# Type is unhashable (bls type), so cast from hex string
return klass.from_bytes(hexstr_to_bytes(d))
else:
# Type is a primitive, cast with correct class
return klass(d)
def recurse_jsonify(d):
"""
Makes bytes objects and unhashable types into strings with 0x, and makes large ints into
strings.
"""
if isinstance(d, list) or isinstance(d, tuple):
new_list = []
for item in d:
if type(item) in unhashable_types or issubclass(type(item), bytes):
item = f"0x{bytes(item).hex()}"
if isinstance(item, dict):
item = recurse_jsonify(item)
if isinstance(item, list):
item = recurse_jsonify(item)
if isinstance(item, tuple):
item = recurse_jsonify(item)
if isinstance(item, Enum):
item = item.name
if isinstance(item, int) and type(item) in big_ints:
item = int(item)
new_list.append(item)
d = new_list
else:
for key, value in d.items():
if type(value) in unhashable_types or issubclass(type(value), bytes):
d[key] = f"0x{bytes(value).hex()}"
if isinstance(value, dict):
d[key] = recurse_jsonify(value)
if isinstance(value, list):
d[key] = recurse_jsonify(value)
if isinstance(value, tuple):
d[key] = recurse_jsonify(value)
if isinstance(value, Enum):
d[key] = value.name
if isinstance(value, int) and type(value) in big_ints:
d[key] = int(value)
return d
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS = {}
def streamable(cls: Any):
"""
This is a decorator for class definitions. It applies the strictdataclass decorator,
which checks all types at construction. It also defines a simple serialization format,
and adds parse, from bytes, stream, and __bytes__ methods.
Serialization format:
- Each field is serialized in order, by calling from_bytes/__bytes__.
- For Lists, there is a 4 byte prefix for the list length.
- For Optionals, there is a one byte prefix, 1 iff object is present, 0 iff not.
All of the constituents must have parse/from_bytes, and stream/__bytes__ and therefore
be of fixed size. For example, int cannot be a constituent since it is not a fixed size,
whereas uint32 can be.
Furthermore, a get_hash() member is added, which performs a serialization and a sha256.
This class is used for deterministic serialization and hashing, for consensus critical
objects such as the block header.
Make sure to use the Streamable class as a parent class when using the streamable decorator,
as it will allow linters to recognize the methods that are added by the decorator. Also,
use the @dataclass(frozen=True) decorator as well, for linters to recognize constructor
arguments.
"""
cls1 = strictdataclass(cls)
t = type(cls.__name__, (cls1, Streamable), {})
parse_functions = []
try:
fields = cls1.__annotations__ # pylint: disable=no-member
except Exception:
fields = {}
for _, f_type in fields.items():
parse_functions.append(cls.function_to_parse_one_item(f_type))
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[t] = parse_functions
return t
def parse_bool(f: BinaryIO) -> bool:
bool_byte = f.read(1)
assert bool_byte is not None and len(bool_byte) == 1 # Checks for EOF
if bool_byte == bytes([0]):
return False
elif bool_byte == bytes([1]):
return True
else:
raise ValueError("Bool byte must be 0 or 1")
def parse_optional(f: BinaryIO, parse_inner_type_f: Callable[[BinaryIO], Any]) -> Optional[Any]:
is_present_bytes = f.read(1)
assert is_present_bytes is not None and len(is_present_bytes) == 1 # Checks for EOF
if is_present_bytes == bytes([0]):
return None
elif is_present_bytes == bytes([1]):
return parse_inner_type_f(f)
else:
raise ValueError("Optional must be 0 or 1")
def parse_bytes(f: BinaryIO) -> bytes:
list_size_bytes = f.read(4)
assert list_size_bytes is not None and len(list_size_bytes) == 4 # Checks for EOF
list_size: uint32 = uint32(int.from_bytes(list_size_bytes, "big"))
bytes_read = f.read(list_size)
assert bytes_read is not None and len(bytes_read) == list_size
return bytes_read
def parse_list(f: BinaryIO, parse_inner_type_f: Callable[[BinaryIO], Any]) -> List[Any]:
full_list: List = []
# wjb assert inner_type != get_args(List)[0]
list_size_bytes = f.read(4)
assert list_size_bytes is not None and len(list_size_bytes) == 4 # Checks for EOF
list_size = uint32(int.from_bytes(list_size_bytes, "big"))
for list_index in range(list_size):
full_list.append(parse_inner_type_f(f))
return full_list
def parse_tuple(f: BinaryIO, list_parse_inner_type_f: List[Callable[[BinaryIO], Any]]) -> Tuple[Any, ...]:
full_list = []
for parse_f in list_parse_inner_type_f:
full_list.append(parse_f(f))
return tuple(full_list)
def parse_size_hints(f: BinaryIO, f_type: Type, bytes_to_read: int) -> Any:
bytes_read = f.read(bytes_to_read)
assert bytes_read is not None and len(bytes_read) == bytes_to_read
return f_type.from_bytes(bytes_read)
def parse_str(f: BinaryIO) -> str:
str_size_bytes = f.read(4)
assert str_size_bytes is not None and len(str_size_bytes) == 4 # Checks for EOF
str_size: uint32 = uint32(int.from_bytes(str_size_bytes, "big"))
str_read_bytes = f.read(str_size)
assert str_read_bytes is not None and len(str_read_bytes) == str_size # Checks for EOF
return bytes.decode(str_read_bytes, "utf-8")
class Streamable:
@classmethod
def function_to_parse_one_item(cls: Type[cls.__name__], f_type: Type): # type: ignore
"""
This function returns a function taking one argument `f: BinaryIO` that parses
and returns a value of the given type.
"""
inner_type: Type
if f_type is bool:
return parse_bool
if is_type_SpecificOptional(f_type):
inner_type = get_args(f_type)[0]
parse_inner_type_f = cls.function_to_parse_one_item(inner_type)
return lambda f: parse_optional(f, parse_inner_type_f)
if hasattr(f_type, "parse"):
return f_type.parse
if f_type == bytes:
return parse_bytes
if is_type_List(f_type):
inner_type = get_args(f_type)[0]
parse_inner_type_f = cls.function_to_parse_one_item(inner_type)
return lambda f: parse_list(f, parse_inner_type_f)
if is_type_Tuple(f_type):
inner_types = get_args(f_type)
list_parse_inner_type_f = [cls.function_to_parse_one_item(_) for _ in inner_types]
return lambda f: parse_tuple(f, list_parse_inner_type_f)
if hasattr(f_type, "from_bytes") and f_type.__name__ in size_hints:
bytes_to_read = size_hints[f_type.__name__]
return lambda f: parse_size_hints(f, f_type, bytes_to_read)
if f_type is str:
return parse_str
raise NotImplementedError(f"Type {f_type} does not have parse")
@classmethod
def parse(cls: Type[cls.__name__], f: BinaryIO) -> cls.__name__: # type: ignore
# Create the object without calling __init__() to avoid unnecessary post-init checks in strictdataclass
obj: Streamable = object.__new__(cls)
fields: Iterator[str] = iter(getattr(cls, "__annotations__", {}))
values: Iterator = (parse_f(f) for parse_f in PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls])
for field, value in zip(fields, values):
object.__setattr__(obj, field, value)
# Use -1 as a sentinel value as it's not currently serializable
if next(fields, -1) != -1:
raise ValueError("Failed to parse incomplete Streamable object")
if next(values, -1) != -1:
raise ValueError("Failed to parse unknown data in Streamable object")
return obj
def stream_one_item(self, f_type: Type, item, f: BinaryIO) -> None:
inner_type: Type
if is_type_SpecificOptional(f_type):
inner_type = get_args(f_type)[0]
if item is None:
f.write(bytes([0]))
else:
f.write(bytes([1]))
self.stream_one_item(inner_type, item, f)
elif f_type == bytes:
f.write(uint32(len(item)).to_bytes(4, "big"))
f.write(item)
elif hasattr(f_type, "stream"):
item.stream(f)
elif hasattr(f_type, "__bytes__"):
f.write(bytes(item))
elif is_type_List(f_type):
assert is_type_List(type(item))
f.write(uint32(len(item)).to_bytes(4, "big"))
inner_type = get_args(f_type)[0]
# wjb assert inner_type != get_args(List)[0] # type: ignore
for element in item:
self.stream_one_item(inner_type, element, f)
elif is_type_Tuple(f_type):
inner_types = get_args(f_type)
assert len(item) == len(inner_types)
for i in range(len(item)):
self.stream_one_item(inner_types[i], item[i], f)
elif f_type is str:
str_bytes = item.encode("utf-8")
f.write(uint32(len(str_bytes)).to_bytes(4, "big"))
f.write(str_bytes)
elif f_type is bool:
f.write(int(item).to_bytes(1, "big"))
else:
raise NotImplementedError(f"can't stream {item}, {f_type}")
def stream(self, f: BinaryIO) -> None:
try:
fields = self.__annotations__ # pylint: disable=no-member
except Exception:
fields = {}
for f_name, f_type in fields.items():
self.stream_one_item(f_type, getattr(self, f_name), f)
def get_hash(self) -> bytes32:
return bytes32(std_hash(bytes(self)))
@classmethod
def from_bytes(cls: Any, blob: bytes) -> Any:
f = io.BytesIO(blob)
parsed = cls.parse(f)
assert f.read() == b""
return parsed
def __bytes__(self: Any) -> bytes:
f = io.BytesIO()
self.stream(f)
return bytes(f.getvalue())
def __str__(self: Any) -> str:
return pp.pformat(recurse_jsonify(dataclasses.asdict(self)))
def __repr__(self: Any) -> str:
return pp.pformat(recurse_jsonify(dataclasses.asdict(self)))
def to_json_dict(self) -> Dict:
return recurse_jsonify(dataclasses.asdict(self))
@classmethod
def from_json_dict(cls: Any, json_dict: Dict) -> Any:
return dataclass_from_dict(cls, json_dict)