# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import warnings

try:
    pytorch_version_one_and_above = int(torch.__version__[0]) > 0
except TypeError:
    pytorch_version_one_and_above = True

def norm(x):
    """Compute the RMS norm of a tensor, or the joint RMS norm of a list of tensors."""
    if torch.is_tensor(x):
        return x.norm() / (x.numel() ** 0.5)
    else:
        return torch.sqrt(sum(x_.norm() ** 2 for x_ in x) / sum(x_.numel() for x_ in x))

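# Illustrative sketch (not part of the original utilities): for a list of tensors,
# `norm` returns the RMS over all elements jointly, matching the RMS of the
# concatenated tensor. The tensors below are made up for demonstration.
def _example_norm():
    joint = norm([torch.ones(2), 3.0 * torch.ones(3)])
    flat = norm(torch.cat([torch.ones(2), 3.0 * torch.ones(3)]))
    return joint, flat  # both equal sqrt((2 + 27) / 5)
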
def flatten(iterable):
    """Recursively flatten nested iterables into a flat list, keeping Tensors intact."""
    out = []
    for i in iterable:
        if hasattr(i, '__iter__') and not isinstance(i, torch.Tensor):
            out.extend(flatten(i))
        else:
            out.append(i)
    return out


def delete_local_computation_graph(inputs):
    """Drop references to intermediate tensors so the local computation graph can be freed."""
    for i in inputs:
        # i.set_()
        del i
    # torch.cuda.empty_cache()
    return

def _possibly_nonzero(x):
    return isinstance(x, torch.Tensor) or x != 0

def _scaled_dot_product(scale, xs, ys):
    """Calculate a scaled, vector inner product between lists of Tensors."""
    # Using _possibly_nonzero lets us avoid wasted computation.
    return sum([(scale * x) * y for x, y in zip(xs, ys) if _possibly_nonzero(x) or _possibly_nonzero(y)])

def _convert_to_tensor(a, dtype=None, device=None):
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)
    if dtype is not None:
        a = a.type(dtype)
    if device is not None:
        a = a.to(device)
    return a

def _dot_product(xs, ys):
    """Calculate the vector inner product between two lists of Tensors."""
    return sum([x * y for x, y in zip(xs, ys)])

def _interp_fit(y0, y1, y_mid, f0, f1, dt):
    """Fit coefficients for 4th order polynomial interpolation.

    Args:
        y0: function value at the start of the interval.
        y1: function value at the end of the interval.
        y_mid: function value at the mid-point of the interval.
        f0: derivative value at the start of the interval.
        f1: derivative value at the end of the interval.
        dt: width of the interval.

    Returns:
        List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial
        `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
        between 0 (start of interval) and 1 (end of interval).
    """
    a = tuple(
        _dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0_, f1_, y0_, y1_, y_mid_])
        for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
    )
    b = tuple(
        _dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0_, f1_, y0_, y1_, y_mid_])
        for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
    )
    c = tuple(
        _dot_product([-4 * dt, dt, -11, -5, 16], [f0_, f1_, y0_, y1_, y_mid_])
        for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
    )
    d = tuple(dt * f0_ for f0_ in f0)
    e = y0
    return [a, b, c, d, e]


def _interp_evaluate(coefficients, t0, t1, t):
    """Evaluate polynomial interpolation at the given time point.

    Args:
        coefficients: list of Tensor coefficients as created by `_interp_fit`.
        t0: scalar float64 Tensor giving the start of the interval.
        t1: scalar float64 Tensor giving the end of the interval.
        t: scalar float64 Tensor giving the desired interpolation point.

    Returns:
        Polynomial interpolation of the coefficients at time `t`.
    """
    dtype = coefficients[0][0].dtype
    device = coefficients[0][0].device

    t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
    t1 = _convert_to_tensor(t1, dtype=dtype, device=device)
    t = _convert_to_tensor(t, dtype=dtype, device=device)

    assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1)
    x = ((t - t0) / (t1 - t0)).type(dtype).to(device)

    xs = [torch.tensor(1).type(dtype).to(device), x]
    for _ in range(2, len(coefficients)):
        xs.append(xs[-1] * x)

    return tuple(_dot_product(coefficients_, reversed(xs)) for coefficients_ in zip(*coefficients))

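# Illustrative sketch (not part of the original utilities): how `_interp_fit`
# and `_interp_evaluate` are meant to pair up. State is passed as a tuple of
# tensors, mirroring the tuple-of-tensors convention used above; the values,
# derivatives, and midpoint below are made up for demonstration.
def _example_interp_fit_evaluate():
    t0, t1 = 0.0, 1.0
    y0 = (torch.tensor([0.0]),)     # value at t0
    y1 = (torch.tensor([1.0]),)     # value at t1
    y_mid = (torch.tensor([0.5]),)  # value at the midpoint
    f0 = (torch.tensor([1.0]),)     # derivative at t0
    f1 = (torch.tensor([1.0]),)     # derivative at t1
    coeffs = _interp_fit(y0, y1, y_mid, f0, f1, t1 - t0)
    # For this (linear) data the fitted quartic reproduces y(t) = t, so this evaluates to 0.25.
    return _interp_evaluate(coeffs, t0, t1, 0.25)
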
# ----------------------------------------------------------------------------------------------------
# cubic hermite spline
import numpy as np
import torch as T

def h_poly_helper(tt):
    # Coefficients of the cubic Hermite basis functions h00, h10, h01, h11,
    # expressed in the monomial basis [1, t, t^2, t^3].
    A = T.tensor([
        [1, 0, -3, 2],
        [0, 1, -2, 1],
        [0, 0, 3, -2],
        [0, 0, -1, 1]
    ], dtype=tt[-1].dtype)
    return [
        sum(A[i, j] * tt[j] for j in range(4))
        for i in range(4)]

def h_poly(t):
    # Powers [1, t, t^2, t^3] of the normalized coordinate t.
    tt = [None for _ in range(4)]
    tt[0] = 1
    for i in range(1, 4):
        tt[i] = tt[i - 1] * t
    return h_poly_helper(tt)

def H_poly(t):
    # Antiderivatives [t, t^2/2, t^3/3, t^4/4] of the monomials, used when integrating the spline.
    tt = [None for _ in range(4)]
    tt[0] = t
    for i in range(1, 4):
        tt[i] = tt[i - 1] * t * i / (i + 1)
    return h_poly_helper(tt)

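# Quick sanity sketch (illustrative only): at the interval endpoints the four
# Hermite basis functions returned by `h_poly` reduce to h_poly(0) -> [1, 0, 0, 0]
# and h_poly(1) -> [0, 0, 1, 0], so the interpolant reproduces the knot values
# exactly at the knots.
def _example_hermite_basis():
    at0 = h_poly(T.tensor(0.0))  # [1, 0, 0, 0]
    at1 = h_poly(T.tensor(1.0))  # [0, 0, 1, 0]
    return at0, at1
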
def interp_cubic_hermite_spline(x, y, xs):
    """Cubic Hermite spline interpolation.

    :param x: 1-D tensor of increasing sample locations (knots).
    :param y: tensor of sample values; y[i] corresponds to x[i] (extra trailing dims allowed).
    :param xs: scalar query location at which to interpolate.
    :return: interpolated value at `xs`, with the same trailing shape as `y`.
    """
    if isinstance(xs, T.Tensor):
        xs_np = xs.data.cpu().numpy()
        xs_np = float(xs_np)
    else:
        xs_np = float(xs)
    xs = T.tensor(xs_np).to(y.device)

    x_tmp = (x[1:] - x[:-1])
    if (x_tmp == 0).all():
        # Degenerate case: the knots coincide, so return the first value.
        return y[0].unsqueeze(0)

    if y.dim() > 1:
        x_tmp = x_tmp.view([-1] + [1] * (y.dim() - 1))
    # Finite-difference slopes at the knots (one-sided at the ends, centered inside).
    m = (y[1:] - y[:-1]) / x_tmp
    m = T.cat([m[[0]], (m[1:] + m[:-1]) / 2, m[[-1]]])

    I = np.searchsorted(x[1:].data.cpu().numpy(), xs_np)
    if isinstance(I, np.int64):
        I = np.array([I])
    # Pull back indices that fall past the last interval.
    I[I == (x.shape[0] - 1)] = I[I == (x.shape[0] - 1)] - 2
    dx = (x[I + 1] - x[I])
    hh = h_poly((xs.expand_as(x[I]) - x[I]) / dx)

    if y.dim() > 1:
        hh = [tmp.view([-1] + [1] * (y.dim() - 1)) for tmp in hh]
        dx = dx.view([-1] + [1] * (y.dim() - 1))
    return hh[0] * y[I] + hh[1] * m[I] * dx + hh[2] * y[I + 1] + hh[3] * m[I + 1] * dx

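# Usage sketch (illustrative only): with linear data the finite-difference
# slopes are exact, so the Hermite interpolant reproduces the line. The numbers
# below are made up for demonstration.
def _example_interp_cubic_hermite_spline():
    x = T.linspace(0.0, 1.0, 5)
    y = 2.0 * x + 1.0
    # Query a point between the second and third knots; expect ~1.6 here.
    return interp_cubic_hermite_spline(x, y, 0.3)
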
def integ(x, y, xs):
    """Integrate the cubic Hermite spline defined by (x, y) from x[0] up to each point in `xs`."""
    x_tmp = (x[1:] - x[:-1])
    if y.dim() > 1:
        x_tmp = x_tmp.view([-1] + [1] * (y.dim() - 1))
    m = (y[1:] - y[:-1]) / x_tmp
    m = T.cat([m[[0]], (m[1:] + m[:-1]) / 2, m[[-1]]])
    I = np.searchsorted(x[1:], xs)
    I[I == (x.shape[0] - 1)] = I[I == (x.shape[0] - 1)] - 2
    # Exact integral of the spline over each full interval, accumulated from x[0].
    Y = T.zeros_like(y)
    Y[1:] = x_tmp * (
        (y[:-1] + y[1:]) / 2 + (m[:-1] - m[1:]) * x_tmp / 12
    )
    Y = Y.cumsum(0)
    # Remaining partial interval up to xs, using the antiderivative basis H_poly.
    dx = (x[I + 1] - x[I])
    hh = H_poly((xs - x[I]) / dx)
    if y.dim() > 1:
        hh = [tmp.view([-1] + [1] * (y.dim() - 1)) for tmp in hh]
        dx = dx.view([-1] + [1] * (y.dim() - 1))
    return Y[I] + dx * (
        hh[0] * y[I] + hh[1] * m[I] * dx + hh[2] * y[I + 1] + hh[3] * m[I + 1] * dx
    )

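# Usage sketch (illustrative only): `integ` accumulates the integral of the
# Hermite spline from x[0] up to each query point. For y = x on [0, 1] the
# spline is exact, so the integral up to 0.5 is 0.5 ** 2 / 2 = 0.125.
def _example_integ():
    x = T.linspace(0.0, 1.0, 5)
    y = x.clone()
    xs = T.tensor([0.5])  # query points are expected array-like, not a bare Python scalar
    return integ(x, y, xs)
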
def _is_iterable(inputs):
    try:
        iter(inputs)
        return True
    except TypeError:
        return False