Skip to content

Commit 6944250

Browse files
committed
Create a new merge() function, and use that for concatenate
It appears as if `union` wasn't really the right underlying function to use for `concatenate`, e.g. it doesn't deal as expected with sites above samples, so I had to roll my own `merge` utility.
1 parent 9d2ff8e commit 6944250

File tree

5 files changed

+509
-37
lines changed

5 files changed

+509
-37
lines changed

docs/python-api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ which perform the same actions but modify the {class}`TableCollection` in place.
261261
.. autosummary::
262262
TreeSequence.simplify
263263
TreeSequence.subset
264+
TreeSequence.merge
264265
TreeSequence.union
265266
TreeSequence.concatenate
266267
TreeSequence.keep_intervals
@@ -753,6 +754,7 @@ a functional way, returning a new tree sequence while leaving the original uncha
753754
TableCollection.delete_sites
754755
TableCollection.trim
755756
TableCollection.shift
757+
TableCollection.merge
756758
TableCollection.union
757759
TableCollection.delete_older
758760
```

python/CHANGELOG.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,19 @@
77
- ``TreeSequence.map_to_vcf_model`` now also returns the transformed positions and
88
contig length. (:user:`benjeffery`, :pr:`XXXX`, :issue:`3173`)
99

10+
- New ``merge`` functions for tree sequences and table collections, to merge another
11+
into the current one (:user:`hyanwong`, :pr:`3183`, :issue:`3181`)
12+
1013
**Bugfixes**
1114

1215
- Fix bug in ``TreeSequence.pair_coalescence_counts`` when ``span_normalise=True``
1316
and a window breakpoint falls within an internal missing interval.
1417
(:user:`nspope`, :pr:`3176`, :issue:`3175`)
1518

19+
- Change ``TreeSequence.concatenate`` to use ``merge``, as ``union`` does not
20+
port edges, sites, or mutations from the added tree sequences if they are associated
21+
with shared nodes (:user:`hyanwong`, :pr:`3183`, :issue:`3168`, :issue:`3182`)
22+
1623
--------------------
1724
[0.6.4] - 2025-05-21
1825
--------------------

python/tests/test_topology.py

Lines changed: 268 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import io
2828
import itertools
2929
import json
30+
import platform
3031
import random
3132
import sys
3233
import unittest
@@ -43,6 +44,9 @@
4344
import tskit.provenance as provenance
4445

4546

47+
IS_WINDOWS = platform.system() == "Windows"
48+
49+
4650
def simple_keep_intervals(tables, intervals, simplify=True, record_provenance=True):
4751
"""
4852
Simple Python implementation of keep_intervals.
@@ -7141,18 +7145,223 @@ def test_bad_seq_len(self):
71417145
ts.shift(1, sequence_length=1)
71427146

71437147

7148+
class TestMerge:
7149+
def test_empty(self):
7150+
ts = tskit.TableCollection(2).tree_sequence()
7151+
merged_ts = ts.merge(ts, node_mapping=[])
7152+
assert merged_ts.num_nodes == 0
7153+
assert merged_ts.num_edges == 0
7154+
assert merged_ts.sequence_length == 2
7155+
7156+
def test_overlay(self):
7157+
ts1 = tskit.Tree.generate_balanced(4, span=2).tree_sequence
7158+
tables = tskit.Tree.generate_comb(4, span=2).tree_sequence.dump_tables()
7159+
tables.populations.add_row()
7160+
tables.nodes[5] = tables.nodes[5].replace(
7161+
flags=tskit.NODE_IS_SAMPLE, population=0
7162+
)
7163+
ts2 = tables.tree_sequence()
7164+
ts_merge = ts1.merge(ts2, node_mapping=np.full(ts1.num_nodes, tskit.NULL))
7165+
assert ts_merge.sequence_length == ts1.sequence_length
7166+
assert ts_merge.num_samples == ts1.num_samples + ts2.num_samples
7167+
assert ts_merge.num_nodes == ts1.num_nodes + ts2.num_nodes
7168+
assert ts_merge.num_edges == ts1.num_edges + ts2.num_edges
7169+
assert ts_merge.num_trees == 1
7170+
assert ts_merge.num_populations == 1
7171+
assert ts_merge.first().num_roots == 2
7172+
7173+
def test_split_and_merge(self):
7174+
# Cut up a single tree into alternating edges and mutations, then merge
7175+
ts = tskit.Tree.generate_comb(4, span=10).tree_sequence
7176+
ts = msprime.sim_mutations(ts, rate=0.1, random_seed=1)
7177+
mut_counts = np.bincount(ts.mutations_site, minlength=ts.num_sites)
7178+
assert min(mut_counts) == 1
7179+
assert max(mut_counts) > 1
7180+
tables1 = ts.dump_tables()
7181+
tables1.mutations.clear()
7182+
tables2 = tables1.copy()
7183+
i = 0
7184+
for s in ts.sites():
7185+
for m in s.mutations:
7186+
i += 1
7187+
if i % 2:
7188+
tables1.mutations.append(m.replace(parent=tskit.NULL))
7189+
else:
7190+
tables2.mutations.append(m.replace(parent=tskit.NULL))
7191+
tables1.simplify()
7192+
tables2.simplify()
7193+
assert tables1.sites.num_rows != ts.num_sites
7194+
tables1.edges.clear()
7195+
tables2.edges.clear()
7196+
for e in ts.edges():
7197+
if e.id % 2:
7198+
tables1.edges.append(e)
7199+
else:
7200+
tables2.edges.append(e)
7201+
ts1 = tables1.tree_sequence()
7202+
ts2 = tables2.tree_sequence()
7203+
new_ts = ts1.merge(ts2, node_mapping=np.arange(ts.num_nodes)).simplify()
7204+
assert new_ts.equals(ts, ignore_provenance=True)
7205+
7206+
def test_multi_tree(self):
7207+
ts = msprime.sim_ancestry(
7208+
2, sequence_length=4, recombination_rate=1, random_seed=1
7209+
)
7210+
ts = msprime.sim_mutations(ts, rate=1, random_seed=1)
7211+
assert ts.num_trees > 3
7212+
assert ts.num_mutations > 4
7213+
ts1 = ts.keep_intervals([[0, 1.5]], simplify=False)
7214+
ts2 = ts.keep_intervals([[1.5, 4]], simplify=False)
7215+
new_ts = ts1.merge(
7216+
ts2, node_mapping=np.arange(ts.num_nodes), add_populations=False
7217+
)
7218+
assert new_ts.num_trees == ts.num_trees + 1
7219+
new_ts = new_ts.simplify()
7220+
new_ts.equals(ts, ignore_provenance=True)
7221+
7222+
def test_new_individuals(self):
7223+
ts1 = msprime.sim_ancestry(2, sequence_length=1, random_seed=1)
7224+
ts2 = msprime.sim_ancestry(2, sequence_length=1, random_seed=2)
7225+
tables = ts2.dump_tables()
7226+
tables.edges.clear()
7227+
ts2 = tables.tree_sequence()
7228+
node_map = np.full(ts2.num_nodes, tskit.NULL)
7229+
node_map[0:2] = [0, 1] # map first two nodes to themselves
7230+
ts_merged = ts1.merge(ts2, node_mapping=node_map)
7231+
assert ts_merged.num_nodes == ts1.num_nodes + ts2.num_nodes - 2
7232+
assert ts1.num_individuals == 2
7233+
assert ts_merged.num_individuals == 3
7234+
7235+
def test_popcheck(self):
7236+
tables = tskit.TableCollection(1)
7237+
p1 = tables.populations.add_row(b"foo")
7238+
p2 = tables.populations.add_row(b"bar")
7239+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1)
7240+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p2)
7241+
ts1 = tables.tree_sequence()
7242+
tables.populations[0] = tables.populations[0].replace(metadata=b"baz")
7243+
ts2 = tables.tree_sequence()
7244+
with pytest.raises(ValueError, match="Non-matching populations"):
7245+
ts1.merge(ts2, node_mapping=[0, 1])
7246+
ts1.merge(ts2, node_mapping=[0, 1], check_populations=False)
7247+
# Check with add_populations=False
7248+
ts1.merge(ts2, node_mapping=[-1, 1]) # only merge the last one
7249+
with pytest.raises(ValueError, match="Non-matching populations"):
7250+
ts1.merge(ts2, node_mapping=[-1, 1], add_populations=False)
7251+
7252+
with pytest.raises(ValueError, match="Non-matching populations"):
7253+
ts1.simplify([0]).merge(ts2, node_mapping=[-1, 1])
7254+
7255+
def test_isolated_mutations(self):
7256+
tables = tskit.TableCollection(1)
7257+
u = tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
7258+
s = tables.sites.add_row(0.5, "A")
7259+
tables.mutations.add_row(s, u, derived_state="T", time=1, metadata=b"xxx")
7260+
ts1 = tables.tree_sequence()
7261+
tables.mutations[0] = tables.mutations[0].replace(time=0.5, metadata=b"yyy")
7262+
ts2 = tables.tree_sequence()
7263+
ts_merge = ts1.merge(ts2, node_mapping=[0])
7264+
assert ts_merge.num_sites == 1
7265+
assert ts_merge.num_mutations == 2
7266+
assert ts_merge.mutation(0).time == 1
7267+
assert ts_merge.mutation(0).parent == tskit.NULL
7268+
assert ts_merge.mutation(0).metadata == b"xxx"
7269+
assert ts_merge.mutation(1).time == 0.5
7270+
assert ts_merge.mutation(1).parent == 0
7271+
assert ts_merge.mutation(1).metadata == b"yyy"
7272+
7273+
def test_identity(self):
7274+
tables = tskit.TableCollection(1)
7275+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
7276+
ts = tables.tree_sequence()
7277+
ts_merge = ts.merge(ts, node_mapping=[0])
7278+
assert ts.equals(ts_merge, ignore_provenance=True)
7279+
7280+
@pytest.mark.skipif(IS_WINDOWS, reason="Msprime gives different result on Windows")
7281+
def test_migrations(self):
7282+
pop_configs = [msprime.PopulationConfiguration(3) for _ in range(2)]
7283+
migration_matrix = [[0, 0.001], [0.001, 0]]
7284+
ts = msprime.simulate(
7285+
population_configurations=pop_configs,
7286+
migration_matrix=migration_matrix,
7287+
record_migrations=True,
7288+
recombination_rate=2,
7289+
random_seed=42, # pick a seed that gives min(migrations.left) > 0
7290+
end_time=100,
7291+
)
7292+
# No migration_table.squash() function exists, so we just try to cut on the
7293+
# LHS of all the migrations
7294+
assert ts.num_migrations > 0
7295+
assert ts.migrations_left.min() > 0
7296+
cutpoint = ts.migrations_left.min()
7297+
ts1 = ts.keep_intervals([[0, cutpoint]], simplify=False)
7298+
ts2 = ts.keep_intervals([[cutpoint, ts.sequence_length]], simplify=False)
7299+
ts_new = ts1.merge(ts2, node_mapping=np.arange(ts.num_nodes))
7300+
tables = ts_new.dump_tables()
7301+
tables.edges.squash()
7302+
tables.sort()
7303+
ts_new = tables.tree_sequence()
7304+
ts.tables.assert_equals(ts_new.tables, ignore_provenance=True)
7305+
7306+
def test_provenance(self):
7307+
tables = tskit.TableCollection(1)
7308+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
7309+
ts = tables.tree_sequence()
7310+
ts_merge = ts.merge(ts, node_mapping=[0], record_provenance=False)
7311+
assert ts_merge.num_provenances == ts.num_provenances
7312+
ts_merge = ts.merge(ts, node_mapping=[0])
7313+
assert ts_merge.num_provenances == ts.num_provenances + 1
7314+
prov = json.loads(ts_merge.provenance(-1).record)
7315+
assert prov["parameters"]["command"] == "merge"
7316+
assert prov["parameters"]["node_mapping"] == [0]
7317+
assert prov["parameters"]["add_populations"] is True
7318+
assert prov["parameters"]["check_populations"] is True
7319+
7320+
def test_bad_sequence_length(self):
7321+
ts1 = tskit.TableCollection(1).tree_sequence()
7322+
ts2 = tskit.TableCollection(2).tree_sequence()
7323+
with pytest.raises(ValueError, match="sequence length"):
7324+
ts1.merge(ts2, node_mapping=[])
7325+
7326+
def test_bad_node_mapping(self):
7327+
ts = tskit.Tree.generate_comb(5).tree_sequence
7328+
with pytest.raises(ValueError, match="node_mapping"):
7329+
ts.merge(ts, node_mapping=[0, 1, 2])
7330+
7331+
def test_bad_populations(self):
7332+
tables = tskit.TableCollection(1)
7333+
tables = tskit.TableCollection(1)
7334+
p1 = tables.populations.add_row()
7335+
p2 = tables.populations.add_row()
7336+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1)
7337+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1)
7338+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p2)
7339+
ts2 = tables.tree_sequence()
7340+
ts1 = ts2.simplify([0, 1])
7341+
assert ts1.num_populations == 1
7342+
assert ts2.num_populations == 2
7343+
ts2.merge(ts1, [0, -1], check_populations=False, add_populations=False)
7344+
with pytest.raises(ValueError, match="population not present"):
7345+
ts1.merge(ts2, [0, -1, -1], check_populations=False, add_populations=False)
7346+
7347+
71447348
class TestConcatenate:
71457349
def test_simple(self):
71467350
ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence
7351+
ts1 = msprime.sim_mutations(ts1, rate=1, random_seed=1)
71477352
ts2 = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence
7353+
ts2 = msprime.sim_mutations(ts2, rate=1, random_seed=1)
71487354
assert ts1.num_samples == ts2.num_samples
71497355
assert ts1.num_nodes != ts2.num_nodes
71507356
joint_ts = ts1.concatenate(ts2)
71517357
assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 5
71527358
assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length
71537359
assert joint_ts.num_samples == ts1.num_samples
7360+
assert joint_ts.num_sites == ts1.num_sites + ts2.num_sites
7361+
assert joint_ts.num_mutations == ts1.num_mutations + ts2.num_mutations
71547362
ts3 = joint_ts.delete_intervals([[2, 5]]).rtrim()
71557363
# Have to simplify here, to remove the redundant nodes
7364+
ts3.tables.assert_equals(ts1.tables, ignore_provenance=True)
71567365
assert ts3.equals(ts1.simplify(), ignore_provenance=True)
71577366
ts4 = joint_ts.delete_intervals([[0, 2]]).ltrim()
71587367
assert ts4.equals(ts2.simplify(), ignore_provenance=True)
@@ -7183,6 +7392,13 @@ def test_empty(self):
71837392
assert ts.num_nodes == 0
71847393
assert ts.sequence_length == 40
71857394

7395+
def test_check_populations(self):
7396+
ts = msprime.sim_ancestry(2)
7397+
ts1 = ts.concatenate(ts, ts, check_populations=True)
7398+
assert ts1.num_populations == 1
7399+
ts2 = ts.concatenate(ts, ts, add_populations=True, check_populations=True)
7400+
assert ts2.num_populations == 3
7401+
71867402
def test_samples_at_end(self):
71877403
ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence
71887404
ts2 = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence
@@ -7200,22 +7416,58 @@ def test_internal_samples(self):
72007416
nodes_flags[:] = tskit.NODE_IS_SAMPLE
72017417
nodes_flags[-1] = 0 # Only root is not a sample
72027418
tables.nodes.flags = nodes_flags
7203-
ts = tables.tree_sequence()
7419+
ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.5, random_seed=1)
7420+
assert ts.num_mutations > 0
7421+
assert ts.num_mutations > ts.num_sites
72047422
joint_ts = ts.concatenate(ts)
72057423
assert joint_ts.num_samples == ts.num_samples
72067424
assert joint_ts.num_nodes == ts.num_nodes + 1
7425+
assert joint_ts.num_mutations == ts.num_mutations * 2
7426+
assert joint_ts.num_sites == ts.num_sites * 2
72077427
assert joint_ts.sequence_length == ts.sequence_length * 2
72087428

72097429
def test_some_shared_samples(self):
7210-
ts1 = tskit.Tree.generate_comb(4, span=2).tree_sequence
7211-
ts2 = tskit.Tree.generate_balanced(8, arity=3, span=3).tree_sequence
7212-
shared = np.full(ts2.num_nodes, tskit.NULL)
7213-
shared[0] = 1
7214-
shared[1] = 0
7215-
joint_ts = ts1.concatenate(ts2, node_mappings=[shared])
7216-
assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length
7217-
assert joint_ts.num_samples == ts1.num_samples + ts2.num_samples - 2
7218-
assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 2
7430+
tables = tskit.Tree.generate_comb(5).tree_sequence.dump_tables()
7431+
tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE)
7432+
ts1 = tables.tree_sequence()
7433+
tables = tskit.Tree.generate_balanced(5).tree_sequence.dump_tables()
7434+
tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE)
7435+
ts2 = tables.tree_sequence()
7436+
assert ts1.num_samples == ts2.num_samples
7437+
joint_ts = ts1.concatenate(ts2)
7438+
assert joint_ts.num_samples == ts1.num_samples
7439+
assert joint_ts.num_edges == ts1.num_edges + ts2.num_edges
7440+
for tree in joint_ts.trees():
7441+
assert tree.num_roots == 1
7442+
7443+
@pytest.mark.parametrize("simplify", [True, False])
7444+
def test_wf_sim(self, simplify):
7445+
# Test that we can split & concat a wf_sim ts, which has internal samples
7446+
tables = wf.wf_sim(
7447+
6,
7448+
5,
7449+
seed=3,
7450+
deep_history=True,
7451+
initial_generation_samples=True,
7452+
num_loci=10,
7453+
)
7454+
tables.sort()
7455+
tables.simplify()
7456+
ts = msprime.mutate(tables.tree_sequence(), rate=0.05, random_seed=234)
7457+
assert ts.num_trees > 2
7458+
assert len(np.unique(ts.nodes_time[ts.samples()])) > 1
7459+
ts1 = ts.keep_intervals([[0, 4.5]], simplify=False).trim()
7460+
ts2 = ts.keep_intervals([[4.5, ts.sequence_length]], simplify=False).trim()
7461+
if simplify:
7462+
ts1 = ts1.simplify(filter_nodes=False)
7463+
ts2, node_map = ts2.simplify(map_nodes=True)
7464+
node_mapping = np.zeros_like(node_map, shape=ts2.num_nodes)
7465+
kept = node_map != tskit.NULL
7466+
node_mapping[node_map[kept]] = np.arange(len(node_map))[kept]
7467+
else:
7468+
node_mapping = np.arange(ts.num_nodes)
7469+
ts_new = ts1.concatenate(ts2, node_mappings=[node_mapping]).simplify()
7470+
ts_new.tables.assert_equals(ts.tables, ignore_provenance=True)
72197471

72207472
def test_provenance(self):
72217473
ts = tskit.Tree.generate_comb(2).tree_sequence
@@ -7233,9 +7485,12 @@ def test_unequal_samples(self):
72337485
with pytest.raises(ValueError, match="must have the same number of samples"):
72347486
ts1.concatenate(ts2)
72357487

7236-
@pytest.mark.skip(
7237-
reason="union bug: https://github.com/tskit-dev/tskit/issues/3168"
7238-
)
7488+
def test_different_sample_numbers(self):
7489+
ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence
7490+
ts2 = tskit.Tree.generate_balanced(4, arity=3, span=3).tree_sequence
7491+
with pytest.raises(ValueError, match="must have the same number of samples"):
7492+
ts1.concatenate(ts2)
7493+
72397494
def test_duplicate_ts(self):
72407495
ts1 = tskit.Tree.generate_comb(3, span=4).tree_sequence
72417496
ts = ts1.keep_intervals([[0, 1]]).trim() # a quarter of the original

0 commit comments

Comments
 (0)