diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..5ee45eb --- /dev/null +++ b/.travis.yml @@ -0,0 +1,24 @@ +notifications: + email: false + +language: python + +matrix: + include: + - os: linux + python: 3.7 + dist: xenial + sudo: true + +install: + - wget https://github.com/andre-martins/AD3/archive/2.2.1.tar.gz + - tar zxvf 2.2.1.tar.gz + - cd AD3-2.2.1; make; cd .. + - pip install pytest numpy ad3==2.2.1 cython + - pip install https://download.pytorch.org/whl/cpu/torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl + - AD3_DIR=AD3-2.2.1/ python setup.py bdist_wheel + - pip install --pre --no-index --find-links dist/ sparsemap + +script: + - echo "Running tests" + - mkdir empty_folder; cd empty_folder; pytest -vs --pyargs sparsemap; cd .. diff --git a/README.md b/README.md index e457279..e7a976f 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,8 @@ ![SparseMAP cartoon](sparsemap.png?raw=true "SparseMAP cartoon") +[![Build Status](https://travis-ci.org/vene/sparsemap.svg?branch=master)](https://travis-ci.org/vene/sparsemap) + SparseMAP is a new method for **sparse structured inference,** able to automatically select only a few global structures: it is situated between MAP inference, which picks a single structure, @@ -29,28 +31,27 @@ to the `cpp` folder for an implementation, and see our paper, ## Current state of the codebase We are working to slowly provide useful implementations. At the moment, -the codebase provides a generic pytorch layer supporting version 0.2, -as well as particular instantiations for sequence, matching, and tree layers. +the codebase provides a generic pytorch 1.0 layer, as well as particular +instantiations for sequence, matching, and tree layers. Dynet custom layers, as well as the SparseMAP loss, are on the way. ## Python Setup -Requirements: numpy, scipy, Cython, pytorch=0.2, and ad3 >= 2.2 +Requirements: numpy, scipy, Cython, pytorch>=1.0, and ad3 >= 2.2 1. Set the `AD3_DIR` environment variable to point to the - [AD3](https://github.com/andre-martins/ad3) source directory. + [AD3](https://github.com/andre-martins/ad3) source directory, + where you have compiled AD3. -2. Inside the `python` dir, run `python setup.py build_ext --inplace`. +2. Run `pip install .` (optionally with the `-e` flag). ### Notes on testing -The implemented layers pass numerical tests. However, the pytorch -gradcheck (as of version 0.2) has a very strict "reentrant" test, which we fail -due to tiny numerical differences. To reliably check gradients, please comment -out the `if not reentrant: ...` part of pytorch's gradcheck.py. +Because of slight numerical differences, we had to relax the reentrancy +test from pytorch's gradcheck. 
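+
+### Example usage
+
+A minimal sketch of calling one of the pytorch layers, mirroring the demo at
+the bottom of `python/sparsemap/layers_pt/matching_layer.py`; the class and
+module names below come from this repository, and only the toy problem sizes
+are illustrative:
+
+```python
+import torch
+from sparsemap.layers_pt.matching_layer import Matching
+
+# SparseMAP bipartite matching layer; double precision keeps gradcheck happy.
+matcher = Matching(max_iter=100)
+scores = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
+
+posterior = matcher(scores)   # sparse combination of a few matchings
+posterior.sum().backward()    # gradients flow back to the scores
+print(scores.grad)
+```
+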
## Dynet (c++) setup: diff --git a/python/setup.py b/python/setup.py deleted file mode 100644 index 407f0bf..0000000 --- a/python/setup.py +++ /dev/null @@ -1,46 +0,0 @@ -import os - -import numpy - -from setuptools import setup -from setuptools.extension import Extension -from Cython.Build import cythonize - -ad3_dir = os.environ.get('AD3_DIR') -if not ad3_dir: - print("Warning: please set the AD3_DIR environment variable to point" - "to the path where you have downloaded the AD3 library.") - exit(1) - - -ext_args = dict( - # below is a hack so we can include ./ad3 as well as ../ad3 - libraries=['ad3'], - extra_compile_args=["-std=c++11"], - language="c++") - - -setup(name="sparsemap", - version="0.1.dev0", - author="Vlad Niculae", - author_email="vlad@vene.ro", - ext_modules=cythonize([ - Extension( - "sparsemap._sparsemap", - ["sparsemap/_sparsemap.pyx"], - include_dirs=["../src", ad3_dir, os.path.join(ad3_dir, 'python'), - numpy.get_include()], - library_dirs=[os.path.join(ad3_dir, 'ad3')], - **ext_args), - Extension( - "sparsemap._factors", - ["sparsemap/_factors.pyx", - "../src/lapjv/lapjv.cpp", - os.path.join(ad3_dir, 'examples', 'cpp', 'parsing', 'FactorTree.cpp') - ], - include_dirs=["../src", ad3_dir, os.path.join(ad3_dir, 'python'), - numpy.get_include()], - library_dirs=[os.path.join(ad3_dir, 'ad3')], - **ext_args), - ]) -) diff --git a/python/sparsemap/__init__.py b/python/sparsemap/__init__.py index 531d089..a7cf63d 100644 --- a/python/sparsemap/__init__.py +++ b/python/sparsemap/__init__.py @@ -1 +1,3 @@ from ._sparsemap import sparsemap +from . import layers_pt +from . import fw_solver diff --git a/python/sparsemap/_sparsemap.pyx b/python/sparsemap/_sparsemap.pyx index 150d291..87a392e 100644 --- a/python/sparsemap/_sparsemap.pyx +++ b/python/sparsemap/_sparsemap.pyx @@ -46,7 +46,7 @@ cpdef sparsemap(PGenericFactor f, vector[Configuration] active_set_c vector[double] distribution vector[double] inverse_A - vector[double] M, Madd + vector[double] M,N GenericFactor* gf @@ -69,8 +69,8 @@ cpdef sparsemap(PGenericFactor f, active_set_c = gf.GetQPActiveSet() distribution = gf.GetQPDistribution() - inverse_A = gf.GetQPInvA() - gf.GetCorrespondence(&M, &Madd) + inverse = gf.GetQPInvA() + gf.GetCorrespondence(&M, &N) n_active = active_set_c.size() n_add = post_additionals.size() @@ -78,18 +78,18 @@ cpdef sparsemap(PGenericFactor f, post_unaries_np = asfloatvec(post_unaries.data(), n_var) post_additionals_np = asfloatvec(post_additionals.data(), n_add) distribution_np = asfloatvec(distribution.data(), n_active) - invA_np = asfloatarray(inverse_A.data(), 1 + n_active, 1 + n_active) + inv_np = asfloatarray(inverse.data(), 1 + n_active, 1 + n_active) M_np = asfloatarray(M.data(), n_active, n_var) - Madd_np = asfloatarray(Madd.data(), n_active, n_add) + N_np = asfloatarray(N.data(), n_active, n_add) active_set_py = [f._cast_configuration(x) for x in active_set_c] solver_data = { 'active_set': active_set_py, 'distribution': distribution_np, - 'inverse_A': invA_np, + 'inverse': inv_np, 'M': M_np, - 'Madd': Madd_np + 'N': N_np } return post_unaries_np, post_additionals_np, solver_data diff --git a/python/sparsemap/fw_solver.py b/python/sparsemap/fw_solver.py index 2f66ab1..4e0af47 100644 --- a/python/sparsemap/fw_solver.py +++ b/python/sparsemap/fw_solver.py @@ -10,9 +10,7 @@ from collections import defaultdict import numpy as np -from numpy.testing import assert_allclose -import pytest class SparseMAPFW(object): @@ -213,145 +211,3 @@ def solve(self, eta_u, eta_v, full_path=False): return 
u, v, active_set, objs, size else: return u, v, active_set - - -@pytest.mark.parametrize('variant', ('vanilla', 'pairwise', 'away-step')) -def test_pairwise_factor(variant): - - class PairwiseFactor(object): - """A factor with two binary variables and a coupling between them.""" - - def vertex(self, y): - - # y is a tuple (0, 0), (0, 1), (1, 0) or (1, 1) - u = np.array(y, dtype=np.float) - v = np.atleast_1d(np.prod(u)) - return u, v - - def map_oracle(self, eta_u, eta_v): - - best_score = -np.inf - best_y = None - for x1 in (0, 1): - for x2 in (0, 1): - y = (x1, x2) - u, v = self.vertex(y) - - score = np.dot(u, eta_u) + np.dot(v, eta_v) - if score > best_score: - best_score = score - best_y = y - return best_y - - def qp(self, eta_u, eta_v): - """Prop 6.5 in Andre Martins' thesis""" - - c1, c2, c12 = eta_u[0], eta_u[1], eta_v[0] - - flip_sign = False - if c12 < 0: - flip_sign = True - c1, c2, c12 = c1 + c12, 1 - c2, -c12 - - if c1 > c2 + c12: - u = [c1, c2 + c12] - elif c2 > c1 + c12: - u = [c1 + c12, c2] - else: - uu = (c1 + c2 + c12) / 2 - u = [uu, uu] - - u = np.clip(np.array(u), 0, 1) - v = np.atleast_1d(np.min(u)) - - if flip_sign: - u[1] = 1 - u[1] - v[0] = u[0] - v[0] - - return u, v - - pw = PairwiseFactor() - fw = SparseMAPFW(pw, max_iter=10000, tol=1e-12, variant=variant) - - params = [ - (np.array([0, 0]), np.array([0])), - (np.array([100, 0]), np.array([0])), - (np.array([0, 100]), np.array([0])), - (np.array([100, 0]), np.array([-100])), - (np.array([0, 100]), np.array([-100])) - ] - - rng = np.random.RandomState(0) - for _ in range(20): - eta_u = rng.randn(2) - eta_v = rng.randn(1) - params.append((eta_u, eta_v)) - - for eta_u, eta_v in params: - - u, v, active_set = fw.solve(eta_u, eta_v) - ustar, vstar = pw.qp(eta_u, eta_v) - - uv = np.concatenate([u, v]) - uvstar = np.concatenate([ustar, vstar]) - - assert_allclose(uv, uvstar, atol=1e-10) - - -@pytest.mark.parametrize('variant', ('vanilla', 'pairwise', 'away-step')) -@pytest.mark.parametrize('k', (1, 4, 20)) -def test_xor(variant, k): - class XORFactor(object): - """A one-of-K factor""" - - def __init__(self, k): - self.k = k - - def vertex(self, y): - # y is an integer between 0 and k-1 - u = np.zeros(k) - u[y] = 1 - v = np.array(()) - - return u, v - - def map_oracle(self, eta_u, eta_v): - return np.argmax(eta_u) - - def qp(self, eta_u, eta_v): - """Projection onto the simplex""" - z = 1 - v = np.array(eta_u) - n_features = v.shape[0] - u = np.sort(v)[::-1] - cssv = np.cumsum(u) - z - ind = np.arange(n_features) + 1 - cond = u - cssv / ind > 0 - rho = ind[cond][-1] - theta = cssv[cond][-1] / float(rho) - uu = np.maximum(v - theta, 0) - vv = np.array(()) - return uu, vv - - xor = XORFactor(k) - fw = SparseMAPFW(xor, max_iter=10000, tol=1e-12, variant=variant) - - params = [np.zeros(k), np.ones(k), np.full(k, -1)] - - rng = np.random.RandomState(0) - for _ in range(20): - eta_u = rng.randn(k) - params.append(eta_u) - - for eta_u in params: - - # try different ways of supplying empty eta_v - for eta_v in (np.array(()), [], 0, None): - - u, v, active_set = fw.solve(eta_u, eta_v) - ustar, vstar = xor.qp(eta_u, eta_v) - - uv = np.concatenate([u, v]) - uvstar = np.concatenate([ustar, vstar]) - - assert_allclose(uv, uvstar, atol=1e-10) diff --git a/python/sparsemap/layers_pt/__init__.py b/python/sparsemap/layers_pt/__init__.py index e69de29..087cc82 100644 --- a/python/sparsemap/layers_pt/__init__.py +++ b/python/sparsemap/layers_pt/__init__.py @@ -0,0 +1,6 @@ +# from .matching_layer import MatchingSparseMarginals +# from 
.seq_layer import SequenceSparseMarginals +# from .seq_layer import SequenceDistanceSparseMarginals +# from .seq_layer import StationarySequencePotentials +# from .tree_layer import TreeSparseMarginals +# from .tree_layer import TreeSparseMarginalsFast diff --git a/python/sparsemap/layers_pt/base.py b/python/sparsemap/layers_pt/base.py index d2911c3..ab0d6e1 100644 --- a/python/sparsemap/layers_pt/base.py +++ b/python/sparsemap/layers_pt/base.py @@ -1,115 +1,137 @@ import numpy as np import torch -from torch.autograd import Variable, Function from .. import sparsemap -from ..utils import S_from_Ainv -class _BaseSparseMarginals(Function): +def _Z_from_inv(inv): + """ - def __init__(self, max_iter=10, verbose=0): - self.max_iter = max_iter - self.verbose = verbose + active set maintains the inverse : - def forward(self, unaries): + inv = [0, 1.T; 1, MtM] ^ -1 - cuda_device = None - if unaries.is_cuda: - cuda_device = unaries.get_device() - unaries = unaries.cpu() + we recover Z = (MtM)^{-1} by Sherman-Morrison-Woodbury + """ - factor = self.build_factor() - u, _, status = sparsemap(factor, unaries, [], - max_iter=self.max_iter, - verbose=self.verbose) - self.status = status + Z = inv[1:, 1:] + k = inv[0, 0] + b = inv[0, 1:].unsqueeze(0) - out = torch.from_numpy(u) - if cuda_device is not None: - out = out.cuda(cuda_device) - return out + Z -= (1 / k) * (b * b.t()) + return Z - def _d_vbar(self, M, dy): - Ainv = torch.from_numpy(self.status['inverse_A']) - S = S_from_Ainv(Ainv) +def _d_vbar(M, dy, inv): - if M.is_cuda: - S = S.cuda() - # B = S11t / 1S1t - # dvbar = (I - B) S M dy + Z = _Z_from_inv(inv) - # we first compute S M dy - first_term = S @ (M @ dy) - # then, BSMt dy = B * first_term. Optimized: - # 1S1t = S.sum() - # S11tx = (S1) (1t * x) - second_term = (first_term.sum() * S.sum(0)) / S.sum() - d_vbar = first_term - second_term - return d_vbar + # B = S11t / 1S1t + # dvbar = (I - B) S M dy + # we first compute S M dy + first_term = Z @ (M @ dy) + # then, BSMt dy = B * first_term. 
Optimized: + # 1S1t = S.sum() + # S11tx = (S1) (1t * x) + second_term = (first_term.sum() * Z.sum(0)) / Z.sum() + d_vbar = first_term - second_term + return d_vbar - def backward(self, dy): - cuda_device = None - if dy.is_cuda: - cuda_device = dy.get_device() - dy = dy.cpu() +def _from_np_like(X_np, Y_pt): + X = torch.from_numpy(X_np) + return torch.as_tensor(X, dtype=Y_pt.dtype, device=Y_pt.device) - M = torch.from_numpy(self.status['M']) - - d_vbar = self._d_vbar(M, dy) - d_unary = M.t() @ d_vbar - if cuda_device is not None: - d_unary = d_unary.cuda(cuda_device) +class _BaseSparseMAP(torch.nn.Module): - return d_unary + def __init__(self, max_iter=20, verbose=0): + self.max_iter = max_iter + self.verbose = verbose + super(_BaseSparseMAP, self).__init__() + def sparsemap(self, unaries, factor, additionals=None): + if additionals is not None: + return _SparseMAPAdd.apply( + unaries, + additionals, + factor, + self, + self.max_iter, + self.verbose) -class _BaseSparseMarginalsAdditionals(_BaseSparseMarginals): + else: + return _SparseMAP.apply( + unaries, + factor, + self, + self.max_iter, + self.verbose) - def forward(self, unaries, additionals): - cuda_device = None - if unaries.is_cuda: - cuda_device = unaries.get_device() - unaries = unaries.cpu() - additionals = additionals.cpu() +class _SparseMAP(torch.autograd.Function): - factor = self.build_factor() - u, uadd, status = sparsemap(factor, unaries, additionals, - max_iter=self.max_iter, - verbose=self.verbose) + @staticmethod + def forward(ctx, unaries, factor, caller=None, max_iter=20, verbose=0): - self.status = status + u, v, status = sparsemap(factor, unaries.cpu(), [], + max_iter=max_iter, + verbose=verbose) - out = torch.from_numpy(u) - if cuda_device is not None: - out = out.cuda(cuda_device) - return out + u = _from_np_like(u, unaries) + inv = _from_np_like(status['inverse'], unaries) + M = _from_np_like(status['M'], unaries) - def backward(self, dy): - cuda_device = None + if caller is not None: + caller.distribution = status['distribution'] + caller.configurations = status['active_set'] - if dy.is_cuda: - cuda_device = dy.get_device() - dy = dy.cpu() + ctx.save_for_backward(inv, M) + return u - M = torch.from_numpy(self.status['M']) - Madd = torch.from_numpy(self.status['Madd']) - if dy.is_cuda: - M = M.cuda() - Madd = Madd.cuda() + @staticmethod + def backward(ctx, dy): - d_vbar = self._d_vbar(M, dy) + inv, M = ctx.saved_tensors + d_vbar = _d_vbar(M, dy, inv) d_unary = M.t() @ d_vbar - d_additionals = Madd.t() @ d_vbar - if cuda_device is not None: - d_unary = d_unary.cuda(cuda_device) - d_additionals = d_additionals.cuda(cuda_device) + return d_unary, None, None, None, None + + +class _SparseMAPAdd(torch.autograd.Function): + """SparseMAP with additional inputs/outputs, as for a linear chain CRF""" + + @staticmethod + def forward(ctx, unaries, additionals, factor, caller=None, max_iter=20, + verbose=0): + + u, _, status = sparsemap(factor, + unaries.cpu(), + additionals.cpu(), + max_iter=max_iter, + verbose=verbose) + + u = _from_np_like(u, unaries) + inv = _from_np_like(status['inverse'], unaries) + M = _from_np_like(status['M'], unaries) + N = _from_np_like(status['N'], unaries) + + if caller is not None: + caller.distribution = status['distribution'] + caller.configurations = status['active_set'] + + ctx.save_for_backward(inv, M, N) + return u + + @staticmethod + def backward(ctx, dy): + + inv, M, N = ctx.saved_tensors + d_vbar = _d_vbar(M, dy, inv) - return d_unary, d_additionals + d_u = M.t() @ d_vbar + d_v = N.t() @ 
d_vbar
+        return d_u, d_v, None, None, None, None
diff --git a/python/sparsemap/layers_pt/matching_layer.py b/python/sparsemap/layers_pt/matching_layer.py
index fee4ab6..ed61241 100644
--- a/python/sparsemap/layers_pt/matching_layer.py
+++ b/python/sparsemap/layers_pt/matching_layer.py
@@ -1,41 +1,36 @@
 from ad3 import PFactorGraph
 
 import torch
-from torch.autograd import Variable, Function
+from torch.autograd import Function
 
-from .base import _BaseSparseMarginals
+from .base import _SparseMAP, _BaseSparseMAP
 from .._factors import PFactorMatching
 
 
-class MatchingSparseMarginals(_BaseSparseMarginals):
+class Matching(_BaseSparseMAP):
+
+    def forward(self, unaries):
+        self.n_rows, self.n_cols = unaries.size()
 
-    def build_factor(self):
         match = PFactorMatching()
         match.initialize(self.n_rows, self.n_cols)
-        return match
 
-    def forward(self, unaries):
-        self.n_rows, self.n_cols = unaries.size()
-        u = super().forward(unaries.view(-1))
-        return u.view_as(unaries)
+        u = self.sparsemap(unaries.view(-1), match)
 
-    def backward(self, dy):
-        dy = dy.contiguous().view(-1)
-        da = super().backward(dy)
-        return da.view(self.n_rows, self.n_cols)
+        return u.view_as(unaries)
 
 
 if __name__ == '__main__':
 
     n_rows = 5
     n_cols = 3
 
-    scores = torch.randn(n_rows, n_cols)
-    scores = Variable(scores, requires_grad=True)
+    scores = torch.randn(n_rows, n_cols, dtype=torch.double, requires_grad=True)
 
-    matcher = MatchingSparseMarginals()
+    matcher = Matching(max_iter=1000)
     matching = matcher(scores)
-    print(matching)
 
-    matching.sum().backward()
+    print(torch.autograd.grad(matching[0, 0], scores, retain_graph=True))
+    print(torch.autograd.grad(matching[0, 0], scores))
 
-    print("dpost_dunary", scores.grad)
+    from torch.autograd import gradcheck
+    print(gradcheck(matcher, scores, eps=1e-4, atol=1e-3))
diff --git a/python/sparsemap/layers_pt/seq_layer.py b/python/sparsemap/layers_pt/seq_layer.py
index efde0b9..392cc33 100644
--- a/python/sparsemap/layers_pt/seq_layer.py
+++ b/python/sparsemap/layers_pt/seq_layer.py
@@ -3,27 +3,23 @@
 from ad3.extensions import PFactorSequence
 
 import torch
-from torch.autograd import Variable, Function
-from torch import nn
-from .base import _BaseSparseMarginalsAdditionals
+from .base import _BaseSparseMAP
 from .._factors import PFactorSequenceDistance
 
 
-class StationarySequencePotentials(nn.Module):
+class StationarySequencePotentials(torch.nn.Module):
 
     def forward(self, transition, n_variables, start=None, end=None):
 
         n_states, n_states_ = transition.size()
         assert n_states == n_states_
 
         if start is None:
-            start = Variable(transition.data.new(n_states))
-            start.data.zero_()
+            start = transition.new_zeros(n_states)
         else:
             assert start.dim() == 1 and start.size()[0] == n_states
 
         if end is None:
-            end = Variable(transition.data.new(n_states))
-            end.data.zero_()
+            end = transition.new_zeros(n_states)
         else:
             assert end.dim() == 1 and end.size()[0] == n_states
 
@@ -32,7 +28,7 @@ def forward(self, transition, n_variables, start=None, end=None):
                            end])
 
 
-class SequenceSparseMarginals(_BaseSparseMarginalsAdditionals):
+class Sequence(_BaseSparseMAP):
 
     def forward(self, unaries, additionals):
         """Returns a weighted sum of the most likely posterior assignments.
@@ -57,30 +53,23 @@ def forward(self, unaries, additionals): """ self.n_variables, self.n_states = unaries.size() - u = super().forward(unaries.view(-1), additionals) - return u.view_as(unaries) - - def backward(self, dy): - dy = dy.contiguous().view(-1) - da, dadd = super().backward(dy) - return da.view(self.n_variables, self.n_states), dadd - - def build_factor(self): seq = PFactorSequence() seq.initialize([self.n_states] * self.n_variables) - return seq + u = self.sparsemap(unaries.view(-1), seq, additionals) + return u.view_as(unaries) -class SequenceDistanceSparseMarginals(SequenceSparseMarginals): - def __init__(self, bandwidth, max_iter=10, verbose=False): +class SequenceDistance(Sequence): + def __init__(self, bandwidth, max_iter=20, verbose=False): self.bandwidth = bandwidth - self.max_iter = max_iter - self.verbose = verbose + super(SequenceDistance, self).__init__(max_iter, verbose) - def build_factor(self): + def forward(self, unaries, additionals): + self.n_variables, self.n_states = unaries.size() seq = PFactorSequenceDistance() seq.initialize(self.n_variables, self.n_states, self.bandwidth) - return seq + u = self.sparsemap(unaries.view(-1), seq, additionals) + return u.view_as(unaries) if __name__ == '__main__': @@ -89,13 +78,13 @@ def build_factor(self): n_states = 3 torch.manual_seed(12) - unary = Variable(torch.randn(n_variables, n_states), requires_grad=True) - start = Variable(torch.randn(n_states), requires_grad=True) - end = Variable(torch.randn(n_states), requires_grad=True) - transition = Variable(torch.randn(n_states, n_states), requires_grad=True) + unary = torch.randn(n_variables, n_states, requires_grad=True) + start = torch.randn(n_states, requires_grad=True) + end = torch.randn(n_states, requires_grad=True) + transition = torch.randn(n_states, n_states, requires_grad=True) stationary_seq = StationarySequencePotentials() - seq_marginals = SequenceSparseMarginals() + seq_marginals = Sequence() additionals = stationary_seq(transition, n_variables, @@ -104,7 +93,7 @@ def build_factor(self): posterior = seq_marginals(unary, additionals) print(posterior) - posterior.sum().backward() + posterior[0, 0].backward() print("dpost_dunary", unary.grad) print("dstart", start.grad) @@ -114,10 +103,10 @@ def build_factor(self): print("With distance-based parametrization") bw = 3 - dist_additional = Variable(torch.randn(1 + 4 * bw), requires_grad=True) - seq_dist_marg = SequenceDistanceSparseMarginals(bw) + dist_additional = torch.randn(1 + 4 * bw, requires_grad=True) + seq_dist_marg = SequenceDistance(bw) posterior = seq_dist_marg(unary, dist_additional) print(posterior) - ((posterior - 0.5)**2).sum().backward() + posterior[0, 0].backward() print("dpost_dunary", unary.grad) print("dpost_dadd", dist_additional.grad) diff --git a/python/sparsemap/layers_pt/tests/custom_gradcheck.py b/python/sparsemap/layers_pt/tests/custom_gradcheck.py new file mode 100644 index 0000000..e8b5ec0 --- /dev/null +++ b/python/sparsemap/layers_pt/tests/custom_gradcheck.py @@ -0,0 +1,238 @@ +"""relaxed reentrancy check. + +Based on https://github.com/pytorch/pytorch/blob/v1.0.1/torch/autograd/gradcheck.py + +By the pytorch authors, released under the same license. 
+""" + +import torch +from torch._six import container_abcs +import torch.testing +import sys +from itertools import product +import warnings + + +def zero_gradients(x): + if isinstance(x, torch.Tensor): + if x.grad is not None: + x.grad.detach_() + x.grad.data.zero_() + elif isinstance(x, container_abcs.Iterable): + for elem in x: + zero_gradients(elem) + + +def make_jacobian(input, num_out): + if isinstance(input, torch.Tensor): + if not input.is_floating_point(): + return None + if not input.requires_grad: + return None + return torch.zeros(input.nelement(), num_out, dtype=input.dtype) + elif isinstance(input, container_abcs.Iterable) and not isinstance(input, str): + jacobians = list(filter( + lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input))) + if not jacobians: + return None + return type(input)(jacobians) + else: + return None + + +def iter_tensors(x, only_requiring_grad=False): + if isinstance(x, torch.Tensor): + if x.requires_grad or not only_requiring_grad: + yield x + elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str): + for elem in x: + for result in iter_tensors(elem, only_requiring_grad): + yield result + + +# `input` is input to `fn` +# `target` is the Tensors wrt whom Jacobians are calculated (default=`input`) +# +# Note that `target` may not even be part of `input` to `fn`, so please be +# **very careful** in this to not clone `target`. +def get_numerical_jacobian(fn, input, target=None, eps=1e-3): + if target is None: + target = input + output_size = fn(input).numel() + jacobian = make_jacobian(target, output_size) + + # It's much easier to iterate over flattened lists of tensors. + # These are reference to the same objects in jacobian, so any changes + # will be reflected in it as well. 
+ x_tensors = [t for t in iter_tensors(target, True)] + j_tensors = [t for t in iter_tensors(jacobian)] + + # TODO: compare structure + for x_tensor, d_tensor in zip(x_tensors, j_tensors): + # need data here to get around the version check because without .data, + # the following code updates version but doesn't change content + x_tensor = x_tensor.data + for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])): + orig = x_tensor[x_idx].item() + x_tensor[x_idx] = orig - eps + outa = fn(input).clone() + x_tensor[x_idx] = orig + eps + outb = fn(input).clone() + x_tensor[x_idx] = orig + + r = (outb - outa) / (2 * eps) + d_tensor[d_idx] = r.detach().reshape(-1) + + return jacobian + + +def get_analytical_jacobian(input, output, reentrance_tol=1e-4): + diff_input_list = list(iter_tensors(input, True)) + jacobian = make_jacobian(input, output.numel()) + jacobian_reentrant = make_jacobian(input, output.numel()) + grad_output = torch.zeros_like(output) + flat_grad_output = grad_output.view(-1) + reentrant = True + correct_grad_sizes = True + + for i in range(flat_grad_output.numel()): + flat_grad_output.zero_() + flat_grad_output[i] = 1 + for jacobian_c in (jacobian, jacobian_reentrant): + grads_input = torch.autograd.grad(output, diff_input_list, grad_output, + retain_graph=True, allow_unused=True) + for jacobian_x, d_x, x in zip(jacobian_c, grads_input, diff_input_list): + if d_x is not None and d_x.size() != x.size(): + correct_grad_sizes = False + elif jacobian_x.numel() != 0: + if d_x is None: + jacobian_x[:, i].zero_() + else: + d_x_dense = d_x.to_dense() if d_x.is_sparse else d_x + assert jacobian_x[:, i].numel() == d_x_dense.numel() + jacobian_x[:, i] = d_x_dense.contiguous().view(-1) + + for jacobian_x, jacobian_reentrant_x in zip(jacobian, jacobian_reentrant): + if (jacobian_x.numel() != 0 and + (jacobian_x - jacobian_reentrant_x).abs().max() > reentrance_tol): + reentrant = False + + return jacobian, reentrant, correct_grad_sizes + + +def _as_tuple(x): + if isinstance(x, tuple): + return x + elif isinstance(x, list): + return tuple(x) + else: + return x, + + +def _differentiable_outputs(x): + return tuple(o for o in _as_tuple(x) if o.requires_grad) + + +def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True): + r"""Check gradients computed via small finite differences against analytical + gradients w.r.t. tensors in :attr:`inputs` that are of floating point type + and with ``requires_grad=True``. + + The check between numerical and analytical gradients uses :func:`~torch.allclose`. + + .. note:: + The default values are designed for :attr:`input` of double precision. + This check will likely fail if :attr:`input` is of less precision, e.g., + ``FloatTensor``. + + .. warning:: + If any checked tensor in :attr:`input` has overlapping memory, i.e., + different indices pointing to the same memory address (e.g., from + :func:`torch.expand`), this check will likely fail because the numerical + gradients computed by point perturbation at such indices will change + values at all other indices that share the same memory address. + + Args: + func (function): a Python function that takes Tensor inputs and returns + a Tensor or a tuple of Tensors + inputs (tuple of Tensor or Tensor): inputs to the function + eps (float, optional): perturbation for finite differences + atol (float, optional): absolute tolerance + rtol (float, optional): relative tolerance + raise_exception (bool, optional): indicating whether to raise an exception if + the check fails. 
The exception gives more information about the + exact nature of the failure. This is helpful when debugging gradchecks. + + Returns: + True if all differences satisfy allclose condition + """ + tupled_inputs = _as_tuple(inputs) + + # Make sure that gradients are saved for all inputs + any_input_requiring_grad = False + for inp in tupled_inputs: + if isinstance(inp, torch.Tensor): + if inp.requires_grad: + if inp.dtype != torch.float64: + warnings.warn( + 'At least one of the inputs that requires gradient ' + 'is not of double precision floating point. ' + 'This check will likely fail if all the inputs are ' + 'not of double precision floating point. ') + any_input_requiring_grad = True + inp.retain_grad() + if not any_input_requiring_grad: + raise ValueError( + 'gradcheck expects at least one input tensor to require gradient, ' + 'but none of the them have requires_grad=True.') + + output = _differentiable_outputs(func(*tupled_inputs)) + + def fail_test(msg): + if raise_exception: + raise RuntimeError(msg) + return False + + for i, o in enumerate(output): + if not o.requires_grad: + continue + + def fn(input): + return _as_tuple(func(*input))[i] + + analytical, reentrant, correct_grad_sizes = get_analytical_jacobian(tupled_inputs, o, reentrance_tol=atol) + numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps) + + if not correct_grad_sizes: + return fail_test('Analytical gradient has incorrect size') + + for j, (a, n) in enumerate(zip(analytical, numerical)): + if a.numel() != 0 or n.numel() != 0: + if not torch.allclose(a, n, rtol, atol): + return fail_test('Jacobian mismatch for output %d with respect to input %d,\n' + 'numerical:%s\nanalytical:%s\n' % (i, j, n, a)) + + if not reentrant: + return fail_test('Backward is not reentrant, i.e., running backward with same ' + 'input and grad_output multiple times gives different values, ' + 'although analytical gradient matches numerical gradient') + + # check if the backward multiplies by grad_output + output = _differentiable_outputs(func(*tupled_inputs)) + if any([o.requires_grad for o in output]): + diff_input_list = list(iter_tensors(tupled_inputs, True)) + if not diff_input_list: + raise RuntimeError("no Tensors requiring grad found in input") + grads_input = torch.autograd.grad(output, diff_input_list, [torch.zeros_like(o) for o in output], + allow_unused=True) + for gi, i in zip(grads_input, diff_input_list): + if gi is None: + continue + if not gi.eq(0).all(): + return fail_test('backward not multiplied by grad_output') + if gi.type() != i.type(): + return fail_test("grad is incorrect type") + if gi.size() != i.size(): + return fail_test('grad is incorrect size') + + return True diff --git a/python/sparsemap/layers_pt/tests/test_matching_layer.py b/python/sparsemap/layers_pt/tests/test_matching_layer.py index 11dec85..914410d 100644 --- a/python/sparsemap/layers_pt/tests/test_matching_layer.py +++ b/python/sparsemap/layers_pt/tests/test_matching_layer.py @@ -1,7 +1,7 @@ -from .. import matching_layer - import torch -from torch.autograd import gradcheck, Variable + +from .custom_gradcheck import gradcheck +from .. 
import matching_layer def test_matching_sparse_decode(): @@ -10,9 +10,7 @@ def test_matching_sparse_decode(): n_cols = 4 for _ in range(20): - matcher = matching_layer.MatchingSparseMarginals(max_iter=100) - W = torch.randn(n_rows, n_cols) - W = Variable(W, requires_grad=True) - res = gradcheck(matcher, (W,), eps=1e-3, - atol=1e-3) + matcher = matching_layer.Matching(max_iter=100) + W = torch.randn(n_rows, n_cols, dtype=torch.double, requires_grad=True) + res = gradcheck(matcher, (W,), eps=1e-3, atol=1e-5) assert res diff --git a/python/sparsemap/layers_pt/tests/test_seq_layer.py b/python/sparsemap/layers_pt/tests/test_seq_layer.py index 70d462c..a5bc021 100644 --- a/python/sparsemap/layers_pt/tests/test_seq_layer.py +++ b/python/sparsemap/layers_pt/tests/test_seq_layer.py @@ -1,7 +1,7 @@ -from .. import seq_layer - import torch -from torch.autograd import gradcheck, Variable + +from .custom_gradcheck import gradcheck +from .. import seq_layer def test_seq_sparse_decode(): @@ -9,13 +9,12 @@ def test_seq_sparse_decode(): n_vars = 4 n_states = 3 for _ in range(20): - sequence_smap = seq_layer.SequenceSparseMarginals(max_iter=1000) - unary = Variable(torch.randn(n_vars, n_states), requires_grad=True) - additionals = Variable(torch.randn(2 * n_states + - (n_vars - 1) * n_states ** 2), - requires_grad=True) - res = gradcheck(sequence_smap, (unary, additionals), eps=1e-4, atol=1e-3) - print(res) + seq = seq_layer.Sequence(max_iter=1000) + unary = torch.randn(n_vars, n_states, dtype=torch.double, requires_grad=True) + additionals = torch.randn(2 * n_states + (n_vars - 1) * n_states ** 2, + dtype=torch.double, + requires_grad=True) + res = gradcheck(seq, (unary, additionals), eps=1e-4, atol=1e-4) assert res @@ -25,10 +24,12 @@ def test_seq_dist_sparse_decode(): n_states = 3 bandwidth = 3 for _ in range(20): - seq_dist_smap = seq_layer.SequenceDistanceSparseMarginals(bandwidth) - unary = Variable(torch.randn(n_vars, n_states), requires_grad=True) - additionals = Variable(torch.randn(1 + 4 * bandwidth), - requires_grad=True) - res = gradcheck(seq_dist_smap, (unary, additionals), eps=1e-4, atol=1e-3) - print(res) + seq = seq_layer.SequenceDistance(bandwidth, max_iter=1000) + unary = torch.randn(n_vars, n_states, + dtype=torch.double, + requires_grad=True) + additionals = torch.randn(1 + 4 * bandwidth, + dtype=torch.double, + requires_grad=True) + res = gradcheck(seq, (unary, additionals), eps=1e-4, atol=1e-3) assert res diff --git a/python/sparsemap/layers_pt/tests/test_tree_layer.py b/python/sparsemap/layers_pt/tests/test_tree_layer.py index 4110f76..bb29a1f 100644 --- a/python/sparsemap/layers_pt/tests/test_tree_layer.py +++ b/python/sparsemap/layers_pt/tests/test_tree_layer.py @@ -1,31 +1,28 @@ -from ..tree_layer import TreeSparseMarginalsFast - import torch -from torch.autograd import gradcheck, Variable +from .custom_gradcheck import gradcheck +from ..tree_layer import DependencyTreeFast def test_fasttree_sparse_decode(): torch.manual_seed(42) - n_nodes = 5 - tsm = TreeSparseMarginalsFast(n_nodes, max_iter=1000) + n_nodes = 6 + tree = DependencyTreeFast(n_nodes, max_iter=100) for _ in range(20): - W = torch.randn(n_nodes, n_nodes + 1).view(-1) - W = Variable(W, requires_grad=True) - res = gradcheck(tsm, (W,), eps=1e-4, - atol=1e-3) - print(res) + W = torch.randn(n_nodes * (n_nodes + 1), + dtype=torch.double, + requires_grad=True) + res = gradcheck(tree, (W,), eps=1e-5, atol=1e-2) assert res def test_meaning_sparse_decode(): n_nodes = 4 - w = torch.zeros(n_nodes, n_nodes + 1) + w = 
torch.zeros(n_nodes, n_nodes + 1, dtype=torch.double) w[2, 1] = 100 - w = Variable(w) - tsm = TreeSparseMarginalsFast(n_nodes, verbose=3) - u = tsm(w.view(-1)) - for config in tsm.status['active_set']: + tree = DependencyTreeFast(n_nodes) + u = tree(w.view(-1)) + for config in tree.configurations: assert config[1 + 2] == 1 @@ -34,9 +31,8 @@ def test_fast_tree_ignores_diag(): # w = torch.zeros(n_nodes, n_nodes + 1) w_init = torch.randn(n_nodes * (n_nodes + 1)) - w = Variable(w_init) - tsm = TreeSparseMarginalsFast(n_nodes) - u = tsm(w.view(-1)) + tree = DependencyTreeFast(n_nodes) + u = tree(w_init.view(-1)) k = 0 for m in range(1, n_nodes + 1): @@ -45,8 +41,7 @@ def test_fast_tree_ignores_diag(): w_init[k] = 0 k += 1 - w = Variable(w_init) - tsm = TreeSparseMarginalsFast(n_nodes) - u_zeroed = tsm(w.view(-1)) + tree = DependencyTreeFast(n_nodes) + u_zeroed = tree(w_init.view(-1)) assert (u_zeroed - u).data.norm() < 1e-12 diff --git a/python/sparsemap/layers_pt/tree_layer.py b/python/sparsemap/layers_pt/tree_layer.py index c93d33e..8331936 100644 --- a/python/sparsemap/layers_pt/tree_layer.py +++ b/python/sparsemap/layers_pt/tree_layer.py @@ -2,81 +2,72 @@ import numpy as np import torch -from torch.autograd import Variable, Function -from torch import nn from ad3 import PFactorGraph from ad3.extensions import PFactorTree -from .base import _BaseSparseMarginals +from .base import _BaseSparseMAP from .._factors import PFactorTreeFast -class TreeSparseMarginals(_BaseSparseMarginals): +class DependencyTree(_BaseSparseMAP): - def __init__(self, n_nodes=None, max_iter=10, verbose=0): + def __init__(self, n_nodes, max_iter=20, verbose=0): self.n_nodes = n_nodes - self.max_iter = max_iter - self.verbose = verbose + super(DependencyTree, self).__init__(max_iter, verbose) + + def forward(self, unaries): - def build_factor(self): n_nodes = self.n_nodes + arcs = [(h, m) for m in range(1, n_nodes + 1) + for h in range(n_nodes + 1) + if h != m] g = PFactorGraph() - self.arcs = [(h, m) - for m in range(1, n_nodes + 1) - for h in range(n_nodes + 1) - if h != m] - arc_vars = [g.create_binary_variable() for _ in self.arcs] + arc_vars = [g.create_binary_variable() for _ in arcs] tree = PFactorTree() g.declare_factor(tree, arc_vars) - tree.initialize(n_nodes + 1, self.arcs) - return tree + tree.initialize(n_nodes + 1, arcs) + return self.sparsemap(unaries.view(-1), tree) -class TreeSparseMarginalsFast(_BaseSparseMarginals): - def __init__(self, n_nodes=None, max_iter=10, verbose=0): - self.n_nodes = n_nodes - self.max_iter = max_iter - self.verbose = verbose +class DependencyTreeFast(DependencyTree): - def build_factor(self): + def forward(self, unaries): n_nodes = self.n_nodes + arcs = [(h, m) for m in range(1, n_nodes + 1) + for h in range(n_nodes + 1) + if h != m] g = PFactorGraph() - self.arcs = [(h, m) - for m in range(1, n_nodes + 1) - for h in range(n_nodes + 1) - if h != m] - arc_vars = [g.create_binary_variable() for _ in self.arcs] + arc_vars = [g.create_binary_variable() for _ in arcs] tree = PFactorTreeFast() g.declare_factor(tree, arc_vars) - tree.initialize(n_nodes + 1) - return tree + tree.initialize(self.n_nodes + 1) + return self.sparsemap(unaries.view(-1), tree) if __name__ == '__main__': n_nodes = 3 - Wt = torch.randn((n_nodes + 1) * n_nodes) - W = Variable(Wt, requires_grad=True) + W = torch.randn((n_nodes + 1) * n_nodes, requires_grad=True) Wskip_a = [] k = 0 for m in range(1, n_nodes + 1): for h in range(n_nodes + 1): if h != m: - Wskip_a.append(Wt[k]) + Wskip_a.append(W.data[k]) k += 1 - 
Wskip_a = np.array(Wskip_a) + Wskip_a = np.array(Wskip_a, dtype=np.double) - Wskip_t = torch.from_numpy(Wskip_a) - Wskip = Variable(Wskip_t, requires_grad=True) + Wskip = torch.from_numpy(Wskip_a).requires_grad_() - tsm_slow = TreeSparseMarginals(n_nodes) - posteriors = tsm_slow(Wskip) + tree_slow = DependencyTree(n_nodes) + posteriors = tree_slow(Wskip) print("posteriors slow", posteriors) - tsm = TreeSparseMarginalsFast(n_nodes) - posteriors = tsm(W) + tree_fast = DependencyTreeFast(n_nodes) + posteriors = tree_fast(W) print("posteriors fast", posteriors) - posteriors.sum().backward() + + posteriors[0].backward() print("dposteriors_dW", W.grad) diff --git a/python/sparsemap/tests/__init__.py b/python/sparsemap/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/sparsemap/tests/test_fw_solver.py b/python/sparsemap/tests/test_fw_solver.py new file mode 100644 index 0000000..9fbbf9b --- /dev/null +++ b/python/sparsemap/tests/test_fw_solver.py @@ -0,0 +1,151 @@ +# Author: Vlad Niculae +# License: BSD 3-clause + + +import numpy as np +from numpy.testing import assert_allclose +import pytest + +from ..fw_solver import SparseMAPFW + + +@pytest.mark.parametrize('variant', ('vanilla', 'pairwise', 'away-step')) +def test_pairwise_factor(variant): + + class PairwiseFactor(object): + """A factor with two binary variables and a coupling between them.""" + + def vertex(self, y): + + # y is a tuple (0, 0), (0, 1), (1, 0) or (1, 1) + u = np.array(y, dtype=np.float) + v = np.atleast_1d(np.prod(u)) + return u, v + + def map_oracle(self, eta_u, eta_v): + + best_score = -np.inf + best_y = None + for x1 in (0, 1): + for x2 in (0, 1): + y = (x1, x2) + u, v = self.vertex(y) + + score = np.dot(u, eta_u) + np.dot(v, eta_v) + if score > best_score: + best_score = score + best_y = y + return best_y + + def qp(self, eta_u, eta_v): + """Prop 6.5 in Andre Martins' thesis""" + + c1, c2, c12 = eta_u[0], eta_u[1], eta_v[0] + + flip_sign = False + if c12 < 0: + flip_sign = True + c1, c2, c12 = c1 + c12, 1 - c2, -c12 + + if c1 > c2 + c12: + u = [c1, c2 + c12] + elif c2 > c1 + c12: + u = [c1 + c12, c2] + else: + uu = (c1 + c2 + c12) / 2 + u = [uu, uu] + + u = np.clip(np.array(u), 0, 1) + v = np.atleast_1d(np.min(u)) + + if flip_sign: + u[1] = 1 - u[1] + v[0] = u[0] - v[0] + + return u, v + + pw = PairwiseFactor() + fw = SparseMAPFW(pw, max_iter=10000, tol=1e-12, variant=variant) + + params = [ + (np.array([0, 0]), np.array([0])), + (np.array([100, 0]), np.array([0])), + (np.array([0, 100]), np.array([0])), + (np.array([100, 0]), np.array([-100])), + (np.array([0, 100]), np.array([-100])) + ] + + rng = np.random.RandomState(0) + for _ in range(20): + eta_u = rng.randn(2) + eta_v = rng.randn(1) + params.append((eta_u, eta_v)) + + for eta_u, eta_v in params: + + u, v, active_set = fw.solve(eta_u, eta_v) + ustar, vstar = pw.qp(eta_u, eta_v) + + uv = np.concatenate([u, v]) + uvstar = np.concatenate([ustar, vstar]) + + assert_allclose(uv, uvstar, atol=1e-10) + + +@pytest.mark.parametrize('variant', ('vanilla', 'pairwise', 'away-step')) +@pytest.mark.parametrize('k', (1, 4, 20)) +def test_xor(variant, k): + class XORFactor(object): + """A one-of-K factor""" + + def __init__(self, k): + self.k = k + + def vertex(self, y): + # y is an integer between 0 and k-1 + u = np.zeros(k) + u[y] = 1 + v = np.array(()) + + return u, v + + def map_oracle(self, eta_u, eta_v): + return np.argmax(eta_u) + + def qp(self, eta_u, eta_v): + """Projection onto the simplex""" + z = 1 + v = np.array(eta_u) + n_features = 
v.shape[0] + u = np.sort(v)[::-1] + cssv = np.cumsum(u) - z + ind = np.arange(n_features) + 1 + cond = u - cssv / ind > 0 + rho = ind[cond][-1] + theta = cssv[cond][-1] / float(rho) + uu = np.maximum(v - theta, 0) + vv = np.array(()) + return uu, vv + + xor = XORFactor(k) + fw = SparseMAPFW(xor, max_iter=10000, tol=1e-12, variant=variant) + + params = [np.zeros(k), np.ones(k), np.full(k, -1)] + + rng = np.random.RandomState(0) + for _ in range(20): + eta_u = rng.randn(k) + params.append(eta_u) + + for eta_u in params: + + # try different ways of supplying empty eta_v + for eta_v in (np.array(()), [], 0, None): + + u, v, active_set = fw.solve(eta_u, eta_v) + ustar, vstar = xor.qp(eta_u, eta_v) + + uv = np.concatenate([u, v]) + uvstar = np.concatenate([ustar, vstar]) + + assert_allclose(uv, uvstar, atol=1e-10) diff --git a/python/sparsemap/utils.py b/python/sparsemap/utils.py deleted file mode 100644 index 4daed32..0000000 --- a/python/sparsemap/utils.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import division - -import math -import numpy as np - -import torch -from torch.autograd import Variable - - -def batch_slices(n_samples, batch_size=32): - n_batches = math.ceil(n_samples / batch_size) - batches = [slice(ix * batch_size, (ix + 1) * batch_size) - for ix in range(n_batches)] - return batches - - -def S_from_Ainv(Ainv): - """See footnote in notes.pdf""" - - # Ainv = torch.FloatTensor(Ainv).view(1 + n_active, 1 + n_active) - S = Ainv[1:, 1:] - k = Ainv[0, 0] - b = Ainv[0, 1:].unsqueeze(0) - - S -= (1 / k) * (b * b.t()) - return S - - -def expand_with_zeros(x, rows, cols): - orig_rows, orig_cols = x.size() - - ret = x - if orig_cols < cols: - horiz = Variable(x.data.new(orig_rows, cols - orig_cols).zero_()) - ret = torch.cat([ret, horiz], dim=-1) - - if orig_rows < rows: - vert = Variable(x.data.new(rows - orig_rows, cols).zero_()) - ret = torch.cat([ret, vert], dim=0) - - return ret - - -def zeros_like(torch_var): - data = torch_var.data.new(torch_var.size()).zero_() - return Variable(data) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..e78f276 --- /dev/null +++ b/setup.py @@ -0,0 +1,58 @@ +import os + +import numpy + +from setuptools import setup +from setuptools.extension import Extension +from Cython.Build import cythonize + +AD3_DIR = os.environ.get('AD3_DIR') +if not AD3_DIR: + print("Warning: please set the AD3_DIR environment variable to point " + "to the path where you have downloaded the AD3 library.") + exit(1) + +# PROJ_ROOT = os.path.dirname(os.path.abspath(__file__)) + + +ext_args = dict( + # below is a hack so we can include ./ad3 as well as ../ad3 + libraries=['ad3'], + extra_compile_args=["-std=c++11"], + language="c++") + + +package_dir = {'sparsemap': 'python/sparsemap'} + +setup(name="sparsemap", + version="0.1.dev0", + author="Vlad Niculae", + author_email="vlad@vene.ro", + package_dir=package_dir, + packages=['sparsemap', 'sparsemap.layers_pt', + 'sparsemap.layers_pt.tests'], + include_package_data=True, + ext_modules=cythonize([ + Extension( + "sparsemap._sparsemap", + ["python/sparsemap/_sparsemap.pyx"], + include_dirs=["src", + AD3_DIR, + os.path.join(AD3_DIR, 'python'), + numpy.get_include()], + library_dirs=[os.path.join(AD3_DIR, 'ad3')], + **ext_args), + Extension( + "sparsemap._factors", + ["python/sparsemap/_factors.pyx", + os.path.join("src", "lapjv", "lapjv.cpp"), + os.path.join("src", "FactorTree.cpp"), + ], + include_dirs=["src", + AD3_DIR, + os.path.join(AD3_DIR, 'python'), + numpy.get_include()], + 
library_dirs=[os.path.join(AD3_DIR, 'ad3')],
+            **ext_args),
+        ])
+)
diff --git a/src/FactorTree.cpp b/src/FactorTree.cpp
new file mode 100644
index 0000000..35c6354
--- /dev/null
+++ b/src/FactorTree.cpp
@@ -0,0 +1,253 @@
+// Copyright (c) 2012 Andre Martins
+// All Rights Reserved.
+//
+// This file is part of AD3 2.1.
+//
+// AD3 2.1 is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// AD3 2.1 is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with AD3 2.1. If not, see <http://www.gnu.org/licenses/>.
+
+#ifndef FACTOR_TREE_H_
+#define FACTOR_TREE_H_
+
+#include "FactorTree.h"
+#include <iostream>
+#include <cstdlib>
+#include <cassert>
+#include <cmath>
+#include <limits>
+#include <vector>
+#include <algorithm>
+
+using namespace std;
+
+namespace AD3 {
+
+// Decoder for the basic model; it finds a maximum weighted arborescence
+// using Chu-Liu-Edmonds' algorithm.
+void FactorTree::RunCLE(const vector<double>& scores,
+                        vector<int> *heads,
+                        double *value) {
+
+
+  // Done once.
+  vector<vector<int> > candidate_heads(length_);
+  vector<vector<double> > candidate_scores(length_);
+  vector<bool> disabled(length_, false);
+  for (int m = 1; m < length_; ++m) {
+    for (int h = 0; h < length_; ++h) {
+      int r = index_arcs_[h][m];
+      if (r < 0) continue;
+      candidate_heads[m].push_back(h);
+      candidate_scores[m].push_back(scores[r]);
+    }
+  }
+
+  RunChuLiuEdmondsIteration(&disabled, &candidate_heads,
+                            &candidate_scores, heads,
+                            value);
+
+  *value = 0;
+  (*heads)[0] = -1;
+  for (int m = 1; m < length_; ++m) {
+    int h = (*heads)[m];
+    assert(h >= 0 && h < length_);
+    int r = index_arcs_[h][m];
+    assert(r >= 0);
+    *value += scores[r];
+  }
+}
+
+void FactorTree::RunChuLiuEdmondsIteration(vector<bool> *disabled,
+                                           vector<vector<int> > *candidate_heads,
+                                           vector<vector<double> >
+                                             *candidate_scores,
+                                           vector<int> *heads,
+                                           double *value) {
+  // Original number of nodes (including the root).
+  int length = disabled->size();
+
+  // Pick the best incoming arc for each node.
+  heads->resize(length);
+  vector<double> best_scores(length);
+  for (int m = 1; m < length; ++m) {
+    if ((*disabled)[m]) continue;
+    int best = -1;
+    for (int k = 0; k < (*candidate_heads)[m].size(); ++k) {
+      if (best < 0 ||
+          (*candidate_scores)[m][k] > (*candidate_scores)[m][best]) {
+        best = k;
+      }
+    }
+    if (best < 0) {
+      // No spanning tree exists. Assign the parent of this node
+      // to the root, and give it a minus infinity score.
+      (*heads)[m] = 0;
+      best_scores[m] = -std::numeric_limits<double>::infinity();
+    } else {
+      (*heads)[m] = (*candidate_heads)[m][best]; //best;
+      best_scores[m] = (*candidate_scores)[m][best]; //best;
+    }
+  }
+
+  // Look for cycles. Return after the first cycle is found.
+  vector<int> cycle;
+  vector<int> visited(length, 0);
+  for (int m = 1; m < length; ++m) {
+    if ((*disabled)[m]) continue;
+    // Examine all the ancestors of m until the root or a cycle is found.
+    int h = m;
+    while (h != 0) {
+      // If already visited, break and check if it is part of a cycle.
+      // If visited[h] < m, the node was visited earlier and seen not
+      // to be part of a cycle.
+      if (visited[h]) break;
+      visited[h] = m;
+      h = (*heads)[h];
+    }
+
+    // Found a cycle to which h belongs.
+    // Obtain the full cycle.
+    if (visited[h] == m) {
+      m = h;
+      do {
+        cycle.push_back(m);
+        m = (*heads)[m];
+      } while (m != h);
+      break;
+    }
+  }
+
+  // If there are no cycles, then this is a well formed tree.
+  if (cycle.empty()) {
+    *value = 0.0;
+    for (int m = 1; m < length; ++m) {
+      *value += best_scores[m];
+    }
+    return;
+  }
+
+  // Build a cycle membership vector for constant-time querying and compute the
+  // score of the cycle.
+  // Nominate a representative node for the cycle and disable all the others.
+  double cycle_score = 0.0;
+  vector<bool> in_cycle(length, false);
+  int representative = cycle[0];
+  for (int k = 0; k < cycle.size(); ++k) {
+    int m = cycle[k];
+    in_cycle[m] = true;
+    cycle_score += best_scores[m];
+    if (m != representative) (*disabled)[m] = true;
+  }
+
+  // Contract the cycle.
+  // 1) Update the score of each child to the maximum score achieved by a parent
+  // node in the cycle.
+  vector<int> best_heads_cycle(length);
+  for (int m = 1; m < length; ++m) {
+    if ((*disabled)[m] || m == representative) continue;
+    double best_score;
+    // If the list of candidate parents of m is shorter than the length of
+    // the cycle, use that. Otherwise, loop through the cycle.
+    int best = -1;
+    for (int k = 0; k < (*candidate_heads)[m].size(); ++k) {
+      if (!in_cycle[(*candidate_heads)[m][k]]) continue;
+      if (best < 0 || (*candidate_scores)[m][k] > best_score) {
+        best = k;
+        best_score = (*candidate_scores)[m][best];
+      }
+    }
+    if (best < 0) continue;
+    best_heads_cycle[m] = (*candidate_heads)[m][best];
+
+    // Reconstruct the list of candidate heads for this m.
+    int l = 0;
+    for (int k = 0; k < (*candidate_heads)[m].size(); ++k) {
+      int h = (*candidate_heads)[m][k];
+      double score = (*candidate_scores)[m][k];
+      if (!in_cycle[h]) {
+        (*candidate_heads)[m][l] = h;
+        (*candidate_scores)[m][l] = score;
+        ++l;
+      }
+    }
+    // If h is in the cycle and is not the representative node,
+    // it will be dropped from the list of candidate heads.
+    (*candidate_heads)[m][l] = representative;
+    (*candidate_scores)[m][l] = best_score;
+    (*candidate_heads)[m].resize(l+1);
+    (*candidate_scores)[m].resize(l+1);
+  }
+
+  // 2) Update the score of each candidate parent of the cycle supernode.
+  vector<int> best_modifiers_cycle(length, -1);
+  vector<int> candidate_heads_representative;
+  vector<double> candidate_scores_representative;
+
+  vector<double> best_scores_cycle(length);
+  // Loop through the cycle.
+  for (int k = 0; k < cycle.size(); ++k) {
+    int m = cycle[k];
+    for (int l = 0; l < (*candidate_heads)[m].size(); ++l) {
+      // Get heads out of the cycle.
+      int h = (*candidate_heads)[m][l];
+      if (in_cycle[h]) continue;
+
+      double score = (*candidate_scores)[m][l] - best_scores[m];
+      if (best_modifiers_cycle[h] < 0 || score > best_scores_cycle[h]) {
+        best_modifiers_cycle[h] = m;
+        best_scores_cycle[h] = score;
+      }
+    }
+  }
+  for (int h = 0; h < length; ++h) {
+    if (best_modifiers_cycle[h] < 0) continue;
+    double best_score = best_scores_cycle[h] + cycle_score;
+    candidate_heads_representative.push_back(h);
+    candidate_scores_representative.push_back(best_score);
+  }
+
+  // Reconstruct the list of candidate heads for the representative node.
+  (*candidate_heads)[representative] = candidate_heads_representative;
+  (*candidate_scores)[representative] = candidate_scores_representative;
+
+  // Save the current head of the representative node (it will be overwritten).
+  int head_representative = (*heads)[representative];
+
+  // Call itself recursively.
+  RunChuLiuEdmondsIteration(disabled,
+                            candidate_heads,
+                            candidate_scores,
+                            heads,
+                            value);
+
+  // Uncontract the cycle.
+  int h = (*heads)[representative];
+  (*heads)[representative] = head_representative;
+  (*heads)[best_modifiers_cycle[h]] = h;
+
+  for (int m = 1; m < length; ++m) {
+    if ((*disabled)[m]) continue;
+    if ((*heads)[m] == representative) {
+      // Get the right parent from within the cycle.
+      (*heads)[m] = best_heads_cycle[m];
+    }
+  }
+  for (int k = 0; k < cycle.size(); ++k) {
+    int m = cycle[k];
+    (*disabled)[m] = false;
+  }
+}
+
+} // namespace AD3
+
+#endif // FACTOR_TREE_H_
diff --git a/src/FactorTree.h b/src/FactorTree.h
new file mode 100644
index 0000000..ebab010
--- /dev/null
+++ b/src/FactorTree.h
@@ -0,0 +1,154 @@
+// Copyright (c) 2012 Andre Martins
+// All Rights Reserved.
+//
+// This file is part of AD3 2.1.
+//
+// AD3 2.1 is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// AD3 2.1 is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with AD3 2.1. If not, see <http://www.gnu.org/licenses/>.
+
+#ifndef FACTOR_TREE
+#define FACTOR_TREE
+
+#include "ad3/GenericFactor.h"
+
+namespace AD3 {
+
+class Arc {
+ public:
+  Arc(int h, int m) : h_(h), m_(m) {}
+  ~Arc() {}
+
+  int head() { return h_; }
+  int modifier() { return m_; }
+
+ private:
+  int h_;
+  int m_;
+};
+
+class FactorTree : public GenericFactor {
+ public:
+  FactorTree() {}
+  virtual ~FactorTree() { ClearActiveSet(); }
+
+  void RunCLE(const vector<double>& scores,
+              vector<int> *heads,
+              double *value);
+
+  // Compute the score of a given assignment.
+  // Note: additional_log_potentials is empty and is ignored.
+  void Maximize(const vector<double> &variable_log_potentials,
+                const vector<double> &additional_log_potentials,
+                Configuration &configuration,
+                double *value) {
+    vector<int>* heads = static_cast<vector<int>*>(configuration);
+    RunCLE(variable_log_potentials, heads, value);
+  }
+
+  // Compute the score of a given assignment.
+  // Note: additional_log_potentials is empty and is ignored.
+  void Evaluate(const vector<double> &variable_log_potentials,
+                const vector<double> &additional_log_potentials,
+                const Configuration configuration,
+                double *value) {
+    const vector<int> *heads = static_cast<const vector<int>*>(configuration);
+    // Heads belong to {0,1,2,...}
+    *value = 0.0;
+    for (int m = 1; m < heads->size(); ++m) {
+      int h = (*heads)[m];
+      int index = index_arcs_[h][m];
+      *value += variable_log_potentials[index];
+    }
+  }
+
+  // Given a configuration with a probability (weight),
+  // increment the vectors of variable and additional posteriors.
+  // Note: additional_log_potentials is empty and is ignored.
+  void UpdateMarginalsFromConfiguration(
+    const Configuration &configuration,
+    double weight,
+    vector<double> *variable_posteriors,
+    vector<double> *additional_posteriors) {
+    const vector<int> *heads = static_cast<const vector<int>*>(configuration);
+    for (int m = 1; m < heads->size(); ++m) {
+      int h = (*heads)[m];
+      int index = index_arcs_[h][m];
+      (*variable_posteriors)[index] += weight;
+    }
+  }
+
+  // Count how many common values two configurations have.
+  int CountCommonValues(const Configuration &configuration1,
+                        const Configuration &configuration2) {
+    const vector<int> *heads1 = static_cast<const vector<int>*>(configuration1);
+    const vector<int> *heads2 = static_cast<const vector<int>*>(configuration2);
+    int count = 0;
+    for (int i = 1; i < heads1->size(); ++i) {
+      if ((*heads1)[i] == (*heads2)[i]) {
+        ++count;
+      }
+    }
+    return count;
+  }
+
+  // Check if two configurations are the same.
+  bool SameConfiguration(
+    const Configuration &configuration1,
+    const Configuration &configuration2) {
+    const vector<int> *heads1 = static_cast<const vector<int>*>(configuration1);
+    const vector<int> *heads2 = static_cast<const vector<int>*>(configuration2);
+    for (int i = 1; i < heads1->size(); ++i) {
+      if ((*heads1)[i] != (*heads2)[i]) return false;
+    }
+    return true;
+  }
+
+  // Delete configuration.
+  void DeleteConfiguration(
+    Configuration configuration) {
+    vector<int> *heads = static_cast<vector<int>*>(configuration);
+    delete heads;
+  }
+
+  // Create configuration.
+  Configuration CreateConfiguration() {
+    vector<int>* heads = new vector<int>(length_);
+    return static_cast<Configuration>(heads);
+  }
+
+ public:
+  void Initialize(int length, const vector<Arc*> &arcs) {
+    length_ = length;
+    index_arcs_.assign(length, vector<int>(length, -1));
+    for (int k = 0; k < arcs.size(); ++k) {
+      int h = arcs[k]->head();
+      int m = arcs[k]->modifier();
+      index_arcs_[h][m] = k;
+    }
+  }
+
+ private:
+  void RunChuLiuEdmondsIteration(vector<bool> *disabled,
+                                 vector<vector<int> > *candidate_heads,
+                                 vector<vector<double> >
+                                   *candidate_scores,
+                                 vector<int> *heads,
+                                 double *value);
+ protected:
+  int length_; // Sentence length (including root symbol).
+  vector<vector<int> > index_arcs_;
+};
+
+} // namespace AD3
+
+#endif // FACTOR_TREE