343 lines
13 KiB
Python
343 lines
13 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import traceback
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
import aiohttp
|
|
|
|
from chia.server.outbound_message import NodeType
|
|
from chia.server.server import ssl_context_for_server
|
|
from chia.types.peer_info import PeerInfo
|
|
from chia.util.byte_types import hexstr_to_bytes
|
|
from chia.util.ints import uint16
|
|
from chia.util.json_util import dict_to_json_str, obj_to_response
|
|
from chia.util.ws_message import create_payload, create_payload_dict, format_response, pong
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class RpcServer:
|
|
"""
|
|
Implementation of RPC server.
|
|
"""
|
|
|
|
def __init__(self, rpc_api: Any, service_name: str, stop_cb: Callable, root_path, net_config):
|
|
self.rpc_api = rpc_api
|
|
self.stop_cb: Callable = stop_cb
|
|
self.log = log
|
|
self.shut_down = False
|
|
self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
|
|
self.service_name = service_name
|
|
self.root_path = root_path
|
|
self.net_config = net_config
|
|
self.crt_path = root_path / net_config["daemon_ssl"]["private_crt"]
|
|
self.key_path = root_path / net_config["daemon_ssl"]["private_key"]
|
|
self.ca_cert_path = root_path / net_config["private_ssl_ca"]["crt"]
|
|
self.ca_key_path = root_path / net_config["private_ssl_ca"]["key"]
|
|
self.ssl_context = ssl_context_for_server(self.ca_cert_path, self.ca_key_path, self.crt_path, self.key_path)
|
|
|
|
async def stop(self):
|
|
self.shut_down = True
|
|
if self.websocket is not None:
|
|
await self.websocket.close()
|
|
|
|
async def _state_changed(self, *args):
|
|
if self.websocket is None:
|
|
return None
|
|
payloads: List[Dict] = await self.rpc_api._state_changed(*args)
|
|
|
|
change = args[0]
|
|
if change == "add_connection" or change == "close_connection":
|
|
data = await self.get_connections({})
|
|
if data is not None:
|
|
|
|
payload = create_payload_dict(
|
|
"get_connections",
|
|
data,
|
|
self.service_name,
|
|
"wallet_ui",
|
|
)
|
|
payloads.append(payload)
|
|
for payload in payloads:
|
|
if "success" not in payload["data"]:
|
|
payload["data"]["success"] = True
|
|
try:
|
|
await self.websocket.send_str(dict_to_json_str(payload))
|
|
except Exception:
|
|
tb = traceback.format_exc()
|
|
self.log.warning(f"Sending data failed. Exception {tb}.")
|
|
|
|
def state_changed(self, *args):
|
|
if self.websocket is None:
|
|
return None
|
|
asyncio.create_task(self._state_changed(*args))
|
|
|
|
def _wrap_http_handler(self, f) -> Callable:
|
|
async def inner(request) -> aiohttp.web.Response:
|
|
request_data = await request.json()
|
|
try:
|
|
res_object = await f(request_data)
|
|
if res_object is None:
|
|
res_object = {}
|
|
if "success" not in res_object:
|
|
res_object["success"] = True
|
|
except Exception as e:
|
|
tb = traceback.format_exc()
|
|
self.log.warning(f"Error while handling message: {tb}")
|
|
if len(e.args) > 0:
|
|
res_object = {"success": False, "error": f"{e.args[0]}"}
|
|
else:
|
|
res_object = {"success": False, "error": f"{e}"}
|
|
|
|
return obj_to_response(res_object)
|
|
|
|
return inner
|
|
|
|
async def get_connections(self, request: Dict) -> Dict:
|
|
if self.rpc_api.service.server is None:
|
|
raise ValueError("Global connections is not set")
|
|
if self.rpc_api.service.server._local_type is NodeType.FULL_NODE:
|
|
# TODO add peaks for peers
|
|
connections = self.rpc_api.service.server.get_connections()
|
|
con_info = []
|
|
if self.rpc_api.service.sync_store is not None:
|
|
peak_store = self.rpc_api.service.sync_store.peer_to_peak
|
|
else:
|
|
peak_store = None
|
|
for con in connections:
|
|
if peak_store is not None and con.peer_node_id in peak_store:
|
|
peak_hash, peak_height, peak_weight = peak_store[con.peer_node_id]
|
|
else:
|
|
peak_height = None
|
|
peak_hash = None
|
|
peak_weight = None
|
|
con_dict = {
|
|
"type": con.connection_type,
|
|
"local_port": con.local_port,
|
|
"peer_host": con.peer_host,
|
|
"peer_port": con.peer_port,
|
|
"peer_server_port": con.peer_server_port,
|
|
"node_id": con.peer_node_id,
|
|
"creation_time": con.creation_time,
|
|
"bytes_read": con.bytes_read,
|
|
"bytes_written": con.bytes_written,
|
|
"last_message_time": con.last_message_time,
|
|
"peak_height": peak_height,
|
|
"peak_weight": peak_weight,
|
|
"peak_hash": peak_hash,
|
|
}
|
|
con_info.append(con_dict)
|
|
else:
|
|
connections = self.rpc_api.service.server.get_connections()
|
|
con_info = [
|
|
{
|
|
"type": con.connection_type,
|
|
"local_port": con.local_port,
|
|
"peer_host": con.peer_host,
|
|
"peer_port": con.peer_port,
|
|
"peer_server_port": con.peer_server_port,
|
|
"node_id": con.peer_node_id,
|
|
"creation_time": con.creation_time,
|
|
"bytes_read": con.bytes_read,
|
|
"bytes_written": con.bytes_written,
|
|
"last_message_time": con.last_message_time,
|
|
}
|
|
for con in connections
|
|
]
|
|
return {"connections": con_info}
|
|
|
|
async def open_connection(self, request: Dict):
|
|
host = request["host"]
|
|
port = request["port"]
|
|
target_node: PeerInfo = PeerInfo(host, uint16(int(port)))
|
|
on_connect = None
|
|
if hasattr(self.rpc_api.service, "on_connect"):
|
|
on_connect = self.rpc_api.service.on_connect
|
|
if getattr(self.rpc_api.service, "server", None) is None or not (
|
|
await self.rpc_api.service.server.start_client(target_node, on_connect)
|
|
):
|
|
raise ValueError("Start client failed, or server is not set")
|
|
return {}
|
|
|
|
async def close_connection(self, request: Dict):
|
|
node_id = hexstr_to_bytes(request["node_id"])
|
|
if self.rpc_api.service.server is None:
|
|
raise aiohttp.web.HTTPInternalServerError()
|
|
connections_to_close = [c for c in self.rpc_api.service.server.get_connections() if c.peer_node_id == node_id]
|
|
if len(connections_to_close) == 0:
|
|
raise ValueError(f"Connection with node_id {node_id.hex()} does not exist")
|
|
for connection in connections_to_close:
|
|
await connection.close()
|
|
return {}
|
|
|
|
async def stop_node(self, request):
|
|
"""
|
|
Shuts down the node.
|
|
"""
|
|
if self.stop_cb is not None:
|
|
self.stop_cb()
|
|
return {}
|
|
|
|
async def ws_api(self, message):
|
|
"""
|
|
This function gets called when new message is received via websocket.
|
|
"""
|
|
|
|
command = message["command"]
|
|
if message["ack"]:
|
|
return None
|
|
|
|
data = None
|
|
if "data" in message:
|
|
data = message["data"]
|
|
if command == "ping":
|
|
return pong()
|
|
|
|
f = getattr(self, command, None)
|
|
if f is not None:
|
|
return await f(data)
|
|
f = getattr(self.rpc_api, command, None)
|
|
if f is not None:
|
|
return await f(data)
|
|
|
|
raise ValueError(f"unknown_command {command}")
|
|
|
|
async def safe_handle(self, websocket, payload):
|
|
message = None
|
|
try:
|
|
message = json.loads(payload)
|
|
self.log.debug(f"Rpc call <- {message['command']}")
|
|
response = await self.ws_api(message)
|
|
|
|
# Only respond if we return something from api call
|
|
if response is not None:
|
|
log.debug(f"Rpc response -> {message['command']}")
|
|
# Set success to true automatically (unless it's already set)
|
|
if "success" not in response:
|
|
response["success"] = True
|
|
await websocket.send_str(format_response(message, response))
|
|
|
|
except Exception as e:
|
|
tb = traceback.format_exc()
|
|
self.log.warning(f"Error while handling message: {tb}")
|
|
if message is not None:
|
|
error = e.args[0] if e.args else e
|
|
res = {"success": False, "error": f"{error}"}
|
|
await websocket.send_str(format_response(message, res))
|
|
|
|
async def connection(self, ws):
|
|
data = {"service": self.service_name}
|
|
payload = create_payload("register_service", data, self.service_name, "daemon")
|
|
await ws.send_str(payload)
|
|
|
|
while True:
|
|
msg = await ws.receive()
|
|
if msg.type == aiohttp.WSMsgType.TEXT:
|
|
message = msg.data.strip()
|
|
# self.log.info(f"received message: {message}")
|
|
await self.safe_handle(ws, message)
|
|
elif msg.type == aiohttp.WSMsgType.BINARY:
|
|
self.log.debug("Received binary data")
|
|
elif msg.type == aiohttp.WSMsgType.PING:
|
|
self.log.debug("Ping received")
|
|
await ws.pong()
|
|
elif msg.type == aiohttp.WSMsgType.PONG:
|
|
self.log.debug("Pong received")
|
|
else:
|
|
if msg.type == aiohttp.WSMsgType.CLOSE:
|
|
self.log.debug("Closing RPC websocket")
|
|
await ws.close()
|
|
elif msg.type == aiohttp.WSMsgType.ERROR:
|
|
self.log.error("Error during receive %s" % ws.exception())
|
|
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
|
pass
|
|
|
|
break
|
|
|
|
await ws.close()
|
|
|
|
async def connect_to_daemon(self, self_hostname: str, daemon_port: uint16):
|
|
while True:
|
|
session = None
|
|
try:
|
|
if self.shut_down:
|
|
break
|
|
session = aiohttp.ClientSession()
|
|
|
|
async with session.ws_connect(
|
|
f"wss://{self_hostname}:{daemon_port}",
|
|
autoclose=True,
|
|
autoping=True,
|
|
heartbeat=60,
|
|
ssl_context=self.ssl_context,
|
|
max_msg_size=100 * 1024 * 1024,
|
|
) as ws:
|
|
self.websocket = ws
|
|
await self.connection(ws)
|
|
self.websocket = None
|
|
await session.close()
|
|
except aiohttp.ClientConnectorError:
|
|
self.log.warning(f"Cannot connect to daemon at ws://{self_hostname}:{daemon_port}")
|
|
except Exception as e:
|
|
tb = traceback.format_exc()
|
|
self.log.warning(f"Exception: {tb} {type(e)}")
|
|
finally:
|
|
if session is not None:
|
|
await session.close()
|
|
await asyncio.sleep(2)
|
|
|
|
|
|
async def start_rpc_server(
|
|
rpc_api: Any,
|
|
self_hostname: str,
|
|
daemon_port: uint16,
|
|
rpc_port: uint16,
|
|
stop_cb: Callable,
|
|
root_path: Path,
|
|
net_config,
|
|
connect_to_daemon=True,
|
|
):
|
|
"""
|
|
Starts an HTTP server with the following RPC methods, to be used by local clients to
|
|
query the node.
|
|
"""
|
|
app = aiohttp.web.Application()
|
|
rpc_server = RpcServer(rpc_api, rpc_api.service_name, stop_cb, root_path, net_config)
|
|
rpc_server.rpc_api.service._set_state_changed_callback(rpc_server.state_changed)
|
|
http_routes: Dict[str, Callable] = rpc_api.get_routes()
|
|
|
|
routes = [aiohttp.web.post(route, rpc_server._wrap_http_handler(func)) for (route, func) in http_routes.items()]
|
|
routes += [
|
|
aiohttp.web.post(
|
|
"/get_connections",
|
|
rpc_server._wrap_http_handler(rpc_server.get_connections),
|
|
),
|
|
aiohttp.web.post(
|
|
"/open_connection",
|
|
rpc_server._wrap_http_handler(rpc_server.open_connection),
|
|
),
|
|
aiohttp.web.post(
|
|
"/close_connection",
|
|
rpc_server._wrap_http_handler(rpc_server.close_connection),
|
|
),
|
|
aiohttp.web.post("/stop_node", rpc_server._wrap_http_handler(rpc_server.stop_node)),
|
|
]
|
|
|
|
app.add_routes(routes)
|
|
if connect_to_daemon:
|
|
daemon_connection = asyncio.create_task(rpc_server.connect_to_daemon(self_hostname, daemon_port))
|
|
runner = aiohttp.web.AppRunner(app, access_log=None)
|
|
await runner.setup()
|
|
site = aiohttp.web.TCPSite(runner, self_hostname, int(rpc_port), ssl_context=rpc_server.ssl_context)
|
|
await site.start()
|
|
|
|
async def cleanup():
|
|
await rpc_server.stop()
|
|
await runner.cleanup()
|
|
if connect_to_daemon:
|
|
await daemon_connection
|
|
|
|
return cleanup
|