From 5e92a132cc3f50743dde3737f588ee0dcf70f043 Mon Sep 17 00:00:00 2001 From: Zecheng Zhang Date: Mon, 30 Sep 2024 00:19:03 +0000 Subject: [PATCH] Update --- examples/ngcf.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/ngcf.py b/examples/ngcf.py index 7f70f64..ba87139 100644 --- a/examples/ngcf.py +++ b/examples/ngcf.py @@ -94,8 +94,7 @@ def reset_parameters(self): def sparse_dropout(self, row: Tensor, col: Tensor, value: Tensor, rate: float, nnz: int) -> SparseTensor: - rand = 1 - rate - rand += torch.rand(nnz) + rand: Tensor = (1 - rate) + torch.rand(nnz) assert isinstance(rand, Tensor) dropout_mask = torch.floor(rand).type(torch.bool) adj = SparseTensor( @@ -131,8 +130,8 @@ def get_embedding(self, norm_adj: SparseTensor, ego_emb = F.dropout(ego_emb) norm_emb = F.normalize(ego_emb, p=2, dim=1) all_embs += [norm_emb] - all_embs: Tensor = torch.cat(all_embs, 1) - return all_embs + res_embs: Tensor = torch.cat(all_embs, 1) + return res_embs def recommendation_loss( self,