Skip to content

Commit

Permalink
Update FlashST.py
Browse files Browse the repository at this point in the history
  • Loading branch information
LZH-YS1998 authored Sep 2, 2024
1 parent fadc836 commit 39d4b8e
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions model/FlashST.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,13 @@ def __init__(self, args):
self.pretrain_model = PromptNet(args)


def forward(self, source, label, select_dataset, batch_seen=None, nadj=None, lpls=None, useGNN=False, DSU=True):
def forward(self, source, label, select_dataset, batch_seen=None, nadj=None, lpls=None, useGNN=False):
if self.mode == 'ori':
return self.forward_ori(source, label, select_dataset, batch_seen)
else:
return self.forward_pretrain(source, label, select_dataset, batch_seen, nadj, lpls, useGNN, DSU)
return self.forward_pretrain(source, label, select_dataset, batch_seen, nadj, lpls, useGNN)

def forward_pretrain(self, source, label, select_dataset, batch_seen=None, nadj=None, lpls=None, useGNN=False, DSU=True):
def forward_pretrain(self, source, label, select_dataset, batch_seen=None, nadj=None, lpls=None, useGNN=False):
x_prompt_return = self.pretrain_model(source[..., :self.input_base_dim], source, None, nadj, lpls, useGNN)
if self.model == 'DMSTGCN':
x_predic = self.predictor(x_prompt_return, source[:, 0, 0, 1], select_dataset) # MTGNN
Expand All @@ -164,4 +164,4 @@ def forward_ori(self, source, label=None, select_dataset=None, batch_seen=None):
x_predic = self.predictor(source[..., :self.input_base_dim], source[:, 0, 0, 1], select_dataset) # MTGNN
else:
x_predic = self.predictor(source[..., :self.input_base_dim], select_dataset)
return x_predic, None
return x_predic, None

0 comments on commit 39d4b8e

Please sign in to comment.