
Commit 0b97aaa
remove changes in other scripts
Yiwen Yuan committed Nov 17, 2024
1 parent 1845d1b commit 0b97aaa
Showing 2 changed files with 3 additions and 4 deletions.
examples/contextgnn_sample_softmax.py (4 changes: 2 additions & 2 deletions)
@@ -40,13 +40,13 @@
 parser.add_argument("--lr", type=float, default=0.001)
 parser.add_argument("--epochs", type=int, default=20)
 parser.add_argument("--eval_epochs_interval", type=int, default=1)
-parser.add_argument("--batch_size", type=int, default=512)
+parser.add_argument("--batch_size", type=int, default=128)
 parser.add_argument("--channels", type=int, default=128)
 parser.add_argument("--aggr", type=str, default="sum")
 parser.add_argument("--num_layers", type=int, default=4)
 parser.add_argument("--num_neighbors", type=int, default=128)
 parser.add_argument("--temporal_strategy", type=str, default="last")
-parser.add_argument("--max_steps_per_epoch", type=int, default=2000)
+parser.add_argument("--max_steps_per_epoch", type=int, default=200)
 parser.add_argument("--num_workers", type=int, default=0)
 parser.add_argument("--seed", type=int, default=42)
 parser.add_argument("--cache_dir", type=str,
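The two changed defaults shrink the per-step workload (batch_size 512 to 128, max_steps_per_epoch 2000 to 200). A minimal sketch of how these argparse defaults interact with explicit flags; the two add_argument calls mirror the diff above, and the parse_args input is a hypothetical invocation:

# Sketch: the new defaults from this commit; command-line flags still override them.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=128)            # was 512
parser.add_argument("--max_steps_per_epoch", type=int, default=200)  # was 2000

# e.g. passing `--batch_size 512` on the command line restores the old value:
args = parser.parse_args(["--batch_size", "512"])
assert args.batch_size == 512 and args.max_steps_per_epoch == 200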
examples/relbench_example.py (3 changes: 1 addition & 2 deletions)
@@ -205,8 +205,7 @@ def train() -> float:
             loss = F.binary_cross_entropy_with_logits(out, target)
             numel = out.numel()
         elif args.model in ['contextgnn', 'shallowrhsgnn']:
-            logits = model(batch, task.src_entity_table, task.dst_entity_table,
-                           dst_index)
+            logits = model(batch, task.src_entity_table, task.dst_entity_table)
             edge_label_index = torch.stack([src_batch, dst_index], dim=0)
             loss = sparse_cross_entropy(logits, edge_label_index)
             numel = len(batch[task.dst_entity_table].batch)
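This hunk drops the extra dst_index argument from the ContextGNN/ShallowRHSGNN forward call; the positive (source, destination) pairs now reach the loss only through edge_label_index. The repository's actual sparse_cross_entropy is not shown in this commit, so the following is only a rough sketch of what such a loss typically computes, assuming logits of shape [num_src, num_dst] and edge_label_index holding the positive pairs:

import torch
import torch.nn.functional as F

def sparse_cross_entropy_sketch(logits: torch.Tensor,
                                edge_label_index: torch.Tensor) -> torch.Tensor:
    # logits: [num_src, num_dst], one score per (source, destination) pair.
    # edge_label_index: [2, num_edges]; row 0 = source rows, row 1 = destination columns.
    src, dst = edge_label_index
    log_probs = F.log_softmax(logits, dim=1)  # normalize over all destinations per source
    return -log_probs[src, dst].mean()        # NLL of the observed positive pairs

# Toy usage mirroring the diff, where
# edge_label_index = torch.stack([src_batch, dst_index], dim=0):
logits = torch.randn(4, 10)
edge_label_index = torch.stack([torch.tensor([0, 1, 2]),
                                torch.tensor([7, 3, 9])], dim=0)
loss = sparse_cross_entropy_sketch(logits, edge_label_index)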
