From 537b208dac609206f2e2fb5803e6c0659a4e731f Mon Sep 17 00:00:00 2001 From: Greg Yang Date: Mon, 30 Dec 2019 15:13:42 +0000 Subject: [PATCH 1/9] Added single-batch batchnorm kernel We add `quadpy` as a dependency for the numerical integration required. --- README.md | 2 + neural_tangents/stax.py | 104 +++++++++++++++++++++++++++++ neural_tangents/tests/stax_test.py | 59 ++++++++++------ setup.py | 3 +- 4 files changed, 145 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 792b56d3..83365b66 100644 --- a/README.md +++ b/README.md @@ -401,3 +401,5 @@ Coming soon. ##### [13] [Mean Field Residual Networks: On the Edge of Chaos.](https://arxiv.org/abs/1712.08969) *NeurIPS 2017.* Greg Yang, Samuel S. Schoenholz ##### [14] [Wide Residual Networks.](https://arxiv.org/abs/1605.07146) *BMVC 2018.* Sergey Zagoruyko, Nikos Komodakis + +##### [15] [Tensor Programs I: Wide Feedforward or Recurrent Neural Networks of Any Architecture are Gaussian Processes.](https://arxiv.org/abs/1910.12478) *NeurIPS 2019.* Greg Yang. diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py index 8349fa77..c58d816f 100644 --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -82,7 +82,9 @@ from neural_tangents.utils.kernel import Kernel from neural_tangents.utils import utils from neural_tangents.utils.kernel import Marginalisation as M +from jax.nn.initializers import ones, zeros import frozendict +import quadpy as qp _CONV_DIMENSION_NUMBERS = ('NHWC', 'HWIO', 'NHWC') @@ -1881,3 +1883,105 @@ def kernel_fn(kernels): setattr(kernel_fn, _COVARIANCES_REQ, {'marginal': M.OVER_PIXELS, 'cross': M.OVER_PIXELS}) return init_fn, apply_fn, kernel_fn + +@_layer +def BatchNormRelu(axis, epsilon=0., center=True, scale=True, + beta_init=zeros, gamma_init=ones): + """Layer construction function for a batch normalization layer. + + See the papers below for the derivation. + https://arxiv.org/abs/1902.08129 + https://arxiv.org/abs/1910.12478 + + The implementation here follows the reference implementation in + https://github.com/thegregyang/GP4A + """ + + assert epsilon < 1e-12 + + _beta_init = lambda rng, shape: beta_init(rng, shape) if center else () + _gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else () + axis = (axis,) if np.isscalar(axis) else axis + + init_fn, bn_apply_fn = ostax.BatchNorm( + axis, epsilon, center, scale, beta_init, gamma_init) + + def apply_fn(params, xs, **kwargs): + xs = bn_apply_fn(params, xs) + return _ab_relu(xs, 0, 1) + + # NOTE: Currently assumes var2 is None and computes single batch kernel + def kernel_fn(kernels): + batch_size = kernels.var1.shape[0] + assert kernels.var2 is None + + G = np.eye(batch_size) - np.ones((batch_size, batch_size)) / batch_size + eigvals, eigvecs = np.linalg.eigh(G @ kernels.nngp @ G) + logeigvals = np.log(eigvals[1:]) + eigvecs = eigvecs[:, 1:] + + # NOTE: Likely the same as _get_ab_relu_kernel. + def VReLU(cov, eps=1e-5): + indices = list(range(cov.shape[-1])) + d = np.sqrt(cov[..., indices, indices]) + dd = d[..., np.newaxis] * d[..., np.newaxis, :] + c = dd ** (-1) * cov + c = np.where(c > 1 - eps, 1 - eps, c) + c = np.where(c < -1 + eps, -1 + eps, c) + c = (np.sqrt(1 - c ** 2) + (np.pi - np.arccos(c)) * c) / np.pi + return np.nan_to_num(0.5 * dd * c) + + # NOTE: eps is a stability factor for the integral, not for batch norm. + def integrand(log_s, logmultifactor=0, eps=1e-10): + logUeigvals = np.logaddexp(0, np.log(2) + log_s[..., None] + logeigvals) + loginteigvals = logeigvals - logUeigvals + + loginteigvals -= 0.5 * np.sum(logUeigvals, axis=-1, keepdims=True) + loginteigvals += logmultifactor + + inteigvals = np.exp(loginteigvals) + + intvals = np.einsum( + 'ij,...j,jk->...ik', eigvecs, inteigvals, eigvecs.T, optimize=True) + return VReLU(intvals) + + # TODO(Greg): Compute the split point correctly. + intargmax = -np.log(batch_size - 1) + npos = 10 # TODO(Greg): Make these parameters, maybe? + nneg = 10 + alpha = 1 / 8. + + schemepos = qp.e1r.gauss_laguerre(npos) + schemeneg = qp.e1r.gauss_laguerre(nneg) + + schemepospoints = schemepos.points + schemenegpoints = schemeneg.points + schemeposweights = schemepos.weights + schemenegweights = schemeneg.weights + + + + integrandpos = lambda xs: np.moveaxis( + np.moveaxis( + alpha * integrand(intargmax + alpha * xs, + logmultifactor=( + intargmax + (1 + alpha) * xs)[..., np.newaxis]), + 2, 0), 3, 1) + + integrandneg = lambda xs: np.moveaxis( + np.exp(intargmax) * np.moveaxis(integrand(intargmax - xs), 2, 0), 3, 1) + + new_nngp = batch_size * ( + np.einsum('...i,i->...', + integrandpos(np.array([schemepospoints.T])), + schemeposweights) + + + np.einsum('...i,i->...', + integrandneg(np.array([schemenegpoints.T])), + schemenegweights)).squeeze(-1) + + var = np.diag(new_nngp) + return kernels._replace(var1=var, nngp=new_nngp, var2=var) + + return init_fn, apply_fn, kernel_fn + diff --git a/neural_tangents/tests/stax_test.py b/neural_tangents/tests/stax_test.py index 94570a86..c5fd3ee5 100644 --- a/neural_tangents/tests/stax_test.py +++ b/neural_tangents/tests/stax_test.py @@ -78,12 +78,14 @@ 'ATTN_PARAM' ] -LAYER_NORM = [ +NORM_AXIS = [ + (0,), (-1,), (1, 3), (1, 2, 3) ] + PARAMETERIZATIONS = ['NTK', 'STANDARD'] utils.update_test_tolerance() @@ -98,7 +100,7 @@ def _get_inputs(key, is_conv, same_inputs, input_shape, fn=np.cos): def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding, - phi, strides, width, is_ntk, proj_into_2d, layer_norm, + phi, strides, width, is_ntk, proj_into_2d, norm_axis, parameterization): fc = partial( stax.Dense, W_std=W_std, b_std=b_std, parameterization=parameterization) @@ -116,17 +118,24 @@ def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding, (2, 3), None, 'SAME' if padding == 'SAME' else 'CIRCULAR') if use_pooling else stax.Identity()), phi, affine) + # If norm_axis is 0, then testing batchnorm; otherwise, layernorm + if norm_axis is None: + norm = stax.Identity() + elif norm_axis == 0 or norm_axis == (0,): + norm = stax.BatchNormRelu((0,)) + else: + norm = stax.LayerNorm(axis=norm_axis) if is_res: block = stax.serial( affine, stax.FanOut(2),stax.parallel(stax.Identity(), res_unit), stax.FanInSum(), - stax.Identity() if layer_norm is None - else stax.LayerNorm(axis=layer_norm)) + norm + ) else: block = stax.serial( affine, res_unit, - stax.Identity() if layer_norm is None - else stax.LayerNorm(axis=layer_norm)) + norm + ) if proj_into_2d == 'FLAT': proj_layer = stax.Flatten() @@ -219,11 +228,11 @@ def test_exact(self, model, width, strides, padding, phi, same_inputs, ' the spatial dimensions is singleton.') W_std, b_std = 2.**0.5, 0.5**0.5 - layer_norm = None + norm_axis = None parameterization = 'ntk' self._check_agreement_with_empirical(W_std, b_std, filter_size, is_conv, - is_ntk, is_res, layer_norm, padding, + is_ntk, is_res, norm_axis, padding, phi, proj_into_2d, same_inputs, strides, use_pooling, width, parameterization) @@ -266,7 +275,7 @@ def test_parameterizations(self, model, width, same_inputs, is_ntk, strides = STRIDES[0] phi = stax.Relu() use_pooling, is_res = False, False - layer_norm = None + norm_axis = None # Check for duplicate / incorrectly-shaped NN configs / wrong backend. if is_conv: @@ -276,7 +285,7 @@ def test_parameterizations(self, model, width, same_inputs, is_ntk, raise jtu.SkipTest('FC models do not have these parameters.') self._check_agreement_with_empirical(W_std, b_std, filter_size, is_conv, - is_ntk, is_res, layer_norm, padding, + is_ntk, is_res, norm_axis, padding, phi, proj_into_2d, same_inputs, strides, use_pooling, width, parameterization) @@ -289,7 +298,7 @@ def test_parameterizations(self, model, width, same_inputs, is_ntk, model, width, 'same_inputs' if same_inputs else 'different_inputs', 'NTK' if is_ntk else 'NNGP', proj_into_2d, - 'layer_norm=%s' % str(layer_norm)), + 'norm_axis=%s' % str(norm_axis)), 'model': model, 'width': @@ -300,21 +309,21 @@ def test_parameterizations(self, model, width, same_inputs, is_ntk, is_ntk, 'proj_into_2d': proj_into_2d, - 'layer_norm': - layer_norm + 'norm_axis': + norm_axis } for model in MODELS for width in WIDTHS for same_inputs in [False, True] for is_ntk in [False, True] for proj_into_2d in PROJECTIONS[:2] - for layer_norm in LAYER_NORM)) + for norm_axis in NORM_AXIS)) def test_layernorm(self, model, width, same_inputs, is_ntk, - proj_into_2d, layer_norm): + proj_into_2d, norm_axis): is_conv = 'conv' in model # Check for duplicate / incorrectly-shaped NN configs / wrong backend. if is_conv: if xla_bridge.get_backend().platform == 'cpu': raise jtu.SkipTest('Not running CNN models on CPU to save time.') - elif proj_into_2d != PROJECTIONS[0] or layer_norm != LAYER_NORM[0]: + elif proj_into_2d != PROJECTIONS[0] or norm_axis != NORM_AXIS[0]: raise jtu.SkipTest('FC models do not have these parameters.') W_std, b_std = 2.**0.5, 0.5**0.5 @@ -325,11 +334,13 @@ def test_layernorm(self, model, width, same_inputs, is_ntk, use_pooling, is_res = False, False parameterization = 'ntk' + # when testing batchnorm, use batch size 5 (instead of 2) + input_shape = INPUT_SHAPE if 0 not in norm_axis else (5,) + INPUT_SHAPE[1:] self._check_agreement_with_empirical(W_std, b_std, filter_size, is_conv, - is_ntk, is_res, layer_norm, padding, + is_ntk, is_res, norm_axis, padding, phi, proj_into_2d, same_inputs, strides, use_pooling, width, - parameterization) + parameterization, input_shape=input_shape) def test_avg_pool(self): X1 = np.ones((4, 2, 3, 2)) @@ -380,15 +391,16 @@ def test_avg_pool(self): True) def _check_agreement_with_empirical(self, W_std, b_std, filter_size, is_conv, - is_ntk, is_res, layer_norm, padding, phi, + is_ntk, is_res, norm_axis, padding, phi, proj_into_2d, same_inputs, strides, - use_pooling, width, parameterization): + use_pooling, width, parameterization, + input_shape=INPUT_SHAPE): key = random.PRNGKey(1) - x1, x2 = _get_inputs(key, is_conv, same_inputs, INPUT_SHAPE) + x1, x2 = _get_inputs(key, is_conv, same_inputs, input_shape) init_fn, apply_fn, kernel_fn = _get_net(W_std, b_std, filter_size, is_conv, use_pooling, is_res, padding, phi, strides, width, is_ntk, - proj_into_2d, layer_norm, + proj_into_2d, norm_axis, parameterization) x1_out_shape, params = init_fn(key, x1.shape) @@ -403,6 +415,9 @@ def _get_empirical(n_samples, get): init_fn, apply_fn, key, n_samples) return kernel_fn_empirical(x1, x2, get) + if (x2 is not None or is_ntk) and norm_axis == (0,): + # TODO(Greg): implement + return if proj_into_2d == 'ATTN_PARAM': # no analytic kernel available, just test forward/backward pass _get_empirical(1, 'ntk' if is_ntk else 'nngp') diff --git a/setup.py b/setup.py index 661080d8..e5840859 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,8 @@ INSTALL_REQUIRES = [ 'jaxlib>=0.1.36', 'jax>=0.1.53', - 'frozendict' + 'frozendict', + 'quadpy' ] From 18d1dee2fa7cd850ac7d58302ae7f3603617e9e6 Mon Sep 17 00:00:00 2001 From: Greg Yang Date: Thu, 30 Apr 2020 23:22:46 +0000 Subject: [PATCH 2/9] got cross batch working; waiting for polish --- neural_tangents/stax.py | 317 +++++++++++++++++++++++++++++++--------- 1 file changed, 245 insertions(+), 72 deletions(-) diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py index 75f02312..32c75bc3 100644 --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -2919,7 +2919,7 @@ def _pool_mask(mask: np.ndarray, return mask -@_layer +@layer def BatchNormRelu(axis, epsilon=0., center=True, scale=True, beta_init=zeros, gamma_init=ones): """Layer construction function for a batch normalization layer. @@ -2932,6 +2932,9 @@ def BatchNormRelu(axis, epsilon=0., center=True, scale=True, https://github.com/thegregyang/GP4A """ + iu = ops.index_update + ix = ops.index + assert epsilon < 1e-12 _beta_init = lambda rng, shape: beta_init(rng, shape) if center else () @@ -2945,77 +2948,247 @@ def apply_fn(params, xs, **kwargs): xs = bn_apply_fn(params, xs) return _ab_relu(xs, 0, 1) - # NOTE: Currently assumes var2 is None and computes single batch kernel + # NOTE: Currently assumes cov2 is None and computes single batch kernel + @_requires(diagonal_batch=False) def kernel_fn(kernels): - batch_size = kernels.var1.shape[0] - assert kernels.var2 is None - - G = np.eye(batch_size) - np.ones((batch_size, batch_size)) / batch_size - eigvals, eigvecs = np.linalg.eigh(G @ kernels.nngp @ G) - logeigvals = np.log(eigvals[1:]) - eigvecs = eigvecs[:, 1:] - - # NOTE: Likely the same as _get_ab_relu_kernel. - def VReLU(cov, eps=1e-5): - indices = list(range(cov.shape[-1])) - d = np.sqrt(cov[..., indices, indices]) - dd = d[..., np.newaxis] * d[..., np.newaxis, :] - c = dd ** (-1) * cov - c = np.where(c > 1 - eps, 1 - eps, c) - c = np.where(c < -1 + eps, -1 + eps, c) - c = (np.sqrt(1 - c ** 2) + (np.pi - np.arccos(c)) * c) / np.pi - return np.nan_to_num(0.5 * dd * c) - - # NOTE: eps is a stability factor for the integral, not for batch norm. - def integrand(log_s, logmultifactor=0, eps=1e-10): - logUeigvals = np.logaddexp(0, np.log(2) + log_s[..., None] + logeigvals) - loginteigvals = logeigvals - logUeigvals - - loginteigvals -= 0.5 * np.sum(logUeigvals, axis=-1, keepdims=True) - loginteigvals += logmultifactor - - inteigvals = np.exp(loginteigvals) - - intvals = np.einsum( - 'ij,...j,jk->...ik', eigvecs, inteigvals, eigvecs.T, optimize=True) - return VReLU(intvals) - - # TODO(Greg): Compute the split point correctly. - intargmax = -np.log(batch_size - 1) - npos = 10 # TODO(Greg): Make these parameters, maybe? - nneg = 10 - alpha = 1 / 8. - - schemepos = qp.e1r.gauss_laguerre(npos) - schemeneg = qp.e1r.gauss_laguerre(nneg) - - schemepospoints = schemepos.points - schemenegpoints = schemeneg.points - schemeposweights = schemepos.weights - schemenegweights = schemeneg.weights - - - - integrandpos = lambda xs: np.moveaxis( - np.moveaxis( - alpha * integrand(intargmax + alpha * xs, - logmultifactor=( - intargmax + (1 + alpha) * xs)[..., np.newaxis]), - 2, 0), 3, 1) - - integrandneg = lambda xs: np.moveaxis( - np.exp(intargmax) * np.moveaxis(integrand(intargmax - xs), 2, 0), 3, 1) - - new_nngp = batch_size * ( - np.einsum('...i,i->...', - integrandpos(np.array([schemepospoints.T])), - schemeposweights) - + - np.einsum('...i,i->...', - integrandneg(np.array([schemenegpoints.T])), - schemenegweights)).squeeze(-1) - - var = np.diag(new_nngp) - return kernels._replace(var1=var, nngp=new_nngp, var2=var) + # print(kernels.cov1) + # assert kernels.cov2 is None + + # apply the below code to cov1 and cov2 + # write new cross batch code for (nngp, cov1, cov2) + + def Gmatrix(batch_size): + return np.eye(batch_size) - np.ones((batch_size, batch_size)) / batch_size + + def singlebatchker(ker): + batch_size = ker.shape[0] + G = Gmatrix(batch_size) + eigvals, eigvecs = np.linalg.eigh(G @ ker @ G) + logeigvals = np.log(eigvals[1:]) + eigvecs = eigvecs[:, 1:] + + # NOTE: Likely the same as _get_ab_relu_kernel. + def VReLU(cov, eps=1e-5): + indices = list(range(cov.shape[-1])) + d = np.sqrt(cov[..., indices, indices]) + dd = d[..., np.newaxis] * d[..., np.newaxis, :] + c = dd ** (-1) * cov + c = np.where(c > 1 - eps, 1 - eps, c) + c = np.where(c < -1 + eps, -1 + eps, c) + c = (np.sqrt(1 - c ** 2) + (np.pi - np.arccos(c)) * c) / np.pi + return np.nan_to_num(0.5 * dd * c) + + # NOTE: eps is a stability factor for the integral, not for batch norm. + def integrand(log_s, logmultifactor=0, eps=1e-10): + logUeigvals = np.logaddexp(0, np.log(2) + log_s[..., None] + logeigvals) + loginteigvals = logeigvals - logUeigvals + + loginteigvals -= 0.5 * np.sum(logUeigvals, axis=-1, keepdims=True) + loginteigvals += logmultifactor + + inteigvals = np.exp(loginteigvals) + + intvals = np.einsum( + 'ij,...j,jk->...ik', eigvecs, inteigvals, eigvecs.T, optimize=True) + return VReLU(intvals) + + # TODO(Greg): Compute the split point correctly. + intargmax = -np.log(batch_size - 1) + npos = 10 # TODO(Greg): Make these parameters, maybe? + nneg = 10 + alpha = 1 / 8. + + schemepos = qp.e1r.gauss_laguerre(npos) + schemeneg = qp.e1r.gauss_laguerre(nneg) + + schemepospoints = schemepos.points + schemenegpoints = schemeneg.points + schemeposweights = schemepos.weights + schemenegweights = schemeneg.weights + + + + integrandpos = lambda xs: np.moveaxis( + np.moveaxis( + alpha * integrand(intargmax + alpha * xs, + logmultifactor=( + intargmax + (1 + alpha) * xs)[..., np.newaxis]), + 2, 0), 3, 1) + + integrandneg = lambda xs: np.moveaxis( + np.exp(intargmax) * np.moveaxis(integrand(intargmax - xs), 2, 0), 3, 1) + + new_nngp = batch_size * ( + np.einsum('...i,i->...', + integrandpos(np.array([schemepospoints.T])), + schemeposweights) + + + np.einsum('...i,i->...', + integrandneg(np.array([schemenegpoints.T])), + schemenegweights)).squeeze(-1) + return new_nngp + + def J1(c, eps=1e-10): + c = np.clip(c, -1+eps, 1-eps) + return (np.sqrt(1-c**2) + (np.pi - np.arccos(c)) * c) / np.pi + + def VBNReLUCrossBatchIntegrand(Xi, Sigma1, Sigma2): + '''Computes the off diagonal block of the BN+ReLU kernel over 2 batches + Input: + Xi: covariance between batch1 and batch2 + Sigma1: autocovariance of batch1 + Sigma2: autocovariance of batch2 + Output: + f: integrand function in the integral for computing cross batch VBNReLU + ''' + n1 = Sigma1.shape[0] + n2 = Sigma2.shape[0] + G1 = Gmatrix(n1) + G2 = Gmatrix(n2) + Delta1, A1 = np.linalg.eigh(G1 @ Sigma1 @ G1) + Delta2, A2 = np.linalg.eigh(G2 @ Sigma2 @ G2) + # kill first 0 eigenval + Delta1 = Delta1[1:] + Delta2 = Delta2[1:] + A1 = A1[:, 1:] + A2 = A2[:, 1:] + + Xidot = A1.T @ Xi @ A2 + Omegadot = np.block([[np.diag(Delta1), Xidot], [Xidot.T, np.diag(Delta2)]]) + Omegadotinv = np.linalg.inv(Omegadot) + + def f(s, t, multfactor=1): + # import pdb + # pdb.set_trace() + # Ddot.shape = (..., n1+n2-2, n1+n2-2) + Ddot = s[..., None, None] * np.eye(n1-1+n2-1) + # Ddot[..., np.arange(n1-1, n1+n2-2), np.arange(n1-1, n1+n2-2)] = t[..., None] + Ddot = iu(Ddot, ix[..., np.arange(n1-1, n1+n2-2), np.arange(n1-1, n1+n2-2)], + t[..., None]) + + ## Compute off-diagonal block of VReLU(Pi) + Pitilde = Omegadotinv + 2 * Ddot + Pitilde = np.linalg.inv(Pitilde) + Pi11diag = np.einsum('ij,...jk,ki->...i', + A1, + Pitilde[..., :n1-1, :n1-1], + A1.T) + Pi22diag = np.einsum('ij,...jk,ki->...i', + A2, + Pitilde[..., n1-1:, n1-1:], + A2.T) + Pi12 = np.einsum('ij,...jk,kl->...il', + A1, + Pitilde[..., :n1-1, n1-1:], + A2.T) + C = J1(np.einsum('...i,...ij,...j->...ij', + Pi11diag**-0.5, + Pi12, + Pi22diag**-0.5)) + VReLUPi12 = 0.5 * np.einsum('...i,...ij,...j->...ij', + Pi11diag**0.5, + C, + Pi22diag**0.5) + + + ## Compute determinant + ind = np.arange(n1+n2-2) + # Ddot <- matrix inverse of Ddot + # Ddot[..., ind, ind] = Ddot[..., ind, ind]**-1 + Ddot = iu(Ddot, ix[..., ind, ind], Ddot[..., ind, ind]**-1) + logdet = np.linalg.slogdet(Ddot + 2 * Omegadot)[1] + return np.exp( + (np.log(multfactor) + + (-n1/2) * np.log(s) + + (-n2/2) * np.log(t) + - 1/2 * logdet)[..., None, None] + np.log(VReLUPi12)) + return f + + def VBNReLUCrossBatch(Xi, Sigma1, Sigma2, npos=10, nneg=5, + alphapos1=1/3, alphaneg1=1, + alphapos2=1/3, alphaneg2=1): + '''Compute VBNReLU for two batches. + + Inputs: + Xi: covariance between batch1 and batch2 + Sigma1: autocovariance of batch1 + Sigma2: autocovariance of batch2 + npos: number of points for integrating the big s side of the VBNReLU integral + (effective for both dimensions of integration) + nneg: number of points for integrating the small s side of the VBNReLU integral + (effective for both dimensions of integration) + alphapos1: reparametrize the large s integral by s = exp(alpha r) in the 1st dimension + alphaneg1: reparametrize the small s integral by s = exp(alpha r) in the 1st dimension + alphapos2: reparametrize the large s integral by s = exp(alpha r) in the 2nd dimension + alphaneg2: reparametrize the small s integral by s = exp(alpha r) in the 2nd dimension + By tuning the `alpha` parameters, the integrand is closer to being well-approximated by + low-degree Laguerre polynomials, which makes the quadrature more accurate in approximating the integral. + Outputs: + The (batch1, batch2) block of block matrix obtained by + applying VBNReLU^{\oplus 2} to the kernel of batch1 and batch2 + ''' + # import pdb + # pdb.set_trace() + # We will do the integration explicitly ourselves: + # We obtain sample points and weights via `quadpy`'s Gauss Laguerre quadrature + # and do the sum ourselves + schemepos = qp.e1r.gauss_laguerre(npos, alpha=0) + schemeneg = qp.e1r.gauss_laguerre(nneg, alpha=0) + dim1 = Sigma1.shape[0] + dim2 = Sigma2.shape[0] + intargmax = (-np.log(2*(dim1-1)), -np.log(2*(dim2-1))) + f = VBNReLUCrossBatchIntegrand(Xi, Sigma1, Sigma2) + # Get the points manually for each dimension + scheme1dpoints = np.concatenate([schemepos.points, -schemeneg.points]) + # Get the weights manually for each dimension + scheme1dwts = np.concatenate([schemepos.weights, schemeneg.weights]) + # Obtain the points for the whole 2d integration + scheme2dpoints = np.meshgrid(scheme1dpoints, scheme1dpoints) + # Obtain the weights for the whole 2d integration + scheme2dwts = scheme1dwts[:, None] * scheme1dwts[None, :] + + def applyalpha(x, alphapos, alphaneg): + x = iu(x, ix[x > 0], x[x > 0] * alphapos) + x = iu(x, ix[x <= 0], x[x <= 0] * alphaneg) + # xx[xx > 0] *= alphapos + # xx[xx <= 0] *= alphaneg + return x + + # def iu(x, mask, ) + def alphafactor(x, y): + a = np.zeros_like(x) + a = iu(a, ix[(x > 0) & (y > 0)], alphapos1 * alphapos2) + a = iu(a, ix[(x > 0) & (y <= 0)], alphapos1 * alphaneg2) + a = iu(a, ix[(x <= 0) & (y > 0)], alphaneg1 * alphapos2) + a = iu(a, ix[(x <= 0) & (y <= 0)], alphaneg1 * alphaneg2) + return a + + integrand = lambda inp: \ + f(np.exp(applyalpha(inp[0], alphapos1, alphaneg1) + intargmax[0]), + np.exp(applyalpha(inp[1], alphapos2, alphaneg2) + intargmax[1]), + multfactor=alphafactor(inp[0], inp[1]) + * np.pi**-1 + * np.exp(applyalpha(inp[0], alphapos1, alphaneg1) + intargmax[0] + + applyalpha(inp[1], alphapos2, alphaneg2) + intargmax[1] + + np.abs(inp[0]) + np.abs(inp[1]) + ) + ) + + return np.sqrt(dim1 * dim2) * np.einsum('ij...,ij->...', + integrand(scheme2dpoints), + scheme2dwts + ) + + + cov1 = singlebatchker(kernels.cov1) + if kernels.cov2 is None: + return kernels.replace(cov1=cov1, nngp=cov1, cov2=None) + + cov2 = singlebatchker(kernels.cov2) + + # TODO compute cross batch + new_nngp = VBNReLUCrossBatch(kernels.nngp, kernels.cov1, kernels.cov2) + + return kernels.replace(cov1=cov1, nngp=new_nngp, cov2=cov2) return init_fn, apply_fn, kernel_fn \ No newline at end of file From 466f14f62834cebd57742145504dfd896e2e8950 Mon Sep 17 00:00:00 2001 From: Greg Yang Date: Thu, 7 May 2020 08:12:55 +0000 Subject: [PATCH 3/9] fix setup requirements --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 9fb8cb88..9fbc15d0 100644 --- a/setup.py +++ b/setup.py @@ -26,8 +26,7 @@ INSTALL_REQUIRES = [ - 'jaxlib>=0.1.58', - 'jax>=0.1.55', + 'jax>=0.1.58', 'frozendict', 'dataclasses', 'quadpy' From fb69975a48f900168c2e4021f0fb69490884d2dc Mon Sep 17 00:00:00 2001 From: Greg Yang Date: Thu, 7 May 2020 08:16:00 +0000 Subject: [PATCH 4/9] delete pdb stuff --- examples/infinite_fcn.py | 49 +++++++++++++++------------------------- 1 file changed, 18 insertions(+), 31 deletions(-) diff --git a/examples/infinite_fcn.py b/examples/infinite_fcn.py index db90c6b9..c3220ccc 100644 --- a/examples/infinite_fcn.py +++ b/examples/infinite_fcn.py @@ -23,8 +23,7 @@ import jax.numpy as np import neural_tangents as nt from neural_tangents import stax -from examples import datasets -from examples import util +from jax import random flags.DEFINE_integer('train_size', 1000, @@ -37,43 +36,31 @@ FLAGS = flags.FLAGS +import pdb def main(unused_argv): # Build data pipelines. print('Loading data.') - x_train, y_train, x_test, y_test = \ - datasets.get_dataset('cifar10', FLAGS.train_size, FLAGS.test_size) + key = random.PRNGKey(0) + key, split = random.split(key) + x_train = random.normal(key=key, shape=[10, 30]) + x_train2 = random.normal(key=split, shape=[10, 30]) # Build the infinite network. - _, _, kernel_fn = stax.serial( - stax.Dense(1, 2., 0.05), - stax.Relu(), - stax.Dense(1, 2., 0.05) + init_fn, apply_fn, kernel_fn = stax.serial( + stax.Dense(1000, 2., 0.05), + stax.BatchNormRelu(0), + stax.Dense(1000, 2., 0.05) ) - # Optionally, compute the kernel in batches, in parallel. - kernel_fn = nt.batch(kernel_fn, - device_count=0, - batch_size=FLAGS.batch_size) - - start = time.time() - # Bayesian and infinite-time gradient descent inference with infinite network. - fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(kernel_fn, - x_train, - y_train, - x_test, - get=('nngp', 'ntk'), - diag_reg=1e-3) - fx_test_nngp.block_until_ready() - fx_test_ntk.block_until_ready() - - duration = time.time() - start - print('Kernel construction and inference done in %s seconds.' % duration) - - # Print out accuracy and loss for infinite network predictions. - loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2) - util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss) - util.print_summary('NTK test', y_test, fx_test_ntk, None, loss) + mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, 1000) + kerobj = kernel_fn(x_train, x_train2) + theory_ker = kerobj.nngp + diff = theory_ker - mc_kernel_fn(x_train, x_train2, get='nngp') + print(diff) + # print(kerobj.cov1 - kerobj.nngp) + print(np.linalg.norm(diff) / np.linalg.norm(theory_ker)) + return if __name__ == '__main__': From ecdd7ad9f86cedecec5d9c6774cde5c304e1a3eb Mon Sep 17 00:00:00 2001 From: Greg Yang Date: Thu, 7 May 2020 08:21:27 +0000 Subject: [PATCH 5/9] Revert "fix setup requirements" This reverts commit 466f14f62834cebd57742145504dfd896e2e8950. --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 9fbc15d0..9fb8cb88 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,8 @@ INSTALL_REQUIRES = [ - 'jax>=0.1.58', + 'jaxlib>=0.1.58', + 'jax>=0.1.55', 'frozendict', 'dataclasses', 'quadpy' From 6d2e68258bf617dfa9acfd18feda8e2103ee6649 Mon Sep 17 00:00:00 2001 From: Greg Yang Date: Thu, 7 May 2020 08:22:45 +0000 Subject: [PATCH 6/9] remove pdb stuff --- neural_tangents/stax.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py index 32c75bc3..fa0b9c66 100644 --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -3058,8 +3058,6 @@ def VBNReLUCrossBatchIntegrand(Xi, Sigma1, Sigma2): Omegadotinv = np.linalg.inv(Omegadot) def f(s, t, multfactor=1): - # import pdb - # pdb.set_trace() # Ddot.shape = (..., n1+n2-2, n1+n2-2) Ddot = s[..., None, None] * np.eye(n1-1+n2-1) # Ddot[..., np.arange(n1-1, n1+n2-2), np.arange(n1-1, n1+n2-2)] = t[..., None] @@ -3127,8 +3125,6 @@ def VBNReLUCrossBatch(Xi, Sigma1, Sigma2, npos=10, nneg=5, The (batch1, batch2) block of block matrix obtained by applying VBNReLU^{\oplus 2} to the kernel of batch1 and batch2 ''' - # import pdb - # pdb.set_trace() # We will do the integration explicitly ourselves: # We obtain sample points and weights via `quadpy`'s Gauss Laguerre quadrature # and do the sum ourselves From 15cd6ace8d85ee99e3de7779bcb7c2c48ce1cf7e Mon Sep 17 00:00:00 2001 From: Greg Yang Date: Thu, 14 May 2020 22:41:59 +0000 Subject: [PATCH 7/9] batchnorm fcn fully working --- neural_tangents/stax.py | 29 ++++++++------ neural_tangents/tests/stax_test.py | 63 ++++++++++++++++++++---------- setup.py | 2 +- 3 files changed, 61 insertions(+), 33 deletions(-) diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py index fa0b9c66..d5a5b6e7 100644 --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -2920,8 +2920,7 @@ def _pool_mask(mask: np.ndarray, @layer -def BatchNormRelu(axis, epsilon=0., center=True, scale=True, - beta_init=zeros, gamma_init=ones): +def BatchNormRelu(axis): """Layer construction function for a batch normalization layer. See the papers below for the derivation. @@ -2931,12 +2930,15 @@ def BatchNormRelu(axis, epsilon=0., center=True, scale=True, The implementation here follows the reference implementation in https://github.com/thegregyang/GP4A """ + epsilon = 0 + center=True + scale=True + beta_init=zeros + gamma_init=ones iu = ops.index_update ix = ops.index - assert epsilon < 1e-12 - _beta_init = lambda rng, shape: beta_init(rng, shape) if center else () _gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else () axis = (axis,) if np.isscalar(axis) else axis @@ -2948,11 +2950,17 @@ def apply_fn(params, xs, **kwargs): xs = bn_apply_fn(params, xs) return _ab_relu(xs, 0, 1) - # NOTE: Currently assumes cov2 is None and computes single batch kernel @_requires(diagonal_batch=False) def kernel_fn(kernels): - # print(kernels.cov1) - # assert kernels.cov2 is None + if not kernels.is_gaussian: + raise NotImplementedError('`BatchNormRelu` is only implemented for the ' + 'case if the input layer is guaranteed to be mean' + '-zero Gaussian, i.e. having `is_gaussian` ' + 'set to `True`.') + if len(kernels.shape1) != 2: + raise NotImplementedError('`BatchNormRelu` only supports fully-connected layers.') + if kernels.ntk is not None: + raise NotImplementedError('NTK is currently not supported by `BatchNormRelu`.') # apply the below code to cov1 and cov2 # write new cross batch code for (nngp, cov1, cov2) @@ -3006,8 +3014,6 @@ def integrand(log_s, logmultifactor=0, eps=1e-10): schemeposweights = schemepos.weights schemenegweights = schemeneg.weights - - integrandpos = lambda xs: np.moveaxis( np.moveaxis( alpha * integrand(intargmax + alpha * xs, @@ -3088,7 +3094,6 @@ def f(s, t, multfactor=1): C, Pi22diag**0.5) - ## Compute determinant ind = np.arange(n1+n2-2) # Ddot <- matrix inverse of Ddot @@ -3178,13 +3183,13 @@ def alphafactor(x, y): cov1 = singlebatchker(kernels.cov1) if kernels.cov2 is None: - return kernels.replace(cov1=cov1, nngp=cov1, cov2=None) + return kernels.replace(cov1=cov1, nngp=cov1, cov2=None, is_gaussian=False) cov2 = singlebatchker(kernels.cov2) # TODO compute cross batch new_nngp = VBNReLUCrossBatch(kernels.nngp, kernels.cov1, kernels.cov2) - return kernels.replace(cov1=cov1, nngp=new_nngp, cov2=cov2) + return kernels.replace(cov1=cov1, nngp=new_nngp, cov2=cov2, is_gaussian=False) return init_fn, apply_fn, kernel_fn \ No newline at end of file diff --git a/neural_tangents/tests/stax_test.py b/neural_tangents/tests/stax_test.py index 04c2c20d..1e631b5a 100644 --- a/neural_tangents/tests/stax_test.py +++ b/neural_tangents/tests/stax_test.py @@ -85,7 +85,8 @@ 'CHW', 'NC', 'NWC', - 'NCHW' + 'NCHW', + 'N' ] POOL_TYPES = [ @@ -101,13 +102,18 @@ test_utils.update_test_tolerance() -def _get_inputs(key, is_conv, same_inputs, input_shape, fn=np.cos): +def _get_inputs(key, is_conv, same_inputs, input_shape, new_batch_size=None, fn=np.cos): key, split = random.split(key) shape = input_shape if is_conv else (input_shape[0], np.prod(input_shape[1:])) - x1 = fn(random.normal(key, shape)) batch_axis = shape.index(BATCH_SIZE) - shape = shape[:batch_axis] + (2 * BATCH_SIZE,) + shape[batch_axis + 1:] - x2 = None if same_inputs else 2 * fn(random.normal(split, shape)) + if new_batch_size is None: + shape1 = shape + shape2 = shape[:batch_axis] + (2 * BATCH_SIZE,) + shape[batch_axis + 1:] + else: + shape1 = shape[:batch_axis] + (new_batch_size,) + shape[batch_axis + 1:] + shape2 = shape[:batch_axis] + (2 * new_batch_size,) + shape[batch_axis + 1:] + x1 = fn(random.normal(key, shape1)) + x2 = None if same_inputs else 2 * fn(random.normal(split, shape2)) return x1, x2 @@ -188,10 +194,16 @@ def conv(out_chan): return stax.GeneralConv( else: pool_or_identity = stax.Identity() dropout_or_identity = dropout if use_dropout else stax.Identity() - layer_norm_or_identity = (stax.Identity() if layer_norm is None - else stax.LayerNorm(axis=layer_norm, - batch_axis=batch_axis, - channel_axis=channel_axis)) + if layer_norm is None: + norm_phi = phi + elif layer_norm == (batch_axis,): + norm_phi = stax.BatchNormRelu(batch_axis) + else: + norm_phi = stax.serial( + stax.LayerNorm(axis=layer_norm, + batch_axis=batch_axis, + channel_axis=channel_axis), + phi) res_unit = stax.serial(dropout_or_identity, affine, pool_or_identity) if is_res: block = stax.serial( @@ -200,14 +212,12 @@ def conv(out_chan): return stax.GeneralConv( stax.parallel(stax.Identity(), res_unit), stax.FanInSum(), - layer_norm_or_identity, - phi) + norm_phi) else: block = stax.serial( affine, res_unit, - layer_norm_or_identity, - phi) + norm_phi) if proj_into_2d == 'FLAT': proj_layer = stax.Flatten(batch_axis, 0) @@ -430,7 +440,7 @@ def test_parameterizations(self, model, width, same_inputs, is_ntk, for is_ntk in [False, True] for proj_into_2d in PROJECTIONS[:2] for layer_norm in LAYER_NORM)) - def test_layernorm(self, + def test_normalization(self, model, width, same_inputs, @@ -442,7 +452,9 @@ def test_layernorm(self, if is_conv: if xla_bridge.get_backend().platform == 'cpu': raise jtu.SkipTest('Not running CNN models on CPU to save time.') - elif proj_into_2d != PROJECTIONS[0] or layer_norm not in ('C', 'NC'): + if layer_norm == 'N': + raise jtu.SkipTest('Skipping batchnorm test for convolutional networks.') + elif proj_into_2d != PROJECTIONS[0] or layer_norm not in ('C', 'NC', 'N'): raise jtu.SkipTest('FC models do not have these parameters.') W_std, b_std = 2.**0.5, 0.5**0.5 @@ -458,8 +470,17 @@ def test_layernorm(self, net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding, phi, strides, width, is_ntk, proj_into_2d, pool_type, layer_norm, parameterization, use_dropout) - self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout, - is_ntk, proj_into_2d) + # # when testing batchnorm, use batch size 5 (instead of 2) + new_batch_size = None if layer_norm != 'N' else 5 + if layer_norm != 'N' or not is_ntk: + self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout, + is_ntk, proj_into_2d, + new_batch_size=new_batch_size) + else: + with self.assertRaises(ValueError): + self._check_agreement_with_empirical(net, same_inputs, is_conv, use_dropout, + is_ntk, proj_into_2d, + new_batch_size=new_batch_size) @jtu.parameterized.named_parameters( jtu.cases_from_list({ @@ -731,12 +752,15 @@ def test_composition_conv(self, avg_pool, same_inputs): self.assertAllClose(composed_ker_out, ker_out_marg, True) def _check_agreement_with_empirical(self, net, same_inputs, is_conv, - use_dropout, is_ntk, proj_into_2d): + use_dropout, is_ntk, proj_into_2d, + new_batch_size=None): (init_fn, apply_fn, kernel_fn), input_shape, device_count = net num_samples = N_SAMPLES * 5 if use_dropout else N_SAMPLES + # num_samples *= 10 if new_batch_size is not None else 1 key = random.PRNGKey(1) - x1, x2 = _get_inputs(key, is_conv, same_inputs, input_shape) + x1, x2 = _get_inputs(key, is_conv, same_inputs, input_shape, + new_batch_size=new_batch_size) x1_out_shape, params = init_fn(key, x1.shape) if same_inputs: @@ -1753,6 +1777,5 @@ def get_attn(): empirical = empirical.reshape(exact.shape) test_utils.assert_close_matrices(self, empirical, exact, tol) - if __name__ == '__main__': jtu.absltest.main() diff --git a/setup.py b/setup.py index 9fb8cb88..4b78b9e8 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ 'jax>=0.1.55', 'frozendict', 'dataclasses', - 'quadpy' + 'quadpy==0.13.2' ] From fe9ab7ca058d4894bbad26dd64abae9d1ba602df Mon Sep 17 00:00:00 2001 From: Greg Yang Date: Mon, 10 Aug 2020 22:39:21 +0000 Subject: [PATCH 8/9] batchnorm conv fully working --- neural_tangents/stax.py | 593 +++++++++++++++++------------ neural_tangents/tests/stax_test.py | 38 +- 2 files changed, 383 insertions(+), 248 deletions(-) mode change 100644 => 100755 neural_tangents/stax.py mode change 100644 => 100755 neural_tangents/tests/stax_test.py diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py old mode 100644 new mode 100755 index d5a5b6e7..741cff8b --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -67,12 +67,14 @@ import enum import functools import operator as op +from functools import reduce import string from typing import Tuple, List, Optional, Iterable, Callable, Union import warnings import frozendict from jax import lax +from jax.api import vmap from jax import linear_util as lu from jax import numpy as np from jax import ops @@ -88,6 +90,7 @@ from jax.nn.initializers import ones, zeros import quadpy as qp +from scipy.special import roots_laguerre class Padding(enum.Enum): CIRCULAR = 'CIRCULAR' @@ -2919,8 +2922,285 @@ def _pool_mask(mask: np.ndarray, return mask +def _batchnorm_relu_kernel_fn(cov1, cov2, nngp): + # apply the below code to cov1 and cov2 + # write new cross batch code for (nngp, cov1, cov2) + + # cov1 += np.eye(cov1.shape[0]) * 1e-5 + # if cov2 is not None: + # cov2 += np.eye(cov2.shape[0]) * 1e-5 + iu = ops.index_update + ix = ops.index + + def Gmatrix(batch_size): + return np.eye(batch_size) - np.ones((batch_size, batch_size)) / batch_size + + def singlebatchker(ker): + batch_size = ker.shape[0] + G = Gmatrix(batch_size) + eigvals, eigvecs = np.linalg.eigh(G @ ker @ G) + # NOTE: 0 eigvals can appear as negative, so explicitly zero out + eigvals = np.where(eigvals < 0, 0, eigvals) + # eigvals[eigvals < 0] = 0 + logeigvals = np.log(eigvals[1:]) + eigvecs = eigvecs[:, 1:] + + # NOTE: Likely the same as _get_ab_relu_kernel. + def VReLU(cov, eps=1e-5): + indices = list(range(cov.shape[-1])) + d = np.sqrt(cov[..., indices, indices]) + dd = d[..., np.newaxis] * d[..., np.newaxis, :] + c = dd ** (-1) * cov + c = np.where(c > 1 - eps, 1 - eps, c) + c = np.where(c < -1 + eps, -1 + eps, c) + c = (np.sqrt(1 - c ** 2) + (np.pi - np.arccos(c)) * c) / np.pi + return np.nan_to_num(0.5 * dd * c) + + # NOTE: eps is a stability factor for the integral, not for batch norm. + def integrand(log_s, logmultifactor=0, eps=1e-10): + logUeigvals = np.logaddexp(0, np.log(2) + log_s[..., None] + logeigvals) + loginteigvals = logeigvals - logUeigvals + + loginteigvals -= 0.5 * np.sum(logUeigvals, axis=-1, keepdims=True) + loginteigvals += logmultifactor + + inteigvals = np.exp(loginteigvals) + + intvals = np.einsum( + 'ij,...j,jk->...ik', eigvecs, inteigvals, eigvecs.T, optimize=True) + return VReLU(intvals) + + # TODO(Greg): Compute the split point correctly. + intargmax = -np.log(batch_size - 1) + npos = 10 # TODO(Greg): Make these parameters, maybe? + nneg = 10 + alpha = 1 / 8. + + schemepospoints, schemeposweights = roots_laguerre(npos) + schemenegpoints, schemenegweights = roots_laguerre(nneg) + # schemepospoints + # schemepos = qp.e1r.gauss_laguerre(npos) + # schemeneg = qp.e1r.gauss_laguerre(nneg) + + # schemepospoints = schemepos.points + # schemenegpoints = schemeneg.points + # schemeposweights = schemepos.weights + # schemenegweights = schemeneg.weights + + integrandpos = lambda xs: np.moveaxis( + np.moveaxis( + alpha * integrand(intargmax + alpha * xs, + logmultifactor=( + intargmax + (1 + alpha) * xs)[..., np.newaxis]), + 2, 0), 3, 1) + + integrandneg = lambda xs: np.moveaxis( + np.exp(intargmax) * np.moveaxis(integrand(intargmax - xs), 2, 0), 3, 1) + + new_nngp = batch_size * ( + np.einsum('...i,i->...', + integrandpos(np.array([schemepospoints.T])), + schemeposweights) + + + np.einsum('...i,i->...', + integrandneg(np.array([schemenegpoints.T])), + schemenegweights)).squeeze(-1) + return new_nngp + + def J1(c, eps=1e-10): + c = np.clip(c, -1+eps, 1-eps) + return (np.sqrt(1-c**2) + (np.pi - np.arccos(c)) * c) / np.pi + + def VBNReLUCrossBatchIntegrand(Xi, Sigma1, Sigma2): + '''Computes the off diagonal block of the BN+ReLU kernel over 2 batches + Input: + Xi: covariance between batch1 and batch2 + Sigma1: autocovariance of batch1 + Sigma2: autocovariance of batch2 + Output: + f: integrand function in the integral for computing cross batch VBNReLU + ''' + # import pdb; pdb.set_trace() + myblock = np.block([[Sigma1, Xi], [Xi.T, Sigma2]]) + # print(np.linalg.norm(myblock - myblock.T)) + eigval1, eigvec1 = np.linalg.eigh(Sigma1) + eigval2, eigvec2 = np.linalg.eigh(Sigma2) + print('eigval1\n', eigval1) + print('eigval2\n', eigval2) + eigvals, eigvecs = np.linalg.eigh(myblock) + print('block eigvals\n', eigvals) + # import sys; sys.exit() + + n1 = Sigma1.shape[0] + n2 = Sigma2.shape[0] + G1 = Gmatrix(n1) + G2 = Gmatrix(n2) + Delta1, A1 = np.linalg.eigh(G1 @ Sigma1 @ G1) + Delta2, A2 = np.linalg.eigh(G2 @ Sigma2 @ G2) + # NOTE: 0 eigvals can appear as negative, so explicitly zero out + Delta1 = np.where(Delta1 < 0, 0, Delta1) + Delta2 = np.where(Delta2 < 0, 0, Delta2) + # kill first 0 eigenval + Delta1 = Delta1[1:] + Delta2 = Delta2[1:] + A1 = A1[:, 1:] + A2 = A2[:, 1:] + + Xidot = A1.T @ Xi @ A2 + Omegadot = np.block([[np.diag(Delta1), Xidot], [Xidot.T, np.diag(Delta2)]]) + Omegadotinv = np.linalg.inv(Omegadot) + # import pdb; pdb.set_trace() + + def f(s, t, multfactor=1): + # Ddot.shape = (..., n1+n2-2, n1+n2-2) + Ddot = s[..., None, None] * np.eye(n1-1+n2-1) + # Ddot[..., np.arange(n1-1, n1+n2-2), np.arange(n1-1, n1+n2-2)] = t[..., None] + Ddot = iu(Ddot, ix[..., np.arange(n1-1, n1+n2-2), np.arange(n1-1, n1+n2-2)], + t[..., None]) + + ## Compute off-diagonal block of VReLU(Pi) + Pitilde = Omegadotinv + 2 * Ddot + Pitilde = np.linalg.inv(Pitilde) + Pi11diag = np.einsum('ij,...jk,ki->...i', + A1, + Pitilde[..., :n1-1, :n1-1], + A1.T) + Pi22diag = np.einsum('ij,...jk,ki->...i', + A2, + Pitilde[..., n1-1:, n1-1:], + A2.T) + Pi12 = np.einsum('ij,...jk,kl->...il', + A1, + Pitilde[..., :n1-1, n1-1:], + A2.T) + C = J1(np.einsum('...i,...ij,...j->...ij', + Pi11diag**-0.5, + Pi12, + Pi22diag**-0.5)) + VReLUPi12 = 0.5 * np.einsum('...i,...ij,...j->...ij', + Pi11diag**0.5, + C, + Pi22diag**0.5) + print('Cnorm', np.linalg.norm(C)) + print('Pitildenorm', np.linalg.norm(Pitilde)) + print('Omegadotinv', np.linalg.norm(Omegadotinv)) + # import pdb; pdb.set_trace() + ## Compute determinant + ind = np.arange(n1+n2-2) + # Ddot <- matrix inverse of Ddot + # Ddot[..., ind, ind] = Ddot[..., ind, ind]**-1 + Ddot = iu(Ddot, ix[..., ind, ind], Ddot[..., ind, ind]**-1) + logdet = np.linalg.slogdet(Ddot + 2 * Omegadot)[1] + # print('logdet', np.linalg.norm(logdet)) + return np.exp( + (np.log(multfactor) + + (-n1/2) * np.log(s) + + (-n2/2) * np.log(t) + - 1/2 * logdet)[..., None, None] + np.log(VReLUPi12)) + return f + + def VBNReLUCrossBatch(Xi, Sigma1, Sigma2, npos=10, nneg=5, + alphapos1=1/3, alphaneg1=1, + alphapos2=1/3, alphaneg2=1): + '''Compute VBNReLU for two batches. + + Inputs: + Xi: covariance between batch1 and batch2 + Sigma1: autocovariance of batch1 + Sigma2: autocovariance of batch2 + npos: number of points for integrating the big s side of the VBNReLU integral + (effective for both dimensions of integration) + nneg: number of points for integrating the small s side of the VBNReLU integral + (effective for both dimensions of integration) + alphapos1: reparametrize the large s integral by s = exp(alpha r) in the 1st dimension + alphaneg1: reparametrize the small s integral by s = exp(alpha r) in the 1st dimension + alphapos2: reparametrize the large s integral by s = exp(alpha r) in the 2nd dimension + alphaneg2: reparametrize the small s integral by s = exp(alpha r) in the 2nd dimension + By tuning the `alpha` parameters, the integrand is closer to being well-approximated by + low-degree Laguerre polynomials, which makes the quadrature more accurate in approximating the integral. + Outputs: + The (batch1, batch2) block of block matrix obtained by + applying VBNReLU^{\oplus 2} to the kernel of batch1 and batch2 + ''' + # We will do the integration explicitly ourselves: + # We obtain sample points and weights via `quadpy`'s Gauss Laguerre quadrature + # and do the sum ourselves + print('B') + blockeig(Sigma1, Sigma2, Xi) + # schemepos = qp.e1r.gauss_laguerre(npos, alpha=0) + # schemeneg = qp.e1r.gauss_laguerre(nneg, alpha=0) + schemepospoints, schemeposweights = roots_laguerre(npos) + schemenegpoints, schemenegweights = roots_laguerre(nneg) + dim1 = Sigma1.shape[0] + dim2 = Sigma2.shape[0] + intargmax = (-np.log(2*(dim1-1)), -np.log(2*(dim2-1))) + f = VBNReLUCrossBatchIntegrand(Xi, Sigma1, Sigma2) + # Get the points manually for each dimension + # scheme1dpoints = np.concatenate([schemepos.points, -schemeneg.points]) + scheme1dpoints = np.concatenate([schemepospoints, -schemenegpoints]) + # Get the weights manually for each dimension + # scheme1dwts = np.concatenate([schemepos.weights, schemeneg.weights]) + scheme1dwts = np.concatenate([schemeposweights, schemenegweights]) + # Obtain the points for the whole 2d integration + scheme2dpoints = np.meshgrid(scheme1dpoints, scheme1dpoints) + # Obtain the weights for the whole 2d integration + scheme2dwts = scheme1dwts[:, None] * scheme1dwts[None, :] + + def applyalpha(x, alphapos, alphaneg): + x = iu(x, ix[x > 0], x[x > 0] * alphapos) + x = iu(x, ix[x <= 0], x[x <= 0] * alphaneg) + # xx[xx > 0] *= alphapos + # xx[xx <= 0] *= alphaneg + return x + + # def iu(x, mask, ) + def alphafactor(x, y): + a = np.zeros_like(x) + a = iu(a, ix[(x > 0) & (y > 0)], alphapos1 * alphapos2) + a = iu(a, ix[(x > 0) & (y <= 0)], alphapos1 * alphaneg2) + a = iu(a, ix[(x <= 0) & (y > 0)], alphaneg1 * alphapos2) + a = iu(a, ix[(x <= 0) & (y <= 0)], alphaneg1 * alphaneg2) + return a + + integrand = lambda inp: \ + f(np.exp(applyalpha(inp[0], alphapos1, alphaneg1) + intargmax[0]), + np.exp(applyalpha(inp[1], alphapos2, alphaneg2) + intargmax[1]), + multfactor=alphafactor(inp[0], inp[1]) + * np.pi**-1 + * np.exp(applyalpha(inp[0], alphapos1, alphaneg1) + intargmax[0] + + applyalpha(inp[1], alphapos2, alphaneg2) + intargmax[1] + + np.abs(inp[0]) + np.abs(inp[1]) + ) + ) + + return np.sqrt(dim1 * dim2) * np.einsum('ij...,ij->...', + integrand(scheme2dpoints), + scheme2dwts + ) + + + new_cov1 = singlebatchker(cov1) + if cov2 is None: + return cov1, None, cov1 + + new_cov2 = singlebatchker(cov2) + + print('C') + blockeig(cov1, cov2, nngp) + print('new C') + blockeig(new_cov1, new_cov2, nngp) + new_nngp = VBNReLUCrossBatch(nngp, cov1, cov2) + + return new_cov1, new_cov2, new_nngp + +def blockeig(cov1, cov2, nngp): + myblock = np.block([[cov1, nngp], + [nngp.T, cov2]]) + eigvals, _ = np.linalg.eigh(myblock) + print(eigvals) + @layer -def BatchNormRelu(axis): +def BatchNormRelu(axis, channel_axis=-1): """Layer construction function for a batch normalization layer. See the papers below for the derivation. @@ -2929,27 +3209,57 @@ def BatchNormRelu(axis): The implementation here follows the reference implementation in https://github.com/thegregyang/GP4A + + Args: + :axis: integer or a tuple, specifies dimensions over which to normalize. + :channel_axis: integer, channel axis. Defaults to `-1`, the trailing axis. + For `kernel_fn`, channel size is considered to be infinite. """ - epsilon = 0 + epsilon = 1e-8 center=True scale=True beta_init=zeros gamma_init=ones - iu = ops.index_update - ix = ops.index - _beta_init = lambda rng, shape: beta_init(rng, shape) if center else () _gamma_init = lambda rng, shape: gamma_init(rng, shape) if scale else () axis = (axis,) if np.isscalar(axis) else axis init_fn, bn_apply_fn = ostax.BatchNorm( axis, epsilon, center, scale, beta_init, gamma_init) + if channel_axis >= 0: + axis = tuple(a if a < channel_axis else a-1 for a in axis) def apply_fn(params, xs, **kwargs): xs = bn_apply_fn(params, xs) return _ab_relu(xs, 0, 1) + def rotate_greg(cov1, axis, flatten=False): + cov1 = utils.unzip_axes(cov1) + ndim = len(cov1.shape) // 2 + naxes = len(axis) + unnorm_size1 = reduce(op.mul, (cov1.shape[i] for i in range(ndim) if i not in axis), 1) + unnorm_size2 = reduce(op.mul, (cov1.shape[i+ndim] for i in range(ndim) if i not in axis), 1) + norm_size1 = reduce(op.mul, (cov1.shape[i] for i in axis), 1) + norm_size2 = reduce(op.mul, (cov1.shape[i+ndim] for i in axis), 1) + + source_axes = list(axis) + list(np.array(axis) + ndim) + _negidx = np.array(list(range(-2*naxes, 0))) + dest_axes = list(2*ndim + _negidx) + cov1 = np.moveaxis(cov1, source_axes, dest_axes) + old_shape = cov1.shape + if not flatten: + return cov1 + cov1 = cov1.reshape(unnorm_size1, unnorm_size2, norm_size1, norm_size2) + + def unrotate(cov): + assert cov.shape == (unnorm_size1, unnorm_size2, norm_size1, norm_size2), str((unnorm_size1, unnorm_size2, norm_size1, norm_size2)) + cov = cov.reshape(*old_shape) + cov = np.moveaxis(cov, dest_axes, source_axes) + cov = utils.zip_axes(cov) + return cov + return cov1, unrotate + @_requires(diagonal_batch=False) def kernel_fn(kernels): if not kernels.is_gaussian: @@ -2957,239 +3267,54 @@ def kernel_fn(kernels): 'case if the input layer is guaranteed to be mean' '-zero Gaussian, i.e. having `is_gaussian` ' 'set to `True`.') - if len(kernels.shape1) != 2: - raise NotImplementedError('`BatchNormRelu` only supports fully-connected layers.') if kernels.ntk is not None: raise NotImplementedError('NTK is currently not supported by `BatchNormRelu`.') - # apply the below code to cov1 and cov2 - # write new cross batch code for (nngp, cov1, cov2) - - def Gmatrix(batch_size): - return np.eye(batch_size) - np.ones((batch_size, batch_size)) / batch_size - - def singlebatchker(ker): - batch_size = ker.shape[0] - G = Gmatrix(batch_size) - eigvals, eigvecs = np.linalg.eigh(G @ ker @ G) - logeigvals = np.log(eigvals[1:]) - eigvecs = eigvecs[:, 1:] - - # NOTE: Likely the same as _get_ab_relu_kernel. - def VReLU(cov, eps=1e-5): - indices = list(range(cov.shape[-1])) - d = np.sqrt(cov[..., indices, indices]) - dd = d[..., np.newaxis] * d[..., np.newaxis, :] - c = dd ** (-1) * cov - c = np.where(c > 1 - eps, 1 - eps, c) - c = np.where(c < -1 + eps, -1 + eps, c) - c = (np.sqrt(1 - c ** 2) + (np.pi - np.arccos(c)) * c) / np.pi - return np.nan_to_num(0.5 * dd * c) - - # NOTE: eps is a stability factor for the integral, not for batch norm. - def integrand(log_s, logmultifactor=0, eps=1e-10): - logUeigvals = np.logaddexp(0, np.log(2) + log_s[..., None] + logeigvals) - loginteigvals = logeigvals - logUeigvals - - loginteigvals -= 0.5 * np.sum(logUeigvals, axis=-1, keepdims=True) - loginteigvals += logmultifactor - - inteigvals = np.exp(loginteigvals) - - intvals = np.einsum( - 'ij,...j,jk->...ik', eigvecs, inteigvals, eigvecs.T, optimize=True) - return VReLU(intvals) - - # TODO(Greg): Compute the split point correctly. - intargmax = -np.log(batch_size - 1) - npos = 10 # TODO(Greg): Make these parameters, maybe? - nneg = 10 - alpha = 1 / 8. - - schemepos = qp.e1r.gauss_laguerre(npos) - schemeneg = qp.e1r.gauss_laguerre(nneg) - - schemepospoints = schemepos.points - schemenegpoints = schemeneg.points - schemeposweights = schemepos.weights - schemenegweights = schemeneg.weights - - integrandpos = lambda xs: np.moveaxis( - np.moveaxis( - alpha * integrand(intargmax + alpha * xs, - logmultifactor=( - intargmax + (1 + alpha) * xs)[..., np.newaxis]), - 2, 0), 3, 1) - - integrandneg = lambda xs: np.moveaxis( - np.exp(intargmax) * np.moveaxis(integrand(intargmax - xs), 2, 0), 3, 1) - - new_nngp = batch_size * ( - np.einsum('...i,i->...', - integrandpos(np.array([schemepospoints.T])), - schemeposweights) - + - np.einsum('...i,i->...', - integrandneg(np.array([schemenegpoints.T])), - schemenegweights)).squeeze(-1) - return new_nngp - - def J1(c, eps=1e-10): - c = np.clip(c, -1+eps, 1-eps) - return (np.sqrt(1-c**2) + (np.pi - np.arccos(c)) * c) / np.pi - - def VBNReLUCrossBatchIntegrand(Xi, Sigma1, Sigma2): - '''Computes the off diagonal block of the BN+ReLU kernel over 2 batches - Input: - Xi: covariance between batch1 and batch2 - Sigma1: autocovariance of batch1 - Sigma2: autocovariance of batch2 - Output: - f: integrand function in the integral for computing cross batch VBNReLU - ''' - n1 = Sigma1.shape[0] - n2 = Sigma2.shape[0] - G1 = Gmatrix(n1) - G2 = Gmatrix(n2) - Delta1, A1 = np.linalg.eigh(G1 @ Sigma1 @ G1) - Delta2, A2 = np.linalg.eigh(G2 @ Sigma2 @ G2) - # kill first 0 eigenval - Delta1 = Delta1[1:] - Delta2 = Delta2[1:] - A1 = A1[:, 1:] - A2 = A2[:, 1:] - - Xidot = A1.T @ Xi @ A2 - Omegadot = np.block([[np.diag(Delta1), Xidot], [Xidot.T, np.diag(Delta2)]]) - Omegadotinv = np.linalg.inv(Omegadot) - - def f(s, t, multfactor=1): - # Ddot.shape = (..., n1+n2-2, n1+n2-2) - Ddot = s[..., None, None] * np.eye(n1-1+n2-1) - # Ddot[..., np.arange(n1-1, n1+n2-2), np.arange(n1-1, n1+n2-2)] = t[..., None] - Ddot = iu(Ddot, ix[..., np.arange(n1-1, n1+n2-2), np.arange(n1-1, n1+n2-2)], - t[..., None]) - - ## Compute off-diagonal block of VReLU(Pi) - Pitilde = Omegadotinv + 2 * Ddot - Pitilde = np.linalg.inv(Pitilde) - Pi11diag = np.einsum('ij,...jk,ki->...i', - A1, - Pitilde[..., :n1-1, :n1-1], - A1.T) - Pi22diag = np.einsum('ij,...jk,ki->...i', - A2, - Pitilde[..., n1-1:, n1-1:], - A2.T) - Pi12 = np.einsum('ij,...jk,kl->...il', - A1, - Pitilde[..., :n1-1, n1-1:], - A2.T) - C = J1(np.einsum('...i,...ij,...j->...ij', - Pi11diag**-0.5, - Pi12, - Pi22diag**-0.5)) - VReLUPi12 = 0.5 * np.einsum('...i,...ij,...j->...ij', - Pi11diag**0.5, - C, - Pi22diag**0.5) - - ## Compute determinant - ind = np.arange(n1+n2-2) - # Ddot <- matrix inverse of Ddot - # Ddot[..., ind, ind] = Ddot[..., ind, ind]**-1 - Ddot = iu(Ddot, ix[..., ind, ind], Ddot[..., ind, ind]**-1) - logdet = np.linalg.slogdet(Ddot + 2 * Omegadot)[1] - return np.exp( - (np.log(multfactor) - + (-n1/2) * np.log(s) - + (-n2/2) * np.log(t) - - 1/2 * logdet)[..., None, None] + np.log(VReLUPi12)) - return f - - def VBNReLUCrossBatch(Xi, Sigma1, Sigma2, npos=10, nneg=5, - alphapos1=1/3, alphaneg1=1, - alphapos2=1/3, alphaneg2=1): - '''Compute VBNReLU for two batches. - - Inputs: - Xi: covariance between batch1 and batch2 - Sigma1: autocovariance of batch1 - Sigma2: autocovariance of batch2 - npos: number of points for integrating the big s side of the VBNReLU integral - (effective for both dimensions of integration) - nneg: number of points for integrating the small s side of the VBNReLU integral - (effective for both dimensions of integration) - alphapos1: reparametrize the large s integral by s = exp(alpha r) in the 1st dimension - alphaneg1: reparametrize the small s integral by s = exp(alpha r) in the 1st dimension - alphapos2: reparametrize the large s integral by s = exp(alpha r) in the 2nd dimension - alphaneg2: reparametrize the small s integral by s = exp(alpha r) in the 2nd dimension - By tuning the `alpha` parameters, the integrand is closer to being well-approximated by - low-degree Laguerre polynomials, which makes the quadrature more accurate in approximating the integral. - Outputs: - The (batch1, batch2) block of block matrix obtained by - applying VBNReLU^{\oplus 2} to the kernel of batch1 and batch2 - ''' - # We will do the integration explicitly ourselves: - # We obtain sample points and weights via `quadpy`'s Gauss Laguerre quadrature - # and do the sum ourselves - schemepos = qp.e1r.gauss_laguerre(npos, alpha=0) - schemeneg = qp.e1r.gauss_laguerre(nneg, alpha=0) - dim1 = Sigma1.shape[0] - dim2 = Sigma2.shape[0] - intargmax = (-np.log(2*(dim1-1)), -np.log(2*(dim2-1))) - f = VBNReLUCrossBatchIntegrand(Xi, Sigma1, Sigma2) - # Get the points manually for each dimension - scheme1dpoints = np.concatenate([schemepos.points, -schemeneg.points]) - # Get the weights manually for each dimension - scheme1dwts = np.concatenate([schemepos.weights, schemeneg.weights]) - # Obtain the points for the whole 2d integration - scheme2dpoints = np.meshgrid(scheme1dpoints, scheme1dpoints) - # Obtain the weights for the whole 2d integration - scheme2dwts = scheme1dwts[:, None] * scheme1dwts[None, :] - - def applyalpha(x, alphapos, alphaneg): - x = iu(x, ix[x > 0], x[x > 0] * alphapos) - x = iu(x, ix[x <= 0], x[x <= 0] * alphaneg) - # xx[xx > 0] *= alphapos - # xx[xx <= 0] *= alphaneg - return x - - # def iu(x, mask, ) - def alphafactor(x, y): - a = np.zeros_like(x) - a = iu(a, ix[(x > 0) & (y > 0)], alphapos1 * alphapos2) - a = iu(a, ix[(x > 0) & (y <= 0)], alphapos1 * alphaneg2) - a = iu(a, ix[(x <= 0) & (y > 0)], alphaneg1 * alphapos2) - a = iu(a, ix[(x <= 0) & (y <= 0)], alphaneg1 * alphaneg2) - return a - - integrand = lambda inp: \ - f(np.exp(applyalpha(inp[0], alphapos1, alphaneg1) + intargmax[0]), - np.exp(applyalpha(inp[1], alphapos2, alphaneg2) + intargmax[1]), - multfactor=alphafactor(inp[0], inp[1]) - * np.pi**-1 - * np.exp(applyalpha(inp[0], alphapos1, alphaneg1) + intargmax[0] - + applyalpha(inp[1], alphapos2, alphaneg2) + intargmax[1] - + np.abs(inp[0]) + np.abs(inp[1]) - ) - ) - - return np.sqrt(dim1 * dim2) * np.einsum('ij...,ij->...', - integrand(scheme2dpoints), - scheme2dwts - ) - - - cov1 = singlebatchker(kernels.cov1) - if kernels.cov2 is None: - return kernels.replace(cov1=cov1, nngp=cov1, cov2=None, is_gaussian=False) - - cov2 = singlebatchker(kernels.cov2) + cov1, cov2, nngp = kernels.cov1, kernels.cov2, kernels.nngp + # myblock = np.block([[cov1, nngp], [nngp.T, cov2]]) + + cov1_flatten, cov1_unrotate = rotate_greg(cov1, axis, flatten=True) + # import pdb; pdb.set_trace() + ll = list(range(cov1_flatten.shape[0])) + cov1diag = cov1_flatten[ll, ll] + + ##### compute kernel ###### + # loop over pairs of non-normalized coordinates + # extract blocks + # cov1.shape = (batch, spatial, batch, spatial) + def fn_sam(cov1, cov2, nngp): + return lax.cond(np.allclose(cov1, cov2), + (cov1, None, cov1), lambda x: _batchnorm_relu_kernel_fn(*x)[2], + (cov1, cov2, nngp), lambda x: _batchnorm_relu_kernel_fn(*x)[2]) + _vmapped_bnrelu_self = vmap(vmap(fn_sam, (None, 0, 0)), (0, None, 0)) + _vmapped_bnrelu_other = vmap(vmap(lambda *x: _batchnorm_relu_kernel_fn(*x)[2], (None, 0, 0)), (0, None, 0)) + # import pdb; pdb.set_trace() + cov1 = _vmapped_bnrelu_self(cov1diag, cov1diag, cov1_flatten) - # TODO compute cross batch - new_nngp = VBNReLUCrossBatch(kernels.nngp, kernels.cov1, kernels.cov2) + # TODO(Greg): bypass to single batch case if no spatial dimension and cov2 is None + if cov2 is None: + cov1 = cov1_unrotate(cov1) + nngp = cov1 + else: + cov2_flatten, cov2_unrotate = rotate_greg(cov2, axis, flatten=True) + # import pdb; pdb.set_trace() + nngp_flatten, nngp_unrotate = rotate_greg(nngp, axis, flatten=True) + cov2diag = cov2_flatten[ll, ll] + print('A') + blockeig(cov1diag[0], cov2diag[0], nngp_flatten[0, 0]) + cov2 = _vmapped_bnrelu_self(cov2diag, cov2diag, cov2_flatten) + # print('cov1', cov1) + # print('cov2', cov2) + # import pdb; pdb.set_trace() + # xxx = _batchnorm_relu_kernel_fn(cov1diag[0], cov2diag[0], nngp_flatten[0, 0]) + # print(xxx) + nngp = _vmapped_bnrelu_other(cov1diag, cov2diag, nngp_flatten) + # import pdb; pdb.set_trace() + cov1 = cov1_unrotate(cov1) + cov2 = cov2_unrotate(cov2) + nngp = nngp_unrotate(nngp) + + return kernels.replace(cov1=cov1, nngp=nngp, cov2=cov2, is_gaussian=False) - return kernels.replace(cov1=cov1, nngp=new_nngp, cov2=cov2, is_gaussian=False) return init_fn, apply_fn, kernel_fn \ No newline at end of file diff --git a/neural_tangents/tests/stax_test.py b/neural_tangents/tests/stax_test.py old mode 100644 new mode 100755 index 1e631b5a..00ad3eb8 --- a/neural_tangents/tests/stax_test.py +++ b/neural_tangents/tests/stax_test.py @@ -26,11 +26,15 @@ from jax.config import config as jax_config from jax.lib import xla_bridge import jax.numpy as np +import numpy as onp import jax.random as random from neural_tangents import stax from neural_tangents.utils import monte_carlo from neural_tangents.utils import test_utils +from absl.testing import absltest +from jax import disable_jit +disable_jit() jax_config.parse_flags_with_absl() @@ -42,7 +46,7 @@ BATCH_SIZE = 2 -INPUT_SHAPE = (BATCH_SIZE, 7, 6, 3) +INPUT_SHAPE = (BATCH_SIZE, 7, 6, 16) WIDTHS = [2**11] @@ -86,7 +90,8 @@ 'NC', 'NWC', 'NCHW', - 'N' + 'N', + 'NHW' ] POOL_TYPES = [ @@ -114,6 +119,7 @@ def _get_inputs(key, is_conv, same_inputs, input_shape, new_batch_size=None, fn= shape2 = shape[:batch_axis] + (2 * new_batch_size,) + shape[batch_axis + 1:] x1 = fn(random.normal(key, shape1)) x2 = None if same_inputs else 2 * fn(random.normal(split, shape2)) + print(shape1, shape2) return x1, x2 @@ -173,7 +179,7 @@ def conv(out_chan): return stax.GeneralConv( ) affine = conv(width) if is_conv else fc(width) - rate = np.onp.random.uniform(0.5, 0.9) + rate = onp.random.uniform(0.5, 0.9) dropout = stax.Dropout(rate, mode='train') if pool_type == 'AVG': @@ -196,8 +202,8 @@ def conv(out_chan): return stax.GeneralConv( dropout_or_identity = dropout if use_dropout else stax.Identity() if layer_norm is None: norm_phi = phi - elif layer_norm == (batch_axis,): - norm_phi = stax.BatchNormRelu(batch_axis) + elif channel_axis not in layer_norm: + norm_phi = stax.BatchNormRelu(layer_norm, channel_axis=channel_axis) else: norm_phi = stax.serial( stax.LayerNorm(axis=layer_norm, @@ -450,10 +456,11 @@ def test_normalization(self, is_conv = 'conv' in model # Check for duplicate / incorrectly-shaped NN configs / wrong backend. if is_conv: - if xla_bridge.get_backend().platform == 'cpu': - raise jtu.SkipTest('Not running CNN models on CPU to save time.') - if layer_norm == 'N': - raise jtu.SkipTest('Skipping batchnorm test for convolutional networks.') + pass + # if xla_bridge.get_backend().platform == 'cpu': + # raise jtu.SkipTest('Not running CNN models on CPU to save time.') + # if layer_norm == 'N': + # raise jtu.SkipTest('Skipping batchnorm test for convolutional networks.') elif proj_into_2d != PROJECTIONS[0] or layer_norm not in ('C', 'NC', 'N'): raise jtu.SkipTest('FC models do not have these parameters.') @@ -670,11 +677,11 @@ def test_sparse_inputs(self, act, kernel): samples = N_SAMPLES if xla_bridge.get_backend().platform == 'gpu': - jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-4 + jtu._default_tolerance[onp.dtype(onp.float64)] = 5e-4 samples = 100 * N_SAMPLES else: - jtu._default_tolerance[np.onp.dtype(np.onp.float32)] = 5e-2 - jtu._default_tolerance[np.onp.dtype(np.onp.float64)] = 5e-3 + jtu._default_tolerance[onp.dtype(onp.float32)] = 5e-2 + jtu._default_tolerance[onp.dtype(onp.float64)] = 5e-3 # a batch of dense inputs x_dense = random.normal(key, (input_count, input_size)) @@ -756,12 +763,13 @@ def _check_agreement_with_empirical(self, net, same_inputs, is_conv, new_batch_size=None): (init_fn, apply_fn, kernel_fn), input_shape, device_count = net + # print(use_dropout) num_samples = N_SAMPLES * 5 if use_dropout else N_SAMPLES # num_samples *= 10 if new_batch_size is not None else 1 key = random.PRNGKey(1) x1, x2 = _get_inputs(key, is_conv, same_inputs, input_shape, new_batch_size=new_batch_size) - + # print(np.linalg.norm(x1), np.linalg.norm(x2)) x1_out_shape, params = init_fn(key, x1.shape) if same_inputs: assert (x2 is None) @@ -796,7 +804,9 @@ def _get_empirical(n_samples, get): empirical = np.reshape(_get_empirical(num_samples, 'ntk'), exact.shape) else: exact, shape1, shape2 = kernel_fn(x1, x2, ('nngp', 'shape1', 'shape2')) + print('getting empirical') empirical = _get_empirical(num_samples, 'nngp') + print(empirical) test_utils.assert_close_matrices(self, exact, empirical, rtol) self.assertEqual(shape1, x1_out_shape) self.assertEqual(shape2, x2_out_shape) @@ -1778,4 +1788,4 @@ def get_attn(): test_utils.assert_close_matrices(self, empirical, exact, tol) if __name__ == '__main__': - jtu.absltest.main() + absltest.main() From 3dc85ccfb897c62b7398f6f5a97e6d5e718e2813 Mon Sep 17 00:00:00 2001 From: Greg Yang Date: Mon, 10 Aug 2020 22:40:38 +0000 Subject: [PATCH 9/9] save --- examples/infinite_fcn.py | 55 +++++++++++++++++----------------------- setup.py | 2 +- 2 files changed, 24 insertions(+), 33 deletions(-) mode change 100644 => 100755 examples/infinite_fcn.py mode change 100644 => 100755 setup.py diff --git a/examples/infinite_fcn.py b/examples/infinite_fcn.py old mode 100644 new mode 100755 index db90c6b9..e7d2dc9c --- a/examples/infinite_fcn.py +++ b/examples/infinite_fcn.py @@ -23,8 +23,7 @@ import jax.numpy as np import neural_tangents as nt from neural_tangents import stax -from examples import datasets -from examples import util +from jax import random flags.DEFINE_integer('train_size', 1000, @@ -37,43 +36,35 @@ FLAGS = flags.FLAGS +import pdb +from jax.experimental import callback +from functools import partial def main(unused_argv): # Build data pipelines. print('Loading data.') - x_train, y_train, x_test, y_test = \ - datasets.get_dataset('cifar10', FLAGS.train_size, FLAGS.test_size) + key = random.PRNGKey(0) + key, split = random.split(key) + x_train = random.normal(key=key, shape=[2, 3, 4, 5]) + x_train2 = random.normal(key=split, shape=[1, 3, 4, 5]) # Build the infinite network. - _, _, kernel_fn = stax.serial( - stax.Dense(1, 2., 0.05), - stax.Relu(), - stax.Dense(1, 2., 0.05) + init_fn, apply_fn, kernel_fn = stax.serial( + stax.Conv(256, (3, 3), padding='SAME'), + stax.BatchNormRelu((0, 1, 2)), + stax.GlobalAvgPool(), + stax.Dense(256, 2., 0.05) ) - - # Optionally, compute the kernel in batches, in parallel. - kernel_fn = nt.batch(kernel_fn, - device_count=0, - batch_size=FLAGS.batch_size) - - start = time.time() - # Bayesian and infinite-time gradient descent inference with infinite network. - fx_test_nngp, fx_test_ntk = nt.predict.gp_inference(kernel_fn, - x_train, - y_train, - x_test, - get=('nngp', 'ntk'), - diag_reg=1e-3) - fx_test_nngp.block_until_ready() - fx_test_ntk.block_until_ready() - - duration = time.time() - start - print('Kernel construction and inference done in %s seconds.' % duration) - - # Print out accuracy and loss for infinite network predictions. - loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2) - util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss) - util.print_summary('NTK test', y_test, fx_test_ntk, None, loss) + # kernel_fn = callback.find_by_value(partial(kernel_fn, get='nngp'), np.nan) + kerobj = kernel_fn(x_train, x_train2, get='nngp') + theory_ker = kerobj + mc_kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, 10000) + diff = theory_ker - mc_kernel_fn(x_train, x_train2, get='nngp') + print(diff) + # print(kerobj.cov1 - kerobj.nngp) + print(np.linalg.norm(diff) / np.linalg.norm(theory_ker)) + # 0.0032839081 + return if __name__ == '__main__': diff --git a/setup.py b/setup.py old mode 100644 new mode 100755 index 4b78b9e8..c3377eec --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ INSTALL_REQUIRES = [ - 'jaxlib>=0.1.58', + 'jaxlib>=0.1.47', 'jax>=0.1.55', 'frozendict', 'dataclasses',