Skip to content

Commit bdfd369

Browse files
committed
add stlcg
1 parent 55fd4aa commit bdfd369

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

trajdiff/stlcg.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -417,14 +417,19 @@ def _initialize_rnn_cell(self, x):
417417
This requires padding on the signal. Currently, the default is to extend the last value.
418418
TODO: have option on this padding
419419
420-
The initial hidden state is of the form (hidden_state, count). count is needed just for the case with self.interval=[0, np.inf) and distributed=True. Since we are scanning through the sigal and outputing the min/max values incrementally, the distributed min function doesn't apply. If there are multiple min values along the signal, the gradient will be distributed equally across them. Otherwise it will only apply to the value that occurs earliest as we scan through the signal (i.e., practically, the last value in the trace as we process the signal backwards).
420+
The initial hidden state is of the form (hidden_state, count). count is needed just for the case with self.interval=[0, np.inf)
421+
and distributed=True. Since we are scanning through the sigal and outputing the min/max values incrementally,
422+
the distributed min function doesn't apply. If there are multiple min values along the signal, the gradient will be distributed
423+
equally across them. Otherwise it will only apply to the value that occurs earliest as we scan through the signal
424+
(i.e., practically, the last value in the trace as we process the signal backwards).
421425
"""
422426
raise NotImplementedError("_initialize_rnn_cell is not implemented")
423427

424428
def _rnn_cell(self, x, hc, scale=-1, agm=False, distributed=False, **kwargs):
425429
"""
426430
x: rnn input [batch_size, 1, ...]
427-
h0: input rnn hidden state. The hidden state is either a tensor, or a tuple of tensors, depending on the interval chosen. Generally, the hidden state is of size [batch_size, rnn_dim,...]
431+
h0: input rnn hidden state. The hidden state is either a tensor, or a tuple of tensors, depending on the interval chosen.
432+
Generally, the hidden state is of size [batch_size, rnn_dim,...]
428433
"""
429434
raise NotImplementedError("_initialize_rnn_cell is not implemented")
430435

@@ -487,6 +492,11 @@ def _initialize_rnn_cell(self, x):
487492
if x.is_cuda:
488493
self.M = self.M.cuda()
489494
self.b = self.b.cuda()
495+
# x = tensor([-1.4800]) with shape [1]
496+
# needs to have 3 elements in the shape to work
497+
# rnn dim is 1
498+
# print(x)
499+
# print(x.shape)
490500
h0 = (
491501
torch.ones([x.shape[0], self.rnn_dim, x.shape[2]], device=x.device)
492502
* x[:, :1, :]

0 commit comments

Comments
 (0)