Skip to content

Commit a59affb

Browse files
authored
Merge pull request #20 from Meteor-Stars/main
Add a new TSG model MNTSG
2 parents 6a6217e + c6d338d commit a59affb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+13754
-0
lines changed

MNTSG/README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# MNTSG
2+
3+
We propose MN-TSG, a framework that explores MOE (Mixture of Experts)-NCDE and integrates it with existing TSG models for irregular or continuous TSG tasks. The key designs of MOE-NCDE are the dynamic functions with mixture of experts and the decoupled design to better optimize the MOE dynamics. Further, we employ the existing TSG model to learn the joint distribution of the mixture of experts and the time series. In this way, the model can not only generate new samples but also produce suitable experts for them to enable MOE-NCDE for refined continuous TSG tasks.
4+
5+
6+
7+
## Environment
8+
Install the environment from the yaml file given here: environment_mntsg.yml
9+
10+
```bash
11+
conda env create -f environment.yml --force --no-deps
12+
```
13+
14+
## Data
15+
Stocks and Energy data are located in /datasets. Sine, MuJoCo, polynomial datasets are generated and the scripts are included in datasets folder.
16+
utils_data.py contains functions to load the data both in regular and irregular setups. Specifically, the irregular data is pre-processed by the TimeDataset_irregular class,
17+
and it might take a while. Once the data pre-processing is done, it is saved in the /datasets folder.
18+
19+
20+
## Reproducing the paper results
21+
By setting the time series length and missing values within the script, the results in the paper can be reproduced:
22+
23+
24+
**Step1:** The MOE NeuralCDE can be trained by `run_irregular_moencde.py`.
25+
26+
---
27+
28+
**Step2:** The training of the joint distribution of MOE expert weights and time series samples can be implemented through script `run_mntsg_diffsuion.py`.
29+
30+
---
31+
32+
**Step3:** By using the initially trained MOE-NeuralCDE, each newly generated sample, and the corresponding generated MOE weights, fine-grained and refined continuous time series generation can be achieved through script `run_irregular_moencde_continues.py`, thereby enhancing the accuracy of downstream tasks with richer temporal information.

MNTSG/TorchDiffEqPack/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .odesolver import odesolve
2+
from .odesolver_mem import odesolve_endtime, odesolve_adjoint, odesolve_adjoint_sym12
3+
__version__ = '0.1.0'

MNTSG/TorchDiffEqPack/misc.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
import torch
4+
import warnings
5+
try:
6+
pytorch_version_one_and_above = int(torch.__version__[0]) > 0
7+
except TypeError:
8+
pytorch_version_one_and_above = True
9+
10+
def norm(x):
11+
"""Compute RMS norm."""
12+
if torch.is_tensor(x):
13+
return x.norm() / (x.numel()**0.5)
14+
else:
15+
return torch.sqrt(sum(x_.norm()**2 for x_ in x) / sum(x_.numel() for x_ in x))
16+
17+
def flatten(iterable):
18+
out = []
19+
for i in iterable:
20+
if hasattr(i,'__iter__') and not isinstance(i, torch.Tensor):
21+
out.extend(flatten(i))
22+
else:
23+
out.append(i)
24+
return out
25+
26+
27+
def delete_local_computation_graph( inputs):
28+
for i in inputs:
29+
#i.set_()
30+
del i
31+
#torch.cuda.empty_cache()
32+
return
33+
34+
def _possibly_nonzero(x):
35+
return isinstance(x, torch.Tensor) or x != 0
36+
37+
def _scaled_dot_product(scale, xs, ys):
38+
"""Calculate a scaled, vector inner product between lists of Tensors."""
39+
# Using _possibly_nonzero lets us avoid wasted computation.
40+
return sum([(scale * x) * y for x, y in zip(xs, ys) if _possibly_nonzero(x) or _possibly_nonzero(y)])
41+
42+
def _convert_to_tensor(a, dtype=None, device=None):
43+
if not isinstance(a, torch.Tensor):
44+
a = torch.tensor(a)
45+
if dtype is not None:
46+
a = a.type(dtype)
47+
if device is not None:
48+
a = a.to(device)
49+
return a
50+
51+
def _dot_product(xs, ys):
52+
"""Calculate the vector inner product between two lists of Tensors."""
53+
return sum([x * y for x, y in zip(xs, ys)])
54+
55+
def _interp_fit(y0, y1, y_mid, f0, f1, dt):
56+
"""Fit coefficients for 4th order polynomial interpolation.
57+
Args:
58+
y0: function value at the start of the interval.
59+
y1: function value at the end of the interval.
60+
y_mid: function value at the mid-point of the interval.
61+
f0: derivative value at the start of the interval.
62+
f1: derivative value at the end of the interval.
63+
dt: width of the interval.
64+
Returns:
65+
List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial
66+
`p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x`
67+
between 0 (start of interval) and 1 (end of interval).
68+
"""
69+
a = tuple(
70+
_dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0_, f1_, y0_, y1_, y_mid_])
71+
for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
72+
)
73+
b = tuple(
74+
_dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0_, f1_, y0_, y1_, y_mid_])
75+
for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
76+
)
77+
c = tuple(
78+
_dot_product([-4 * dt, dt, -11, -5, 16], [f0_, f1_, y0_, y1_, y_mid_])
79+
for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid)
80+
)
81+
d = tuple(dt * f0_ for f0_ in f0)
82+
e = y0
83+
return [a, b, c, d, e]
84+
85+
86+
def _interp_evaluate(coefficients, t0, t1, t):
87+
"""Evaluate polynomial interpolation at the given time point.
88+
Args:
89+
coefficients: list of Tensor coefficients as created by `interp_fit`.
90+
t0: scalar float64 Tensor giving the start of the interval.
91+
t1: scalar float64 Tensor giving the end of the interval.
92+
t: scalar float64 Tensor giving the desired interpolation point.
93+
Returns:
94+
Polynomial interpolation of the coefficients at time `t`.
95+
"""
96+
97+
dtype = coefficients[0][0].dtype
98+
device = coefficients[0][0].device
99+
100+
t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
101+
t1 = _convert_to_tensor(t1, dtype=dtype, device=device)
102+
t = _convert_to_tensor(t, dtype=dtype, device=device)
103+
104+
assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1)
105+
x = ((t - t0) / (t1 - t0)).type(dtype).to(device)
106+
107+
xs = [torch.tensor(1).type(dtype).to(device), x]
108+
for _ in range(2, len(coefficients)):
109+
xs.append(xs[-1] * x)
110+
111+
return tuple(_dot_product(coefficients_, reversed(xs)) for coefficients_ in zip(*coefficients))
112+
113+
114+
# ----------------------------------------------------------------------------------------------------
115+
# cubic hermite spline
116+
import matplotlib.pylab as P
117+
import torch as T
118+
119+
def h_poly_helper(tt):
120+
A = T.tensor([
121+
[1, 0, -3, 2],
122+
[0, 1, -2, 1],
123+
[0, 0, 3, -2],
124+
[0, 0, -1, 1]
125+
], dtype=tt[-1].dtype)
126+
return [
127+
sum( A[i, j]*tt[j] for j in range(4) )
128+
for i in range(4) ]
129+
130+
def h_poly(t):
131+
tt = [ None for _ in range(4) ]
132+
tt[0] = 1
133+
for i in range(1, 4):
134+
tt[i] = tt[i-1]*t
135+
return h_poly_helper(tt)
136+
137+
def H_poly(t):
138+
tt = [ None for _ in range(4) ]
139+
tt[0] = t
140+
for i in range(1, 4):
141+
tt[i] = tt[i-1]*t*i/(i+1)
142+
return h_poly_helper(tt)
143+
144+
def interp_cubic_hermite_spline(x, y, xs):
145+
"""
146+
:param x: tensor
147+
:param y: tensor
148+
:param xs: tensor
149+
:return:
150+
"""
151+
if isinstance(xs, T.Tensor):
152+
xs_np = xs.data.cpu().numpy()
153+
xs_np = float(xs_np)
154+
else:
155+
xs_np = float(xs)
156+
xs = T.tensor(xs_np).to(y.device)
157+
158+
x_tmp = (x[1:] - x[:-1])
159+
if x_tmp == 0:
160+
return y[0].unsqueeze(0)
161+
162+
if y.dim() > 1:
163+
x_tmp = x_tmp.view([-1]+[1]*(y.dim()-1))
164+
m = (y[1:] - y[:-1])/ x_tmp
165+
m = T.cat([m[[0]], (m[1:] + m[:-1])/2, m[[-1]]])
166+
167+
I = P.searchsorted(x[1:].data.cpu().numpy(), xs_np)
168+
if isinstance(I, P.int64):
169+
I = P.array([I])
170+
I[I== (x.shape[0]-1)] = I[I== (x.shape[0]-1)] - 2
171+
dx = (x[I+1]-x[I])
172+
hh = h_poly((xs.expand_as(x[I])-x[I])/dx)
173+
174+
if y.dim() > 1:
175+
hh = [tmp.view([-1]+[1]*(y.dim()-1)) for tmp in hh]
176+
dx = dx.view([-1]+[1]*(y.dim()-1))
177+
return hh[0]*y[I] + hh[1]*m[I]*dx + hh[2]*y[I+1] + hh[3]*m[I+1]*dx
178+
179+
def integ(x, y, xs):
180+
x_tmp = (x[1:] - x[:-1])
181+
if y.dim() > 1:
182+
x_tmp = x_tmp.view([-1] + [1] * (y.dim() - 1))
183+
m = (y[1:] - y[:-1])/ x_tmp
184+
m = T.cat([m[[0]], (m[1:] + m[:-1])/2, m[[-1]]])
185+
I = P.searchsorted(x[1:], xs)
186+
I[I == (x.shape[0] - 1)] = I[I == (x.shape[0] - 1)] - 2
187+
Y = T.zeros_like(y)
188+
Y[1:] = x_tmp*(
189+
(y[:-1]+y[1:])/2 + (m[:-1] - m[1:])*x_tmp/12
190+
)
191+
Y = Y.cumsum(0)
192+
dx = (x[I+1]-x[I])
193+
hh = H_poly((xs-x[I])/dx)
194+
if y.dim() > 1:
195+
hh = [tmp.view([-1]+[1]*(y.dim()-1)) for tmp in hh]
196+
dx = dx.view([-1]+[1]*(y.dim()-1))
197+
return Y[I] + dx*(
198+
hh[0]*y[I] + hh[1]*m[I]*dx + hh[2]*y[I+1] + hh[3]*m[I+1]*dx
199+
)
200+
201+
def _is_iterable(inputs):
202+
try:
203+
iter(inputs)
204+
return True
205+
except TypeError:
206+
return False
207+
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .ode_solver import odesolve

0 commit comments

Comments
 (0)