From f5c1b09c7fdb8619d2ac52edd1391b6faf23a3aa Mon Sep 17 00:00:00 2001 From: Gertjan Bisschop Date: Mon, 18 Dec 2023 16:47:30 +0000 Subject: [PATCH 1/4] initial hudson --- algorithms.py | 57 +++++++++++++++++++++++++++++++++------- tests/test_algorithms.py | 26 ++++++++++++++++++ 2 files changed, 73 insertions(+), 10 deletions(-) diff --git a/algorithms.py b/algorithms.py index 55fa82c66..7ae416e46 100644 --- a/algorithms.py +++ b/algorithms.py @@ -549,6 +549,7 @@ def __init__( gene_conversion_rate=0.0, gene_conversion_length=1, discrete_genome=True, + stop_condition=None, ): # Must be a square matrix. N = len(migration_matrix) @@ -640,6 +641,7 @@ def __init__( for time in census_times: self.modifier_events.append((time[0], self.census_event, time)) self.modifier_events.sort() + self.stop_condition = stop_condition def initialise(self, ts): root_time = np.max(self.tables.nodes.time) @@ -684,12 +686,32 @@ def initialise(self, ts): self.set_segment_mass(seg) seg = seg.next + def get_num_ancestors(self): + return sum(pop.get_num_ancestors() for pop in self.P) + def ancestors_remain(self): """ Returns True if the simulation is not finished, i.e., there is some ancestral material that has not fully coalesced. """ - return sum(pop.get_num_ancestors() for pop in self.P) != 0 + return self.get_num_ancestors() != 0 + + def assert_stop_condition(self): + """ + Returns true if the simulation is not finished given the global + stopping condition that was specified. + """ + if self.stop_condition is None: + return self.ancestors_remain() + elif self.stop_condition == "grand_mrca": + return self.get_num_ancestors() > 1 + elif self.stop_condition == "all_local_mrcas": + return any(num_anc > 1 for num_anc in self.S.values()) + elif self.stop_condition == "time": + return self.get_num_ancestors() > 1 + else: + print("Error: unknown stop condition-", self.stop_condition) + raise ValueError def change_population_size(self, pop_id, size): self.P[pop_id].set_start_size(size) @@ -835,16 +857,20 @@ def finalise(self): def simulate(self, end_time): self.verify() if self.model == "hudson": - self.hudson_simulate(end_time) + ret = self.hudson_simulate(end_time) elif self.model == "dtwf": - self.dtwf_simulate() + ret = self.dtwf_simulate() elif self.model == "fixed_pedigree": - self.pedigree_simulate() + ret = self.pedigree_simulate() elif self.model == "single_sweep": - self.single_sweep_simulate() + ret = self.single_sweep_simulate() else: print("Error: bad model specification -", self.model) raise ValueError + + if ret == 2: # _msprime.EXIT_MAX_TIME: + self.t = end_time + print(end_time) return self.finalise() def get_potential_destinations(self): @@ -897,15 +923,14 @@ def hudson_simulate(self, end_time): """ Simulates the algorithm until all loci have coalesced. """ + ret = 0 infinity = sys.float_info.max non_empty_pops = {pop.id for pop in self.P if pop.get_num_ancestors() > 0} potential_destinations = self.get_potential_destinations() # only worried about label 0 below - while len(non_empty_pops) > 0: + while self.assert_stop_condition(): self.verify() - if self.t >= end_time: - break # self.print_state() re_rate = self.get_total_recombination_rate(label=0) t_re = infinity @@ -948,6 +973,9 @@ def hudson_simulate(self, end_time): mig_dest = k min_time = min(t_re, t_ca, t_gcin, t_gc_left, t_mig) assert min_time != infinity + if self.t + min_time > end_time: + ret = 2 # _msprime.MAX_EVENT_TIME + break if self.t + min_time > self.modifier_events[0][0]: t, func, args = self.modifier_events.pop(0) self.t = t @@ -992,6 +1020,8 @@ def hudson_simulate(self, end_time): X = {pop.id for pop in self.P if pop.get_num_ancestors() > 0} assert non_empty_pops == X + return ret + def single_sweep_simulate(self): """ Does a structed coalescent until end_freq is reached, using @@ -1956,12 +1986,15 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): j = self.S.floor_key(r_max) self.S[r_max] = self.S[j] # Update the number of extant segments. - if self.S[left] == 2: + if self.S[left] == 2 and self.stop_condition is None: self.S[left] = 0 right = self.S.succ_key(left) else: + stop_at = 2 + if self.stop_condition is not None: + stop_at = 1 right = left - while right < r_max and self.S[right] != 2: + while right < r_max and self.S[right] != stop_at: self.S[right] -= 1 right = self.S.succ_key(right) alpha = self.alloc_segment( @@ -2275,6 +2308,7 @@ def run_simulate(args): gene_conversion_rate=gc_rate, gene_conversion_length=mean_tract_length, discrete_genome=args.discrete, + stop_condition=args.stop_condition, ) ts = s.simulate(args.end_time) ts.dump(args.output_file) @@ -2373,6 +2407,9 @@ def add_simulator_arguments(parser): parser.add_argument( "--end-time", type=float, default=np.inf, help="The end for simulations." ) + parser.add_argument( + "--stop-condition", type=str, default=None, help="Global stopping condition" + ) def main(args=None): diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index ce48c4778..4c638ddea 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -431,3 +431,29 @@ def test_one_gen_pedigree(self, num_founders): tables.dump(ts_path) ts = self.run_script(f"0 --from-ts {ts_path} -r 1 --model=fixed_pedigree") assert len(ts.dump_tables().edges) == 0 + + def test_stopping_condition_time(self): + end_time = 2.5 + ts = self.run_script(f"10 --stop-condition=time --end-time={end_time}") + assert ts.max_root_time == end_time + assert ts.num_samples == 10 + assert ts.num_trees > 1 + assert not has_discrete_genome(ts) + assert ts.sequence_length == 100 + + def test_stopping_condition_all_mrcas(self): + ts = self.run_script("10 --stop-condition=all_local_mrcas") + roots = [tree.root for tree in ts.trees()] + assert len(set(roots)) > 1 + roots_time = [tree.time(tree.root) for tree in ts.trees()] + assert len(set(roots_time)) == 1 + assert ts.num_samples == 10 + assert ts.num_trees > 1 + assert not has_discrete_genome(ts) + assert ts.sequence_length == 100 + + def test_stopping_condition_grand_mrca(self): + ts = self.run_script("10 --stop-condition=grand_mrca") + assert ts.num_trees > 1 + roots = [tree.root for tree in ts.trees()] + assert len(set(roots)) == 1 From 259496457d67b442e2f9d9ad03f243f923d5a7b5 Mon Sep 17 00:00:00 2001 From: Gertjan Bisschop Date: Tue, 19 Dec 2023 11:53:31 +0000 Subject: [PATCH 2/4] update algo --- algorithms.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/algorithms.py b/algorithms.py index 7ae416e46..6e6fdb55c 100644 --- a/algorithms.py +++ b/algorithms.py @@ -1840,13 +1840,19 @@ def merge_ancestors(self, H, pop_id, label, new_node_id=-1): j = self.S.floor_key(r_max) self.S[r_max] = self.S[j] # Update the number of extant segments. - if self.S[left] == len(X): + if self.S[left] == len(X) and self.stop_condition is None: self.S[left] = 0 right = self.S.succ_key(left) else: right = left - while right < r_max and self.S[right] != len(X): - self.S[right] -= len(X) - 1 + while right < r_max: + if self.S[right] <= len(X): + if self.stop_condition is None: + break + else: + self.S[right] = 1 + else: + self.S[right] -= len(X) - 1 right = self.S.succ_key(right) alpha = self.alloc_segment(left, right, new_node_id, pop_id) # Update the heaps and make the record. @@ -1990,12 +1996,15 @@ def merge_two_ancestors(self, population_index, label, x, y, u=-1): self.S[left] = 0 right = self.S.succ_key(left) else: - stop_at = 2 - if self.stop_condition is not None: - stop_at = 1 right = left - while right < r_max and self.S[right] != stop_at: - self.S[right] -= 1 + while right < r_max: + if self.S[right] <= 2: + if self.stop_condition is None: + break + else: + self.S[right] = 1 + else: + self.S[right] -= 1 right = self.S.succ_key(right) alpha = self.alloc_segment( left=left, From 71407e3860bb15841e8004d45ca39b74123a23cc Mon Sep 17 00:00:00 2001 From: Gertjan Bisschop Date: Tue, 19 Dec 2023 12:42:40 +0000 Subject: [PATCH 3/4] pedigree --- algorithms.py | 32 ++++++++++++++++++++++++-------- tests/test_algorithms.py | 16 ++++++++++++++++ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/algorithms.py b/algorithms.py index 6e6fdb55c..e08ef337f 100644 --- a/algorithms.py +++ b/algorithms.py @@ -709,6 +709,8 @@ def assert_stop_condition(self): return any(num_anc > 1 for num_anc in self.S.values()) elif self.stop_condition == "time": return self.get_num_ancestors() > 1 + elif self.stop_condition == "pedigree": + return True else: print("Error: unknown stop condition-", self.stop_condition) raise ValueError @@ -859,9 +861,9 @@ def simulate(self, end_time): if self.model == "hudson": ret = self.hudson_simulate(end_time) elif self.model == "dtwf": - ret = self.dtwf_simulate() + ret = self.dtwf_simulate(end_time) elif self.model == "fixed_pedigree": - ret = self.pedigree_simulate() + ret = self.pedigree_simulate(end_time) elif self.model == "single_sweep": ret = self.single_sweep_simulate() else: @@ -870,7 +872,6 @@ def simulate(self, end_time): if ret == 2: # _msprime.EXIT_MAX_TIME: self.t = end_time - print(end_time) return self.finalise() def get_potential_destinations(self): @@ -1129,21 +1130,27 @@ def single_sweep_simulate(self): self.set_labels(u, 0) self.P[0].add(tmp) - def pedigree_simulate(self): + def pedigree_simulate(self, end_time): """ Simulates through the provided pedigree, stopping at the top. """ self.pedigree = Pedigree(self.tables) - self.dtwf_climb_pedigree() + ret = self.dtwf_climb_pedigree(end_time) + return ret - def dtwf_simulate(self): + def dtwf_simulate(self, end_time): """ Simulates the algorithm until all loci have coalesced. """ + ret = 0 while self.ancestors_remain(): + if self.t + 1 > end_time: + ret = 2 # _msprime.EXIT_MAX_TIME + break self.t += 1 self.verify() self.dtwf_generation() + return ret def dtwf_generation(self): """ @@ -1281,13 +1288,14 @@ def process_pedigree_common_ancestors(self, ind, ploid): self.flush_edges() self.verify() - def dtwf_climb_pedigree(self): + def dtwf_climb_pedigree(self, end_time): """ Simulates transmission of ancestral material through a pre-specified pedigree """ assert self.num_populations == 1 # Single pop/pedigree for now pop = self.P[0] + ret = 0 # Go through the extant lineages and gather the ancestral material # into the corresponding pedigree individuals. @@ -1300,9 +1308,13 @@ def dtwf_climb_pedigree(self): # Visit pedigree individuals in time order. visit_order = sorted(self.pedigree.individuals, key=lambda x: (x.time, x.id)) for ind in visit_order: + if ind.time > end_time: + ret = 2 # _msprime.EXIT_MAX_TIME + break self.t = ind.time for ploid in range(ind.ploidy): self.process_pedigree_common_ancestors(ind, ploid) + return ret def store_arg_edges(self, segment, u=-1): if u == -1: @@ -2295,6 +2307,10 @@ def run_simulate(args): else: from_ts = tskit.load(args.from_ts) tables = from_ts.dump_tables() + if args.stop_condition == "full_pedigree": + end_time = np.max(from_ts.nodes_time) + else: + end_time = args.end_time s = Simulator( tables=tables, @@ -2319,7 +2335,7 @@ def run_simulate(args): discrete_genome=args.discrete, stop_condition=args.stop_condition, ) - ts = s.simulate(args.end_time) + ts = s.simulate(end_time) ts.dump(args.output_file) if args.verbose: s.print_state() diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 4c638ddea..624c583b9 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -457,3 +457,19 @@ def test_stopping_condition_grand_mrca(self): assert ts.num_trees > 1 roots = [tree.root for tree in ts.trees()] assert len(set(roots)) == 1 + + def test_stopping_condition_pedigree(self): + num_founders = 4 + num_generations = 10 + tables = simulate_pedigree( + num_founders=num_founders, num_generations=num_generations + ) + with tempfile.TemporaryDirectory() as tmpdir: + ts_path = pathlib.Path(tmpdir) / "pedigree.trees" + tables.dump(ts_path) + ts = self.run_script( + f"0 --from-ts {ts_path} --model=fixed_pedigree -r 0.1 \ + --stop-condition=full_pedigree" + ) + assert ts.num_trees > 1 + assert ts.max_root_time == num_generations - 1 From df0f20d02b2409e1436ddec68d38fafade8759d9 Mon Sep 17 00:00:00 2001 From: Gertjan Bisschop Date: Tue, 19 Dec 2023 15:44:54 +0000 Subject: [PATCH 4/4] dtwf --- algorithms.py | 6 +++--- tests/test_algorithms.py | 10 +++++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/algorithms.py b/algorithms.py index e08ef337f..b2d90bdb3 100644 --- a/algorithms.py +++ b/algorithms.py @@ -708,7 +708,7 @@ def assert_stop_condition(self): elif self.stop_condition == "all_local_mrcas": return any(num_anc > 1 for num_anc in self.S.values()) elif self.stop_condition == "time": - return self.get_num_ancestors() > 1 + return True elif self.stop_condition == "pedigree": return True else: @@ -1143,7 +1143,7 @@ def dtwf_simulate(self, end_time): Simulates the algorithm until all loci have coalesced. """ ret = 0 - while self.ancestors_remain(): + while self.assert_stop_condition(): if self.t + 1 > end_time: ret = 2 # _msprime.EXIT_MAX_TIME break @@ -2307,7 +2307,7 @@ def run_simulate(args): else: from_ts = tskit.load(args.from_ts) tables = from_ts.dump_tables() - if args.stop_condition == "full_pedigree": + if args.stop_condition == "pedigree": end_time = np.max(from_ts.nodes_time) else: end_time = args.end_time diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 624c583b9..d62e1d16f 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -469,7 +469,15 @@ def test_stopping_condition_pedigree(self): tables.dump(ts_path) ts = self.run_script( f"0 --from-ts {ts_path} --model=fixed_pedigree -r 0.1 \ - --stop-condition=full_pedigree" + --stop-condition=pedigree" ) assert ts.num_trees > 1 assert ts.max_root_time == num_generations - 1 + + def test_stopping_condition_dtwf(self): + end_time = 20 + ts = self.run_script( + f"10 --model=dtwf --stop-condition=time --end-time={end_time}" + ) + assert ts.num_trees > 1 + assert ts.max_root_time == end_time