# chia-blockchain/chia/wallet/wallet_node.py
#
# 905 lines
# 39 KiB
# Python

import asyncio
import json
import logging
import socket
import time
import traceback
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple, Union, Any
from blspy import PrivateKey
from chia.consensus.block_record import BlockRecord
from chia.consensus.constants import ConsensusConstants
from chia.consensus.multiprocess_validation import PreValidationResult
from chia.protocols import wallet_protocol
from chia.protocols.full_node_protocol import RequestProofOfWeight, RespondProofOfWeight
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.wallet_protocol import (
RejectAdditionsRequest,
RejectRemovalsRequest,
RequestAdditions,
RequestHeaderBlocks,
RespondAdditions,
RespondBlockHeader,
RespondHeaderBlocks,
RespondRemovals,
)
from chia.server.node_discovery import WalletPeers
from chia.server.outbound_message import Message, NodeType, make_msg
from chia.server.server import ChiaServer
from chia.server.ws_connection import WSChiaConnection
from chia.types.blockchain_format.coin import Coin, hash_coin_list
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.header_block import HeaderBlock
from chia.types.peer_info import PeerInfo
from chia.util.byte_types import hexstr_to_bytes
from chia.util.errors import Err, ValidationError
from chia.util.ints import uint32, uint128
from chia.util.keychain import Keychain
from chia.util.lru_cache import LRUCache
from chia.util.merkle_set import MerkleSet, confirm_included_already_hashed, confirm_not_included_already_hashed
from chia.util.path import mkdir, path_from_root
from chia.wallet.block_record import HeaderBlockRecord
from chia.wallet.derivation_record import DerivationRecord
from chia.wallet.settings.settings_objects import BackupInitialized
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.backup_utils import open_backup_file
from chia.wallet.util.wallet_types import WalletType
from chia.wallet.wallet_action import WalletAction
from chia.wallet.wallet_blockchain import ReceiveBlockResult
from chia.wallet.wallet_state_manager import WalletStateManager
from chia.util.profiler import profile_task
# Wallet-side network node. The attributes declared below hold its
# configuration, its server handle, the wallet state manager and the flags
# and tasks that drive syncing.
# NOTE(review): leading whitespace is missing in this copy of the file; the
# annotations below belong inside the class body.
class WalletNode:
# Declared here; not assigned anywhere in the visible part of the file.
key_config: Dict
# Configuration dictionary (read with keys such as "database_path",
# "selected_network", "trusted_peers", "peer_connect_interval").
config: Dict
# Consensus constants (e.g. WEIGHT_PROOF_RECENT_BLOCKS, BLOCKS_CACHE_SIZE).
constants: ConsensusConstants
# None until set_server() attaches a server.
server: Optional[ChiaServer]
log: logging.Logger
# Created in set_server().
wallet_peers: WalletPeers
# Maintains the state of the wallet (blockchain and transactions), handles DB connections
wallet_state_manager: Optional[WalletStateManager]
# How far away from LCA we must be to perform a full sync. Before then, do a short sync,
# which is consecutive requests for the previous block
short_sync_threshold: int
# Set by _close(); polled by the background loops and the resend helpers.
_shut_down: bool
# Root directory used to resolve the database path (see _start).
root_path: Path
# Forwarded to the wallet state manager via set_callback().
state_changed_callback: Optional[Callable]
# Declared here; not assigned anywhere in the visible part of the file.
syncing: bool
# Read by fetch_blocks_and_validate when deciding whether a peer is trusted.
full_node_peer: Optional[PeerInfo]
# Task running _periodically_check_full_node(), created in _start().
peer_task: Optional[asyncio.Task]
# True between a successful _start() and _await_closed().
logged_in: bool
def __init__(
    self,
    config: Dict,
    keychain: Keychain,
    root_path: Path,
    consensus_constants: ConsensusConstants,
    name: Optional[str] = None,
):
    """
    Store the collaborators and reset all per-login state.

    No server, database or task is created here: the server is attached
    through ``set_server`` and the wallet state manager, the sync event and
    the background tasks are created in ``_start``.

    Args:
        config: configuration dictionary for the wallet.
        keychain: source of private keys (see ``get_key_for_fingerprint``).
        root_path: root directory used to resolve the database path.
        consensus_constants: consensus constants used while syncing.
        name: logger name; when falsy the module's ``__name__`` is used.
    """
    self.config = config
    self.constants = consensus_constants
    self.root_path = root_path
    if name:
        self.log = logging.getLogger(name)
    else:
        self.log = logging.getLogger(__name__)
    # Normal operation data
    self.cached_blocks: Dict = {}
    self.future_block_hashes: Dict = {}
    self.keychain = keychain
    # Sync data
    self._shut_down = False
    self.proof_hashes: List = []
    self.header_hashes: List = []
    self.header_hashes_error = False
    self.short_sync_threshold = 15  # Change the test when changing this
    self.potential_blocks_received: Dict = {}
    self.potential_header_hashes: Dict = {}
    self.state_changed_callback = None
    self.wallet_state_manager = None
    self.backup_initialized = False  # Delay first launch sync after user imports backup info or decides to skip
    self.server = None
    self.wsm_close_task = None
    self.sync_task: Optional[asyncio.Task] = None
    self.new_peak_lock: Optional[asyncio.Lock] = None
    self.logged_in_fingerprint: Optional[int] = None
    self.peer_task = None
    # fetch_blocks_and_validate reads this attribute; give it a defined value
    # so the read cannot raise AttributeError when nobody assigned it after
    # construction. Code that assigns it later simply overrides this default.
    self.full_node_peer = None
    self.logged_in = False
    self.last_new_peak_messages = LRUCache(5)
def get_key_for_fingerprint(self, fingerprint: Optional[int]):
    """
    Pick a private key from the keychain.

    Args:
        fingerprint: fingerprint of the wanted key, or None to take the
            first key the keychain returns.

    Returns:
        The matching private key, or None when the keychain is empty (a
        warning is logged) or no key has the requested fingerprint.
    """
    key_entries = self.keychain.get_all_private_keys()
    if len(key_entries) == 0:
        self.log.warning("No keys present. Create keys with the UI, or with the 'chia keys' program.")
        return None
    if fingerprint is None:
        # No preference given: fall back to the first key in the keychain.
        return key_entries[0][0]
    # First key whose public-key fingerprint matches, else None.
    return next(
        (sk for sk, _ in key_entries if sk.get_g1().get_fingerprint() == fingerprint),
        None,
    )
async def _start(
self,
fingerprint: Optional[int] = None,
new_wallet: bool = False,
backup_file: Optional[Path] = None,
skip_backup_import: bool = False,
) -> bool:
private_key = self.get_key_for_fingerprint(fingerprint)
if private_key is None:
self.logged_in = False
return False
if self.config.get("enable_profiler", False):
asyncio.create_task(profile_task(self.root_path, "wallet", self.log))
db_path_key_suffix = str(private_key.get_g1().get_fingerprint())
db_path_replaced: str = (
self.config["database_path"]
.replace("CHALLENGE", self.config["selected_network"])
.replace("KEY", db_path_key_suffix)
)
path = path_from_root(self.root_path, db_path_replaced)
mkdir(path.parent)
assert self.server is not None
self.wallet_state_manager = await WalletStateManager.create(
private_key, self.config, path, self.constants, self.server
)
self.wsm_close_task = None
assert self.wallet_state_manager is not None
backup_settings: BackupInitialized = self.wallet_state_manager.user_settings.get_backup_settings()
if backup_settings.user_initialized is False:
if new_wallet is True:
await self.wallet_state_manager.user_settings.user_created_new_wallet()
self.wallet_state_manager.new_wallet = True
elif skip_backup_import is True:
await self.wallet_state_manager.user_settings.user_skipped_backup_import()
elif backup_file is not None:
await self.wallet_state_manager.import_backup_info(backup_file)
else:
self.backup_initialized = False
await self.wallet_state_manager.close_all_stores()
self.wallet_state_manager = None
self.logged_in = False
return False
self.backup_initialized = True
if backup_file is not None:
json_dict = open_backup_file(backup_file, self.wallet_state_manager.private_key)
if "start_height" in json_dict["data"]:
start_height = json_dict["data"]["start_height"]
self.config["starting_height"] = max(0, start_height - self.config["start_height_buffer"])
else:
self.config["starting_height"] = 0
else:
self.config["starting_height"] = 0
if self.state_changed_callback is not None:
self.wallet_state_manager.set_callback(self.state_changed_callback)
self.wallet_state_manager.set_pending_callback(self._pending_tx_handler)
self._shut_down = False
self.peer_task = asyncio.create_task(self._periodically_check_full_node())
self.sync_event = asyncio.Event()
self.sync_task = asyncio.create_task(self.sync_job())
self.logged_in_fingerprint = fingerprint
self.logged_in = True
return True
def _close(self):
self.log.info("self._close")
self.logged_in_fingerprint = None
self._shut_down = True
async def _await_closed(self):
self.log.info("self._await_closed")
await self.server.close_all_connections()
asyncio.create_task(self.wallet_peers.ensure_is_closed())
if self.wallet_state_manager is not None:
await self.wallet_state_manager.close_all_stores()
self.wallet_state_manager = None
if self.sync_task is not None:
self.sync_task.cancel()
self.sync_task = None
if self.peer_task is not None:
self.peer_task.cancel()
self.peer_task = None
self.logged_in = False
def _set_state_changed_callback(self, callback: Callable):
self.state_changed_callback = callback
if self.wallet_state_manager is not None:
self.wallet_state_manager.set_callback(self.state_changed_callback)
self.wallet_state_manager.set_pending_callback(self._pending_tx_handler)
def _pending_tx_handler(self):
if self.wallet_state_manager is None or self.backup_initialized is False:
return None
asyncio.create_task(self._resend_queue())
async def _action_messages(self) -> List[Message]:
if self.wallet_state_manager is None or self.backup_initialized is False:
return []
actions: List[WalletAction] = await self.wallet_state_manager.action_store.get_all_pending_actions()
result: List[Message] = []
for action in actions:
data = json.loads(action.data)
action_data = data["data"]["action_data"]
if action.name == "request_puzzle_solution":
coin_name = bytes32(hexstr_to_bytes(action_data["coin_name"]))
height = uint32(action_data["height"])
msg = make_msg(
ProtocolMessageTypes.request_puzzle_solution,
wallet_protocol.RequestPuzzleSolution(coin_name, height),
)
result.append(msg)
return result
async def _resend_queue(self):
if (
self._shut_down
or self.server is None
or self.wallet_state_manager is None
or self.backup_initialized is None
):
return None
for msg, sent_peers in await self._messages_to_resend():
if (
self._shut_down
or self.server is None
or self.wallet_state_manager is None
or self.backup_initialized is None
):
return None
full_nodes = self.server.get_full_node_connections()
for peer in full_nodes:
if peer.peer_node_id in sent_peers:
continue
await peer.send_message(msg)
for msg in await self._action_messages():
if (
self._shut_down
or self.server is None
or self.wallet_state_manager is None
or self.backup_initialized is None
):
return None
await self.server.send_to_all([msg], NodeType.FULL_NODE)
async def _messages_to_resend(self) -> List[Tuple[Message, Set[bytes32]]]:
if self.wallet_state_manager is None or self.backup_initialized is False or self._shut_down:
return []
messages: List[Tuple[Message, Set[bytes32]]] = []
records: List[TransactionRecord] = await self.wallet_state_manager.tx_store.get_not_sent()
for record in records:
if record.spend_bundle is None:
continue
msg = make_msg(
ProtocolMessageTypes.send_transaction,
wallet_protocol.SendTransaction(record.spend_bundle),
)
already_sent = set()
for peer, status, _ in record.sent_to:
already_sent.add(hexstr_to_bytes(peer))
messages.append((msg, already_sent))
return messages
def set_server(self, server: ChiaServer):
    """
    Attach ``server`` to the node and start peer discovery on top of it.

    A ``WalletPeers`` object is built from the peer-related configuration
    entries and its ``start()`` coroutine is scheduled as a task, so this
    must be called while an event loop is running.
    """
    self.server = server
    cfg = self.config
    # TODO: perhaps use a different set of DNS seeders for wallets, to split the traffic.
    discovery = WalletPeers(
        self.server,
        self.root_path,
        cfg["target_peer_count"],
        cfg["wallet_peers_path"],
        cfg["introducer_peer"],
        [],  # passed positionally; its meaning is defined by WalletPeers
        cfg["peer_connect_interval"],
        cfg["selected_network"],
        self.log,
    )
    self.wallet_peers = discovery
    asyncio.create_task(discovery.start())
async def on_connect(self, peer: WSChiaConnection):
if self.wallet_state_manager is None or self.backup_initialized is False:
return None
messages_peer_ids = await self._messages_to_resend()
for msg, peer_ids in messages_peer_ids:
if peer.peer_node_id in peer_ids:
continue
await peer.send_message(msg)
if not self.has_full_node() and self.wallet_peers is not None:
asyncio.create_task(self.wallet_peers.on_connect(peer))
async def _periodically_check_full_node(self) -> None:
tries = 0
while not self._shut_down and tries < 5:
if self.has_full_node():
await self.wallet_peers.ensure_is_closed()
break
tries += 1
await asyncio.sleep(self.config["peer_connect_interval"])
def has_full_node(self) -> bool:
    """
    Report whether the full node named in ``config["full_node_peer"]`` is connected.

    The configured host is compared both as written and after name
    resolution. When the full node is connected, every other full-node
    connection is scheduled to be closed.

    Returns:
        True when the configured full node is among the current full-node
        connections; False when there is no server, no ``full_node_peer``
        entry in the configuration, or no matching connection.
    """
    if self.server is None:
        return False
    if "full_node_peer" not in self.config:
        return False

    full_node_peer = PeerInfo(
        self.config["full_node_peer"]["host"],
        self.config["full_node_peer"]["port"],
    )
    try:
        full_node_resolved = PeerInfo(socket.gethostbyname(full_node_peer.host), full_node_peer.port)
    except OSError as e:
        # socket.gaierror is an OSError subclass. A host name that cannot be
        # resolved right now must not break the callers (connection handler,
        # periodic check); fall back to comparing the unresolved peer only.
        self.log.warning(f"Could not resolve full node host {full_node_peer.host}: {e}")
        full_node_resolved = full_node_peer

    connections = self.server.get_full_node_connections()
    peers = [c.get_peer_info() for c in connections]
    if full_node_peer not in peers and full_node_resolved not in peers:
        return False

    self.log.info(f"Will not attempt to connect to other nodes, already connected to {full_node_peer}")
    for connection in connections:
        peer_info = connection.get_peer_info()
        if peer_info != full_node_peer and peer_info != full_node_resolved:
            self.log.info(f"Closing unnecessary connection to {connection.get_peer_info()}.")
            asyncio.create_task(connection.close())
    return True
# Turn header blocks received from `peer` into HeaderBlockRecords and hand
# them, in the given order, to the wallet blockchain's receive_block().
#
# For a transaction block the state manager is first asked which additions
# and removals to request (get_filter_additions_removals); those are then
# fetched from the same peer via get_additions / get_removals. Other blocks
# are recorded with empty coin lists.
#
# Returns None in every case. Raises ValueError when get_additions or
# get_removals returns None for a transaction block.
# NOTE(review): leading whitespace is missing in this copy, so the nesting
# of the statements below cannot be read off the text - confirm it against
# version control before editing.
async def complete_blocks(self, header_blocks: List[HeaderBlock], peer: WSChiaConnection):
# Without a wallet state manager there is no blockchain to add blocks to.
if self.wallet_state_manager is None:
return None
# NOTE(review): this list is appended to below but never read in this method.
header_block_records: List[HeaderBlockRecord] = []
assert self.server
# Passed to receive_block() as `trusted=`.
trusted = self.server.is_trusted_peer(peer, self.config["trusted_peers"])
# Held while the blocks are processed; _sync takes the same lock.
async with self.wallet_state_manager.blockchain.lock:
for block in header_blocks:
if block.is_transaction_block:
# Find additions and removals
(additions, removals,) = await self.wallet_state_manager.get_filter_additions_removals(
block, block.transactions_filter, None
)
# Get Additions
added_coins = await self.get_additions(peer, block, additions)
if added_coins is None:
raise ValueError("Failed to fetch additions")
# Get removals
removed_coins = await self.get_removals(peer, block, added_coins, removals)
if removed_coins is None:
raise ValueError("Failed to fetch removals")
hbr = HeaderBlockRecord(block, added_coins, removed_coins)
else:
hbr = HeaderBlockRecord(block, [], [])
header_block_records.append(hbr)
(
result,
error,
fork_h,
) = await self.wallet_state_manager.blockchain.receive_block(hbr, trusted=trusted)
# A new peak triggers the "new_block" / "sync_changed" notifications below.
if result == ReceiveBlockResult.NEW_PEAK:
if not self.wallet_state_manager.sync_mode:
self.wallet_state_manager.blockchain.clean_block_records()
self.wallet_state_manager.state_changed("new_block")
self.wallet_state_manager.state_changed("sync_changed")
# An invalid block ends processing and drops the peer that sent it.
elif result == ReceiveBlockResult.INVALID_BLOCK:
self.log.info(f"Invalid block from peer: {peer.get_peer_info()} {error}")
await peer.close()
return None
else:
self.log.debug(f"Result: {result}")
# React to a NewPeakWallet announcement (`peak`) received from `peer`.
#
# The announcement is ignored when the wallet's current peak already has at
# least the announced weight. Otherwise the announced header is requested
# and one of two paths is taken:
#   * "close" peak (see the height conditions below): headers are fetched
#     backwards until a known block or height 0 is reached, then completed
#     through complete_blocks();
#   * peak at or above WEIGHT_PROOF_RECENT_BLOCKS: a weight proof is
#     requested, and the header is recorded as a potential peak before
#     start_sync() is called.
#
# Returns None in every case.
# NOTE(review): leading whitespace is missing in this copy, so the nesting
# of the statements below cannot be read off the text - confirm it against
# version control before editing.
async def new_peak_wallet(self, peak: wallet_protocol.NewPeakWallet, peer: WSChiaConnection):
if self.wallet_state_manager is None:
return None
curr_peak = self.wallet_state_manager.blockchain.get_peak()
# Nothing to gain from a peak that is not heavier than ours.
if curr_peak is not None and curr_peak.weight >= peak.weight:
return None
# The lock is created lazily on first use.
if self.new_peak_lock is None:
self.new_peak_lock = asyncio.Lock()
async with self.new_peak_lock:
request = wallet_protocol.RequestBlockHeader(peak.height)
response: Optional[RespondBlockHeader] = await peer.request_block_header(request)
if response is None or not isinstance(response, RespondBlockHeader) or response.header_block is None:
return None
header_block = response.header_block
# Short path: either we have no peak and the announced height is below
# WEIGHT_PROOF_RECENT_BLOCKS, or our peak is less than 200 blocks below
# the announced height.
if (curr_peak is None and header_block.height < self.constants.WEIGHT_PROOF_RECENT_BLOCKS) or (
curr_peak is not None and curr_peak.height > header_block.height - 200
):
top = header_block
blocks = [top]
# Fetch blocks backwards until we hit the one that we have,
# then complete them with additions / removals going forward
while not self.wallet_state_manager.blockchain.contains_block(top.prev_header_hash) and top.height > 0:
request_prev = wallet_protocol.RequestBlockHeader(top.height - 1)
response_prev: Optional[RespondBlockHeader] = await peer.request_block_header(request_prev)
if response_prev is None:
return None
if not isinstance(response_prev, RespondBlockHeader):
return None
# NOTE(review): unlike the first response above, header_block is not
# checked for None here before it is used as the next `top`.
prev_head = response_prev.header_block
blocks.append(prev_head)
top = prev_head
# The list was built newest-first; complete_blocks gets it oldest-first.
blocks.reverse()
await self.complete_blocks(blocks, peer)
await self.wallet_state_manager.create_more_puzzle_hashes()
elif header_block.height >= self.constants.WEIGHT_PROOF_RECENT_BLOCKS:
# Request weight proof
# Sync if PoW validates
# While a sync is running the announcement is only remembered; sync_job
# replays the remembered announcements once the sync has finished.
if self.wallet_state_manager.sync_mode:
self.last_new_peak_messages.put(peer, peak)
return None
weight_request = RequestProofOfWeight(header_block.height, header_block.header_hash)
weight_proof_response: RespondProofOfWeight = await peer.request_proof_of_weight(
weight_request, timeout=360
)
if weight_proof_response is None:
return None
weight_proof = weight_proof_response.wp
# Re-checked because the state manager may have been cleared during the await above.
if self.wallet_state_manager is None:
return None
# Trusted peers: the fork point is taken without validating the proof.
if self.server is not None and self.server.is_trusted_peer(peer, self.config["trusted_peers"]):
valid, fork_point = self.wallet_state_manager.weight_proof_handler.get_fork_point_no_validations(
weight_proof
)
else:
valid, fork_point, _ = await self.wallet_state_manager.weight_proof_handler.validate_weight_proof(
weight_proof
)
# NOTE(review): with the indentation lost it is unclear whether this check
# covers both branches above or only the untrusted one - confirm.
if not valid:
self.log.error(
f"invalid weight proof, num of epochs {len(weight_proof.sub_epochs)}"
f" recent blocks num ,{len(weight_proof.recent_chain_data)}"
)
self.log.debug(f"{weight_proof}")
return None
self.log.info(f"Validated, fork point is {fork_point}")
# Remember the fork point and the header as a sync candidate, then wake the sync loop.
self.wallet_state_manager.sync_store.add_potential_fork_point(
header_block.header_hash, uint32(fork_point)
)
self.wallet_state_manager.sync_store.add_potential_peak(header_block)
self.start_sync()
def start_sync(self) -> None:
    """Wake the sync loop by setting ``sync_event`` (logged first)."""
    self.log.info("self.sync_event.set()")
    wake_up = self.sync_event
    wake_up.set()
async def check_new_peak(self) -> None:
    """
    Start a sync if a recorded potential peak is heavier than the current peak.

    Does nothing when there is no wallet state manager or no current peak.
    When a heavier potential peak is found the coroutine sleeps five
    seconds and then calls ``start_sync()``.
    """
    if self.wallet_state_manager is None:
        return None
    current_peak: Optional[BlockRecord] = self.wallet_state_manager.blockchain.get_peak()
    if current_peak is None:
        return None
    candidates: List[
        Tuple[bytes32, HeaderBlock]
    ] = self.wallet_state_manager.sync_store.get_potential_peaks_tuples()
    # any() stops at the first heavier candidate, like the original loop did.
    if any(current_peak.weight < block.weight for _, block in candidates):
        await asyncio.sleep(5)
        self.start_sync()
    return None
# Background loop (started as a task in _start) that performs one _sync()
# run each time sync_event is set, until _shut_down becomes True.
#
# Exceptions raised by _sync() are logged and do not end the loop; sync mode
# is switched off again in the `finally` clause.
# NOTE(review): leading whitespace is missing in this copy, so it cannot be
# told from the text whether the replay loop near the end sits inside the
# `finally` clause or after it - confirm against version control.
async def sync_job(self) -> None:
while True:
self.log.info("Loop start in sync job")
if self._shut_down is True:
break
# Scheduled, not awaited: check_new_peak() may call start_sync() itself
# when a heavier potential peak is already recorded.
asyncio.create_task(self.check_new_peak())
# Block until start_sync() sets the event.
await self.sync_event.wait()
# Fresh cache for announcements that new_peak_wallet() defers while sync mode is on.
self.last_new_peak_messages = LRUCache(5)
self.sync_event.clear()
# Shutdown may have been requested while waiting.
if self._shut_down is True:
break
try:
assert self.wallet_state_manager is not None
self.wallet_state_manager.set_sync_mode(True)
await self._sync()
except Exception as e:
tb = traceback.format_exc()
self.log.error(f"Loop exception in sync {e}. {tb}")
finally:
# The state manager can be None here if the node was closed during the sync.
if self.wallet_state_manager is not None:
self.wallet_state_manager.set_sync_mode(False)
# Replay the announcements deferred during the sync, each in its own task.
for peer, peak in self.last_new_peak_messages.cache.items():
asyncio.create_task(self.new_peak_wallet(peak, peer))
self.log.info("Loop end in sync job")
# Download and validate header blocks in batches, from the fork point up to
# the heaviest potential peak recorded in the sync store.
#
# Returns None. Returns early when prerequisites are missing, when there is
# no potential peak above height 0, when the wallet peak is already at least
# as heavy, or when no full node is connected. Raises RuntimeError when no
# connected peer could supply a batch.
# NOTE(review): leading whitespace is missing in this copy, so the nesting
# of the statements below (in particular the `break` in the fork-point
# probe and the extent of the batch loop body) cannot be read off the text -
# confirm it against version control before editing.
async def _sync(self) -> None:
"""
Wallet has fallen far behind (or is starting up for the first time), and must be synced
up to the LCA of the blockchain.
"""
if self.wallet_state_manager is None or self.backup_initialized is False or self.server is None:
return None
# Pick the heaviest of the recorded potential peaks.
highest_weight: uint128 = uint128(0)
peak_height: uint32 = uint32(0)
peak: Optional[HeaderBlock] = None
potential_peaks: List[
Tuple[bytes32, HeaderBlock]
] = self.wallet_state_manager.sync_store.get_potential_peaks_tuples()
self.log.info(f"Have collected {len(potential_peaks)} potential peaks")
for header_hash, potential_peak_block in potential_peaks:
if potential_peak_block.weight > highest_weight:
highest_weight = potential_peak_block.weight
peak_height = potential_peak_block.height
peak = potential_peak_block
# peak_height is still 0 when no candidate had a positive weight.
if peak_height is None or peak_height == 0:
return None
if self.wallet_state_manager.peak is not None and highest_weight <= self.wallet_state_manager.peak.weight:
self.log.info("Not performing sync, already caught up.")
return None
peers: List[WSChiaConnection] = self.server.get_full_node_connections()
if len(peers) == 0:
self.log.info("No peers to sync to")
return None
# complete_blocks() takes the same lock, so the two do not interleave.
async with self.wallet_state_manager.blockchain.lock:
fork_height = None
if peak is not None:
# Fork point stored for this peak by new_peak_wallet() (add_potential_fork_point).
fork_height = self.wallet_state_manager.sync_store.get_potential_fork_point(peak.header_hash)
our_peak_height = self.wallet_state_manager.blockchain.get_peak_height()
# NOTE(review): "ses" presumably stands for sub-epoch summary - confirm.
ses_heigths = self.wallet_state_manager.blockchain.get_ses_heights()
if len(ses_heigths) > 2 and our_peak_height is not None:
ses_heigths.sort()
# Third-highest of the sorted heights.
max_fork_ses_height = ses_heigths[-3]
# This is fork point in SES in case where fork was not detected
if (
self.wallet_state_manager.blockchain.get_peak_height() is not None
and fork_height == max_fork_ses_height
):
peers = self.server.get_full_node_connections()
for peer in peers:
# Grab a block at peak + 1 and check if fork point is actually our current height
potential_height = uint32(our_peak_height + 1)
block_response: Optional[Any] = await peer.request_header_blocks(
wallet_protocol.RequestHeaderBlocks(potential_height, potential_height)
)
if block_response is not None and isinstance(
block_response, wallet_protocol.RespondHeaderBlocks
):
our_peak = self.wallet_state_manager.blockchain.get_peak()
# If the peer's next block builds on our peak, there is no fork below our peak.
if (
our_peak is not None
and block_response.header_blocks[0].prev_header_hash == our_peak.header_hash
):
fork_height = our_peak_height
break
# No fork point known: sync from the start.
if fork_height is None:
fork_height = uint32(0)
await self.wallet_state_manager.blockchain.warmup(fork_height)
batch_size = self.constants.MAX_BLOCK_COUNT_PER_REQUESTS
advanced_peak = False
# Batches start one block below the fork height (but not below 0).
for i in range(max(0, fork_height - 1), peak_height, batch_size):
start_height = i
end_height = min(peak_height, start_height + batch_size)
peers = self.server.get_full_node_connections()
added = False
# Try the connected full nodes in turn until one delivers the batch.
for peer in peers:
try:
# The fork height is only passed along until the peak has advanced once.
added, advanced_peak = await self.fetch_blocks_and_validate(
peer, uint32(start_height), uint32(end_height), None if advanced_peak else fork_height
)
if added:
break
except Exception as e:
# A peer that fails is disconnected; the error is logged and not re-raised here.
await peer.close()
exc = traceback.format_exc()
self.log.error(f"Error while trying to fetch from peer:{e} {exc}")
if not added:
raise RuntimeError(f"Was not able to add blocks {start_height}-{end_height}")
# Note: `peak` is rebound here from the target HeaderBlock to the wallet's current peak.
peak = self.wallet_state_manager.blockchain.get_peak()
assert peak is not None
# Called with the lower of (batch end, current peak) minus BLOCKS_CACHE_SIZE.
self.wallet_state_manager.blockchain.clean_block_record(
min(
end_height - self.constants.BLOCKS_CACHE_SIZE,
peak.height - self.constants.BLOCKS_CACHE_SIZE,
)
)
# Request the header blocks height_start..height_end from `peer`, complete
# transaction blocks with their additions/removals and feed every block to
# the wallet blockchain.
#
# Returns (validated, advanced_peak). (False, False) is returned when there
# is no wallet state manager; (False, advanced_peak) when pre-validation
# yields no results.
# Raises ValueError when the peer's response is missing or of the wrong
# type, when additions/removals cannot be fetched, or when a block is
# reported invalid; raises ValidationError when a pre-validation result
# carries an error.
# NOTE(review): leading whitespace is missing in this copy, so the nesting
# of the statements below cannot be read off the text - confirm it against
# version control before editing.
async def fetch_blocks_and_validate(
self,
peer: WSChiaConnection,
height_start: uint32,
height_end: uint32,
fork_point_with_peak: Optional[uint32],
) -> Tuple[bool, bool]:
"""
Returns whether the blocks validated, and whether the peak was advanced
"""
if self.wallet_state_manager is None:
return False, False
self.log.info(f"Requesting blocks {height_start}-{height_end}")
request = RequestHeaderBlocks(uint32(height_start), uint32(height_end))
res: Optional[RespondHeaderBlocks] = await peer.request_header_blocks(request)
if res is None or not isinstance(res, RespondHeaderBlocks):
raise ValueError("Peer returned no response")
header_blocks: List[HeaderBlock] = res.header_blocks
advanced_peak = False
if header_blocks is None:
raise ValueError(f"No response from peer {peer}")
# Trust decision. `and` binds tighter than `or`, so this reads as:
# (a full node peer is set and its host matches) or (the peer host is 127.0.0.1).
# NOTE(review): self.full_node_peer must have been assigned before this
# runs - confirm where that happens.
# NOTE(review): complete_blocks() and new_peak_wallet() decide trust via
# self.server.is_trusted_peer(...) instead - confirm the difference is intended.
if (
self.full_node_peer is not None
and peer.peer_host == self.full_node_peer.host
or peer.peer_host == "127.0.0.1"
):
trusted = True
pre_validation_results: Optional[List[PreValidationResult]] = None
else:
trusted = False
pre_validation_results = await self.wallet_state_manager.blockchain.pre_validate_blocks_multiprocessing(
header_blocks
)
if pre_validation_results is None:
return False, advanced_peak
assert len(header_blocks) == len(pre_validation_results)
for i in range(len(header_blocks)):
header_block = header_blocks[i]
# Results are matched to blocks by index; a recorded error aborts the batch.
if not trusted and pre_validation_results is not None and pre_validation_results[i].error is not None:
raise ValidationError(Err(pre_validation_results[i].error))
# The fork point is only passed on until the peak has advanced once.
fork_point_with_old_peak = None if advanced_peak else fork_point_with_peak
if header_block.is_transaction_block:
# Find additions and removals
(additions, removals,) = await self.wallet_state_manager.get_filter_additions_removals(
header_block, header_block.transactions_filter, fork_point_with_old_peak
)
# Get Additions
added_coins = await self.get_additions(peer, header_block, additions)
if added_coins is None:
raise ValueError("Failed to fetch additions")
# Get removals
removed_coins = await self.get_removals(peer, header_block, added_coins, removals)
if removed_coins is None:
raise ValueError("Failed to fetch removals")
header_block_record = HeaderBlockRecord(header_block, added_coins, removed_coins)
else:
header_block_record = HeaderBlockRecord(header_block, [], [])
# Timed only for the debug log below.
start_t = time.time()
if trusted:
# Trusted peers: no pre-validation result is passed.
(result, error, fork_h,) = await self.wallet_state_manager.blockchain.receive_block(
header_block_record, None, trusted, fork_point_with_old_peak
)
else:
assert pre_validation_results is not None
(result, error, fork_h,) = await self.wallet_state_manager.blockchain.receive_block(
header_block_record, pre_validation_results[i], trusted, fork_point_with_old_peak
)
self.log.debug(
f"Time taken to validate {header_block.height} with fork "
f"{fork_point_with_old_peak}: {time.time() - start_t}"
)
if result == ReceiveBlockResult.NEW_PEAK:
advanced_peak = True
self.wallet_state_manager.state_changed("new_block")
elif result == ReceiveBlockResult.INVALID_BLOCK:
raise ValueError("Value error peer sent us invalid block")
if advanced_peak:
await self.wallet_state_manager.create_more_puzzle_hashes()
return True, advanced_peak
def validate_additions(
self,
coins: List[Tuple[bytes32, List[Coin]]],
proofs: Optional[List[Tuple[bytes32, bytes, Optional[bytes]]]],
root,
):
if proofs is None:
# Verify root
additions_merkle_set = MerkleSet()
# Addition Merkle set contains puzzlehash and hash of all coins with that puzzlehash
for puzzle_hash, coins_l in coins:
additions_merkle_set.add_already_hashed(puzzle_hash)
additions_merkle_set.add_already_hashed(hash_coin_list(coins_l))
additions_root = additions_merkle_set.get_root()
if root != additions_root:
return False
else:
for i in range(len(coins)):
assert coins[i][0] == proofs[i][0]
coin_list_1: List[Coin] = coins[i][1]
puzzle_hash_proof: bytes32 = proofs[i][1]
coin_list_proof: Optional[bytes32] = proofs[i][2]
if len(coin_list_1) == 0:
# Verify exclusion proof for puzzle hash
not_included = confirm_not_included_already_hashed(
root,
coins[i][0],
puzzle_hash_proof,
)
if not_included is False:
return False
else:
try:
# Verify inclusion proof for coin list
included = confirm_included_already_hashed(
root,
hash_coin_list(coin_list_1),
coin_list_proof,
)
if included is False:
return False
except AssertionError:
return False
try:
# Verify inclusion proof for puzzle hash
included = confirm_included_already_hashed(
root,
coins[i][0],
puzzle_hash_proof,
)
if included is False:
return False
except AssertionError:
return False
return True
def validate_removals(self, coins, proofs, root):
    """
    Check a peer's removals response against the block's removals root.

    Args:
        coins: ``(coin_name, coin_or_None)`` pairs returned by the peer.
        proofs: None when the peer returned all removals (the merkle root
            is then rebuilt from the coins that are present); otherwise one
            ``(coin_name, proof)`` entry per item of ``coins``, same order.
        root: removals root taken from the block.

    Returns:
        True when the response is consistent with ``root``, else False.
    """
    if proofs is None:
        # If there are no proofs, it means all removals were returned in the response.
        # we must find the ones relevant to our wallets.
        # Verify removals root
        merkle_set = MerkleSet()
        # TODO review all verification
        for _name, removed in coins:
            if removed is not None:
                merkle_set.add_already_hashed(removed.name())
        if root != merkle_set.get_root():
            return False
        return True

    # This means the full node has responded only with the relevant removals
    # for our wallet. Each merkle proof must be verified.
    if len(coins) != len(proofs):
        return False
    for entry, proof in zip(coins, proofs):
        # Coins are in the same order as proofs
        if entry[0] != proof[0]:
            return False
        removed = entry[1]
        if removed is None:
            # Verifies merkle proof of exclusion
            if confirm_not_included_already_hashed(root, entry[0], proof[1]) is False:
                return False
            continue
        # Verifies merkle proof of inclusion of coin name
        if entry[0] != removed.name():
            return False
        if confirm_included_already_hashed(root, removed.name(), proof[1]) is False:
            return False
    return True
async def get_additions(self, peer: WSChiaConnection, block_i, additions) -> Optional[List[Coin]]:
if len(additions) > 0:
additions_request = RequestAdditions(block_i.height, block_i.header_hash, additions)
additions_res: Optional[Union[RespondAdditions, RejectAdditionsRequest]] = await peer.request_additions(
additions_request
)
if additions_res is None:
await peer.close()
return None
elif isinstance(additions_res, RespondAdditions):
validated = self.validate_additions(
additions_res.coins,
additions_res.proofs,
block_i.foliage_transaction_block.additions_root,
)
if not validated:
await peer.close()
return None
added_coins = []
for ph_coins in additions_res.coins:
ph, coins = ph_coins
added_coins.extend(coins)
return added_coins
elif isinstance(additions_res, RejectRemovalsRequest):
await peer.close()
return None
return None
else:
added_coins = []
return added_coins
async def get_removals(self, peer: WSChiaConnection, block_i, additions, removals) -> Optional[List[Coin]]:
assert self.wallet_state_manager is not None
request_all_removals = False
# Check if we need all removals
for coin in additions:
puzzle_store = self.wallet_state_manager.puzzle_store
record_info: Optional[DerivationRecord] = await puzzle_store.get_derivation_record_for_puzzle_hash(
coin.puzzle_hash.hex()
)
if record_info is not None and record_info.wallet_type == WalletType.COLOURED_COIN:
# TODO why ?
request_all_removals = True
break
if record_info is not None and record_info.wallet_type == WalletType.DISTRIBUTED_ID:
request_all_removals = True
break
if len(removals) > 0 or request_all_removals:
if request_all_removals:
removals_request = wallet_protocol.RequestRemovals(block_i.height, block_i.header_hash, None)
else:
removals_request = wallet_protocol.RequestRemovals(block_i.height, block_i.header_hash, removals)
removals_res: Optional[Union[RespondRemovals, RejectRemovalsRequest]] = await peer.request_removals(
removals_request
)
if removals_res is None:
return None
elif isinstance(removals_res, RespondRemovals):
validated = self.validate_removals(
removals_res.coins,
removals_res.proofs,
block_i.foliage_transaction_block.removals_root,
)
if validated is False:
await peer.close()
return None
removed_coins = []
for _, coins_l in removals_res.coins:
if coins_l is not None:
removed_coins.append(coins_l)
return removed_coins
elif isinstance(removals_res, RejectRemovalsRequest):
return None
else:
return None
else:
return []