Skip to content

Commit 9cb050d

Browse files
committed
test_cli, date changes
1 parent acc8dec commit 9cb050d

File tree

9 files changed

+151
-70
lines changed

9 files changed

+151
-70
lines changed

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2019 Anthony Wilder Wohns
3+
Copyright (c) 2020 University of Oxford
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

docs/index.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
.. currentmodule:: tsdate
22
.. tsdate documentation master file, created by Wilder Wohns
33
4-
Welcome to tsdate's documentation!
4+
Welcome to tsdate's documentation
55
==================================
66

7-
``Tsdate`` is a scalable method for estimating times in the past for ancestral
8-
nodes in a :ref:`tree sequence <sec_python_api>`, in other words, *dating* a
9-
tree sequence. It assumes a prior distribution of node times given by neutral
10-
coalescent theory, and updates the times on the basis of the number of mutations
11-
along each edge or branch of the tree sequence (i.e. using the "molecular clock").
7+
Contents:
128

139
.. toctree::
1410
:maxdepth: 2
1511
:caption: Contents:
1612

13+
introduction
14+
installation
15+
tutorial
16+
1717

1818

1919
Indices and tables

evaluation/evaluate_accuracy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,16 +337,16 @@ def plot_tsdate_accuracy(all_results, parameter, parameter_arr, prior_distr, inf
337337

338338
if parameter == 'Mut Rate':
339339
plt.suptitle("Evaluating " + parameter + ": " + inferred + " " + node_mut +
340-
" vs. True " + node_mut + ". \n Inside-Outside Algorithm and Maximization. \n" + prior_distr + " Prior, n=250, Length=1Mb, Rec Rate=1e-8", y=0.99, size=21)
340+
" vs. True " + node_mut + ". \n Inside-Outside Algorithm and Maximization. \n" + prior_distr + " Prior, n=100, Length=1Mb, Rec Rate=1e-8", y=0.99, size=21)
341341
elif parameter == 'Sample Size':
342342
plt.suptitle("Evaluating " + parameter + ": " + inferred + " " + node_mut +
343343
" vs. True " + node_mut + ". \n Inside-Outside Algorithm and Maximization. \n" + prior_distr + " Prior, Length=1Mb, Mut Rate=1e-8, Rec Rate=1e-8", y=0.99, size=21)
344344
elif parameter == 'Length':
345345
plt.suptitle("Evaluating " + parameter + ": " + inferred + " " + node_mut +
346-
" vs. True " + node_mut + ". \n Inside-Outside Algorithm and Maximization. \n" + prior_distr + " Prior, n=250, Mut Rate=1e-8, Rec Rate=1e-8", y=0.99, size=21)
346+
" vs. True " + node_mut + ". \n Inside-Outside Algorithm and Maximization. \n" + prior_distr + " Prior, n=100, Mut Rate=1e-8, Rec Rate=1e-8", y=0.99, size=21)
347347
elif parameter == 'Timepoints':
348348
plt.suptitle("Evaluating " + parameter + ": " + inferred + " " + node_mut +
349-
" vs. True " + node_mut + ". \n Inside-Outside Algorithm and Maximization. \n" + prior_distr + " Prior, n=250, length=1Mb, Mut Rate=1e-8, Rec Rate=1e-8", y=0.99, size=21)
349+
" vs. True " + node_mut + ". \n Inside-Outside Algorithm and Maximization. \n" + prior_distr + " Prior, n=100, length=1Mb, Mut Rate=1e-8, Rec Rate=1e-8", y=0.99, size=21)
350350
# plt.tight_layout()
351351
plt.savefig("evaluation/" + parameter + "_" + inferred + "_" + node_mut + "_" + prior_distr +
352352
"_accuracy", dpi=300, bbox_inches='tight')

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"numpy>=1.17.0",
2222
"tskit>=0.2.3",
2323
"scipy>1.2.3",
24+
"numba>=0.46",
2425
"tqdm"
2526
],
2627
project_urls={

tests/test_cli.py

Lines changed: 111 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# MIT License
22
#
3-
# Copyright (c) 2019 Anthony Wilder Wohns
3+
# Copyright (c) 2020 University of Oxford
44
#
55
# Permission is hereby granted, free of charge, to any person obtaining a copy
66
# of this software and associated documentation files (the "Software"), to deal
@@ -24,15 +24,16 @@
2424
"""
2525
import io
2626
import sys
27-
import tempfile # NOQA - not currently used
28-
import pathlib # NOQA - not currently used
27+
import tempfile
28+
import pathlib
2929
import unittest
30+
from unittest import mock
3031

31-
import tskit # NOQA - not currently used
32-
import msprime # NOQA - not currently used
33-
import numpy as np # NOQA - not currently used
32+
import tskit
33+
import msprime
34+
import numpy as np
3435

35-
import tsdate # NOQA - not currently used
36+
import tsdate
3637
import tsdate.cli as cli
3738

3839

@@ -76,10 +77,10 @@ class TestTsdateArgParser(unittest.TestCase):
7677

7778
def test_default_values(self):
7879
parser = cli.tsdate_cli_parser()
79-
args = parser.parse_args([self.infile, self.output])
80+
args = parser.parse_args([self.infile, self.output, "1"])
8081
self.assertEqual(args.ts, self.infile)
8182
self.assertEqual(args.output, self.output)
82-
self.assertEqual(args.Ne, 10000)
83+
self.assertEqual(args.Ne, 1)
8384
self.assertEqual(args.mutation_rate, None)
8485
self.assertEqual(args.recombination_rate, None)
8586
self.assertEqual(args.epsilon, 1e-6)
@@ -88,66 +89,147 @@ def test_default_values(self):
8889
self.assertEqual(args.method, 'inside_outside')
8990
self.assertFalse(args.progress)
9091

91-
def test_Ne(self):
92-
parser = cli.tsdate_cli_parser()
93-
args = parser.parse_args([self.infile, self.output, "-n", "10000"])
94-
self.assertEqual(args.Ne, 10000)
95-
args = parser.parse_args([self.infile, self.output, "--Ne", "10000"])
96-
self.assertEqual(args.Ne, 10000)
97-
9892
def test_mutation_rate(self):
9993
parser = cli.tsdate_cli_parser()
100-
args = parser.parse_args([self.infile, self.output, "-m", "1e10"])
94+
args = parser.parse_args([self.infile, self.output, "10000", "-m", "1e10"])
10195
self.assertEqual(args.mutation_rate, 1e10)
102-
args = parser.parse_args([self.infile, self.output, "--mutation-rate", "1e10"])
96+
args = parser.parse_args([self.infile, self.output, "10000", "--mutation-rate", "1e10"])
10397
self.assertEqual(args.mutation_rate, 1e10)
10498

10599
def test_recombination_rate(self):
106100
parser = cli.tsdate_cli_parser()
107-
args = parser.parse_args([self.infile, self.output, "-r", "1e-100"])
101+
args = parser.parse_args([self.infile, self.output, "10000", "-r", "1e-100"])
108102
self.assertEqual(args.recombination_rate, 1e-100)
109103
args = parser.parse_args(
110-
[self.infile, self.output, "--recombination-rate", "1e-100"])
104+
[self.infile, self.output, "10000", "--recombination-rate", "1e-100"])
111105
self.assertEqual(args.recombination_rate, 1e-100)
112106

113107
def test_epsilon(self):
114108
parser = cli.tsdate_cli_parser()
115-
args = parser.parse_args([self.infile, self.output, "-e", "123"])
109+
args = parser.parse_args([self.infile, self.output, "10000", "-e", "123"])
116110
self.assertEqual(args.epsilon, 123)
117-
args = parser.parse_args([self.infile, self.output, "--epsilon", "321"])
111+
args = parser.parse_args([self.infile, self.output, "10000", "--epsilon", "321"])
118112
self.assertEqual(args.epsilon, 321)
119113

120114
def test_num_threads(self):
121115
parser = cli.tsdate_cli_parser()
122-
args = parser.parse_args([self.infile, self.output, "--num-threads", "1"])
116+
args = parser.parse_args([self.infile, self.output, "10000", "--num-threads", "1"])
123117
self.assertEqual(args.num_threads, 1)
124-
args = parser.parse_args([self.infile, self.output, "--num-threads", "2"])
118+
args = parser.parse_args([self.infile, self.output, "10000", "--num-threads", "2"])
125119
self.assertEqual(args.num_threads, 2)
126120

127121
def test_probability_space(self):
128122
parser = cli.tsdate_cli_parser()
129-
args = parser.parse_args([self.infile, self.output, "--probability-space",
123+
args = parser.parse_args([self.infile, self.output, "10000", "--probability-space",
130124
"linear"])
131125
self.assertEqual(args.probability_space, "linear")
132-
args = parser.parse_args([self.infile, self.output, "--probability-space",
126+
args = parser.parse_args([self.infile, self.output, "10000", "--probability-space",
133127
"logarithmic"])
134128
self.assertEqual(args.probability_space, "logarithmic")
135129

136130
def test_method(self):
137131
parser = cli.tsdate_cli_parser()
138-
args = parser.parse_args([self.infile, self.output, "--method",
132+
args = parser.parse_args([self.infile, self.output, "10000", "--method",
139133
"inside_outside"])
140134
self.assertEqual(args.method, "inside_outside")
141-
args = parser.parse_args([self.infile, self.output, "--method",
135+
args = parser.parse_args([self.infile, self.output, "10000", "--method",
142136
"maximization"])
143137
self.assertEqual(args.method, "maximization")
144138

145139
def test_progress(self):
146140
parser = cli.tsdate_cli_parser()
147-
args = parser.parse_args([self.infile, self.output, "--progress"])
141+
args = parser.parse_args([self.infile, self.output, "10000", "--progress"])
148142
self.assertTrue(args.progress)
149143

150144

145+
class TestEndToEnd(unittest.TestCase):
    """
    Class to test input to CLI outputs dated tree sequences.

    Each test simulates a small tree sequence with msprime, runs it through
    ``cli.tsdate_main`` and checks that the output differs from the input
    only in its node times.
    """

    def ts_equal_except_times(self, ts1, ts2):
        """
        Assert that ``ts1`` and ``ts2`` hold the same tables, ignoring node
        times (which must differ), edge order and provenance.
        """
        for (_, t1), (_, t2) in zip(ts1.tables, ts2.tables):
            if isinstance(t1, tskit.ProvenanceTable):
                # TODO - should check that the provenance has had the "tsdate"
                # method added
                pass
            elif isinstance(t1, tskit.NodeTable):
                for column_name in t1.column_names:
                    if column_name == 'time':
                        # Node times are expected to change when dating
                        continue
                    col_t1 = getattr(t1, column_name)
                    col_t2 = getattr(t2, column_name)
                    self.assertTrue(
                        np.array_equal(col_t1, col_t2), msg=column_name)
            elif isinstance(t1, tskit.EdgeTable):
                # Edges may have been re-ordered, since sortedness requirements
                # specify they are sorted by parent time, and the relative
                # order of (unconnected) parent nodes might have changed due to
                # time inference
                self.assertEqual(set(t1), set(t2))
            else:
                self.assertEqual(t1, t2)
        # The dated and undated tree sequences should not have the same node
        # times
        self.assertFalse(
            np.array_equal(ts1.tables.nodes.time, ts2.tables.nodes.time))

    def verify(self, input_ts, cmd):
        """
        Dump ``input_ts`` to a temporary file, run the CLI on it with the extra
        arguments given in the whitespace-separated string ``cmd``, and check
        the dated output against the input.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            input_filename = pathlib.Path(tmpdir) / "input.trees"
            input_ts.dump(input_filename)
            output_filename = pathlib.Path(tmpdir) / "output.trees"
            # Build the argument list directly rather than splitting one big
            # string, so that a temporary directory whose path contains
            # whitespace cannot be broken into several arguments.
            arg_list = [str(input_filename), str(output_filename)] + cmd.split()
            with mock.patch("tsdate.cli.setup_logging"):
                # capture_output is a helper defined elsewhere in this module;
                # it is unpacked here as a (stdout, stderr) pair.
                stdout, stderr = capture_output(cli.tsdate_main, arg_list)
            self.assertEqual(len(stderr), 0)
            self.assertEqual(len(stdout), 0)
            output_ts = tskit.load(output_filename)
        self.assertEqual(input_ts.num_samples, output_ts.num_samples)
        self.ts_equal_except_times(input_ts, output_ts)
        # TODO - inspect the provenance of the output, e.g.
        # json.loads(output_ts.provenance(0).record)

    def test_ts(self):
        input_ts = msprime.simulate(10, random_seed=1)
        cmd = "1"
        self.verify(input_ts, cmd)

    def test_mutation_rate(self):
        input_ts = msprime.simulate(10, random_seed=1)
        cmd = "1 --mutation-rate 1e-8"
        self.verify(input_ts, cmd)

    def test_recombination_rate(self):
        input_ts = msprime.simulate(10, random_seed=1)
        cmd = "1 --recombination-rate 1e-8"
        self.assertRaises(NotImplementedError, self.verify, input_ts, cmd)

    def test_epsilon(self):
        input_ts = msprime.simulate(10, random_seed=1)
        cmd = "1 --epsilon 1e-3"
        self.verify(input_ts, cmd)

    def test_num_threads(self):
        input_ts = msprime.simulate(10, random_seed=1)
        cmd = "1 --num-threads 2"
        self.verify(input_ts, cmd)

    def test_probability_space(self):
        input_ts = msprime.simulate(10, random_seed=1)
        cmd = "1 --probability-space linear"
        self.verify(input_ts, cmd)
        cmd = "1 --probability-space logarithmic"
        self.verify(input_ts, cmd)

    def test_method(self):
        input_ts = msprime.simulate(10, random_seed=1)
        cmd = "1 --method inside_outside"
        self.verify(input_ts, cmd)
        cmd = "1 --method maximization"
        self.assertRaises(ValueError, self.verify, input_ts, cmd)
231+
232+
151233
class TestCli(unittest.TestCase):
152234
"""
153235
Superclass of tests that run the CLI.

tsdate/cli.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@
2525
import argparse
2626
import logging
2727
import sys
28-
# sys.path.insert(1, '../tsdate')
29-
import tskit
3028

29+
import tskit
3130
import tsdate
3231

3332
logger = logging.getLogger(__name__)
@@ -61,7 +60,7 @@ def tsdate_cli_parser():
6160
help="Tree sequence from which we estimate age")
6261
parser.add_argument('output',
6362
help="path and name of output file")
64-
parser.add_argument('-n', '--Ne', type=float, default=10000,
63+
parser.add_argument('Ne', type=float,
6564
help="effective population size")
6665
parser.add_argument('-m', '--mutation-rate', type=float, default=None,
6766
help="mutation rate")
@@ -101,11 +100,11 @@ def run_date(args):
101100

102101

103102
def main(args):
104-
# Load tree sequence
105103
run_date(args)
106104

107105

108106
def tsdate_main(arg_list=None):
109107
parser = tsdate_cli_parser()
110108
args = parser.parse_args(arg_list)
109+
setup_logging(args)
111110
main(args)

0 commit comments

Comments
 (0)