Skip to content

Commit

Permalink
Move HybridGNN to device to avoid device mis-match error in `relben…
Browse files Browse the repository at this point in the history
…ch_example.py` (#10)

* fix device

* update

* update
  • Loading branch information
akihironitta authored Jul 31, 2024
1 parent 0cca3cf commit dac6427
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions examples/relbench_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,17 @@
aggr=args.aggr,
norm="layer_norm",
).to(device)
elif args.model == 'hybridgnn':
model = HybridGNN(data=data, col_stats_dict=col_stats_dict,
num_nodes=num_dst_nodes_dict["train"],
num_layers=args.num_layers, channels=args.channels,
aggr="sum", norm="layer_norm", embedding_dim=64)
elif args.model == "hybridgnn":
model = HybridGNN(
data=data,
col_stats_dict=col_stats_dict,
num_nodes=num_dst_nodes_dict["train"],
num_layers=args.num_layers,
channels=args.channels,
aggr="sum",
norm="layer_norm",
embedding_dim=64,
).to(device)
else:
raise ValueError(f"Unsupported model type {args.model}.")

Expand All @@ -142,7 +148,7 @@ def train() -> float:
steps = 0
total_steps = min(len(loader_dict["train"]), args.max_steps_per_epoch)
sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device)
for batch in tqdm(loader_dict["train"], total=total_steps):
for batch in tqdm(loader_dict["train"], total=total_steps, desc="Train"):
batch = batch.to(device)

# Get ground-truth
Expand Down Expand Up @@ -192,23 +198,22 @@ def train() -> float:


@torch.no_grad()
def test(loader: NeighborLoader) -> np.ndarray:
def test(loader: NeighborLoader, desc: str) -> np.ndarray:
model.eval()

pred_list: List[Tensor] = []
for batch in tqdm(loader):
for batch in tqdm(loader, desc=desc):
batch = batch.to(device)
batch_size = batch[task.src_entity_table].batch_size

if args.model == 'idgnn':
if args.model == "idgnn":
out = (model.forward(batch, task.src_entity_table,
task.dst_entity_table).detach().flatten())
scores = torch.zeros(batch_size, task.num_dst_nodes,
device=out.device)
scores[batch[task.dst_entity_table].batch,
batch[task.dst_entity_table].n_id] = torch.sigmoid(out)
elif args.model == 'hybridgnn':
# Get ground-truth
elif args.model == "hybridgnn":
out = model(batch, task.src_entity_table,
task.dst_entity_table).detach()
scores = torch.sigmoid(out)
Expand All @@ -226,21 +231,21 @@ def test(loader: NeighborLoader) -> np.ndarray:
for epoch in range(1, args.epochs + 1):
train_loss = train()
if epoch % args.eval_epochs_interval == 0:
val_pred = test(loader_dict["val"])
val_pred = test(loader_dict["val"], desc="Val")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, "
f"Val metrics: {val_metrics}")

if val_metrics[tune_metric] > best_val_metric:
best_val_metric = val_metrics[tune_metric]
state_dict = copy.deepcopy(model.state_dict())
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}

assert state_dict is not None
model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"])
val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = task.evaluate(val_pred, task.get_table("val"))
print(f"Best Val metrics: {val_metrics}")
print(f"Best val metrics: {val_metrics}")

test_pred = test(loader_dict["test"])
test_pred = test(loader_dict["test"], desc="Test")
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

0 comments on commit dac6427

Please sign in to comment.