Skip to content

Commit

Permalink
Merge pull request #391 from jeromekelleher/samples-numpy
Browse files Browse the repository at this point in the history
Change ts.samples to return numpy array.
  • Loading branch information
jeromekelleher authored Feb 2, 2018
2 parents 937e6a7 + 088f7be commit dc70090
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 120 deletions.
20 changes: 11 additions & 9 deletions msprime/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -1884,7 +1884,8 @@ def trees(
flags |= _msprime.SAMPLE_LISTS
kwargs = {"flags": flags}
if tracked_samples is not None:
kwargs["tracked_samples"] = tracked_samples
# TODO remove this when we allow numpy arrays in the low-level API.
kwargs["tracked_samples"] = list(tracked_samples)
ll_sparse_tree = _msprime.SparseTree(self._ll_tree_sequence, **kwargs)
iterator = _msprime.SparseTreeIterator(ll_sparse_tree)
sparse_tree = SparseTree(ll_sparse_tree, self)
Expand Down Expand Up @@ -2003,9 +2004,7 @@ def pairwise_diversity(self, samples=None):
"""
if samples is None:
samples = self.samples()
else:
samples = list(samples)
return self._ll_tree_sequence.get_pairwise_diversity(samples)
return self._ll_tree_sequence.get_pairwise_diversity(list(samples))

def node(self, id_):
"""
Expand Down Expand Up @@ -2054,25 +2053,28 @@ def get_samples(self, population_id=None):

def samples(self, population=None, population_id=None):
"""
Returns the samples matching the specified population ID.
Returns an array of the sample node IDs in this tree sequence. If the
``population`` parameter is specified, only return sample IDs from this
population.
:param int population: The population of interest. If None,
return all samples.
:param int population_id: Deprecated alias for ``population``.
:return: The ID of the population we wish to find samples from.
If None, return samples from all populations.
:rtype: list
:return: A numpy array of the node IDs for the samples of interest.
:rtype: numpy.ndarray (dtype=np.int32)
"""
if population is not None and population_id is not None:
raise ValueError(
"population_id and population are aliases. Cannot specify both")
if population_id is not None:
population = population_id
# TODO the low-level tree sequence should perform this operation natively
# and return a numpy array.
samples = self._ll_tree_sequence.get_samples()
if population is not None:
samples = [
u for u in samples if self.get_population(u) == population]
return samples
return np.array(samples, dtype=np.int32)

def write_vcf(self, output, ploidy=1, contig_id="1"):
"""
Expand Down
153 changes: 83 additions & 70 deletions tests/test_demography.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import tempfile
import unittest

import numpy as np

import msprime


Expand Down Expand Up @@ -481,16 +483,16 @@ def test_two_pops_single_sample(self):
demographic_events=demographic_events,
random_seed=1)
tree = next(ts.trees())
self.assertEqual(tree.get_root(), 2)
self.assertGreater(tree.get_time(2), t)
self.assertEqual(tree.get_population(0), 0)
self.assertEqual(tree.get_population(1), 1)
self.assertEqual(tree.get_population(2), 2)
self.assertEqual(ts.get_population(0), 0)
self.assertEqual(ts.get_population(1), 1)
self.assertEqual(ts.get_samples(), [0, 1])
self.assertEqual(ts.get_samples(0), [0])
self.assertEqual(ts.get_samples(1), [1])
self.assertEqual(tree.root, 2)
self.assertGreater(tree.time(2), t)
self.assertEqual(tree.population(0), 0)
self.assertEqual(tree.population(1), 1)
self.assertEqual(tree.population(2), 2)
self.assertEqual(ts.node(0).population, 0)
self.assertEqual(ts.node(1).population, 1)
self.assertEqual(list(ts.samples()), [0, 1])
self.assertEqual(list(ts.samples(0)), [0])
self.assertEqual(list(ts.samples(1)), [1])

def test_two_pops_multiple_samples(self):
# Made absolutely sure that all samples have coalesced within
Expand All @@ -511,17 +513,25 @@ def test_two_pops_multiple_samples(self):
demographic_events=demographic_events,
random_seed=1)
tree = next(ts.trees())
self.assertEqual(tree.get_root(), 2 * n - 2)
self.assertGreater(tree.get_time(tree.get_root()), t)
self.assertEqual(tree.root, 2 * n - 2)
self.assertGreater(tree.time(tree.root), t)
for j in range(n // 2):
self.assertEqual(tree.get_population(j), 0)
self.assertEqual(tree.get_population(n // 2 + j), 1)
self.assertEqual(tree.population(j), 0)
self.assertEqual(tree.population(n // 2 + j), 1)
self.assertEqual(ts.get_population(j), 0)
self.assertEqual(ts.get_population(n // 2 + j), 1)
self.assertEqual(tree.get_population(tree.get_root()), 2)
self.assertEqual(ts.get_samples(0), list(range(n // 2)))
self.assertEqual(ts.get_samples(1), list(range(n // 2, n)))
self.assertEqual(ts.get_samples(2), [])
self.assertEqual(tree.population(tree.root), 2)

self.assertTrue(np.array_equal(
ts.samples(0), np.arange(n // 2, dtype=np.int32)))
self.assertTrue(np.array_equal(
ts.samples(1), np.arange(n // 2, n, dtype=np.int32)))
self.assertTrue(np.array_equal(
ts.samples(2), np.array([], dtype=np.int32)))

# self.assertEqual(ts.samples(0), list(range(n // 2)))
# self.assertEqual(ts.samples(1), list(range(n // 2, n)))
# self.assertEqual(ts.samples(2), [])

def test_three_pops_migration(self):
n = 9
Expand All @@ -541,29 +551,32 @@ def test_three_pops_migration(self):
demographic_events=demographic_events,
random_seed=1)
tree = next(ts.trees())
self.assertEqual(tree.get_root(), 2 * n - 2)
self.assertGreater(tree.get_time(tree.get_root()), t)
self.assertEqual(tree.root, 2 * n - 2)
self.assertGreater(tree.time(tree.root), t)
for j in range(n // 3):
self.assertEqual(tree.get_population(j), 0)
self.assertEqual(tree.get_population(n // 3 + j), 1)
self.assertEqual(tree.get_population(2 * (n // 3) + j), 2)
self.assertEqual(tree.population(j), 0)
self.assertEqual(tree.population(n // 3 + j), 1)
self.assertEqual(tree.population(2 * (n // 3) + j), 2)
self.assertEqual(ts.get_population(j), 0)
self.assertEqual(ts.get_population(n // 3 + j), 1)
self.assertEqual(ts.get_population(2 * (n // 3) + j), 2)
# The MRCAs of 0, 1 and 3 must have occured in deme 0
self.assertEqual(tree.get_population(tree.get_mrca(0, n // 3)), 0)
self.assertEqual(tree.population(tree.get_mrca(0, n // 3)), 0)
self.assertEqual(
tree.get_population(tree.get_mrca(0, 2 * (n // 3))), 0)
tree.population(tree.get_mrca(0, 2 * (n // 3))), 0)
# The MRCAs of all the samples within each deme must have
# occured within that deme
for k in range(3):
deme_samples = range(k * (n // 3), (k + 1) * (n // 3))
for u, v in itertools.combinations(deme_samples, 2):
mrca_pop = tree.get_population(tree.get_mrca(u, v))
mrca_pop = tree.population(tree.get_mrca(u, v))
self.assertEqual(k, mrca_pop)
self.assertEqual(ts.get_samples(0), list(range(n // 3)))
self.assertEqual(ts.get_samples(1), list(range(n // 3, 2 * (n // 3))))
self.assertEqual(ts.get_samples(2), list(range(2 * (n // 3), n)))
self.assertTrue(np.array_equal(
ts.samples(0), np.arange(n // 3, dtype=np.int32)))
self.assertTrue(np.array_equal(
ts.samples(1), np.arange(n // 3, 2 * (n // 3), dtype=np.int32)))
self.assertTrue(np.array_equal(
ts.samples(2), np.arange(2 * (n // 3), n, dtype=np.int32)))

def test_four_pops_three_mass_migrations(self):
t1 = 1
Expand All @@ -588,25 +601,25 @@ def test_four_pops_three_mass_migrations(self):
tree = next(ts.trees())
# Check the leaves have the correct population.
for j in range(4):
self.assertEqual(tree.get_population(j), j)
self.assertEqual(tree.population(j), j)
self.assertEqual(ts.get_population(j), j)
self.assertEqual(ts.get_samples(j), [j])
self.assertEqual(ts.samples(j), [j])
# The MRCA of 0 and 1 should happen in 1 at time > t1, and < t2
u = tree.get_mrca(0, 1)
self.assertEqual(u, 4)
self.assertEqual(tree.get_population(u), 1)
g = tree.get_time(u) * 4
self.assertEqual(tree.population(u), 1)
g = tree.time(u) * 4
self.assertTrue(t1 < g < t2)
# The MRCA of 0, 1 and 2 should happen in 2 at time > t2 and < t3
u = tree.get_mrca(0, 2)
self.assertEqual(u, 5)
self.assertEqual(tree.get_population(u), 2)
self.assertTrue(t2 < tree.get_time(u) < t3)
self.assertEqual(tree.population(u), 2)
self.assertTrue(t2 < tree.time(u) < t3)
# The MRCA of 0, 1, 2 and 3 should happen in 3 at time > t3
u = tree.get_mrca(0, 3)
self.assertEqual(u, 6)
self.assertEqual(tree.get_population(u), 3)
self.assertGreater(tree.get_time(u), t3)
self.assertEqual(tree.population(u), 3)
self.assertGreater(tree.time(u), t3)

def test_empty_demes(self):
t1 = 1
Expand All @@ -630,19 +643,19 @@ def test_empty_demes(self):
random_seed=1)
tree = next(ts.trees())
# Check the leaves have the correct population.
self.assertEqual(tree.get_population(0), 0)
self.assertEqual(tree.get_population(1), 3)
self.assertEqual(ts.get_population(0), 0)
self.assertEqual(ts.get_population(1), 3)
self.assertEqual(ts.get_samples(0), [0])
self.assertEqual(ts.get_samples(1), [])
self.assertEqual(ts.get_samples(2), [])
self.assertEqual(ts.get_samples(3), [1])
self.assertEqual(tree.population(0), 0)
self.assertEqual(tree.population(1), 3)
self.assertEqual(ts.node(0).population, 0)
self.assertEqual(ts.node(1).population, 3)
self.assertEqual(list(ts.samples(0)), [0])
self.assertEqual(list(ts.samples(1)), [])
self.assertEqual(list(ts.samples(2)), [])
self.assertEqual(list(ts.samples(3)), [1])
# The MRCA of 0, 1 in 3 at time > t3
u = tree.get_mrca(0, 1)
self.assertEqual(u, 2)
self.assertEqual(tree.get_population(u), 3)
g = tree.get_time(u) * 4
self.assertEqual(tree.population(u), 3)
g = tree.time(u) * 4
self.assertGreater(g, t3)

def test_migration_rate_directionality(self):
Expand All @@ -661,16 +674,16 @@ def test_migration_rate_directionality(self):
demographic_events=demographic_events,
random_seed=1)
tree = next(ts.trees())
self.assertEqual(tree.get_root(), 2)
self.assertGreater(tree.get_time(2), t / 4)
self.assertEqual(tree.get_population(0), 0)
self.assertEqual(tree.get_population(1), 1)
self.assertEqual(tree.get_population(2), 2)
self.assertEqual(ts.get_population(0), 0)
self.assertEqual(ts.get_population(1), 1)
self.assertEqual(ts.get_samples(), [0, 1])
self.assertEqual(ts.get_samples(0), [0])
self.assertEqual(ts.get_samples(1), [1])
self.assertEqual(tree.root, 2)
self.assertGreater(tree.time(2), t / 4)
self.assertEqual(tree.population(0), 0)
self.assertEqual(tree.population(1), 1)
self.assertEqual(tree.population(2), 2)
self.assertEqual(ts.node(0).population, 0)
self.assertEqual(ts.node(1).population, 1)
self.assertEqual(list(ts.samples()), [0, 1])
self.assertEqual(list(ts.samples(0)), [0])
self.assertEqual(list(ts.samples(1)), [1])

def test_many_demes(self):
num_demes = 300
Expand All @@ -687,13 +700,13 @@ def test_many_demes(self):
demographic_events=demographic_events,
random_seed=1)
tree = next(ts.trees())
self.assertEqual(tree.get_root(), 2)
self.assertGreater(tree.get_time(2), t)
self.assertEqual(tree.get_population(0), 0)
self.assertEqual(tree.get_population(1), num_demes - 1)
self.assertEqual(tree.get_population(2), num_demes - 1)
self.assertEqual(ts.get_population(0), 0)
self.assertEqual(ts.get_population(1), num_demes - 1)
self.assertEqual(tree.root, 2)
self.assertGreater(tree.time(2), t)
self.assertEqual(tree.population(0), 0)
self.assertEqual(tree.population(1), num_demes - 1)
self.assertEqual(tree.population(2), num_demes - 1)
self.assertEqual(ts.node(0).population, 0)
self.assertEqual(ts.node(1).population, num_demes - 1)

def test_instantaneous_bottleneck_locations(self):
population_configurations = [
Expand All @@ -718,14 +731,14 @@ def test_instantaneous_bottleneck_locations(self):
demographic_events=demographic_events,
random_seed=1)
tree = next(ts.trees())
self.assertGreater(tree.get_time(tree.get_root()), t4)
self.assertEqual(tree.get_population(tree.get_root()), 0)
self.assertGreater(tree.time(tree.root), t4)
self.assertEqual(tree.population(tree.root), 0)
# The parent of all the samples from each deme should be in that deme.
for pop in range(3):
parents = [
tree.get_parent(u) for u in ts.samples(population=pop)]
for v in parents:
self.assertEqual(tree.get_population(v), pop)
self.assertEqual(tree.population(v), pop)


class TestMigrationRecords(unittest.TestCase):
Expand Down Expand Up @@ -887,7 +900,7 @@ def test_coalescence_after_growth_rate_change(self):
tree = next(ts.trees())
u = tree.get_mrca(0, 1)
self.assertEqual(u, 2)
self.assertAlmostEqual(g, tree.get_time(u), places=1)
self.assertAlmostEqual(g, tree.time(u), places=1)

def test_coalescence_after_size_change(self):
Ne = 20000
Expand All @@ -913,7 +926,7 @@ def test_coalescence_after_size_change(self):
tree = next(ts.trees())
u = tree.get_mrca(0, 1)
self.assertEqual(u, 2)
self.assertAlmostEqual(g, tree.get_time(u), places=1)
self.assertAlmostEqual(g, tree.time(u), places=1)

def test_instantaneous_bottleneck(self):
Ne = 0.5
Expand All @@ -937,7 +950,7 @@ def test_instantaneous_bottleneck(self):
random_seed=1, num_replicates=10)
for ts in reps:
tree = next(ts.trees())
self.assertAlmostEqual(t, tree.get_time(tree.get_root()), places=5)
self.assertAlmostEqual(t, tree.time(tree.root), places=5)


class TestLowLevelConversions(unittest.TestCase):
Expand Down
Loading

0 comments on commit dc70090

Please sign in to comment.