Skip to content

Commit

Permalink
fix code
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Oct 21, 2024
1 parent 835b0a6 commit 599e431
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions examples/ijcai_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os.path as osp
import pickle
import warnings
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -60,8 +60,8 @@
data = HeteroData()


def calculate_hit_rate(pred, target):
r"""This function calculates hit_rate
def calculate_hit_rate(pred: torch.Tensor, target: List[Optional[int]]):
r"""Calculates hit rate when pred is a tensor and target is a list
Args:
pred (torch.Tensor): Prediction tensor of size (num_entity,
num_target_predicitons_per_entity).
Expand All @@ -81,12 +81,18 @@ def calculate_hit_rate(pred, target):
return hits / total


def calculate_hit_rate_on_sparse_target(pred, target):
def calculate_hit_rate_on_sparse_target(pred: torch.Tensor,
target: torch.sparse.Tensor):
r"""Calculates hit rate when pred is a tensor and target is a sparse
tensor
Args:
pred (torch.Tensor): Prediction tensor of size (num_entity,
num_target_predicitons_per_entity).
target (torch.sparse.Tensor): Target sparse tensor.
"""
crow_indices = target.crow_indices()
col_indices = target.col_indices()
values = target.values()
import pdb
pdb.set_trace()
# Iterate through each row and check if predictions match ground truth
hits = 0
num_rows = val_pred.shape[0]
Expand All @@ -95,11 +101,13 @@ def calculate_hit_rate_on_sparse_target(pred, target):
# Get the ground truth indices for this row
row_start = crow_indices[i].item()
row_end = crow_indices[i + 1].item()
true_indices = col_indices[row_start:row_end].tolist()
dst_indices = col_indices[row_start:row_end]
bool_indices = values[row_start:row_end]
true_indices = dst_indices[bool_indices].tolist()

# Check if any of the predicted values match the true indices
pred_indices = pred[i]
if any(pred in true_indices for pred in pred_indices):
if torch.isin(pred_indices, true_indices).any():
hits += 1

# Callculate hit rate
Expand Down

0 comments on commit 599e431

Please sign in to comment.