# chia-blockchain/chia/wallet/wallet_transaction_store.py

# 443 lines
# 16 KiB
# Python

import time
from typing import Dict, List, Optional, Tuple
import aiosqlite
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.mempool_inclusion_status import MempoolInclusionStatus
from chia.util.db_wrapper import DBWrapper
from chia.util.errors import Err
from chia.util.ints import uint8, uint32
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.transaction_type import TransactionType
class WalletTransactionStore:
    """
    WalletTransactionStore stores transaction history for the wallet.

    Records are persisted in the ``transaction_record`` SQLite table (the full
    serialized record lives in column 0) and mirrored in in-memory caches.
    """

    db_connection: aiosqlite.Connection
    db_wrapper: DBWrapper
    # Every known record, keyed by transaction id (``record.name``).
    tx_record_cache: Dict[bytes32, TransactionRecord]
    # tx_id -> (time submitted, count); maintained by get_not_sent().
    tx_submitted: Dict[bytes32, Tuple[int, int]]
    # wallet_id -> {tx_id -> record} for records whose ``confirmed`` flag is false.
    unconfirmed_for_wallet: Dict[int, Dict[bytes32, TransactionRecord]]

    @classmethod
    async def create(cls, db_wrapper: DBWrapper) -> "WalletTransactionStore":
        """
        Build a store on top of ``db_wrapper``.

        Sets the journal/synchronous pragmas, creates the ``transaction_record``
        table and its indexes if they do not exist, commits, and then loads the
        in-memory caches from the database.
        """
        self = cls()
        self.db_wrapper = db_wrapper
        self.db_connection = self.db_wrapper.db

        await self.db_connection.execute("pragma journal_mode=wal")
        await self.db_connection.execute("pragma synchronous=2")

        await self.db_connection.execute(
            (
                "CREATE TABLE IF NOT EXISTS transaction_record("
                " transaction_record blob,"
                " bundle_id text PRIMARY KEY,"  # NOTE: bundle_id is being stored as bytes, not hex
                " confirmed_at_height bigint,"
                " created_at_time bigint,"
                " to_puzzle_hash text,"
                " amount blob,"
                " fee_amount blob,"
                " confirmed int,"
                " sent int,"
                " wallet_id bigint,"
                " trade_id text,"
                " type int)"
            )
        )

        index_statements = (
            # Useful for reorg lookups
            "CREATE INDEX IF NOT EXISTS tx_confirmed_index on transaction_record(confirmed_at_height)",
            "CREATE INDEX IF NOT EXISTS tx_created_index on transaction_record(created_at_time)",
            "CREATE INDEX IF NOT EXISTS tx_confirmed on transaction_record(confirmed)",
            "CREATE INDEX IF NOT EXISTS tx_sent on transaction_record(sent)",
            # NOTE(review): second index on created_at_time (see tx_created_index);
            # kept so the schema created here stays unchanged.
            "CREATE INDEX IF NOT EXISTS tx_created_time on transaction_record(created_at_time)",
            "CREATE INDEX IF NOT EXISTS tx_type on transaction_record(type)",
            "CREATE INDEX IF NOT EXISTS tx_to_puzzle_hash on transaction_record(to_puzzle_hash)",
            "CREATE INDEX IF NOT EXISTS wallet_id on transaction_record(wallet_id)",
        )
        for statement in index_statements:
            await self.db_connection.execute(statement)

        await self.db_connection.commit()

        self.tx_record_cache = {}
        self.tx_submitted = {}
        self.unconfirmed_for_wallet = {}
        await self.rebuild_tx_cache()
        return self

    @staticmethod
    def _records_from_rows(rows) -> List[TransactionRecord]:
        """Deserialize the record blob stored in column 0 of each row."""
        return [TransactionRecord.from_bytes(row[0]) for row in rows]

    async def _fetch_records(self, sql: str, params=()) -> List[TransactionRecord]:
        """Run a SELECT on ``transaction_record`` and return the deserialized rows."""
        cursor = await self.db_connection.execute(sql, params)
        rows = await cursor.fetchall()
        await cursor.close()
        return self._records_from_rows(rows)

    async def rebuild_tx_cache(self):
        """
        Reload ``tx_record_cache`` and ``unconfirmed_for_wallet`` from the database.

        Both caches are replaced. ``unconfirmed_for_wallet`` is keyed by wallet id
        and maps to a ``{tx_id: record}`` dict, the same layout that
        ``add_transaction_record`` maintains and ``get_unconfirmed_for_wallet`` reads.
        """
        # init cache here
        all_records = await self.get_all_transactions()
        self.tx_record_cache = {}
        self.unconfirmed_for_wallet = {}
        for record in all_records:
            self.tx_record_cache[record.name] = record
            # Every wallet seen gets an entry, even if all of its records are confirmed.
            unconfirmed_dict = self.unconfirmed_for_wallet.setdefault(record.wallet_id, {})
            if not record.confirmed:
                unconfirmed_dict[record.name] = record

    async def _clear_database(self):
        """Delete every row of ``transaction_record`` and commit. Caches are not touched."""
        cursor = await self.db_connection.execute("DELETE FROM transaction_record")
        await cursor.close()
        await self.db_connection.commit()

    async def add_transaction_record(self, record: TransactionRecord, in_transaction: bool) -> None:
        """
        Store TransactionRecord in DB and Cache.

        The caches are updated first, then the row is inserted (or replaced, keyed
        on ``bundle_id``). When ``in_transaction`` is false this method takes
        ``db_wrapper.lock``, commits on success, rebuilds the caches from the
        database on failure, and releases the lock. When ``in_transaction`` is
        true none of that happens here; presumably the caller owns the lock and
        the commit -- TODO confirm against callers.
        """
        self.tx_record_cache[record.name] = record
        if record.wallet_id not in self.unconfirmed_for_wallet:
            self.unconfirmed_for_wallet[record.wallet_id] = {}
        unconfirmed_dict = self.unconfirmed_for_wallet[record.wallet_id]
        if record.confirmed and record.name in unconfirmed_dict:
            unconfirmed_dict.pop(record.name)
        if not record.confirmed:
            unconfirmed_dict[record.name] = record

        if not in_transaction:
            await self.db_wrapper.lock.acquire()
        try:
            cursor = await self.db_connection.execute(
                "INSERT OR REPLACE INTO transaction_record VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
                (
                    bytes(record),
                    record.name,
                    record.confirmed_at_height,
                    record.created_at_time,
                    record.to_puzzle_hash.hex(),
                    bytes(record.amount),
                    bytes(record.fee_amount),
                    int(record.confirmed),
                    record.sent,
                    record.wallet_id,
                    record.trade_id,
                    record.type,
                ),
            )
            await cursor.close()
            if not in_transaction:
                await self.db_connection.commit()
        except BaseException:
            # The caches were already updated above; resync them with what is
            # actually stored before propagating the error.
            if not in_transaction:
                await self.rebuild_tx_cache()
            raise
        finally:
            if not in_transaction:
                self.db_wrapper.lock.release()

    async def set_confirmed(self, tx_id: bytes32, height: uint32):
        """
        Updates transaction to be confirmed.

        Does nothing if ``tx_id`` is unknown. Otherwise stores a copy of the record
        with ``confirmed=True`` and ``confirmed_at_height=height`` via
        ``add_transaction_record(..., in_transaction=True)``.
        """
        current: Optional[TransactionRecord] = await self.get_transaction_record(tx_id)
        if current is None:
            return None
        tx: TransactionRecord = TransactionRecord(
            confirmed_at_height=height,
            created_at_time=current.created_at_time,
            to_puzzle_hash=current.to_puzzle_hash,
            amount=current.amount,
            fee_amount=current.fee_amount,
            confirmed=True,
            sent=current.sent,
            spend_bundle=current.spend_bundle,
            additions=current.additions,
            removals=current.removals,
            wallet_id=current.wallet_id,
            sent_to=current.sent_to,
            # NOTE(review): trade_id is reset rather than copied from `current` -- confirm intended.
            trade_id=None,
            type=current.type,
            name=current.name,
        )
        await self.add_transaction_record(tx, True)

    async def increment_sent(
        self,
        tx_id: bytes32,
        name: str,
        send_status: MempoolInclusionStatus,
        err: Optional[Err],
    ) -> bool:
        """
        Updates transaction sent count (Full Node has received spend_bundle and sent ack).

        Appends ``(name, status, error name or None)`` to the record's ``sent_to``
        list. The ``sent`` counter is incremented only when ``name`` was not
        already present in ``sent_to``. Returns False if ``tx_id`` is unknown,
        True once the updated record has been stored (outside a transaction).
        """
        current: Optional[TransactionRecord] = await self.get_transaction_record(tx_id)
        if current is None:
            return False

        sent_to = current.sent_to.copy()

        err_str = err.name if err is not None else None
        append_data = (name, uint8(send_status.value), err_str)

        # Peers that already have an entry for this transaction.
        current_peers = {peer_id for peer_id, _status, _error in sent_to}
        if name in current_peers:
            sent_count = uint32(current.sent)
        else:
            sent_count = uint32(current.sent + 1)

        sent_to.append(append_data)

        tx: TransactionRecord = TransactionRecord(
            confirmed_at_height=current.confirmed_at_height,
            created_at_time=current.created_at_time,
            to_puzzle_hash=current.to_puzzle_hash,
            amount=current.amount,
            fee_amount=current.fee_amount,
            confirmed=current.confirmed,
            sent=sent_count,
            spend_bundle=current.spend_bundle,
            additions=current.additions,
            removals=current.removals,
            wallet_id=current.wallet_id,
            sent_to=sent_to,
            # NOTE(review): trade_id is reset rather than copied from `current` -- confirm intended.
            trade_id=None,
            type=current.type,
            name=current.name,
        )
        await self.add_transaction_record(tx, False)
        return True

    async def tx_reorged(self, record: TransactionRecord):
        """
        Updates transaction sent count to 0 and resets confirmation data

        Stores a copy of ``record`` with ``confirmed=False``, height 0, ``sent=0``
        and an empty ``sent_to`` via ``add_transaction_record(..., in_transaction=True)``.
        """
        tx: TransactionRecord = TransactionRecord(
            confirmed_at_height=uint32(0),
            created_at_time=record.created_at_time,
            to_puzzle_hash=record.to_puzzle_hash,
            amount=record.amount,
            fee_amount=record.fee_amount,
            confirmed=False,
            sent=uint32(0),
            spend_bundle=record.spend_bundle,
            additions=record.additions,
            removals=record.removals,
            wallet_id=record.wallet_id,
            sent_to=[],
            # NOTE(review): trade_id is reset rather than copied from `record` -- confirm intended.
            trade_id=None,
            type=record.type,
            name=record.name,
        )
        await self.add_transaction_record(tx, True)

    async def get_transaction_record(self, tx_id: bytes32) -> Optional[TransactionRecord]:
        """
        Checks DB and cache for TransactionRecord with id: id and returns it.

        The cache is consulted first; on a miss the database is queried. Returns
        None when the id is found in neither. A database hit is not added to the cache.
        """
        if tx_id in self.tx_record_cache:
            return self.tx_record_cache[tx_id]

        # NOTE: bundle_id is being stored as bytes, not hex
        cursor = await self.db_connection.execute("SELECT * from transaction_record WHERE bundle_id=?", (tx_id,))
        row = await cursor.fetchone()
        await cursor.close()
        if row is None:
            return None
        return TransactionRecord.from_bytes(row[0])

    async def get_not_sent(self) -> List[TransactionRecord]:
        """
        Returns the list of transaction that have not been received by full node yet.

        Candidates are all rows with ``confirmed=0``, throttled through
        ``tx_submitted``:
          * not tracked yet: returned, tracked as (now, 1);
          * tracked more than 10 minutes ago: returned, tracking reset to (now, 1);
          * tracked within the last 10 minutes: returned only while the count is
            below 5, and the count is incremented each time.
        """
        current_time = int(time.time())
        unconfirmed = await self._fetch_records(
            "SELECT * from transaction_record WHERE confirmed=?",
            (0,),
        )
        records = []
        for record in unconfirmed:
            if record.name in self.tx_submitted:
                time_submitted, count = self.tx_submitted[record.name]
                if time_submitted < current_time - (60 * 10):
                    records.append(record)
                    self.tx_submitted[record.name] = current_time, 1
                elif count < 5:
                    records.append(record)
                    self.tx_submitted[record.name] = time_submitted, (count + 1)
            else:
                records.append(record)
                self.tx_submitted[record.name] = current_time, 1
        return records

    async def get_farming_rewards(self):
        """
        Returns the list of all farming rewards.

        These are the confirmed records whose type is FEE_REWARD or COINBASE_REWARD.
        """
        fee_int = TransactionType.FEE_REWARD.value
        pool_int = TransactionType.COINBASE_REWARD.value
        return await self._fetch_records(
            "SELECT * from transaction_record WHERE confirmed=? and (type=? or type=?)", (1, fee_int, pool_int)
        )

    async def get_all_unconfirmed(self) -> List[TransactionRecord]:
        """
        Returns the list of all transaction that have not yet been confirmed.

        Read from the database (``confirmed=0``), not from the in-memory cache.
        """
        return await self._fetch_records("SELECT * from transaction_record WHERE confirmed=?", (0,))

    async def get_unconfirmed_for_wallet(self, wallet_id: int) -> List[TransactionRecord]:
        """
        Returns the list of transaction that have not yet been confirmed.

        Served from the ``unconfirmed_for_wallet`` cache; an unknown wallet id
        yields an empty list.
        """
        if wallet_id in self.unconfirmed_for_wallet:
            return list(self.unconfirmed_for_wallet[wallet_id].values())
        return []

    async def get_transactions_between(self, wallet_id: int, start, end) -> List[TransactionRecord]:
        """Return up to ``end - start`` transactions of ``wallet_id``.

        Rows whose ``confirmed_at_height`` is among the ``start`` lowest heights in
        the whole table are excluded; the rest are selected by descending height,
        limited to ``end - start`` rows, and the resulting list is reversed before
        being returned (so it is ascending by ``confirmed_at_height``).

        NOTE(review): this block was previously documented as returning reverse
        chronological order with ``start = 0`` meaning the most recent
        transaction; the final ``reverse()`` and the shape of the sub-select do
        not obviously match that description -- confirm the intended paging.
        """
        limit = end - start
        # start/limit are bound as parameters rather than formatted into the SQL text.
        records = await self._fetch_records(
            "SELECT * from transaction_record where wallet_id=? and confirmed_at_height not in"
            " (select confirmed_at_height from transaction_record order by confirmed_at_height"
            " ASC LIMIT ?)"
            " order by confirmed_at_height DESC LIMIT ?",
            (wallet_id, start, limit),
        )
        records.reverse()
        return records

    async def get_transaction_count_for_wallet(self, wallet_id) -> int:
        """Return the number of stored rows for ``wallet_id`` (0 if the query yields no row)."""
        cursor = await self.db_connection.execute(
            "SELECT COUNT(*) FROM transaction_record where wallet_id=?", (wallet_id,)
        )
        count_result = await cursor.fetchone()
        count = count_result[0] if count_result is not None else 0
        await cursor.close()
        return count

    async def get_all_transactions_for_wallet(
        self, wallet_id: int, type: Optional[int] = None
    ) -> List[TransactionRecord]:
        """
        Returns all stored transactions for ``wallet_id``.

        When ``type`` is given, only rows whose ``type`` column equals it are returned.
        """
        if type is None:
            return await self._fetch_records("SELECT * from transaction_record where wallet_id=?", (wallet_id,))
        return await self._fetch_records(
            "SELECT * from transaction_record where wallet_id=? and type=?",
            (
                wallet_id,
                type,
            ),
        )

    async def get_all_transactions(self) -> List[TransactionRecord]:
        """
        Returns all stored transactions.
        """
        return await self._fetch_records("SELECT * from transaction_record")

    async def get_transaction_above(self, height: int) -> List[TransactionRecord]:
        """Return every record whose ``confirmed_at_height`` is strictly greater than ``height``."""
        # Can be -1 (get all tx)
        return await self._fetch_records("SELECT * from transaction_record WHERE confirmed_at_height>?", (height,))

    async def rollback_to_block(self, height: int):
        """
        Drop every record confirmed above ``height`` from ``tx_record_cache`` and the database.

        No commit is issued here (presumably the caller commits -- TODO confirm),
        and ``tx_submitted`` / ``unconfirmed_for_wallet`` are left untouched.
        """
        # Delete from storage
        # Collect first: the cache cannot be mutated while it is being iterated.
        to_delete = [tx for tx in self.tx_record_cache.values() if tx.confirmed_at_height > height]
        for tx in to_delete:
            self.tx_record_cache.pop(tx.name)

        c1 = await self.db_connection.execute("DELETE FROM transaction_record WHERE confirmed_at_height>?", (height,))
        await c1.close()