# chia-blockchain/chia/server/ws_connection.py
#
# 455 lines
# 18 KiB
# Python

import asyncio
import logging
import time
import traceback
from typing import Any, Callable, Dict, List, Optional
from aiohttp import WSCloseCode, WSMessage, WSMsgType
from chia.cmds.init_funcs import chia_full_version_str
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.shared_protocol import Capability, Handshake
from chia.server.outbound_message import Message, NodeType, make_msg
from chia.server.rate_limits import RateLimiter
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import PeerInfo
from chia.util.errors import Err, ProtocolError
from chia.util.ints import uint8, uint16
from chia.util.network import class_for_type, is_localhost

# Each message is prepended with LENGTH_BYTES bytes specifying the length.
# Max size is 2^(8*4) bytes, which is around 4GiB.
LENGTH_BYTES: int = 4
class WSChiaConnection:
    """
    Represents a websocket connection to another node.

    The local type and port are ours, while the peer host and port belong to the peer we are
    connected to. ``connection_type`` and ``peer_server_port`` are only known once
    ``perform_handshake`` has completed; ``peer_node_id`` is supplied by the caller at construction.

    Any public attribute that is not defined on the class is treated as the name of a protocol
    request (see ``__getattr__``), e.g. ``await connection.some_request(payload, timeout=10)``.
    """

    def __init__(
        self,
        local_type: NodeType,
        ws: Any,  # Websocket
        server_port: int,
        log: logging.Logger,
        is_outbound: bool,
        is_feeler: bool,  # Special type of connection, that disconnects after the handshake.
        peer_host,
        incoming_queue,
        close_callback: Callable,
        peer_id,
        inbound_rate_limit_percent: int,
        outbound_rate_limit_percent: int,
        close_event=None,
        session=None,
    ):
        """
        Wrap an already-established websocket.

        Raises:
            ValueError: if the peer address cannot be read from the websocket's transport.
        """
        # Strong references to fire-and-forget tasks (deferred close, rate-limit retries). The event
        # loop only keeps weak references to tasks, so without this they could be collected mid-flight.
        # Set first so it exists before anything else can touch it.
        self._background_tasks: set = set()

        # Local properties
        self.ws: Any = ws
        self.local_type = local_type
        self.local_port = server_port

        # Remote properties
        self.peer_host = peer_host
        peername = self.ws._writer.transport.get_extra_info("peername")
        if peername is None:
            raise ValueError(f"Was not able to get peername from {self.ws._writer} at {self.peer_host}")
        connection_port = peername[1]
        self.peer_port = connection_port
        self.peer_server_port: Optional[uint16] = None
        self.peer_node_id = peer_id

        self.log = log

        # connection properties
        self.is_outbound = is_outbound
        self.is_feeler = is_feeler

        # ChiaConnection metrics
        self.creation_time = time.time()
        self.bytes_read = 0
        self.bytes_written = 0
        self.last_message_time: float = 0

        # Messaging
        self.incoming_queue: asyncio.Queue = incoming_queue
        self.outgoing_queue: asyncio.Queue = asyncio.Queue()

        self.inbound_task: Optional[asyncio.Task] = None
        self.outbound_task: Optional[asyncio.Task] = None
        self.active: bool = False  # once handshake is successful this will be changed to True
        self.close_event: Optional[asyncio.Event] = close_event
        self.session = session
        self.close_callback = close_callback

        # Request bookkeeping, all keyed by the request nonce (the message id)
        self.pending_requests: Dict[bytes32, asyncio.Event] = {}
        self.pending_timeouts: Dict[bytes32, asyncio.Task] = {}
        self.request_results: Dict[bytes32, Message] = {}

        self.closed = False
        self.connection_type: Optional[NodeType] = None
        if is_outbound:
            self.request_nonce: uint16 = uint16(0)
        else:
            # Different nonce to reduce chances of overlap. Each peer will increment the nonce by one for each
            # request. The receiving peer (not is_outbound), will use 2^15 to 2^16 - 1
            self.request_nonce = uint16(2 ** 15)

        # This means that even if the other peer's boundaries for each minute are not aligned, we will not
        # disconnect. Also it allows a little flexibility.
        self.outbound_rate_limiter = RateLimiter(incoming=False, percentage_of_limit=outbound_rate_limit_percent)
        self.inbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=inbound_rate_limit_percent)

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule ``coro`` as a task and hold a reference to it until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    @staticmethod
    def _make_handshake_msg(network_id: str, protocol_version: str, server_port: int, local_type: NodeType):
        """Build the handshake message that describes this node to the peer."""
        return make_msg(
            ProtocolMessageTypes.handshake,
            Handshake(
                network_id,
                protocol_version,
                chia_full_version_str(),
                uint16(server_port),
                uint8(local_type.value),
                [(uint16(Capability.BASE.value), "1")],
            ),
        )

    @staticmethod
    def _parse_handshake(message: Optional[Message], network_id: str) -> Handshake:
        """
        Validate the peer's first message and decode it as a Handshake.

        Raises:
            ProtocolError: with INVALID_HANDSHAKE if nothing was received or the message is not a
                handshake (including an unknown message type id), and with INCOMPATIBLE_NETWORK_ID
                if the peer is on a different network.
        """
        if message is None:
            raise ProtocolError(Err.INVALID_HANDSHAKE)
        # Check the type before decoding, so that a non-handshake message is reported as an invalid
        # handshake instead of as whatever error decoding an unrelated payload happens to raise.
        try:
            message_type = ProtocolMessageTypes(message.type)
        except ValueError as e:
            raise ProtocolError(Err.INVALID_HANDSHAKE) from e
        if message_type != ProtocolMessageTypes.handshake:
            raise ProtocolError(Err.INVALID_HANDSHAKE)
        handshake = Handshake.from_bytes(message.data)
        if handshake.network_id != network_id:
            raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
        return handshake

    async def perform_handshake(self, network_id: str, protocol_version: str, server_port: int, local_type: NodeType):
        """
        Exchange handshakes with the peer and start the inbound / outbound handler tasks.

        On an outbound connection we send our handshake first and then read the peer's; on an inbound
        connection we read first and reply afterwards. On success ``peer_server_port`` and
        ``connection_type`` are filled in from the peer's handshake.

        Returns:
            True once the handshake succeeded and both handler tasks have been created.

        Raises:
            ProtocolError: if the peer's handshake is missing, invalid or for another network.
        """
        if self.is_outbound:
            outbound_handshake = self._make_handshake_msg(network_id, protocol_version, server_port, local_type)
            await self._send_message(outbound_handshake)
            inbound_handshake = self._parse_handshake(await self._read_one_message(), network_id)
        else:
            try:
                message = await self._read_one_message()
            except Exception as e:
                raise ProtocolError(Err.INVALID_HANDSHAKE) from e
            inbound_handshake = self._parse_handshake(message, network_id)
            outbound_handshake = self._make_handshake_msg(network_id, protocol_version, server_port, local_type)
            await self._send_message(outbound_handshake)

        self.peer_server_port = inbound_handshake.server_port
        self.connection_type = NodeType(inbound_handshake.node_type)

        self.outbound_task = asyncio.create_task(self.outbound_handler())
        self.inbound_task = asyncio.create_task(self.inbound_handler())
        return True

    async def close(self, ban_time: int = 0, ws_close_code: WSCloseCode = WSCloseCode.OK, error: Optional[Err] = None):
        """
        Closes the connection, and finally calls the close_callback on the server, so the connections gets removed
        from the global list.

        Only the first call does anything. The handler tasks are cancelled, the websocket and the
        session (if any) are closed, ``close_event`` is set and pending request timeouts are cancelled,
        which wakes every coroutine blocked in ``create_request``.

        Args:
            ban_time: forwarded unchanged to ``close_callback``.
            ws_close_code: close code sent to the peer.
            error: optional error whose numeric value is sent to the peer as the close message.

        Raises:
            Exception: whatever closing the websocket or session raised, after it has been logged.
                The close event, timeout cancellation and ``close_callback`` still run in that case.
        """
        if self.closed:
            return None
        self.closed = True

        message = b"" if error is None else str(int(error.value)).encode("utf-8")

        try:
            if self.inbound_task is not None:
                self.inbound_task.cancel()
            if self.outbound_task is not None:
                self.outbound_task.cancel()
            try:
                if self.ws is not None and self.ws._closed is False:
                    await self.ws.close(code=ws_close_code, message=message)
            finally:
                # Release the session even when closing the websocket itself failed.
                if self.session is not None:
                    await self.session.close()
        except Exception:
            error_stack = traceback.format_exc()
            self.log.warning(f"Exception closing socket: {error_stack}")
            raise
        finally:
            # Always release waiters and notify the server, otherwise a failed close would leave
            # requests hanging until their timeout and the connection stuck in the server's list.
            if self.close_event is not None:
                self.close_event.set()
            self.cancel_pending_timeouts()
            self.close_callback(self, ban_time)

    def cancel_pending_timeouts(self):
        """Cancel every outstanding request timeout task; each one sets its request's event when cancelled."""
        for task in self.pending_timeouts.values():
            task.cancel()

    async def outbound_handler(self):
        """Send messages from ``outgoing_queue`` to the peer until the connection is closed or sending fails."""
        try:
            while not self.closed:
                msg = await self.outgoing_queue.get()
                if msg is not None:
                    await self._send_message(msg)
        except asyncio.CancelledError:
            pass
        except (BrokenPipeError, ConnectionResetError) as e:
            self.log.warning(f"{e} {self.peer_host}")
        except Exception as e:
            error_stack = traceback.format_exc()
            self.log.error(f"Exception: {e} with {self.peer_host}")
            self.log.error(f"Exception Stack: {error_stack}")

    async def inbound_handler(self):
        """
        Read messages from the peer until the connection is closed or reading fails.

        A message whose id matches a pending request is stored as that request's result and its
        waiter is woken; every other message is put on ``incoming_queue`` as ``(message, self)``.
        """
        try:
            while not self.closed:
                message: Optional[Message] = await self._read_one_message()
                if message is None:
                    continue
                if message.id in self.pending_requests:
                    self.request_results[message.id] = message
                    self.pending_requests[message.id].set()
                else:
                    await self.incoming_queue.put((message, self))
        except asyncio.CancelledError:
            self.log.debug("Inbound_handler task cancelled")
        except Exception as e:
            error_stack = traceback.format_exc()
            self.log.error(f"Exception: {e}")
            self.log.error(f"Exception Stack: {error_stack}")

    async def send_message(self, message: Message):
        """Send message sends a message with no tracking / callback."""
        if self.closed:
            return None
        await self.outgoing_queue.put(message)

    def __getattr__(self, attr_name: str):
        """
        Turn an unknown attribute into a protocol request of the same name.

        Only called when normal attribute lookup fails. The returned coroutine function takes the
        request payload as its first positional argument and an optional ``timeout`` keyword
        (default 60), sends the request and returns the decoded response, or None if no response
        arrived.

        Raises:
            AttributeError: immediately for private / dunder names, so that ``hasattr`` and protocol
                probes such as ``copy`` and ``pickle`` do not mistake them for requests. When the
                returned coroutine runs, also if the peer's node type has no method ``attr_name`` or
                the local node type has no handler for the response's message type.
        """
        if attr_name.startswith("_"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {attr_name!r}")

        # TODO KWARGS
        async def invoke(*args, **kwargs):
            timeout = kwargs.get("timeout", 60)
            attribute = getattr(class_for_type(self.connection_type), attr_name, None)
            if attribute is None:
                raise AttributeError(f"Node type {self.connection_type} does not have method {attr_name}")

            msg = Message(uint8(getattr(ProtocolMessageTypes, attr_name).value), None, args[0])
            request_start_t = time.time()
            result = await self.create_request(msg, timeout)
            self.log.debug(
                f"Time for request {attr_name}: {self.get_peer_info()} = {time.time() - request_start_t}, "
                f"None? {result is None}"
            )
            if result is not None:
                response_name = ProtocolMessageTypes(result.type).name
                ret_attr = getattr(class_for_type(self.local_type), response_name, None)
                if ret_attr is None:
                    raise AttributeError(f"Node type {self.local_type} does not have method {response_name}")

                # The response is decoded with the annotated type of the local handler's payload
                # parameter: the last annotation that is neither the return type nor "peer".
                req = None
                for key, annotation in ret_attr.__annotations__.items():
                    if key not in ("return", "peer"):
                        req = annotation
                assert req is not None
                result = req.from_bytes(result.data)
            return result

        return invoke

    async def create_request(self, message_no_id: Message, timeout: int) -> Optional[Message]:
        """
        Sends a message and waits for a response.

        Args:
            message_no_id: the request; its type and data are reused, its id is replaced by a fresh nonce.
            timeout: seconds to wait for the response.

        Returns:
            The response message, or None if the connection is already closed, the timeout expired,
            or the wait was released by the connection closing.
        """
        if self.closed:
            return None

        # We will wait for this event, it will be set either by the response, or the timeout
        event = asyncio.Event()

        # The request nonce is an integer between 0 and 2**16 - 1, which is used to match requests to responses
        # If is_outbound, 0 <= nonce < 2^15, else 2^15 <= nonce < 2^16
        request_id = self.request_nonce
        if self.is_outbound:
            self.request_nonce = uint16(self.request_nonce + 1) if self.request_nonce != (2 ** 15 - 1) else uint16(0)
        else:
            self.request_nonce = (
                uint16(self.request_nonce + 1) if self.request_nonce != (2 ** 16 - 1) else uint16(2 ** 15)
            )

        message = Message(message_no_id.type, request_id, message_no_id.data)

        self.pending_requests[message.id] = event
        await self.outgoing_queue.put(message)

        # If the timeout passes, we set the event
        async def time_out(req_id, req_timeout):
            try:
                await asyncio.sleep(req_timeout)
                if req_id in self.pending_requests:
                    self.pending_requests[req_id].set()
            except asyncio.CancelledError:
                # Cancelled by cancel_pending_timeouts() while closing: release the waiter right away.
                if req_id in self.pending_requests:
                    self.pending_requests[req_id].set()
                raise

        timeout_task = asyncio.create_task(time_out(message.id, timeout))
        self.pending_timeouts[message.id] = timeout_task
        try:
            await event.wait()
        finally:
            # Drop the bookkeeping and stop the timer as soon as the wait ends (response, timeout,
            # close, or the caller being cancelled), so that no sleeping task or dict entry lingers.
            # The request is removed first so the cancelled timer finds nothing left to signal.
            self.pending_requests.pop(message.id, None)
            self.pending_timeouts.pop(message.id, None)
            timeout_task.cancel()

        result: Optional[Message] = self.request_results.pop(message.id, None)
        if result is not None:
            self.log.debug(f"<- {ProtocolMessageTypes(result.type).name} from: {self.peer_host}:{self.peer_port}")
        return result

    async def reply_to_request(self, response: Message):
        """Queue ``response`` for sending; does nothing if the connection is closed."""
        if self.closed:
            return None
        await self.outgoing_queue.put(response)

    async def send_messages(self, messages: List[Message]):
        """Queue every message in ``messages`` for sending, in order; does nothing if the connection is closed."""
        if self.closed:
            return None
        for message in messages:
            await self.outgoing_queue.put(message)

    async def _wait_and_retry(self, msg: Message, queue: asyncio.Queue):
        """Put a rate-limited message back on ``queue`` after one second; failures are only logged."""
        try:
            await asyncio.sleep(1)
            await queue.put(msg)
        except Exception as e:
            self.log.debug(f"Exception {e} while waiting to retry sending rate limited message")
            return None

    async def _send_message(self, message: Message):
        """
        Serialize ``message`` and write it to the websocket, subject to the outbound rate limiter.

        When the limiter rejects a message for a non-localhost peer, the message is not sent now:
        it is re-queued after a delay, except for ``respond_peers`` messages, which are dropped.
        For localhost peers the limit is only logged.
        """
        encoded: bytes = bytes(message)
        size = len(encoded)
        assert size < (2 ** (LENGTH_BYTES * 8))
        if not self.outbound_rate_limiter.process_msg_and_check(message):
            if not is_localhost(self.peer_host):
                self.log.debug(
                    f"Rate limiting ourselves. message type: {ProtocolMessageTypes(message.type).name}, "
                    f"peer: {self.peer_host}"
                )
                # TODO: fix this special case. This function has rate limits which are too low.
                if ProtocolMessageTypes(message.type) != ProtocolMessageTypes.respond_peers:
                    self._spawn(self._wait_and_retry(message, self.outgoing_queue))
                return None
            else:
                self.log.debug(
                    f"Not rate limiting ourselves. message type: {ProtocolMessageTypes(message.type).name}, "
                    f"peer: {self.peer_host}"
                )

        await self.ws.send_bytes(encoded)
        self.log.debug(f"-> {ProtocolMessageTypes(message.type).name} to peer {self.peer_host} {self.peer_node_id}")
        self.bytes_written += size

    async def _close_and_wait(self, ban_time: int = 0) -> None:
        """
        Schedule ``close(ban_time)`` as a separate task, then pause for three seconds.

        The close runs in its own task because it cancels the inbound handler, which is usually the
        caller. NOTE(review): the pause presumably gives that cancellation time to land before the
        caller reads again -- confirm.
        """
        self._spawn(self.close(ban_time))
        await asyncio.sleep(3)

    async def _read_one_message(self) -> Optional[Message]:
        """
        Wait up to 30 seconds for one websocket frame and turn it into a protocol message.

        Returns:
            The decoded message for a binary frame that is accepted, otherwise None: nothing arrived
            in time, the frame was a close / error / unexpected frame (the connection is then
            scheduled to close), or this full node is disconnecting a rate-limited peer.
        """
        try:
            message: WSMessage = await self.ws.receive(30)
        except asyncio.TimeoutError:
            # self.ws._closed if we didn't receive a ping / pong
            if self.ws._closed:
                await self._close_and_wait()
            return None

        if self.connection_type is not None:
            connection_type_str = NodeType(self.connection_type).name.lower()
        else:
            connection_type_str = ""

        if message.type == WSMsgType.CLOSING:
            self.log.debug(
                f"Closing connection to {connection_type_str} {self.peer_host}:"
                f"{self.peer_server_port}/"
                f"{self.peer_port}"
            )
            await self._close_and_wait()
        elif message.type == WSMsgType.CLOSE:
            self.log.debug(
                f"Peer closed connection {connection_type_str} {self.peer_host}:"
                f"{self.peer_server_port}/"
                f"{self.peer_port}"
            )
            await self._close_and_wait()
        elif message.type == WSMsgType.CLOSED:
            if not self.closed:
                await self._close_and_wait()
        elif message.type == WSMsgType.BINARY:
            data = message.data
            full_message_loaded: Message = Message.from_bytes(data)
            self.bytes_read += len(data)
            self.last_message_time = time.time()
            try:
                message_type = ProtocolMessageTypes(full_message_loaded.type).name
            except Exception:
                message_type = "Unknown"
            if not self.inbound_rate_limiter.process_msg_and_check(full_message_loaded):
                if self.local_type == NodeType.FULL_NODE and not is_localhost(self.peer_host):
                    self.log.error(
                        f"Peer has been rate limited and will be disconnected: {self.peer_host}, "
                        f"message: {message_type}"
                    )
                    # Only full node disconnects peers, to prevent abuse and crashing timelords, farmers, etc
                    await self._close_and_wait(300)
                    return None
                self.log.warning(
                    f"Peer surpassed rate limit {self.peer_host}, message: {message_type}, "
                    f"port {self.peer_port} but not disconnecting"
                )
            return full_message_loaded
        elif message.type == WSMsgType.ERROR:
            self.log.error(f"WebSocket Error: {message}")
            # The payload of an ERROR frame is an exception, and not every exception carries a close code.
            too_big = getattr(message.data, "code", None) == WSCloseCode.MESSAGE_TOO_BIG
            await self._close_and_wait(300 if too_big else 0)
        else:
            self.log.error(f"Unexpected WebSocket message type: {message}")
            await self._close_and_wait()
        return None

    def get_peer_info(self) -> Optional[PeerInfo]:
        """
        Return the peer's address, or None if the transport no longer reports one.

        The host is the one reported by the transport; the port is the server port announced in the
        peer's handshake when known, otherwise the port of this connection.
        """
        result = self.ws._writer.transport.get_extra_info("peername")
        if result is None:
            return None
        connection_host = result[0]
        port = self.peer_server_port if self.peer_server_port is not None else self.peer_port
        return PeerInfo(connection_host, port)