diff --git a/autograd/core.py b/autograd/core.py index 77a041ada..089008d68 100644 --- a/autograd/core.py +++ b/autograd/core.py @@ -90,6 +90,41 @@ def translate_vjp(vjpfun, fun, argnum): else: raise Exception("Bad VJP '{}' for '{}'".format(vjpfun, fun.__name__)) +def vjp_numeric(fun, argnum=0, step=1e-6, mode='centered'): + """ Evaluatest the vector-jacobian product numerically, using a step size + `step` to evaluate the jacobian. """ + + def vjpfun(ans, *args, **kwargs): + arg = args[argnum] + arg_vs = vspace(arg) + shape = arg_vs.shape + num_p = arg_vs.size + fn_vs = vspace(ans) + + def vjp(v): + vjp_num = arg_vs.zeros() + for ip in range(int(num_p)): + if mode == 'forward': + args_for = list(args) + args_for[argnum] = arg_vs.add(arg, arg_vs.scalar_mul(arg_vs.one_ind(ip), step)) + fn_for = fun(*args_for, **kwargs) + neg_ans = fn_vs.scalar_mul(ans, -1.0) + dfn_dp = fn_vs.scalar_mul(fn_vs.add(fn_for, neg_ans), 1.0/step) + elif mode == 'centered': + args_for = list(args) + args_for[argnum] = arg_vs.add(arg, arg_vs.scalar_mul(arg_vs.one_ind(ip), step/2)) + fn_for = fun(*args_for, **kwargs) + args_back = list(args) + args_back[argnum] = arg_vs.add(arg, arg_vs.scalar_mul(arg_vs.one_ind(ip), -step/2)) + fn_back = fun(*args_back, **kwargs) + neg_fn_back = fn_vs.scalar_mul(fn_back, -1.0) + dfn_dp = fn_vs.scalar_mul(fn_vs.add(fn_for, neg_fn_back), 1.0/step) + + vjp_num[arg_vs.one_ind(ip)==1.] = arg_vs.inner_prod(v, dfn_dp) + return vjp_num + return vjp + return vjpfun + # -------------------- forward mode -------------------- def make_jvp(fun, x): diff --git a/autograd/extend.py b/autograd/extend.py index 16879cd80..16280b8a5 100644 --- a/autograd/extend.py +++ b/autograd/extend.py @@ -2,4 +2,4 @@ from .tracer import Box, primitive, register_notrace, notrace_primitive from .core import (SparseObject, VSpace, vspace, VJPNode, JVPNode, defvjp_argnums, defvjp_argnum, defvjp, - defjvp_argnums, defjvp_argnum, defjvp, def_linear) + defjvp_argnums, defjvp_argnum, defjvp, def_linear, vjp_numeric) diff --git a/autograd/numpy/numpy_vspaces.py b/autograd/numpy/numpy_vspaces.py index 8eda1b2a2..327b6de1f 100644 --- a/autograd/numpy/numpy_vspaces.py +++ b/autograd/numpy/numpy_vspaces.py @@ -13,6 +13,10 @@ def size(self): return np.prod(self.shape) def ndim(self): return len(self.shape) def zeros(self): return np.zeros(self.shape, dtype=self.dtype) def ones(self): return np.ones( self.shape, dtype=self.dtype) + def one_ind(self, ind): + out = np.zeros(self.shape, dtype=self.dtype) + out[np.unravel_index(ind, shape=self.shape)] = np.array([1.]).astype(dtype=self.dtype) + return out def standard_basis(self): for idxs in np.ndindex(*self.shape):