22
22
"""
23
23
Python implementation of the low-level supporting code for forward simulations.
24
24
"""
25
- import collections
25
+ import itertools
26
26
import random
27
27
28
28
import numpy as np
29
29
import pytest
30
30
31
31
import tskit
32
+ from tests import simplify
32
33
33
34
34
- def simplify_with_buffer (tables , parent_buffer , samples , verbose ):
35
- # Pretend this was done efficiently internally without any sorting
36
- # by creating a simplifier object and adding the ancstry for the
37
- # new parents appropriately before flushing through the rest of the
38
- # edges.
39
- for parent , edges in parent_buffer .items ():
40
- for left , right , child in edges :
35
+ class BirthBuffer :
36
+ def __init__ (self ):
37
+ self .edges = {}
38
+ self .parents = []
39
+
40
+ def add_edge (self , left , right , parent , child ):
41
+ if parent not in self .edges :
42
+ self .parents .append (parent )
43
+ self .edges [parent ] = []
44
+ self .edges [parent ].append ((child , left , right ))
45
+
46
+ def clear (self ):
47
+ self .edges = {}
48
+ self .parents = []
49
+
50
+ def __str__ (self ):
51
+ s = ""
52
+ for parent in self .parents :
53
+ for child , left , right in self .edges [parent ]:
54
+ s += f"{ parent } \t { child } \t { left :0.3f} \t { right :0.3f} \n "
55
+ return s
56
+
57
+
58
+ def add_younger_edges_to_simplifier (simplifier , t , tables , edge_offset ):
59
+ parent_edges = []
60
+ while (
61
+ edge_offset < len (tables .edges )
62
+ and tables .nodes .time [tables .edges .parent [edge_offset ]] <= t
63
+ ):
64
+ print ("edge offset = " , edge_offset )
65
+ if len (parent_edges ) == 0 :
66
+ last_parent = tables .edges .parent [edge_offset ]
67
+ else :
68
+ last_parent = parent_edges [- 1 ].parent
69
+ if last_parent == tables .edges .parent [edge_offset ]:
70
+ parent_edges .append (tables .edges [edge_offset ])
71
+ else :
72
+ print (
73
+ "Flush " , tables .nodes .time [parent_edges [- 1 ].parent ], len (parent_edges )
74
+ )
75
+ simplifier .process_parent_edges (parent_edges )
76
+ parent_edges = []
77
+ edge_offset += 1
78
+ if len (parent_edges ) > 0 :
79
+ print ("Flush " , tables .nodes .time [parent_edges [- 1 ].parent ], len (parent_edges ))
80
+ simplifier .process_parent_edges (parent_edges )
81
+ return edge_offset
82
+
83
+
84
+ def simplify_with_births (tables , births , alive , verbose ):
85
+ total_edges = len (tables .edges )
86
+ for edges in births .edges .values ():
87
+ total_edges += len (edges )
88
+ if verbose > 0 :
89
+ print ("Simplify with births" )
90
+ # print(births)
91
+ print ("total_input edges = " , total_edges )
92
+ print ("alive = " , alive )
93
+ print ("\t table edges:" , len (tables .edges ))
94
+ print ("\t table nodes:" , len (tables .nodes ))
95
+
96
+ simplifier = simplify .Simplifier (tables .tree_sequence (), alive )
97
+ nodes_time = tables .nodes .time
98
+ # This should be almost sorted, because
99
+ parent_time = nodes_time [births .parents ]
100
+ index = np .argsort (parent_time )
101
+ print (index )
102
+ offset = 0
103
+ for parent in np .array (births .parents )[index ]:
104
+ offset = add_younger_edges_to_simplifier (
105
+ simplifier , nodes_time [parent ], tables , offset
106
+ )
107
+ edges = [
108
+ tskit .Edge (left , right , parent , child )
109
+ for child , left , right in sorted (births .edges [parent ])
110
+ ]
111
+ # print("Adding parent from time", nodes_time[parent], len(edges))
112
+ # print("edges = ", edges)
113
+ simplifier .process_parent_edges (edges )
114
+ # simplifier.print_state()
115
+
116
+ # FIXME should probably reuse the add_younger_edges_to_simplifier function
117
+ # for this - doesn't quite seem to work though
118
+ for _ , edges in itertools .groupby (tables .edges [offset :], lambda e : e .parent ):
119
+ edges = list (edges )
120
+ simplifier .process_parent_edges (edges )
121
+
122
+ simplifier .check_state ()
123
+ assert simplifier .parent_edges_processed == total_edges
124
+ # if simplifier.parent_edges_processed != total_edges:
125
+ # print("HERE!!!!", total_edges)
126
+ simplifier .finalise ()
127
+
128
+ tables .nodes .replace_with (simplifier .tables .nodes )
129
+ tables .edges .replace_with (simplifier .tables .edges )
130
+
131
+ # This is needed because we call .tree_sequence here and later.
132
+ # Can be removed is we change the Simplifier to take a set of
133
+ # tables which it modifies, like the C version.
134
+ tables .drop_index ()
135
+ # Just to check
136
+ tables .tree_sequence ()
137
+
138
+ births .clear ()
139
+ # Add back all the edges with an alive parent to the buffer, so that
140
+ # we store them contiguously
141
+ keep = np .ones (len (tables .edges ), dtype = bool )
142
+ for u in alive :
143
+ u = simplifier .node_id_map [u ]
144
+ for e in np .where (tables .edges .parent == u )[0 ]:
145
+ keep [e ] = False
146
+ edge = tables .edges [e ]
147
+ # print(edge)
148
+ births .add_edge (edge .left , edge .right , edge .parent , edge .child )
149
+
150
+ if verbose > 0 :
151
+ print ("Done" )
152
+ print (births )
153
+ print ("\t table edges:" , len (tables .edges ))
154
+ print ("\t table nodes:" , len (tables .nodes ))
155
+
156
+
157
+ def simplify_with_births_easy (tables , births , alive , verbose ):
158
+ for parent , edges in births .edges .items ():
159
+ for child , left , right in edges :
41
160
tables .edges .add_row (left , right , parent , child )
42
161
tables .sort ()
43
- tables .simplify (samples )
44
- # We've exhausted the parent buffer, so clear it out. In reality we'd
45
- # do this more carefully, like KT does in the post_simplify step.
46
- parent_buffer . clear ( )
162
+ tables .simplify (alive )
163
+ births . clear ()
164
+
165
+ # print(tables.nodes.time[tables.edges.parent] )
47
166
48
167
49
168
def wright_fisher (
@@ -52,7 +171,7 @@ def wright_fisher(
52
171
rng = random .Random (seed )
53
172
tables = tskit .TableCollection (L )
54
173
alive = [tables .nodes .add_row (time = T ) for _ in range (N )]
55
- parent_buffer = collections . defaultdict ( list )
174
+ births = BirthBuffer ( )
56
175
57
176
t = T
58
177
while t > 0 :
@@ -66,12 +185,16 @@ def wright_fisher(
66
185
a = rng .randint (0 , N - 1 )
67
186
b = rng .randint (0 , N - 1 )
68
187
x = rng .uniform (0 , L )
69
- parent_buffer [alive [a ]].append ((0 , x , u ))
70
- parent_buffer [alive [b ]].append ((x , L , u ))
188
+ # TODO Possibly more natural do this like
189
+ # births.add(u, parents=[a, b], breaks=[0, x, L])
190
+ births .add_edge (0 , x , alive [a ], u )
191
+ births .add_edge (x , L , alive [b ], u )
71
192
alive = next_alive
72
193
if t % simplify_interval == 0 or t == 0 :
73
- simplify_with_buffer (tables , parent_buffer , alive , verbose = verbose )
194
+ simplify_with_births (tables , births , alive , verbose = verbose )
195
+ # simplify_with_births_easy(tables, births, alive, verbose=verbose)
74
196
alive = list (range (N ))
197
+ # print(tables.tree_sequence())
75
198
return tables .tree_sequence ()
76
199
77
200
@@ -115,3 +238,22 @@ def test_full_simulation(self):
115
238
ts = wright_fisher (N = 5 , T = 500 , death_proba = 0.9 , simplify_interval = 1000 )
116
239
for tree in ts .trees ():
117
240
assert tree .num_roots == 1
241
+
242
+
243
+ class TestSimplifyIntervals :
244
+ @pytest .mark .parametrize ("interval" , [1 , 10 , 33 , 100 ])
245
+ def test_non_overlapping_generations (self , interval ):
246
+ N = 10
247
+ ts = wright_fisher (N , T = 100 , death_proba = 1 , simplify_interval = interval )
248
+ assert ts .num_samples == N
249
+
250
+ @pytest .mark .parametrize ("interval" , [1 , 10 , 33 , 100 ])
251
+ @pytest .mark .parametrize ("death_proba" , [0.33 , 0.5 , 0.9 ])
252
+ def test_overlapping_generations (self , interval , death_proba ):
253
+ N = 4
254
+ ts = wright_fisher (
255
+ N , T = 20 , death_proba = death_proba , simplify_interval = interval , verbose = 1
256
+ )
257
+ assert ts .num_samples == N
258
+ print ()
259
+ print (ts .draw_text ())
0 commit comments