diff --git a/docs/python-api.md b/docs/python-api.md index 1ccf28d0ff..a452d7076a 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -261,6 +261,7 @@ which perform the same actions but modify the {class}`TableCollection` in place. .. autosummary:: TreeSequence.simplify TreeSequence.subset + TreeSequence.merge TreeSequence.union TreeSequence.concatenate TreeSequence.keep_intervals @@ -753,6 +754,7 @@ a functional way, returning a new tree sequence while leaving the original uncha TableCollection.delete_sites TableCollection.trim TableCollection.shift + TableCollection.merge TableCollection.union TableCollection.delete_older ``` diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 2baf0cc1c3..cacae488fc 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -7,12 +7,19 @@ - ``TreeSequence.map_to_vcf_model`` now also returns the transformed positions and contig length. (:user:`benjeffery`, :pr:`XXXX`, :issue:`3173`) +- New ``merge`` functions for tree sequences and table collections, to merge another + into the current one (:user:`hyanwong`, :pr:`3183`, :issue:`3181`) + **Bugfixes** - Fix bug in ``TreeSequence.pair_coalescence_counts`` when ``span_normalise=True`` and a window breakpoint falls within an internal missing interval. 
(:user:`nspope`, :pr:`3176`, :issue:`3175`) +- Change ``TreeSequence.concatenate`` to use ``merge``, as ``union`` does not + port edges, sites, or mutations from the added tree sequences if they are associated + with shared nodes (:user:`hyanwong`, :pr:`3183`, :issue:`3168`, :issue:`3182`) + -------------------- [0.6.4] - 2025-05-21 -------------------- diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index b2eb4ea6a7..51f1f362ed 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -27,6 +27,7 @@ import io import itertools import json +import platform import random import sys import unittest @@ -43,6 +44,9 @@ import tskit.provenance as provenance +IS_WINDOWS = platform.system() == "Windows" + + def simple_keep_intervals(tables, intervals, simplify=True, record_provenance=True): """ Simple Python implementation of keep_intervals. @@ -7141,18 +7145,223 @@ def test_bad_seq_len(self): ts.shift(1, sequence_length=1) +class TestMerge: + def test_empty(self): + ts = tskit.TableCollection(2).tree_sequence() + merged_ts = ts.merge(ts, node_mapping=[]) + assert merged_ts.num_nodes == 0 + assert merged_ts.num_edges == 0 + assert merged_ts.sequence_length == 2 + + def test_overlay(self): + ts1 = tskit.Tree.generate_balanced(4, span=2).tree_sequence + tables = tskit.Tree.generate_comb(4, span=2).tree_sequence.dump_tables() + tables.populations.add_row() + tables.nodes[5] = tables.nodes[5].replace( + flags=tskit.NODE_IS_SAMPLE, population=0 + ) + ts2 = tables.tree_sequence() + ts_merge = ts1.merge(ts2, node_mapping=np.full(ts1.num_nodes, tskit.NULL)) + assert ts_merge.sequence_length == ts1.sequence_length + assert ts_merge.num_samples == ts1.num_samples + ts2.num_samples + assert ts_merge.num_nodes == ts1.num_nodes + ts2.num_nodes + assert ts_merge.num_edges == ts1.num_edges + ts2.num_edges + assert ts_merge.num_trees == 1 + assert ts_merge.num_populations == 1 + assert ts_merge.first().num_roots == 2 + + def 
test_split_and_merge(self): + # Cut up a single tree into alternating edges and mutations, then merge + ts = tskit.Tree.generate_comb(4, span=10).tree_sequence + ts = msprime.sim_mutations(ts, rate=0.1, random_seed=1) + mut_counts = np.bincount(ts.mutations_site, minlength=ts.num_sites) + assert min(mut_counts) == 1 + assert max(mut_counts) > 1 + tables1 = ts.dump_tables() + tables1.mutations.clear() + tables2 = tables1.copy() + i = 0 + for s in ts.sites(): + for m in s.mutations: + i += 1 + if i % 2: + tables1.mutations.append(m.replace(parent=tskit.NULL)) + else: + tables2.mutations.append(m.replace(parent=tskit.NULL)) + tables1.simplify() + tables2.simplify() + assert tables1.sites.num_rows != ts.num_sites + tables1.edges.clear() + tables2.edges.clear() + for e in ts.edges(): + if e.id % 2: + tables1.edges.append(e) + else: + tables2.edges.append(e) + ts1 = tables1.tree_sequence() + ts2 = tables2.tree_sequence() + new_ts = ts1.merge(ts2, node_mapping=np.arange(ts.num_nodes)).simplify() + assert new_ts.equals(ts, ignore_provenance=True) + + def test_multi_tree(self): + ts = msprime.sim_ancestry( + 2, sequence_length=4, recombination_rate=1, random_seed=1 + ) + ts = msprime.sim_mutations(ts, rate=1, random_seed=1) + assert ts.num_trees > 3 + assert ts.num_mutations > 4 + ts1 = ts.keep_intervals([[0, 1.5]], simplify=False) + ts2 = ts.keep_intervals([[1.5, 4]], simplify=False) + new_ts = ts1.merge( + ts2, node_mapping=np.arange(ts.num_nodes), add_populations=False + ) + assert new_ts.num_trees == ts.num_trees + 1 + new_ts = new_ts.simplify() + assert new_ts.equals(ts, ignore_provenance=True) + + def test_new_individuals(self): + ts1 = msprime.sim_ancestry(2, sequence_length=1, random_seed=1) + ts2 = msprime.sim_ancestry(2, sequence_length=1, random_seed=2) + tables = ts2.dump_tables() + tables.edges.clear() + ts2 = tables.tree_sequence() + node_map = np.full(ts2.num_nodes, tskit.NULL) + node_map[0:2] = [0, 1] # map first two nodes to themselves + ts_merged = ts1.merge(ts2, 
node_mapping=node_map) + assert ts_merged.num_nodes == ts1.num_nodes + ts2.num_nodes - 2 + assert ts1.num_individuals == 2 + assert ts_merged.num_individuals == 3 + + def test_popcheck(self): + tables = tskit.TableCollection(1) + p1 = tables.populations.add_row(b"foo") + p2 = tables.populations.add_row(b"bar") + tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1) + tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p2) + ts1 = tables.tree_sequence() + tables.populations[0] = tables.populations[0].replace(metadata=b"baz") + ts2 = tables.tree_sequence() + with pytest.raises(ValueError, match="Non-matching populations"): + ts1.merge(ts2, node_mapping=[0, 1]) + ts1.merge(ts2, node_mapping=[0, 1], check_populations=False) + # Check with add_populations=False + ts1.merge(ts2, node_mapping=[-1, 1]) # only merge the last one + with pytest.raises(ValueError, match="Non-matching populations"): + ts1.merge(ts2, node_mapping=[-1, 1], add_populations=False) + + with pytest.raises(ValueError, match="Non-matching populations"): + ts1.simplify([0]).merge(ts2, node_mapping=[-1, 1]) + + def test_isolated_mutations(self): + tables = tskit.TableCollection(1) + u = tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE) + s = tables.sites.add_row(0.5, "A") + tables.mutations.add_row(s, u, derived_state="T", time=1, metadata=b"xxx") + ts1 = tables.tree_sequence() + tables.mutations[0] = tables.mutations[0].replace(time=0.5, metadata=b"yyy") + ts2 = tables.tree_sequence() + ts_merge = ts1.merge(ts2, node_mapping=[0]) + assert ts_merge.num_sites == 1 + assert ts_merge.num_mutations == 2 + assert ts_merge.mutation(0).time == 1 + assert ts_merge.mutation(0).parent == tskit.NULL + assert ts_merge.mutation(0).metadata == b"xxx" + assert ts_merge.mutation(1).time == 0.5 + assert ts_merge.mutation(1).parent == 0 + assert ts_merge.mutation(1).metadata == b"yyy" + + def test_identity(self): + tables = tskit.TableCollection(1) + tables.nodes.add_row(time=0, 
flags=tskit.NODE_IS_SAMPLE) + ts = tables.tree_sequence() + ts_merge = ts.merge(ts, node_mapping=[0]) + assert ts.equals(ts_merge, ignore_provenance=True) + + @pytest.mark.skipif(IS_WINDOWS, reason="Msprime gives different result on Windows") + def test_migrations(self): + pop_configs = [msprime.PopulationConfiguration(3) for _ in range(2)] + migration_matrix = [[0, 0.001], [0.001, 0]] + ts = msprime.simulate( + population_configurations=pop_configs, + migration_matrix=migration_matrix, + record_migrations=True, + recombination_rate=2, + random_seed=42, # pick a seed that gives min(migrations.left) > 0 + end_time=100, + ) + # No migration_table.squash() function exists, so we just try to cut on the + # LHS of all the migrations + assert ts.num_migrations > 0 + assert ts.migrations_left.min() > 0 + cutpoint = ts.migrations_left.min() + ts1 = ts.keep_intervals([[0, cutpoint]], simplify=False) + ts2 = ts.keep_intervals([[cutpoint, ts.sequence_length]], simplify=False) + ts_new = ts1.merge(ts2, node_mapping=np.arange(ts.num_nodes)) + tables = ts_new.dump_tables() + tables.edges.squash() + tables.sort() + ts_new = tables.tree_sequence() + ts.tables.assert_equals(ts_new.tables, ignore_provenance=True) + + def test_provenance(self): + tables = tskit.TableCollection(1) + tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE) + ts = tables.tree_sequence() + ts_merge = ts.merge(ts, node_mapping=[0], record_provenance=False) + assert ts_merge.num_provenances == ts.num_provenances + ts_merge = ts.merge(ts, node_mapping=[0]) + assert ts_merge.num_provenances == ts.num_provenances + 1 + prov = json.loads(ts_merge.provenance(-1).record) + assert prov["parameters"]["command"] == "merge" + assert prov["parameters"]["node_mapping"] == [0] + assert prov["parameters"]["add_populations"] is True + assert prov["parameters"]["check_populations"] is True + + def test_bad_sequence_length(self): + ts1 = tskit.TableCollection(1).tree_sequence() + ts2 = 
tskit.TableCollection(2).tree_sequence() + with pytest.raises(ValueError, match="sequence length"): + ts1.merge(ts2, node_mapping=[]) + + def test_bad_node_mapping(self): + ts = tskit.Tree.generate_comb(5).tree_sequence + with pytest.raises(ValueError, match="node_mapping"): + ts.merge(ts, node_mapping=[0, 1, 2]) + + def test_bad_populations(self): + tables = tskit.TableCollection(1) + tables = tskit.TableCollection(1) + p1 = tables.populations.add_row() + p2 = tables.populations.add_row() + tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1) + tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p1) + tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE, population=p2) + ts2 = tables.tree_sequence() + ts1 = ts2.simplify([0, 1]) + assert ts1.num_populations == 1 + assert ts2.num_populations == 2 + ts2.merge(ts1, [0, -1], check_populations=False, add_populations=False) + with pytest.raises(ValueError, match="population not present"): + ts1.merge(ts2, [0, -1, -1], check_populations=False, add_populations=False) + + class TestConcatenate: def test_simple(self): ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence + ts1 = msprime.sim_mutations(ts1, rate=1, random_seed=1) ts2 = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence + ts2 = msprime.sim_mutations(ts2, rate=1, random_seed=1) assert ts1.num_samples == ts2.num_samples assert ts1.num_nodes != ts2.num_nodes joint_ts = ts1.concatenate(ts2) assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 5 assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length assert joint_ts.num_samples == ts1.num_samples + assert joint_ts.num_sites == ts1.num_sites + ts2.num_sites + assert joint_ts.num_mutations == ts1.num_mutations + ts2.num_mutations ts3 = joint_ts.delete_intervals([[2, 5]]).rtrim() # Have to simplify here, to remove the redundant nodes + ts3.tables.assert_equals(ts1.tables, ignore_provenance=True) assert ts3.equals(ts1.simplify(), 
ignore_provenance=True) ts4 = joint_ts.delete_intervals([[0, 2]]).ltrim() assert ts4.equals(ts2.simplify(), ignore_provenance=True) @@ -7183,6 +7392,13 @@ def test_empty(self): assert ts.num_nodes == 0 assert ts.sequence_length == 40 + def test_check_populations(self): + ts = msprime.sim_ancestry(2) + ts1 = ts.concatenate(ts, ts, check_populations=True) + assert ts1.num_populations == 1 + ts2 = ts.concatenate(ts, ts, add_populations=True, check_populations=True) + assert ts2.num_populations == 3 + def test_samples_at_end(self): ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence ts2 = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence @@ -7200,22 +7416,58 @@ def test_internal_samples(self): nodes_flags[:] = tskit.NODE_IS_SAMPLE nodes_flags[-1] = 0 # Only root is not a sample tables.nodes.flags = nodes_flags - ts = tables.tree_sequence() + ts = msprime.sim_mutations(tables.tree_sequence(), rate=0.5, random_seed=1) + assert ts.num_mutations > 0 + assert ts.num_mutations > ts.num_sites joint_ts = ts.concatenate(ts) assert joint_ts.num_samples == ts.num_samples assert joint_ts.num_nodes == ts.num_nodes + 1 + assert joint_ts.num_mutations == ts.num_mutations * 2 + assert joint_ts.num_sites == ts.num_sites * 2 assert joint_ts.sequence_length == ts.sequence_length * 2 def test_some_shared_samples(self): - ts1 = tskit.Tree.generate_comb(4, span=2).tree_sequence - ts2 = tskit.Tree.generate_balanced(8, arity=3, span=3).tree_sequence - shared = np.full(ts2.num_nodes, tskit.NULL) - shared[0] = 1 - shared[1] = 0 - joint_ts = ts1.concatenate(ts2, node_mappings=[shared]) - assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length - assert joint_ts.num_samples == ts1.num_samples + ts2.num_samples - 2 - assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 2 + tables = tskit.Tree.generate_comb(5).tree_sequence.dump_tables() + tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE) + ts1 = tables.tree_sequence() + tables = 
tskit.Tree.generate_balanced(5).tree_sequence.dump_tables() + tables.nodes[5] = tables.nodes[5].replace(flags=tskit.NODE_IS_SAMPLE) + ts2 = tables.tree_sequence() + assert ts1.num_samples == ts2.num_samples + joint_ts = ts1.concatenate(ts2) + assert joint_ts.num_samples == ts1.num_samples + assert joint_ts.num_edges == ts1.num_edges + ts2.num_edges + for tree in joint_ts.trees(): + assert tree.num_roots == 1 + + @pytest.mark.parametrize("simplify", [True, False]) + def test_wf_sim(self, simplify): + # Test that we can split & concat a wf_sim ts, which has internal samples + tables = wf.wf_sim( + 6, + 5, + seed=3, + deep_history=True, + initial_generation_samples=True, + num_loci=10, + ) + tables.sort() + tables.simplify() + ts = msprime.mutate(tables.tree_sequence(), rate=0.05, random_seed=234) + assert ts.num_trees > 2 + assert len(np.unique(ts.nodes_time[ts.samples()])) > 1 + ts1 = ts.keep_intervals([[0, 4.5]], simplify=False).trim() + ts2 = ts.keep_intervals([[4.5, ts.sequence_length]], simplify=False).trim() + if simplify: + ts1 = ts1.simplify(filter_nodes=False) + ts2, node_map = ts2.simplify(map_nodes=True) + node_mapping = np.zeros_like(node_map, shape=ts2.num_nodes) + kept = node_map != tskit.NULL + node_mapping[node_map[kept]] = np.arange(len(node_map))[kept] + else: + node_mapping = np.arange(ts.num_nodes) + ts_new = ts1.concatenate(ts2, node_mappings=[node_mapping]).simplify() + ts_new.tables.assert_equals(ts.tables, ignore_provenance=True) def test_provenance(self): ts = tskit.Tree.generate_comb(2).tree_sequence @@ -7233,9 +7485,12 @@ def test_unequal_samples(self): with pytest.raises(ValueError, match="must have the same number of samples"): ts1.concatenate(ts2) - @pytest.mark.skip( - reason="union bug: https://github.com/tskit-dev/tskit/issues/3168" - ) + def test_different_sample_numbers(self): + ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence + ts2 = tskit.Tree.generate_balanced(4, arity=3, span=3).tree_sequence + with 
pytest.raises(ValueError, match="must have the same number of samples"): + ts1.concatenate(ts2) + def test_duplicate_ts(self): ts1 = tskit.Tree.generate_comb(3, span=4).tree_sequence ts = ts1.keep_intervals([[0, 1]]).trim() # a quarter of the original diff --git a/python/tskit/tables.py b/python/tskit/tables.py index bc078164c0..003d41ab14 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -4161,8 +4161,11 @@ def union( portions of ``other`` to itself. To perform the node-wise union, the method relies on a ``node_mapping`` array, that maps nodes in ``other`` to its equivalent node in ``self`` or ``tskit.NULL`` if - the node is exclusive to ``other``. See :meth:`TreeSequence.union` for a more - detailed description. + the node is exclusive to ``other``. See :meth:`TreeSequence.union` + for a more detailed description. + + .. seealso:: + :meth:`.merge` for combining two tables edgewise :param TableCollection other: Another table collection. :param list node_mapping: An array of node IDs that relate nodes in @@ -4196,6 +4199,143 @@ def union( record=json.dumps(provenance.get_provenance_dict(parameters)) ) + def merge( + self, + other, + node_mapping, + *, + add_populations=None, + check_populations=None, + record_provenance=True, + ): + """ + Merge another table collection into this one, edgewise. Nodes in ``other`` + whose mapping is set to :data:`tskit.NULL` will be added as new nodes. + See :meth:`TreeSequence.merge` for a more detailed description. + + .. seealso:: + :meth:`.union` for combining two tables nodewise + + :param TableCollection other: Another table collection. + :param list node_mapping: An array of node IDs that relate nodes in + ``other`` to nodes in ``self``: the k-th element of ``node_mapping`` + should be the index of the equivalent node in ``self``, or + :data:`tskit.NULL` if the node is not present in ``self`` (in which + case it will be added to ``self``). 
+ :param bool record_provenance: If ``True`` (default), record details of this + call to ``merge`` in this table collection's provenance + information (Default: ``True``). + :param bool add_populations: If ``True`` (default), populations referred + to from nodes new to ``self`` will be added as new populations. + If ``False`` new populations will not be created, and populations + with the same ID in ``self`` and the other tree sequences will be reused. + :param bool check_populations: If ``True`` (default), check that the + populations referred to from nodes in the other tree sequences + are identical to those in ``self``. + :raises ValueError: If the node mapping is not of the correct length, + or if the sequence lengths of the two tree sequences are not equal, + or if the populations referred to from nodes in the other tree sequence + do not match those in ``self``. + """ + if add_populations is None: + add_populations = True + if check_populations is None: + check_populations = True + + node_mapping = util.safe_np_int_cast(node_mapping, np.int32) + node_map = node_mapping.copy() + if node_map.shape != (other.nodes.num_rows,): + raise ValueError( + "node_mapping must be of length equal to the number of nodes in other" + ) + if self.sequence_length != other.sequence_length: + raise ValueError( + "Tree sequences must have same sequence lengths: use trim or shift to" + " adjust sequence lengths as necessary" + ) + if check_populations: + nodes_pop = other.nodes.population + if add_populations: + # Only need to check those populations used by nodes in the node_mapping + nodes_pop = nodes_pop[node_map != tskit.NULL] + for i in np.unique(nodes_pop[nodes_pop != tskit.NULL]): + if ( + i >= self.populations.num_rows + or self.populations[i] != other.populations[i] + ): + raise ValueError(f"Non-matching populations: population {i} of other" + f" ({other.populations[i]}) is absent from or differs in self") + individual_map = {} + population_map = {} + for new_node in np.where(node_map == 
tskit.NULL)[0]: + params = {} + node = other.nodes[new_node] + if node.individual != tskit.NULL: + if node.individual not in individual_map: + individual_map[node.individual] = self.individuals.append( + other.individuals[node.individual] + ) + params["individual"] = individual_map[node.individual] + if node.population != tskit.NULL: + if add_populations: + if node.population not in population_map: + population_map[node.population] = self.populations.append( + other.populations[node.population] + ) + params["population"] = population_map[node.population] + else: + if node.population >= self.populations.num_rows: + raise ValueError( + "One of the tree sequences to concatenate has a " + "population not present in the existing tree sequence" + ) + node_map[new_node] = self.nodes.append(node.replace(**params)) + + for e in other.edges: + self.edges.append( + e.replace(child=node_map[e.child], parent=node_map[e.parent]) + ) + site_map = {} + site_positions = {p: i for i, p in enumerate(self.sites.position)} + for site_id, site in enumerate(other.sites): + if site.position in site_positions: + site_map[site_id] = site_positions[site.position] + else: + site_map[site_id] = self.sites.append(site) + for mut in other.mutations: + self.mutations.append( + mut.replace( + node=node_map[mut.node], + site=site_map[mut.site], + parent=tskit.NULL, + ) + ) + for mig in other.migrations: + self.migrations.append( + mig.replace( + node=node_map[mig.node], + source=population_map.get(mig.source, mig.source), + dest=population_map.get(mig.dest, mig.dest), + ) + ) + self.sort() + self.build_index() + self.compute_mutation_parents() + + if record_provenance: + other_records = [prov.record for prov in other.provenances] + other_timestamps = [prov.timestamp for prov in other.provenances] + parameters = { + "command": "merge", + "other": {"timestamp": other_timestamps, "record": other_records}, + "node_mapping": node_mapping.tolist(), + "add_populations": add_populations, + 
"check_populations": check_populations, + } + self.provenances.add_row( + record=json.dumps(provenance.get_provenance_dict(parameters)) + ) + def ibd_segments( self, *, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 52157d1b5a..4441d3a7cc 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7098,14 +7098,19 @@ def shift(self, value, sequence_length=None, record_provenance=True): return tables.tree_sequence() def concatenate( - self, *args, node_mappings=None, record_provenance=True, add_populations=None + self, + *args, + node_mappings=None, + record_provenance=True, + add_populations=None, + check_populations=None, ): r""" - Concatenate a set of tree sequences to the right of this one, by repeatedly - calling {meth}`union` with an (optional) - node mapping for each of the ``others``. If any node mapping is ``None`` - only map the sample nodes between the input tree sequence and this one, - based on the numerical order of sample node IDs. + Concatenate a set of tree sequences to the right of this one by shifting + them and then calling :meth:`~TreeSequence.merge`, by default merging on the + sample nodes of each tree sequence, in order of their node IDs. See the + documentation for :meth:`~TreeSequence.merge` for further details (with the + difference that ``concatenate`` defaults to ``add_populations=False``). .. note:: To add gaps between the concatenated tables, use :meth:`shift` or @@ -7113,22 +7118,30 @@ def concatenate( :param TreeSequence \*args: A list of other tree sequences to append to the right of this one. - :param Union[list, None] node_mappings: An list of node mappings for each + :param Union[list, None] node_mappings: A list of node mappings for each input tree sequence in ``args``. Each should either be an array of integers of the same length as the number of nodes in the equivalent - input tree sequence (see :meth:`union` for details), or ``None``. + input tree sequence, or ``None``. 
If ``None``, only sample nodes are mapped to each other. Default: ``None``, treated as ``[None] * len(args)``. - :param bool record_provenance: If True (default), record details of this + :param bool record_provenance: If ``True``, record details of this call to ``concatenate`` in the returned tree sequence's provenance - information (Default: True). - :param bool add_populations: If True (default), nodes new to ``self`` will - be assigned new population IDs (see :meth:`union`) + information (Default: ``True``). + :param bool add_populations: If ``True``, populations referred to from nodes + new to ``self`` will be added as new populations. If ``False`` new + populations will not be created, and populations with the same ID in + ``self`` and the other tree sequences will be reused. Default: ``None``, + treated as ``False``. + :param bool check_populations: If ``True``, check that the populations + referred to from nodes in the other tree sequences are identical to those + in ``self``. Default: ``None``, treated as ``True``. 
""" if node_mappings is None: node_mappings = [None] * len(args) if add_populations is None: - add_populations = True + add_populations = False + if check_populations is None: + check_populations = True if len(node_mappings) != len(args): raise ValueError( "You must provide the same number of node_mappings as args" @@ -7138,26 +7151,26 @@ def concatenate( tables = self.dump_tables() tables.drop_index() - for node_mapping, other in zip(node_mappings, args): - if node_mapping is None: + for node_map, other in zip(node_mappings, args): + if node_map is None: other_samples = other.samples() if len(other_samples) != len(samples): raise ValueError( "each `other` must have the same number of samples as `self`" ) - node_mapping = np.full(other.num_nodes, tskit.NULL, dtype=np.int32) - node_mapping[other_samples] = samples + node_map = np.full(other.num_nodes, tskit.NULL, dtype=np.int32) + node_map[other_samples] = samples other_tables = other.dump_tables() other_tables.shift(tables.sequence_length, record_provenance=False) tables.sequence_length = other_tables.sequence_length - # NB: should we use a different default for add_populations? - tables.union( + tables.merge( other_tables, - node_mapping=node_mapping, - check_shared_equality=False, # Else checks fail with internal samples - record_provenance=False, + node_mapping=node_map, add_populations=add_populations, + check_populations=check_populations, + record_provenance=False, ) + if record_provenance: parameters = { "command": "concatenate", @@ -7439,6 +7452,62 @@ def union( ) return tables.tree_sequence() + def merge( + self, + other, + node_mapping, + *, + add_populations=None, + check_populations=None, + record_provenance=True, + ): + """ + Merge another tree sequence into this one, edgewise. Nodes in ``other`` + whose mapping is set to :data:`tskit.NULL` will be added as new nodes. + All other nodes will remain unaltered (hence will be associated with the + same populations and individuals as in ``self``). 
Items that will be ported + from ``other`` into ``self`` (and given new IDs) are: + + 1. All edges in ``other`` + 2. All migrations in ``other`` + 3. All mutations in ``other`` + 4. Sites whose positions are new to ``self`` + 5. Individuals whose nodes are new to ``self``. + 6. If ``add_populations=True``, populations whose nodes are new to ``self`` + + .. seealso:: + :meth:`.union` for combining two tree sequences nodewise, rather than + edgewise. + + :param TreeSequence other: Another tree sequence. + :param list node_mapping: An array of node IDs that relate nodes in + ``other`` to nodes in ``self``: the k-th element of ``node_mapping`` + should be the index of the equivalent node in ``self``, or + :data:`tskit.NULL` if the node is not present in ``self`` (in which + case it will be added to ``self``). + :param bool record_provenance: If ``True`` (default), record details of this + call to ``merge`` in the returned tree sequence's provenance + information (Default: ``True``). + :param bool add_populations: If ``True``, populations referenced by nodes new to + ``self`` will be added as new populations. If ``False`` new populations + will not be created, and populations with the same ID in ``self`` and the + other tree sequences will be reused. Default: ``None``, treated as ``True``. + :param bool check_populations: If ``True``, check that the populations referred + to from nodes in the other tree sequences are identical to those in ``self``. + Default: ``None``, treated as ``True``. + :return: A new tree sequence containing the merged tables. + :rtype: tskit.TreeSequence + """ + tables = self.dump_tables() + tables.merge( + other.tables, + node_mapping=node_mapping, + add_populations=add_populations, + check_populations=check_populations, + record_provenance=record_provenance, + ) + return tables.tree_sequence() + def draw_svg( self, path=None,