Skip to content

Commit ebe8339

Browse files
committed
implement the TxsGraph class
1 parent 8df40ad commit ebe8339

File tree

4 files changed

+159
-87
lines changed

4 files changed

+159
-87
lines changed

py/parse_simulation_data.py

+30-49
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,7 @@
77

88
from datatypes import BTC, TXID, btc_to_sat
99
from paths import LN
10-
from txs_graph import build_txs_graph, get_downstream
11-
12-
13-
def get_htlcs_claimed_by_timeout(graph: DiGraph, commitment_txid: TXID) -> List[TXID]:
14-
"""
15-
return a list of txids that claimed an HTLC output from the given
16-
commitment transaction, using timeout-claim (as opposed to success-claim)
17-
"""
18-
19-
# from the bolt:
20-
# " HTLC-Timeout and HTLC-Success Transactions... are almost identical,
21-
# except the HTLC-timeout transaction is timelocked "
22-
#
23-
# i.e. if the child_tx has a non-zero locktime, it is an HTLC-timeout
24-
# TODO: what about claiming of local/remote outputs? are they locked? check it
25-
26-
return [
27-
child_tx
28-
for _, child_tx in graph.out_edges(commitment_txid)
29-
if graph.nodes[child_tx]["tx"]["locktime"] > 0
30-
]
10+
from txs_graph.txs_graph import TxsGraph
3111

3212

3313
# ---------- txid-to-label functions ----------
@@ -110,19 +90,15 @@ def extract_bitcoin_funding_txids(simulation_outfile: str) -> Set[TXID]:
11090
return txids
11191

11292

113-
def get_all_direct_children(txid, graph: DiGraph) -> Set[TXID]:
114-
return {txid for _, txid in graph.out_edges(txid)}
115-
116-
11793
def flatten(s: Iterable[Iterable[Any]]) -> List[Any]:
    """
    Concatenate the items of every inner iterable of `s` into a single list,
    preserving the order in which they are encountered.
    """
    return [item for inner in s for item in inner]
11995

12096

121-
def find_commitments(simulation_outfile: str, graph: DiGraph) -> List[TXID]:
97+
def find_commitments(simulation_outfile: str, graph: TxsGraph) -> List[TXID]:
12298
bitcoin_fundings = extract_bitcoin_funding_txids(simulation_outfile=simulation_outfile)
12399

124100
ln_channel_fundings = flatten(
125-
get_all_direct_children(txid, graph=graph)
101+
graph.get_all_direct_children(txid)
126102
for txid in bitcoin_fundings
127103
)
128104

@@ -132,7 +108,7 @@ def find_commitments(simulation_outfile: str, graph: DiGraph) -> List[TXID]:
132108
list(filter(
133109
# only keep those with the expected balance
134110
lambda child_txid: graph.edges[(channel_funding_txid, child_txid)]["value"] == LN_CHANNEL_BALANCE,
135-
get_all_direct_children(txid=channel_funding_txid, graph=graph)
111+
graph.get_all_direct_children(txid=channel_funding_txid)
136112
))
137113
for channel_funding_txid in ln_channel_fundings
138114
)
@@ -151,24 +127,28 @@ def find_commitments(simulation_outfile: str, graph: DiGraph) -> List[TXID]:
151127
f"Failed to find commitments. "
152128
f"txid {commitment_txid[-4:]} expected to have exactly 1 input, but has {num_inputs}"
153129
)
154-
130+
155131
return commitments
156132

157133

158-
def is_replaceable_by_fee(txid, graph: DiGraph) -> bool:
159-
for input_dict in graph.nodes[txid]["tx"]["vin"]:
160-
if input_dict['sequence'] < (0xffffffff - 1):
161-
return True
162-
return False
163-
164-
165-
def print_nsequence(txids: Iterable[TXID], graph: DiGraph):
166-
for txid in txids:
167-
print(f"txid {txid_to_short_txid(txid)}:")
168-
for i, input_dict in enumerate(graph.nodes[txid]["tx"]["vin"]):
169-
sequence: int = input_dict['sequence']
170-
sequence_str = format(sequence, 'x')
171-
print(f"input {i}: sequence={sequence_str}")
134+
def get_htlcs_claimed_by_timeout(graph: TxsGraph, commitment_txid: TXID) -> List[TXID]:
    """
    Collect the txids of the transactions that spend an output of the given
    commitment transaction and that claimed an HTLC by timeout (as opposed
    to a success-claim).

    The result keeps the order in which `graph.out_edges` reports the
    children of `commitment_txid`.
    """

    # from the bolt:
    #   " HTLC-Timeout and HTLC-Success Transactions... are almost identical,
    #     except the HTLC-timeout transaction is timelocked "
    #
    # so a child tx with a non-zero locktime is treated as an HTLC-timeout
    # TODO: what about claiming of local/remote outputs? are they locked? check it

    timeout_claims = []
    for _, spender_txid in graph.out_edges(commitment_txid):
        locktime = graph.nodes[spender_txid]["tx"]["locktime"]
        if locktime > 0:
            timeout_claims.append(spender_txid)
    return timeout_claims
172152

173153

174154
def main(simulation_name):
@@ -177,26 +157,27 @@ def main(simulation_name):
177157
outfile = os.path.join(LN, "simulations", f"{simulation_name}.out")
178158
dotfile = os.path.join(LN, f"{simulation_name}.dot")
179159
jpgfile = os.path.join(LN, f"{simulation_name}.jpg")
180-
181-
txs_graph = build_txs_graph(datadir)
160+
161+
txs_graph = TxsGraph.from_datadir(datadir)
182162

183163
commitments = find_commitments(simulation_outfile=outfile, graph=txs_graph)
184164

185165
for commitment_txid in commitments:
186166
short_txid = txid_to_short_txid(commitment_txid)
187167
num_outputs = len(txs_graph.nodes[commitment_txid]["tx"]["vout"])
188-
replaceable = is_replaceable_by_fee(txid=commitment_txid, graph=txs_graph)
189-
num_htlcs_stolen = len(get_htlcs_claimed_by_timeout(commitment_txid=commitment_txid, graph=txs_graph))
168+
replaceable = txs_graph.is_replaceable_by_fee(txid=commitment_txid)
169+
nsequence = txs_graph.get_minimal_nsequence(commitment_txid)
170+
num_htlcs_stolen = len(get_htlcs_claimed_by_timeout(graph=txs_graph, commitment_txid=commitment_txid))
190171
print(
191172
f"commitment: {short_txid:<5} "
192173
f"num-outputs: {num_outputs:<4} "
193174
f"replaceable: {str(replaceable):<5} "
175+
f"nsequence: {nsequence:x} "
194176
f"htlcs-stolen: {num_htlcs_stolen}"
195177
)
196178

197179
# this graph includes only interesting txs (no coinbase and other junk)
198-
restricted_txs_graph = get_downstream(
199-
graph=txs_graph,
180+
restricted_txs_graph = txs_graph.get_downstream(
200181
sources=extract_bitcoin_funding_txids(simulation_outfile=outfile),
201182
)
202183

py/tests/txs_graph_test.py

+36-21
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,69 @@
11
import unittest
22

3-
from networkx import DiGraph
4-
5-
from txs_graph import get_downstream
3+
from txs_graph.txs_graph import TxsGraph
64

75

86
class TxsGraphTest(unittest.TestCase):
9-
def test_downstream_1(self):
10-
graph = DiGraph()
7+
8+
def get_test_graph_1(self) -> TxsGraph:
9+
graph = TxsGraph()
1110
graph.add_edge("a", "b")
1211
graph.add_edge("b", "c")
1312
graph.add_edge("c", "e")
1413
graph.add_edge("e", "d")
1514
graph.add_edge("d", "b")
1615
graph.add_edge("e", "f")
17-
18-
downstream = get_downstream(graph, sources={"a"})
16+
return graph
17+
18+
def test_downstream_1_1(self):
19+
graph = self.get_test_graph_1()
20+
downstream = graph.get_downstream(sources={"a"})
1921
downstream_nodes = set(downstream.nodes)
2022
self.assertSetEqual(downstream_nodes, {"a", "b", "c", "d", "e", "f"})
21-
22-
downstream = get_downstream(graph, sources={"b", "e"})
23+
24+
def test_downstream_1_2(self):
25+
graph = self.get_test_graph_1()
26+
downstream = graph.get_downstream(sources={"b", "e"})
2327
downstream_nodes = set(downstream.nodes)
2428
self.assertSetEqual(downstream_nodes, {"b", "c", "d", "e", "f"})
25-
26-
downstream = get_downstream(graph, sources={"e"})
29+
30+
def test_downstream_1_3(self):
31+
graph = self.get_test_graph_1()
32+
downstream = graph.get_downstream(sources={"e"})
2733
downstream_nodes = set(downstream.nodes)
2834
self.assertSetEqual(downstream_nodes, {"b", "c", "d", "e", "f"})
29-
30-
downstream = get_downstream(graph, sources={"f"})
35+
36+
def test_downstream_1_4(self):
37+
graph = self.get_test_graph_1()
38+
downstream = graph.get_downstream(sources={"f"})
3139
downstream_nodes = set(downstream.nodes)
3240
self.assertSetEqual(downstream_nodes, {"f"})
3341

34-
def test_downstream_2(self):
35-
graph = DiGraph()
42+
def get_test_graph_2(self) -> TxsGraph:
43+
graph = TxsGraph()
3644
graph.add_edge(1, 3)
3745
graph.add_edge(1, 4)
3846
graph.add_edge(2, 4)
3947
graph.add_edge(2, 5)
4048
graph.add_edge(4, 6)
4149
graph.add_edge(5, 6)
42-
43-
downstream = get_downstream(graph, sources={1, 2})
50+
return graph
51+
52+
def test_downstream_2_1(self):
53+
graph = self.get_test_graph_2()
54+
downstream = graph.get_downstream(sources={1, 2})
4455
downstream_nodes = set(downstream.nodes)
4556
self.assertSetEqual(downstream_nodes, {1, 2, 3, 4, 5, 6})
46-
47-
downstream = get_downstream(graph, sources={2})
57+
58+
def test_downstream_2_2(self):
59+
graph = self.get_test_graph_2()
60+
downstream = graph.get_downstream(sources={2})
4861
downstream_nodes = set(downstream.nodes)
4962
self.assertSetEqual(downstream_nodes, {2, 4, 5, 6})
50-
51-
downstream = get_downstream(graph, sources={3, 4})
63+
64+
def test_downstream_2_3(self):
65+
graph = self.get_test_graph_2()
66+
downstream = graph.get_downstream(sources={3, 4})
5267
downstream_nodes = set(downstream.nodes)
5368
self.assertSetEqual(downstream_nodes, {3, 4, 6})
5469

py/txs_graph/txs_graph.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from typing import Any, Iterable, List
2+
3+
from networkx.algorithms.traversal.breadth_first_search import bfs_edges
4+
from networkx.classes.digraph import DiGraph
5+
6+
from datatypes import TXID
7+
from txs_graph.txs_graph_utils import find_tx_fee, load_blocks, load_txs
8+
9+
10+
class TxsGraph(DiGraph):
    """
    A directed graph of transactions.

    Each node represents a transaction and each edge (src, dest) represents
    an output of `src` that is spent by an input of `dest`.
    See `from_datadir` for the attributes attached to nodes and edges.
    """

    @classmethod
    def from_datadir(cls, datadir: str) -> "TxsGraph":
        """
        read all block and transaction files in the given datadir and construct
        a full transaction graph.

        Each node represents a transaction. the node's id is the txid and it has
        the following attributes:
            - "tx" - the full tx json, as returned by bitcoind
            - "fee" - the tx fee
            - "height" - the block height in which the tx was included

        Each edge has the following attributes:
            - "value" the value in BTC of the output represented by this edge

        Raises KeyError if a loaded transaction does not appear in any of the
        loaded blocks, or if an input refers to a transaction that was not loaded.
        """
        blocks = load_blocks(datadir)
        txs = load_txs(datadir)

        txid_to_fee = {txid: find_tx_fee(txid, txs) for txid in txs}
        txid_to_height = {
            txid: block["height"]
            for block in blocks.values()
            for txid in block["tx"]
        }

        # instantiate via cls so that subclasses get an instance of their own type
        graph = cls()

        # add all transactions
        for txid, tx in txs.items():
            graph.add_node(txid, tx=tx, fee=txid_to_fee[txid], height=txid_to_height[txid])

        # add edges between transactions
        for dest_txid, dest_tx in txs.items():
            for entry in dest_tx["vin"]:
                if "coinbase" in entry:
                    continue  # coinbase transaction. no src
                src_txid = entry["txid"]
                index = entry["vout"]
                value = txs[src_txid]["vout"][index]["value"]
                # NOTE(review): a DiGraph keeps a single edge per (src, dest) pair,
                # so if dest_tx spends several outputs of the same src tx only the
                # value of the last one processed is kept - confirm this is acceptable
                graph.add_edge(src_txid, dest_txid, value=value)

        return graph

    def get_all_direct_children(self, txid: TXID) -> List[TXID]:
        """
        return the txids of all transactions that spend an output of the given txid
        """
        return [child_txid for _, child_txid in self.out_edges(txid)]

    def get_minimal_nsequence(self, txid: TXID) -> int:
        """
        return the minimal nsequence of an input in the given txid
        """
        return min(
            input_dict["sequence"]
            for input_dict in self.nodes[txid]["tx"]["vin"]
        )

    def is_replaceable_by_fee(self, txid: TXID) -> bool:
        """
        return True if at least one input of the given txid has an nsequence
        lower than 0xffffffff - 1 (the opt-in replace-by-fee signal of BIP 125)
        """
        return self.get_minimal_nsequence(txid) < (0xffffffff - 1)

    def get_downstream(self, sources: Iterable[Any]) -> "TxsGraph":
        """
        return the downstream of sources in the given graph, i.e. a new graph
        that contains the sources, every node reachable from them, and ALL the
        edges of this graph between those nodes (with their attributes).

        sources must be an iterable of existing node ids in the graph.
        A KeyError is raised if a source is not a node of this graph.
        """
        sources = list(sources)  # sources may be iterable only once (e.g. map), so we copy

        # node id -> node attributes. a dict (rather than a set) is used to keep
        # a deterministic node order: the sources first, then bfs discovery order.
        # looking up every source up front makes an unknown id fail with KeyError
        reachable = {src: self.nodes[src] for src in sources}

        # nodes whose entire downstream is already part of `reachable`
        expanded = set()
        for src in sources:
            if src in expanded:
                continue  # reached by an earlier traversal. nothing new to find
            expanded.add(src)
            for _, node in bfs_edges(self, source=src):
                expanded.add(node)
                reachable.setdefault(node, self.nodes[node])

        downstream = type(self)()
        downstream.add_nodes_from(reachable.items())

        # bfs_edges yields only the edges of the bfs tree, so copying just those
        # would drop every other edge between downstream nodes (e.g. a tx that
        # spends outputs of two different downstream txs). `reachable` is closed
        # under "spent by", hence every out-edge of its nodes belongs here.
        downstream.add_edges_from(self.edges(list(reachable), data=True))

        return downstream

py/txs_graph.py py/txs_graph/txs_graph_utils.py

+5-17
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import json
22
import os
3-
from typing import Any, Dict, Iterable
3+
from typing import Dict
44

5-
from networkx.algorithms.traversal.breadth_first_search import bfs_edges
65
from networkx.classes.digraph import DiGraph
76

87
from datatypes import BTC, Block, BlockHash, TX, TXID
98

9+
"""
10+
A collection of helper functions to build a complete TxsGraph
11+
"""
12+
1013

1114
def load_blocks(datadir: str) -> Dict[BlockHash, Block]:
1215
blocks = {}
@@ -97,18 +100,3 @@ def build_txs_graph(datadir: str) -> DiGraph:
97100
graph.add_edge(src_txid, dest_txid, value=value)
98101

99102
return graph
100-
101-
102-
def get_downstream(graph: DiGraph, sources: Iterable[Any]) -> DiGraph:
103-
"""
104-
return the downstream of sources in the given graph.
105-
sources must be an iterable of existing node ids in the graph
106-
"""
107-
sources = list(sources) # sources may be iterable only once (e.g. map), so we copy
108-
nodes_to_include = set(sources)
109-
110-
# for each source, compute the set of its ancestors using bfs
111-
nodes_to_include.update(
112-
*({u for v, u in bfs_edges(graph, source=src)} for src in sources)
113-
)
114-
return graph.subgraph(nodes=nodes_to_include)

0 commit comments

Comments
 (0)