Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Oct 21, 2024
1 parent 5856967 commit 909692f
Showing 1 changed file with 41 additions and 27 deletions.
68 changes: 41 additions & 27 deletions examples/ijcai_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
choices=["contextgnn", "idgnn", "shallowrhsgnn"],
)
# Each option string is registered exactly once: argparse rejects a second
# add_argument() call that reuses an existing option string.
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--eval_epochs_interval", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--channels", type=int, default=128)
parser.add_argument("--aggr", type=str, default="sum")
parser.add_argument("--num_layers", type=int, default=6)
parser.add_argument("--num_neighbors", type=int, default=128)
parser.add_argument("--temporal_strategy", type=str, default="last")
parser.add_argument("--max_steps_per_epoch", type=int, default=2000)
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--eval_k", type=int, default=10)
parser.add_argument("--seed", type=int, default=42)
Expand All @@ -49,6 +49,7 @@
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.set_num_threads(1)
    # NOTE(review): assumed to belong under the CUDA guard -- confirm.
    torch.cuda.empty_cache()
# Seed the RNGs from the command line so runs are repeatable.
seed_everything(args.seed)

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
Expand All @@ -61,22 +62,38 @@


def calculate_hit_rate(pred, target):
    """Return the hit rate of ``pred`` against one target item per row.

    Args:
        pred: Indexable collection where ``pred[i]`` holds the predicted
            items for row ``i``; each row must support the ``in`` operator.
        target: Sequence with one ground-truth item per row, or ``None``
            for a row without a label. ``None`` rows are skipped and do
            not count towards the denominator.

    Returns:
        float: ``hits / total`` over the labelled rows, or ``0.0`` when no
        row carries a label (instead of raising ``ZeroDivisionError``).
    """
    hits = 0
    total = 0
    for i, expected in enumerate(target):
        if expected is None:
            continue
        total += 1
        if expected in pred[i]:
            hits += 1

    # Nothing to score: report zero rather than dividing by zero.
    if total == 0:
        return 0.0
    return hits / total

def calculate_hit_rate_on_sparse_target(pred, target):
    """Return the fraction of rows of ``pred`` that hit a sparse CSR target.

    Row ``i`` counts as a hit when at least one entry of ``pred[i]`` equals
    a column index stored in row ``i`` of ``target``.

    Args:
        pred: Row-indexable predictions supporting ``len()``; ``pred[i]``
            is an iterable of predicted column indices.
        target: Ground truth exposing ``crow_indices()`` and
            ``col_indices()`` (e.g. a ``torch`` sparse CSR tensor). Only
            the sparsity pattern is read; stored values are ignored.

    Returns:
        float: ``hits / len(pred)``, or ``0.0`` when ``pred`` has no rows.
    """
    # Score the arguments that were passed in rather than module-level
    # globals, so the function works for any split and any caller.
    num_rows = len(pred)
    if num_rows == 0:
        return 0.0

    # Convert the index tensors once so the loop handles plain Python ints.
    row_ptr = target.crow_indices().tolist()
    columns = target.col_indices().tolist()

    hits = 0
    for i in range(num_rows):
        # Column indices stored for row i in CSR layout.
        true_indices = columns[row_ptr[i]:row_ptr[i + 1]]
        if any(p in true_indices for p in pred[i]):
            hits += 1

    return hits / num_rows


Expand Down Expand Up @@ -306,13 +323,12 @@ def test(loader: NeighborLoader, desc: str) -> np.ndarray:
# Tracking state for the best validation result seen so far.
state_dict = None
best_val_metric = 0
tune_metric = 'hr'
val_metrics = dict()
for epoch in range(1, args.epochs + 1):
    train_loss = train()
    if epoch % args.eval_epochs_interval == 0:
        val_pred = test(loader_dict["val"], desc="Val")
        # Score against the sparse target directly; val_metrics stays a
        # dict so it can be indexed by tune_metric.
        val_metrics[tune_metric] = calculate_hit_rate_on_sparse_target(
            val_pred, dst_nodes_dict['val'])
        print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, "
              f"Val metrics: {val_metrics}")

Expand All @@ -323,14 +339,12 @@ def test(loader: NeighborLoader, desc: str) -> np.ndarray:
# Restore the checkpointed weights before the final evaluation.
assert state_dict is not None
model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"], desc="Best val")
val_metrics = calculate_hit_rate_on_sparse_target(
    val_pred, dst_nodes_dict['val'])
print(f"Best val metrics: {val_metrics}")

# NOTE(review): pickle.load can execute arbitrary code embedded in the
# file -- only load 'tst_int' from a trusted source.
with open(osp.join(path, 'tst_int'), 'rb') as fs:
    mat = pickle.load(fs)

test_pred = test(loader_dict["test"], desc="Test")
test_metrics = calculate_hit_rate(test_pred, mat)
print(f"Best test metrics: {test_metrics}")

0 comments on commit 909692f

Please sign in to comment.