1
1
# MIT License
2
2
#
3
- # Copyright (c) 2019 Anthony Wilder Wohns
3
+ # Copyright (c) 2020 University of Oxford
4
4
#
5
5
# Permission is hereby granted, free of charge, to any person obtaining a copy
6
6
# of this software and associated documentation files (the "Software"), to deal
24
24
"""
25
25
import io
26
26
import sys
27
- import tempfile # NOQA - not currently used
28
- import pathlib # NOQA - not currently used
27
+ import tempfile
28
+ import pathlib
29
29
import unittest
30
+ from unittest import mock
30
31
31
- import tskit # NOQA - not currently used
32
- import msprime # NOQA - not currently used
33
- import numpy as np # NOQA - not currently used
32
+ import tskit
33
+ import msprime
34
+ import numpy as np
34
35
35
- import tsdate # NOQA - not currently used
36
+ import tsdate
36
37
import tsdate .cli as cli
37
38
38
39
@@ -76,10 +77,10 @@ class TestTsdateArgParser(unittest.TestCase):
76
77
77
78
def test_default_values (self ):
78
79
parser = cli .tsdate_cli_parser ()
79
- args = parser .parse_args ([self .infile , self .output ])
80
+ args = parser .parse_args ([self .infile , self .output , "1" ])
80
81
self .assertEqual (args .ts , self .infile )
81
82
self .assertEqual (args .output , self .output )
82
- self .assertEqual (args .Ne , 10000 )
83
+ self .assertEqual (args .Ne , 1 )
83
84
self .assertEqual (args .mutation_rate , None )
84
85
self .assertEqual (args .recombination_rate , None )
85
86
self .assertEqual (args .epsilon , 1e-6 )
@@ -88,66 +89,147 @@ def test_default_values(self):
88
89
self .assertEqual (args .method , 'inside_outside' )
89
90
self .assertFalse (args .progress )
90
91
91
- def test_Ne (self ):
92
- parser = cli .tsdate_cli_parser ()
93
- args = parser .parse_args ([self .infile , self .output , "-n" , "10000" ])
94
- self .assertEqual (args .Ne , 10000 )
95
- args = parser .parse_args ([self .infile , self .output , "--Ne" , "10000" ])
96
- self .assertEqual (args .Ne , 10000 )
97
-
98
92
def test_mutation_rate (self ):
99
93
parser = cli .tsdate_cli_parser ()
100
- args = parser .parse_args ([self .infile , self .output , "-m" , "1e10" ])
94
+ args = parser .parse_args ([self .infile , self .output , "10000" , " -m" , "1e10" ])
101
95
self .assertEqual (args .mutation_rate , 1e10 )
102
- args = parser .parse_args ([self .infile , self .output , "--mutation-rate" , "1e10" ])
96
+ args = parser .parse_args ([self .infile , self .output , "10000" , " --mutation-rate" , "1e10" ])
103
97
self .assertEqual (args .mutation_rate , 1e10 )
104
98
105
99
def test_recombination_rate (self ):
106
100
parser = cli .tsdate_cli_parser ()
107
- args = parser .parse_args ([self .infile , self .output , "-r" , "1e-100" ])
101
+ args = parser .parse_args ([self .infile , self .output , "10000" , " -r" , "1e-100" ])
108
102
self .assertEqual (args .recombination_rate , 1e-100 )
109
103
args = parser .parse_args (
110
- [self .infile , self .output , "--recombination-rate" , "1e-100" ])
104
+ [self .infile , self .output , "10000" , " --recombination-rate" , "1e-100" ])
111
105
self .assertEqual (args .recombination_rate , 1e-100 )
112
106
113
107
def test_epsilon (self ):
114
108
parser = cli .tsdate_cli_parser ()
115
- args = parser .parse_args ([self .infile , self .output , "-e" , "123" ])
109
+ args = parser .parse_args ([self .infile , self .output , "10000" , " -e" , "123" ])
116
110
self .assertEqual (args .epsilon , 123 )
117
- args = parser .parse_args ([self .infile , self .output , "--epsilon" , "321" ])
111
+ args = parser .parse_args ([self .infile , self .output , "10000" , " --epsilon" , "321" ])
118
112
self .assertEqual (args .epsilon , 321 )
119
113
120
114
def test_num_threads (self ):
121
115
parser = cli .tsdate_cli_parser ()
122
- args = parser .parse_args ([self .infile , self .output , "--num-threads" , "1" ])
116
+ args = parser .parse_args ([self .infile , self .output , "10000" , " --num-threads" , "1" ])
123
117
self .assertEqual (args .num_threads , 1 )
124
- args = parser .parse_args ([self .infile , self .output , "--num-threads" , "2" ])
118
+ args = parser .parse_args ([self .infile , self .output , "10000" , " --num-threads" , "2" ])
125
119
self .assertEqual (args .num_threads , 2 )
126
120
127
121
def test_probability_space (self ):
128
122
parser = cli .tsdate_cli_parser ()
129
- args = parser .parse_args ([self .infile , self .output , "--probability-space" ,
123
+ args = parser .parse_args ([self .infile , self .output , "10000" , " --probability-space" ,
130
124
"linear" ])
131
125
self .assertEqual (args .probability_space , "linear" )
132
- args = parser .parse_args ([self .infile , self .output , "--probability-space" ,
126
+ args = parser .parse_args ([self .infile , self .output , "10000" , " --probability-space" ,
133
127
"logarithmic" ])
134
128
self .assertEqual (args .probability_space , "logarithmic" )
135
129
136
130
def test_method (self ):
137
131
parser = cli .tsdate_cli_parser ()
138
- args = parser .parse_args ([self .infile , self .output , "--method" ,
132
+ args = parser .parse_args ([self .infile , self .output , "10000" , " --method" ,
139
133
"inside_outside" ])
140
134
self .assertEqual (args .method , "inside_outside" )
141
- args = parser .parse_args ([self .infile , self .output , "--method" ,
135
+ args = parser .parse_args ([self .infile , self .output , "10000" , " --method" ,
142
136
"maximization" ])
143
137
self .assertEqual (args .method , "maximization" )
144
138
145
139
def test_progress (self ):
146
140
parser = cli .tsdate_cli_parser ()
147
- args = parser .parse_args ([self .infile , self .output , "--progress" ])
141
+ args = parser .parse_args ([self .infile , self .output , "10000" , " --progress" ])
148
142
self .assertTrue (args .progress )
149
143
150
144
145
+ class TestEndToEnd (unittest .TestCase ):
146
+ """
147
+ Class to test input to CLI outputs dated tree sequences.
148
+ """
149
+ def ts_equal_except_times (self , ts1 , ts2 ):
150
+ for (t1_name , t1 ), (t2_name , t2 ) in zip (ts1 .tables , ts2 .tables ):
151
+ if isinstance (t1 , tskit .ProvenanceTable ):
152
+ # TO DO - should check that the provenance has had the "tsdate" method
153
+ # added
154
+ pass
155
+ elif isinstance (t1 , tskit .NodeTable ):
156
+ for column_name in t1 .column_names :
157
+ if column_name != 'time' :
158
+ col_t1 = getattr (t1 , column_name )
159
+ col_t2 = getattr (t2 , column_name )
160
+ self .assertTrue (np .array_equal (col_t1 , col_t2 ))
161
+ elif isinstance (t1 , tskit .EdgeTable ):
162
+ # Edges may have been re-ordered, since sortedness requirements specify
163
+ # they are sorted by parent time, and the relative order of
164
+ # (unconnected) parent nodes might have changed due to time inference
165
+ self .assertEquals (set (t1 ), set (t2 ))
166
+ else :
167
+ self .assertEquals (t1 , t2 )
168
+ # The dated and undated tree sequences should not have the same node times
169
+ self .assertTrue (not np .array_equal (ts1 .tables .nodes .time , ts2 .tables .nodes .time ))
170
+
171
+ def verify (self , input_ts , cmd ):
172
+ with tempfile .TemporaryDirectory () as tmpdir :
173
+ input_filename = pathlib .Path (tmpdir ) / "input.trees"
174
+ input_ts .dump (input_filename )
175
+ output_filename = pathlib .Path (tmpdir ) / "output.trees"
176
+ full_cmd = str (input_filename ) + f" { output_filename } " + cmd
177
+ with mock .patch ("tsdate.cli.setup_logging" ):
178
+ stdout , stderr = capture_output (cli .tsdate_main , full_cmd .split ())
179
+ self .assertEqual (len (stderr ), 0 )
180
+ self .assertEqual (len (stdout ), 0 )
181
+ output_ts = tskit .load (output_filename )
182
+ self .assertEqual (input_ts .num_samples , output_ts .num_samples )
183
+ self .ts_equal_except_times (input_ts , output_ts )
184
+ # provenance = json.loads(ts.provenance(0).record)
185
+
186
+ def test_ts (self ):
187
+ input_ts = msprime .simulate (10 , random_seed = 1 )
188
+ cmd = "1"
189
+ self .verify (input_ts , cmd )
190
+
191
+ def test_mutation_rate (self ):
192
+ input_ts = msprime .simulate (10 , random_seed = 1 )
193
+ cmd = "1 --mutation-rate 1e-8"
194
+ self .verify (input_ts , cmd )
195
+
196
+ def test_recombination_rate (self ):
197
+ input_ts = msprime .simulate (10 , random_seed = 1 )
198
+ cmd = "1 --recombination-rate 1e-8"
199
+ self .assertRaises (NotImplementedError , self .verify , input_ts , cmd )
200
+
201
+ def test_epsilon (self ):
202
+ input_ts = msprime .simulate (10 , random_seed = 1 )
203
+ cmd = "1 --epsilon 1e-3"
204
+ self .verify (input_ts , cmd )
205
+
206
+ def test_num_threads (self ):
207
+ input_ts = msprime .simulate (10 , random_seed = 1 )
208
+ cmd = "1 --num-threads 2"
209
+ self .verify (input_ts , cmd )
210
+
211
+ def test_probability_space (self ):
212
+ input_ts = msprime .simulate (10 , random_seed = 1 )
213
+ cmd = "1 --probability-space linear"
214
+ self .verify (input_ts , cmd )
215
+ cmd = "1 --probability-space logarithmic"
216
+ self .verify (input_ts , cmd )
217
+
218
+ def test_probability_space (self ):
219
+ input_ts = msprime .simulate (10 , random_seed = 1 )
220
+ cmd = "1 --probability-space linear"
221
+ self .verify (input_ts , cmd )
222
+ cmd = "1 --probability-space logarithmic"
223
+ self .verify (input_ts , cmd )
224
+
225
+ def test_method (self ):
226
+ input_ts = msprime .simulate (10 , random_seed = 1 )
227
+ cmd = "1 --method inside_outside"
228
+ self .verify (input_ts , cmd )
229
+ cmd = "1 --method maximization"
230
+ self .assertRaises (ValueError , self .verify , input_ts , cmd )
231
+
232
+
151
233
class TestCli (unittest .TestCase ):
152
234
"""
153
235
Superclass of tests that run the CLI.
0 commit comments