Skip to content

Commit c8dad24

Browse files
committed
Create a new merge() function, and use that for concatenate
It appears as if `union` wasn't really the right underlying function to use for `concatenate`, e.g. it doesn't deal as expected with sites above samples, so I had to roll my own `merge` utility.
1 parent c86884d commit c8dad24

File tree

3 files changed

+390
-35
lines changed

3 files changed

+390
-35
lines changed

python/tests/test_topology.py

Lines changed: 185 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7141,18 +7141,180 @@ def test_bad_seq_len(self):
71417141
ts.shift(1, sequence_length=1)
71427142

71437143

7144+
class TestMerge:
7145+
def test_empty(self):
7146+
ts = tskit.TableCollection(2).tree_sequence()
7147+
merged_ts = ts.merge(ts, node_mapping=[])
7148+
assert merged_ts.num_nodes == 0
7149+
assert merged_ts.num_edges == 0
7150+
assert merged_ts.sequence_length == 2
7151+
7152+
def test_simple(self):
7153+
# Cut up a single tree into alternating edges and mutations, then merge
7154+
ts = tskit.Tree.generate_comb(4, span=10).tree_sequence
7155+
ts = msprime.sim_mutations(ts, rate=0.1, random_seed=1)
7156+
mut_counts = np.bincount(ts.mutations_site, minlength=ts.num_sites)
7157+
assert min(mut_counts) == 1
7158+
assert max(mut_counts) > 1
7159+
tables1 = ts.dump_tables()
7160+
tables1.mutations.clear()
7161+
tables2 = tables1.copy()
7162+
i = 0
7163+
for s in ts.sites():
7164+
for m in s.mutations:
7165+
i += 1
7166+
if i % 2:
7167+
tables1.mutations.append(m.replace(parent=tskit.NULL))
7168+
else:
7169+
tables2.mutations.append(m.replace(parent=tskit.NULL))
7170+
tables1.simplify()
7171+
tables2.simplify()
7172+
assert tables1.sites.num_rows != ts.num_sites
7173+
tables1.edges.clear()
7174+
tables2.edges.clear()
7175+
for e in ts.edges():
7176+
if e.id % 2:
7177+
tables1.edges.append(e)
7178+
else:
7179+
tables2.edges.append(e)
7180+
ts1 = tables1.tree_sequence()
7181+
ts2 = tables2.tree_sequence()
7182+
new_ts = ts1.merge(ts2, node_mapping=np.arange(ts.num_nodes)).simplify()
7183+
assert new_ts.equals(ts, ignore_provenance=True)
7184+
7185+
def test_multi_tree(self):
7186+
ts = msprime.sim_ancestry(
7187+
2, sequence_length=4, recombination_rate=1, random_seed=1
7188+
)
7189+
ts = msprime.sim_mutations(ts, rate=1, random_seed=1)
7190+
assert ts.num_trees > 3
7191+
assert ts.num_mutations > 4
7192+
ts1 = ts.keep_intervals([[0, 1.5]], simplify=False)
7193+
ts2 = ts.keep_intervals([[1.5, 4]], simplify=False)
7194+
new_ts = ts1.merge(
7195+
ts2, node_mapping=np.arange(ts.num_nodes), add_populations=False
7196+
)
7197+
assert new_ts.num_trees == ts.num_trees + 1
7198+
new_ts = new_ts.simplify()
7199+
new_ts.equals(ts, ignore_provenance=True)
7200+
7201+
def test_new_individuals(self):
7202+
ts1 = msprime.sim_ancestry(2, sequence_length=1, random_seed=1)
7203+
ts2 = msprime.sim_ancestry(2, sequence_length=1, random_seed=2)
7204+
tables = ts2.dump_tables()
7205+
tables.edges.clear()
7206+
ts2 = tables.tree_sequence()
7207+
node_map = np.full(ts2.num_nodes, tskit.NULL)
7208+
node_map[0:2] = [0, 1] # map first two nodes to themselves
7209+
ts_merged = ts1.merge(ts2, node_mapping=node_map)
7210+
assert ts_merged.num_nodes == ts1.num_nodes + ts2.num_nodes - 2
7211+
assert ts1.num_individuals == 2
7212+
assert ts_merged.num_individuals == 3
7213+
7214+
def test_popcheck(self):
7215+
tables = tskit.TableCollection(1)
7216+
p1 = tables.populations.add_row(b"foo")
7217+
p2 = tables.populations.add_row(b"bar")
7218+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1)
7219+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p2)
7220+
ts1 = tables.tree_sequence()
7221+
tables.populations[0] = tables.populations[0].replace(metadata=b"baz")
7222+
ts2 = tables.tree_sequence()
7223+
with pytest.raises(ValueError, match="Non-matching populations"):
7224+
ts1.merge(ts2, node_mapping=[0, 1])
7225+
ts1.merge(ts2, node_mapping=[0, 1], check_populations=False)
7226+
# Check with add_populations=False
7227+
ts1.merge(ts2, node_mapping=[-1, 1]) # only merge the last one
7228+
with pytest.raises(ValueError, match="Non-matching populations"):
7229+
ts1.merge(ts2, node_mapping=[-1, 1], add_populations=False)
7230+
7231+
with pytest.raises(ValueError, match="Non-matching populations"):
7232+
ts1.simplify([0]).merge(ts2, node_mapping=[-1, 1])
7233+
7234+
def test_isolated_mutations(self):
7235+
tables = tskit.TableCollection(1)
7236+
u = tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
7237+
s = tables.sites.add_row(0.5, "A")
7238+
tables.mutations.add_row(s, u, derived_state="T", time=1, metadata=b"xxx")
7239+
ts1 = tables.tree_sequence()
7240+
tables.mutations[0] = tables.mutations[0].replace(time=0.5, metadata=b"yyy")
7241+
ts2 = tables.tree_sequence()
7242+
ts_merge = ts1.merge(ts2, node_mapping=[0])
7243+
assert ts_merge.num_sites == 1
7244+
assert ts_merge.num_mutations == 2
7245+
assert ts_merge.mutation(0).time == 1
7246+
assert ts_merge.mutation(0).parent == tskit.NULL
7247+
assert ts_merge.mutation(0).metadata == b"xxx"
7248+
assert ts_merge.mutation(1).time == 0.5
7249+
assert ts_merge.mutation(1).parent == 0
7250+
assert ts_merge.mutation(1).metadata == b"yyy"
7251+
7252+
def test_identity(self):
7253+
tables = tskit.TableCollection(1)
7254+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
7255+
ts = tables.tree_sequence()
7256+
ts_merge = ts.merge(ts, node_mapping=[0])
7257+
assert ts.equals(ts_merge, ignore_provenance=True)
7258+
7259+
def test_provenance(self):
7260+
tables = tskit.TableCollection(1)
7261+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
7262+
ts = tables.tree_sequence()
7263+
ts_merge = ts.merge(ts, node_mapping=[0], record_provenance=False)
7264+
assert ts_merge.num_provenances == ts.num_provenances
7265+
ts_merge = ts.merge(ts, node_mapping=[0])
7266+
assert ts_merge.num_provenances == ts.num_provenances + 1
7267+
prov = json.loads(ts_merge.provenance(-1).record)
7268+
assert prov["parameters"]["command"] == "merge"
7269+
assert prov["parameters"]["node_mapping"] == [0]
7270+
assert prov["parameters"]["add_populations"] is True
7271+
assert prov["parameters"]["check_populations"] is True
7272+
7273+
def test_bad_sequence_length(self):
7274+
ts1 = tskit.TableCollection(1).tree_sequence()
7275+
ts2 = tskit.TableCollection(2).tree_sequence()
7276+
with pytest.raises(ValueError, match="sequence length"):
7277+
ts1.merge(ts2, node_mapping=[])
7278+
7279+
def test_bad_node_mapping(self):
7280+
ts = tskit.Tree.generate_comb(5).tree_sequence
7281+
with pytest.raises(ValueError, match="node_mapping"):
7282+
ts.merge(ts, node_mapping=[0, 1, 2])
7283+
7284+
def test_bad_populations(self):
7285+
tables = tskit.TableCollection(1)
7286+
tables = tskit.TableCollection(1)
7287+
p1 = tables.populations.add_row()
7288+
p2 = tables.populations.add_row()
7289+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1)
7290+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1)
7291+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p2)
7292+
ts2 = tables.tree_sequence()
7293+
ts1 = ts2.simplify([0, 1])
7294+
assert ts1.num_populations == 1
7295+
assert ts2.num_populations == 2
7296+
ts2.merge(ts1, [0, -1], check_populations=False, add_populations=False)
7297+
with pytest.raises(ValueError, match="population not present"):
7298+
ts1.merge(ts2, [0, -1, -1], check_populations=False, add_populations=False)
7299+
7300+
71447301
class TestConcatenate:
71457302
def test_simple(self):
71467303
ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence
7304+
ts1 = msprime.sim_mutations(ts1, rate=1, random_seed=1)
71477305
ts2 = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence
7306+
ts2 = msprime.sim_mutations(ts2, rate=1, random_seed=1)
71487307
assert ts1.num_samples == ts2.num_samples
71497308
assert ts1.num_nodes != ts2.num_nodes
71507309
joint_ts = ts1.concatenate(ts2)
71517310
assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 5
71527311
assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length
71537312
assert joint_ts.num_samples == ts1.num_samples
7313+
assert joint_ts.num_sites == ts1.num_sites + ts2.num_sites
7314+
assert joint_ts.num_mutations == ts1.num_mutations + ts2.num_mutations
71547315
ts3 = joint_ts.delete_intervals([[2, 5]]).rtrim()
71557316
# Have to simplify here, to remove the redundant nodes
7317+
ts3.tables.assert_equals(ts1.tables, ignore_provenance=True)
71567318
assert ts3.equals(ts1.simplify(), ignore_provenance=True)
71577319
ts4 = joint_ts.delete_intervals([[0, 2]]).ltrim()
71587320
assert ts4.equals(ts2.simplify(), ignore_provenance=True)
@@ -7200,22 +7362,29 @@ def test_internal_samples(self):
72007362
nodes_flags[:] = tskit.NODE_IS_SAMPLE
72017363
nodes_flags[-1] = 0 # Only root is not a sample
72027364
tables.nodes.flags = nodes_flags
7203-
ts = tables.tree_sequence()
7365+
ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.5, random_seed=1)
7366+
assert ts.num_mutations > 0
7367+
assert ts.num_mutations > ts.num_sites
72047368
joint_ts = ts.concatenate(ts)
72057369
assert joint_ts.num_samples == ts.num_samples
72067370
assert joint_ts.num_nodes == ts.num_nodes + 1
7371+
assert joint_ts.num_mutations == ts.num_mutations * 2
7372+
assert joint_ts.num_sites == ts.num_sites * 2
72077373
assert joint_ts.sequence_length == ts.sequence_length * 2
72087374

72097375
def test_some_shared_samples(self):
7210-
ts1 = tskit.Tree.generate_comb(4, span=2).tree_sequence
7211-
ts2 = tskit.Tree.generate_balanced(8, arity=3, span=3).tree_sequence
7212-
shared = np.full(ts2.num_nodes, tskit.NULL)
7213-
shared[0] = 1
7214-
shared[1] = 0
7215-
joint_ts = ts1.concatenate(ts2, node_mappings=[shared])
7216-
assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length
7217-
assert joint_ts.num_samples == ts1.num_samples + ts2.num_samples - 2
7218-
assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 2
7376+
tables = tskit.Tree.generate_comb(5).tree_sequence.dump_tables()
7377+
tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE)
7378+
ts1 = tables.tree_sequence()
7379+
tables = tskit.Tree.generate_balanced(5).tree_sequence.dump_tables()
7380+
tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE)
7381+
ts2 = tables.tree_sequence()
7382+
assert ts1.num_samples == ts2.num_samples
7383+
joint_ts = ts1.concatenate(ts2)
7384+
assert joint_ts.num_samples == ts1.num_samples
7385+
assert joint_ts.num_edges == ts1.num_edges + ts2.num_edges
7386+
for tree in joint_ts.trees():
7387+
assert tree.num_roots == 1
72197388

72207389
def test_provenance(self):
72217390
ts = tskit.Tree.generate_comb(2).tree_sequence
@@ -7233,9 +7402,12 @@ def test_unequal_samples(self):
72337402
with pytest.raises(ValueError, match="must have the same number of samples"):
72347403
ts1.concatenate(ts2)
72357404

7236-
@pytest.mark.skip(
7237-
reason="union bug: https://github.com/tskit-dev/tskit/issues/3168"
7238-
)
7405+
def test_different_sample_numbers(self):
7406+
ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence
7407+
ts2 = tskit.Tree.generate_balanced(4, arity=3, span=3).tree_sequence
7408+
with pytest.raises(ValueError, match="must have the same number of samples"):
7409+
ts1.concatenate(ts2)
7410+
72397411
def test_duplicate_ts(self):
72407412
ts1 = tskit.Tree.generate_comb(3, span=4).tree_sequence
72417413
ts = ts1.keep_intervals([[0, 1]]).trim() # a quarter of the original

python/tskit/tables.py

Lines changed: 138 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4161,8 +4161,11 @@ def union(
41614161
portions of ``other`` to itself. To perform the node-wise union,
41624162
the method relies on a ``node_mapping`` array, that maps nodes in
41634163
``other`` to its equivalent node in ``self`` or ``tskit.NULL`` if
4164-
the node is exclusive to ``other``. See :meth:`TreeSequence.union` for a more
4165-
detailed description.
4164+
the node is exclusive to ``other``. See :meth:`TreeSequence.union`
4165+
for a more detailed description.
4166+
4167+
.. seealso::
4168+
:meth:`.union` for a combining two tables edgewise
41664169
41674170
:param TableCollection other: Another table collection.
41684171
:param list node_mapping: An array of node IDs that relate nodes in
@@ -4196,6 +4199,139 @@ def union(
41964199
record=json.dumps(provenance.get_provenance_dict(parameters))
41974200
)
41984201

4202+
def merge(
4203+
self,
4204+
other,
4205+
node_mapping,
4206+
*,
4207+
add_populations=None,
4208+
check_populations=None,
4209+
record_provenance=True,
4210+
):
4211+
"""
4212+
Merge another table collection into this one, edgewise. Nodes in ``other``
4213+
whose mapping is set to :data:`tskit.NULL` will be added as new nodes.
4214+
See :meth:`TreeSequence.merge` for a more detailed description.
4215+
4216+
.. seealso::
4217+
:meth:`.union` for a combining two tables nodewise
4218+
4219+
:param TableCollection other: Another table collection.
4220+
:param list node_mapping: An array of node IDs that relate nodes in
4221+
``other`` to nodes in ``self``: the k-th element of ``node_mapping``
4222+
should be the index of the equivalent node in ``self``, or
4223+
:data:`tskit.NULL` if the node is not present in ``self`` (in which
4224+
case it will be added to ``self``).
4225+
:param bool record_provenance: If ``True`` (default), record details of this
4226+
call to ``merge`` in the returned tree sequence's provenance
4227+
information (Default: ``True``).
4228+
:param bool add_populations: If ``True`` (default), populations referred
4229+
to from nodes new to ``self`` will be added as a new populations.
4230+
If ``False`` new populations will not be created, and populations
4231+
with the same ID in ``self`` and the other tree sequences will be reused.
4232+
:param bool check_populations: If ``True`` (default), check that the
4233+
populations referred to from nodes in the other tree sequences
4234+
are identical to those in ``self``.
4235+
4236+
"""
4237+
if add_populations is None:
4238+
add_populations = True
4239+
if check_populations is None:
4240+
check_populations = True
4241+
4242+
node_mapping = util.safe_np_int_cast(node_mapping, np.int32)
4243+
node_map = node_mapping.copy()
4244+
if node_map.shape != (other.nodes.num_rows,):
4245+
raise ValueError(
4246+
"node_mapping must be of length equal to the number of nodes in other"
4247+
)
4248+
if self.sequence_length != other.sequence_length:
4249+
raise ValueError(
4250+
"Tree sequences must have same sequence lengths: use trim or shift to"
4251+
" adjust sequence lengths as necessary"
4252+
)
4253+
if check_populations:
4254+
nodes_pop = other.nodes.population
4255+
if add_populations:
4256+
# Only need to check those populations used by nodes in the node_mapping
4257+
nodes_pop = nodes_pop[node_map != tskit.NULL]
4258+
for i in np.unique(nodes_pop[nodes_pop != tskit.NULL]):
4259+
if (
4260+
i >= self.populations.num_rows
4261+
or self.populations[i] != other.populations[i]
4262+
):
4263+
raise ValueError("Non-matching populations")
4264+
individual_map = {}
4265+
population_map = {}
4266+
for new_node in np.where(node_map == tskit.NULL)[0]:
4267+
params = {}
4268+
node = other.nodes[new_node]
4269+
if node.individual != tskit.NULL:
4270+
if node.individual not in individual_map:
4271+
individual_map[node.individual] = self.individuals.append(
4272+
other.individuals[node.individual]
4273+
)
4274+
params["individual"] = individual_map[node.individual]
4275+
if node.population != tskit.NULL:
4276+
if add_populations:
4277+
if node.population not in population_map:
4278+
population_map[node.population] = self.populations.append(
4279+
other.populations[node.population]
4280+
)
4281+
params["population"] = population_map[node.population]
4282+
else:
4283+
if node.population >= self.populations.num_rows:
4284+
raise ValueError(
4285+
"One of the tree sequences to concatenate has a "
4286+
"population not present in the existing tree sequence"
4287+
)
4288+
node_map[new_node] = self.nodes.append(node.replace(**params))
4289+
4290+
for e in other.edges:
4291+
self.edges.append(
4292+
e.replace(child=node_map[e.child], parent=node_map[e.parent])
4293+
)
4294+
site_map = {}
4295+
site_positions = {p: i for i, p in enumerate(self.sites.position)}
4296+
for site_id, site in enumerate(other.sites):
4297+
if site.position in site_positions:
4298+
site_map[site_id] = site_positions[site.position]
4299+
else:
4300+
site_map[site_id] = self.sites.append(site)
4301+
for mut in other.mutations:
4302+
self.mutations.append(
4303+
mut.replace(
4304+
node=node_map[mut.node],
4305+
site=site_map[mut.site],
4306+
parent=tskit.NULL,
4307+
)
4308+
)
4309+
for mig in other.migrations:
4310+
self.migrations.append(
4311+
mig.replace(
4312+
node=node_map[mig.node],
4313+
source=population_map.get(mig.source, mig.source),
4314+
dest=population_map.get(mig.dest, mig.dest),
4315+
)
4316+
)
4317+
self.sort()
4318+
self.build_index()
4319+
self.compute_mutation_parents()
4320+
4321+
if record_provenance:
4322+
other_records = [prov.record for prov in other.provenances]
4323+
other_timestamps = [prov.timestamp for prov in other.provenances]
4324+
parameters = {
4325+
"command": "merge",
4326+
"other": {"timestamp": other_timestamps, "record": other_records},
4327+
"node_mapping": node_mapping.tolist(),
4328+
"add_populations": add_populations,
4329+
"check_populations": check_populations,
4330+
}
4331+
self.provenances.add_row(
4332+
record=json.dumps(provenance.get_provenance_dict(parameters))
4333+
)
4334+
41994335
def ibd_segments(
42004336
self,
42014337
*,

0 commit comments

Comments
 (0)