Skip to content

Commit a3ed3e5

Browse files
Add almost-working version with Python simplifier
1 parent 8223573 commit a3ed3e5

File tree

2 files changed

+171
-23
lines changed

2 files changed

+171
-23
lines changed

Diff for: python/tests/simplify.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# MIT License
22
#
3-
# Copyright (c) 2019-2022 Tskit Developers
3+
# Copyright (c) 2019-2023 Tskit Developers
44
# Copyright (c) 2015-2018 University of Oxford
55
#
66
# Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -114,6 +114,8 @@ def __init__(
114114
filter_nodes=True,
115115
update_sample_flags=True,
116116
):
117+
# DELETE ME
118+
self.parent_edges_processed = 0
117119
self.ts = ts
118120
self.n = len(sample)
119121
self.reduce_to_site_topology = reduce_to_site_topology
@@ -397,6 +399,7 @@ def process_parent_edges(self, edges):
397399
"""
398400
Process all of the edges for a given parent.
399401
"""
402+
self.parent_edges_processed += len(edges)
400403
assert len({e.parent for e in edges}) == 1
401404
parent = edges[0].parent
402405
S = []
@@ -535,6 +538,14 @@ def insert_input_roots(self):
535538
offset += 1
536539
self.sort_offset = offset
537540

541+
def finalise(self):
    """
    Run the end-of-simplification steps on the output tables.

    Inserts the input roots when ``keep_input_roots`` is set, finalises
    sites and references, and sorts the output tables from
    ``sort_offset`` onwards unless ``sort_offset`` is -1.
    """
    if self.keep_input_roots:
        self.insert_input_roots()
    self.finalise_sites()
    self.finalise_references()
    # A sort_offset of -1 is the "no sort required" sentinel.
    if self.sort_offset != -1:
        self.tables.sort(edge_start=self.sort_offset)
548+
538549
def simplify(self):
539550
if self.ts.num_edges > 0:
540551
all_edges = list(self.ts.edges())
@@ -545,12 +556,7 @@ def simplify(self):
545556
edges = []
546557
edges.append(e)
547558
self.process_parent_edges(edges)
548-
if self.keep_input_roots:
549-
self.insert_input_roots()
550-
self.finalise_sites()
551-
self.finalise_references()
552-
if self.sort_offset != -1:
553-
self.tables.sort(edge_start=self.sort_offset)
559+
self.finalise()
554560
ts = self.tables.tree_sequence()
555561
return ts, self.node_id_map
556562

Diff for: python/tests/test_forward_sims.py

+158-16
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,147 @@
2222
"""
2323
Python implementation of the low-level supporting code for forward simulations.
2424
"""
25-
import collections
25+
import itertools
2626
import random
2727

2828
import numpy as np
2929
import pytest
3030

3131
import tskit
32+
from tests import simplify
3233

3334

34-
def simplify_with_buffer(tables, parent_buffer, samples, verbose):
35-
# Pretend this was done efficiently internally without any sorting
36-
# by creating a simplifier object and adding the ancstry for the
37-
# new parents appropriately before flushing through the rest of the
38-
# edges.
39-
for parent, edges in parent_buffer.items():
40-
for left, right, child in edges:
35+
class BirthBuffer:
    """
    Buffer of edges grouped by parent node.

    ``edges`` maps each parent to a list of ``(child, left, right)`` tuples
    and ``parents`` lists the parents in the order they were first seen.
    """

    def __init__(self):
        self.edges = {}
        self.parents = []

    def add_edge(self, left, right, parent, child):
        """
        Record an edge from ``parent`` to ``child`` over ``[left, right)``.
        """
        if parent not in self.edges:
            self.parents.append(parent)
            self.edges[parent] = []
        self.edges[parent].append((child, left, right))

    def clear(self):
        """
        Discard all buffered edges and parents.
        """
        self.edges = {}
        self.parents = []

    def __str__(self):
        # One tab-separated line per edge, parents in first-seen order.
        return "".join(
            f"{parent}\t{child}\t{left:0.3f}\t{right:0.3f}\n"
            for parent in self.parents
            for child, left, right in self.edges[parent]
        )
56+
57+
58+
def add_younger_edges_to_simplifier(simplifier, t, tables, edge_offset):
    """
    Pass the edges of ``tables`` from row ``edge_offset`` onwards whose
    parent node has time <= ``t`` to ``simplifier``.

    Consecutive edges that share a parent are batched into a single
    ``simplifier.process_parent_edges`` call. Returns the index of the
    first edge that was not consumed.
    """
    # Read the columns once rather than on every loop iteration.
    edges_parent = tables.edges.parent
    nodes_time = tables.nodes.time
    num_edges = len(tables.edges)
    parent_edges = []
    while edge_offset < num_edges and nodes_time[edges_parent[edge_offset]] <= t:
        edge = tables.edges[edge_offset]
        if parent_edges and edge.parent != parent_edges[-1].parent:
            # Parent changed: flush the finished batch. The current edge
            # must still be kept, as it starts the next parent's batch.
            simplifier.process_parent_edges(parent_edges)
            parent_edges = []
        parent_edges.append(edge)
        edge_offset += 1
    if parent_edges:
        simplifier.process_parent_edges(parent_edges)
    return edge_offset
82+
83+
84+
def simplify_with_births(tables, births, alive, verbose):
    """
    Simplify ``tables`` with respect to the ``alive`` nodes, merging in the
    edges held in the ``births`` buffer without adding them to the edge
    table and sorting first.

    Buffered parents are visited in order of node time. Before each one,
    the table edges whose parent time is <= that parent's time are passed
    to the simplifier; the buffered edges for the parent then follow as one
    batch. Any remaining table edges are flushed afterwards. The node and
    edge tables are replaced by the simplifier's output, and ``births`` is
    refilled with the output edges whose parent is an alive node.

    Progress information is printed when ``verbose > 0``.
    """
    total_edges = len(tables.edges) + sum(
        len(edges) for edges in births.edges.values()
    )
    if verbose > 0:
        print("Simplify with births")
        print("total_input edges = ", total_edges)
        print("alive = ", alive)
        print("\ttable edges:", len(tables.edges))
        print("\ttable nodes:", len(tables.nodes))

    simplifier = simplify.Simplifier(tables.tree_sequence(), alive)
    nodes_time = tables.nodes.time
    # Order the buffered parents by node time.
    parent_time = nodes_time[births.parents]
    index = np.argsort(parent_time)
    if verbose > 0:
        print(index)
    offset = 0
    for parent in np.array(births.parents)[index]:
        offset = add_younger_edges_to_simplifier(
            simplifier, nodes_time[parent], tables, offset
        )
        edges = [
            tskit.Edge(left, right, parent, child)
            for child, left, right in sorted(births.edges[parent])
        ]
        simplifier.process_parent_edges(edges)

    # FIXME should probably reuse the add_younger_edges_to_simplifier function
    # for this - doesn't quite seem to work though
    for _, edges in itertools.groupby(tables.edges[offset:], lambda e: e.parent):
        simplifier.process_parent_edges(list(edges))

    simplifier.check_state()
    # Every input edge (table and buffered) must have been seen exactly once.
    assert simplifier.parent_edges_processed == total_edges
    simplifier.finalise()

    tables.nodes.replace_with(simplifier.tables.nodes)
    tables.edges.replace_with(simplifier.tables.edges)

    # This is needed because we call .tree_sequence here and later.
    # Can be removed if we change the Simplifier to take a set of
    # tables which it modifies, like the C version.
    tables.drop_index()
    # Just to check
    tables.tree_sequence()

    births.clear()
    # Add back all the edges with an alive parent to the buffer, so that
    # we store them contiguously.
    # NOTE(review): the previous revision also built a boolean ``keep`` mask
    # marking these rows False but never applied it, so the edges stay in
    # ``tables.edges`` as well as in the buffer. Presumably they were meant
    # to be dropped from the table - confirm.
    edges_parent = tables.edges.parent
    for u in alive:
        mapped = simplifier.node_id_map[u]
        for e in np.where(edges_parent == mapped)[0]:
            edge = tables.edges[e]
            births.add_edge(edge.left, edge.right, edge.parent, edge.child)

    if verbose > 0:
        print("Done")
        print(births)
        print("\ttable edges:", len(tables.edges))
        print("\ttable nodes:", len(tables.nodes))
155+
156+
157+
def simplify_with_births_easy(tables, births, alive, verbose):
    """
    Straightforward variant: append every buffered edge to the edge table,
    sort, simplify with respect to ``alive`` and empty the buffer.

    ``verbose`` is accepted so the signature matches simplify_with_births;
    it is not used.
    """
    for parent, edges in births.edges.items():
        # The buffer stores (child, left, right); add_row takes
        # (left, right, parent, child).
        for child, left, right in edges:
            tables.edges.add_row(left, right, parent, child)
    tables.sort()
    tables.simplify(alive)
    births.clear()
164+
165+
# print(tables.nodes.time[tables.edges.parent])
47166

48167

49168
def wright_fisher(
@@ -52,7 +171,7 @@ def wright_fisher(
52171
rng = random.Random(seed)
53172
tables = tskit.TableCollection(L)
54173
alive = [tables.nodes.add_row(time=T) for _ in range(N)]
55-
parent_buffer = collections.defaultdict(list)
174+
births = BirthBuffer()
56175

57176
t = T
58177
while t > 0:
@@ -66,12 +185,16 @@ def wright_fisher(
66185
a = rng.randint(0, N - 1)
67186
b = rng.randint(0, N - 1)
68187
x = rng.uniform(0, L)
69-
parent_buffer[alive[a]].append((0, x, u))
70-
parent_buffer[alive[b]].append((x, L, u))
188+
# TODO Possibly more natural do this like
189+
# births.add(u, parents=[a, b], breaks=[0, x, L])
190+
births.add_edge(0, x, alive[a], u)
191+
births.add_edge(x, L, alive[b], u)
71192
alive = next_alive
72193
if t % simplify_interval == 0 or t == 0:
73-
simplify_with_buffer(tables, parent_buffer, alive, verbose=verbose)
194+
simplify_with_births(tables, births, alive, verbose=verbose)
195+
# simplify_with_births_easy(tables, births, alive, verbose=verbose)
74196
alive = list(range(N))
197+
# print(tables.tree_sequence())
75198
return tables.tree_sequence()
76199

77200

@@ -115,3 +238,22 @@ def test_full_simulation(self):
115238
ts = wright_fisher(N=5, T=500, death_proba=0.9, simplify_interval=1000)
116239
for tree in ts.trees():
117240
assert tree.num_roots == 1
241+
242+
243+
class TestSimplifyIntervals:
    """
    Run the Wright-Fisher simulation with a range of simplification
    intervals and check that the number of samples is the population size.
    """

    @pytest.mark.parametrize("interval", [1, 10, 33, 100])
    def test_non_overlapping_generations(self, interval):
        # death_proba=1 is passed for the non-overlapping case.
        N = 10
        ts = wright_fisher(N, T=100, death_proba=1, simplify_interval=interval)
        assert ts.num_samples == N

    @pytest.mark.parametrize("interval", [1, 10, 33, 100])
    @pytest.mark.parametrize("death_proba", [0.33, 0.5, 0.9])
    def test_overlapping_generations(self, interval, death_proba):
        N = 4
        # verbose=1 also exercises the diagnostic printing paths.
        ts = wright_fisher(
            N, T=20, death_proba=death_proba, simplify_interval=interval, verbose=1
        )
        assert ts.num_samples == N
        print()
        print(ts.draw_text())

0 commit comments

Comments
 (0)