# Source: chia-blockchain/chia/wallet/wallet_coin_store.py
# (253 lines, 11 KiB, Python)

from typing import Dict, List, Optional, Set
import aiosqlite
import sqlite3
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.db_wrapper import DBWrapper
from chia.util.ints import uint32, uint64
from chia.wallet.util.wallet_types import WalletType
from chia.wallet.wallet_coin_record import WalletCoinRecord
class WalletCoinStore:
    """
    This object handles CoinRecords in DB used by wallet.

    Every row of the ``coin_record`` table is mirrored in ``coin_record_cache``; the
    unspent subset is additionally indexed per wallet in ``unspent_coin_wallet_cache``.
    Mutating methods update both the caches and the DB so the two stay in sync.
    """

    db_connection: aiosqlite.Connection
    # coin_record_cache keeps ALL coin records in memory. [record_name: record]
    coin_record_cache: Dict[bytes32, WalletCoinRecord]
    # unspent_coin_wallet_cache keeps ALL unspent coin records for wallet in memory [wallet_id: [record_name: record]]
    unspent_coin_wallet_cache: Dict[int, Dict[bytes32, WalletCoinRecord]]
    db_wrapper: DBWrapper

    @classmethod
    async def create(cls, wrapper: DBWrapper) -> "WalletCoinStore":
        """
        Build a store on top of ``wrapper.db``.

        Configures the connection, creates the ``coin_record`` table and its indexes if
        they are missing, commits, and loads every stored record into the in-memory caches.
        """
        self = cls()
        self.db_connection = wrapper.db
        self.db_wrapper = wrapper
        statements = (
            "pragma journal_mode=wal",
            # synchronous=2 is SQLite's FULL mode
            "pragma synchronous=2",
            (
                "CREATE TABLE IF NOT EXISTS coin_record("
                "coin_name text PRIMARY KEY,"
                " confirmed_height bigint,"
                " spent_height bigint,"
                " spent int,"
                " coinbase int,"
                " puzzle_hash text,"
                " coin_parent text,"
                " amount blob,"
                " wallet_type int,"
                " wallet_id int)"
            ),
            # Useful for reorg lookups
            "CREATE INDEX IF NOT EXISTS coin_confirmed_height on coin_record(confirmed_height)",
            "CREATE INDEX IF NOT EXISTS coin_spent_height on coin_record(spent_height)",
            "CREATE INDEX IF NOT EXISTS coin_spent on coin_record(spent)",
            "CREATE INDEX IF NOT EXISTS coin_puzzlehash on coin_record(puzzle_hash)",
            "CREATE INDEX IF NOT EXISTS wallet_type on coin_record(wallet_type)",
            "CREATE INDEX IF NOT EXISTS wallet_id on coin_record(wallet_id)",
        )
        for statement in statements:
            # Close every cursor explicitly, as the other methods of this class do.
            cursor = await self.db_connection.execute(statement)
            await cursor.close()
        await self.db_connection.commit()
        self.coin_record_cache = {}
        self.unspent_coin_wallet_cache = {}
        await self.rebuild_wallet_cache()
        return self

    async def _clear_database(self) -> None:
        """Delete every coin record from the DB, commit, and empty the in-memory caches."""
        cursor = await self.db_connection.execute("DELETE FROM coin_record")
        await cursor.close()
        await self.db_connection.commit()
        # Without this, get_coin_record / get_unspent_* would keep serving deleted records.
        self.coin_record_cache = {}
        self.unspent_coin_wallet_cache = {}

    async def rebuild_wallet_cache(self) -> None:
        """Discard both in-memory caches and repopulate them from the rows currently in the DB."""
        all_coins = await self.get_all_coins()
        self.unspent_coin_wallet_cache = {}
        self.coin_record_cache = {}
        for coin_record in all_coins:
            name = coin_record.name()
            self.coin_record_cache[name] = coin_record
            if coin_record.spent is False:
                self.unspent_coin_wallet_cache.setdefault(coin_record.wallet_id, {})[name] = coin_record

    async def add_coin_record(self, record: WalletCoinRecord) -> None:
        """
        Store ``record`` in the in-memory caches and insert (or replace) its DB row.

        Spent records are dropped from their wallet's unspent cache; unspent records are
        added to it. This method does not commit.
        NOTE(review): presumably the caller commits via ``db_wrapper`` — confirm.
        """
        name = record.name()
        self.coin_record_cache[name] = record
        if record.spent:
            wallet_unspent = self.unspent_coin_wallet_cache.get(record.wallet_id)
            if wallet_unspent is not None:
                wallet_unspent.pop(name, None)
        else:
            self.unspent_coin_wallet_cache.setdefault(record.wallet_id, {})[name] = record

        # Column order matches the CREATE TABLE statement in create().
        cursor = await self.db_connection.execute(
            "INSERT OR REPLACE INTO coin_record VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                name.hex(),
                record.confirmed_block_height,
                record.spent_block_height,
                int(record.spent),
                int(record.coinbase),
                str(record.coin.puzzle_hash.hex()),
                str(record.coin.parent_coin_info.hex()),
                bytes(record.coin.amount),
                record.wallet_type,
                record.wallet_id,
            ),
        )
        await cursor.close()

    async def set_spent(self, coin_name: bytes32, height: uint32) -> WalletCoinRecord:
        """
        Mark the record for ``coin_name`` as spent at ``height`` and persist it.

        Returns the new (spent) record. The coin must already be known to the store;
        this is enforced with an ``assert``.
        """
        current: Optional[WalletCoinRecord] = await self.get_coin_record(coin_name)
        assert current is not None
        spent: WalletCoinRecord = WalletCoinRecord(
            current.coin,
            current.confirmed_block_height,
            height,
            True,
            current.coinbase,
            current.wallet_type,
            current.wallet_id,
        )
        await self.add_coin_record(spent)
        return spent

    def coin_record_from_row(self, row: sqlite3.Row) -> WalletCoinRecord:
        """
        Convert a ``coin_record`` row into a ``WalletCoinRecord``.

        Column indexes follow the table definition: 0 coin_name, 1 confirmed_height,
        2 spent_height, 3 spent, 4 coinbase, 5 puzzle_hash, 6 coin_parent, 7 amount,
        8 wallet_type, 9 wallet_id.
        """
        coin = Coin(bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), uint64.from_bytes(row[7]))
        return WalletCoinRecord(
            coin, uint32(row[1]), uint32(row[2]), bool(row[3]), bool(row[4]), WalletType(row[8]), row[9]
        )

    async def get_coin_record(self, coin_name: bytes32) -> Optional[WalletCoinRecord]:
        """Return the CoinRecord with the given coin id (cache first, then DB), or None."""
        cached = self.coin_record_cache.get(coin_name)
        if cached is not None:
            return cached
        cursor = await self.db_connection.execute("SELECT * from coin_record WHERE coin_name=?", (coin_name.hex(),))
        row = await cursor.fetchone()
        await cursor.close()
        if row is None:
            return None
        return self.coin_record_from_row(row)

    async def get_first_coin_height(self) -> Optional[uint32]:
        """Return the lowest confirmed height of any stored coin, or None if there are none."""
        cursor = await self.db_connection.execute("SELECT MIN(confirmed_height) FROM coin_record;")
        row = await cursor.fetchone()
        await cursor.close()
        if row is not None and row[0] is not None:
            return uint32(row[0])
        return None

    async def get_unspent_coins_at_height(self, height: Optional[uint32] = None) -> Set[WalletCoinRecord]:
        """
        Return the set of CoinRecords that are unspent.

        With ``height`` None, this is every record currently marked unspent. With a
        height, it is every record that was unspent *at* that height: confirmed at
        ``height`` or earlier, and either still unspent or spent after ``height``.
        """
        records = self.coin_record_cache.values()
        if height is None:
            return {record for record in records if not record.spent}
        # The confirmed-height bound applies to both cases; a coin confirmed after
        # `height` did not exist yet, even if it is unspent now.
        return {
            record
            for record in records
            if record.confirmed_block_height <= height and (not record.spent or record.spent_block_height > height)
        }

    async def get_unspent_coins_for_wallet(self, wallet_id: int) -> Set[WalletCoinRecord]:
        """Return the set of unspent CoinRecords cached for ``wallet_id`` (empty if none)."""
        wallet_coins: Dict[bytes32, WalletCoinRecord] = self.unspent_coin_wallet_cache.get(wallet_id, {})
        return set(wallet_coins.values())

    async def get_all_coins(self) -> Set[WalletCoinRecord]:
        """Return every CoinRecord stored in the DB (read directly, not from the cache)."""
        cursor = await self.db_connection.execute("SELECT * from coin_record")
        rows = await cursor.fetchall()
        await cursor.close()
        return {self.coin_record_from_row(row) for row in rows}

    async def get_coin_records_by_puzzle_hash(self, puzzle_hash: bytes32) -> List[WalletCoinRecord]:
        """Return a list of all coin records in the DB with the given puzzle hash."""
        cursor = await self.db_connection.execute("SELECT * from coin_record WHERE puzzle_hash=?", (puzzle_hash.hex(),))
        rows = await cursor.fetchall()
        await cursor.close()
        return [self.coin_record_from_row(row) for row in rows]

    async def rollback_to_block(self, height: int) -> None:
        """
        Roll the wallet's coin state back to ``height``.

        Coins confirmed after ``height`` are removed. Coins spent after ``height`` are
        marked unspent (spent height reset to 0). ``height`` may be -1 to roll back
        everything. Both the caches and the DB are updated; the DB is not committed here.
        """
        removed_names: List[bytes32] = []
        # Iterate over a snapshot because the loop body rewrites cache entries.
        for coin_name, coin_record in list(self.coin_record_cache.items()):
            if coin_record.spent_block_height > height:
                unspent_record = WalletCoinRecord(
                    coin_record.coin,
                    coin_record.confirmed_block_height,
                    uint32(0),
                    False,
                    coin_record.coinbase,
                    coin_record.wallet_type,
                    coin_record.wallet_id,
                )
                self.coin_record_cache[coin_name] = unspent_record
                # The wallet may have no unspent-cache entry yet (e.g. all its coins were
                # spent), so create it on demand instead of raising KeyError.
                self.unspent_coin_wallet_cache.setdefault(coin_record.wallet_id, {})[coin_name] = unspent_record
            if coin_record.confirmed_block_height > height:
                removed_names.append(coin_name)

        for coin_name in removed_names:
            removed = self.coin_record_cache.pop(coin_name)
            wallet_unspent = self.unspent_coin_wallet_cache.get(removed.wallet_id)
            if wallet_unspent is not None:
                wallet_unspent.pop(coin_name, None)

        c1 = await self.db_connection.execute("DELETE FROM coin_record WHERE confirmed_height>?", (height,))
        await c1.close()
        c2 = await self.db_connection.execute(
            "UPDATE coin_record SET spent_height = 0, spent = 0 WHERE spent_height>?",
            (height,),
        )
        await c2.close()