Candidate entities for scoring in pipeline #40

Merged · 3 commits · Mar 21, 2024
24 changes: 21 additions & 3 deletions besskge/pipeline.py
@@ -23,9 +23,9 @@
class AllScoresPipeline(torch.nn.Module):
"""
Pipeline to compute scores of (h, r, ?) / (?, r, t) queries against all entities
in the KG, and related prediction metrics.
It supports filtering out the scores of specific completions that appear in a given
set of triples.
in the KG (or a given subset of entities), and related prediction metrics.
It supports filtering out, for each query, the scores of specific completions that
appear in a given set of triples.

To be used in combination with a batch sampler based on a
"h_shard"/"t_shard"-partitioned triple set.
@@ -38,6 +38,7 @@ def __init__(
score_fn: BaseScoreFunction,
evaluation: Optional[Evaluation] = None,
filter_triples: Optional[List[Union[torch.Tensor, NDArray[np.int32]]]] = None,
candidate_ents: Optional[Union[torch.Tensor, NDArray[np.int32]]] = None,
return_scores: bool = False,
return_topk: bool = False,
k: int = 10,
@@ -62,6 +63,12 @@
The set of all triples whose scores need to be filtered.
The triples passed here must have GLOBAL IDs for head/tail
entities. Default: None.
:param candidate_ents:
If specified, score queries only against a given set of entities.
This array needs to contain the global IDs of the
candidate entities to be used for completion. All other entities
will then be ignored when scoring queries.
Default: None (i.e. score queries against all entities).
:param return_scores:
If True, store and return scores of all queries' completions
(with filters applied, if specified).
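A usage note outside the diff: the candidate set is just a flat array of GLOBAL entity IDs (a numpy int32 array or a torch tensor, per the type hint above). A minimal sketch with an illustrative `n_entity`, mirroring the `compl_candidates` array used in the test further down:

```python
import numpy as np

n_entity = 500  # illustrative KG size
# e.g. every 5th entity, analogous to compl_candidates in the test below
candidate_ents = np.arange(0, n_entity, step=5, dtype=np.int32)
# passed to the constructor as AllScoresPipeline(..., candidate_ents=candidate_ents, ...)
```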
@@ -165,6 +172,13 @@ def __init__(
],
dim=0,
)
self.candidate_mask: Optional[torch.Tensor] = None
if candidate_ents is not None:
self.candidate_mask = torch.from_numpy(
np.setdiff1d(
np.arange(self.bess_module.sharding.n_entity), candidate_ents
)
)

def forward(self) -> Dict[str, Any]:
"""
@@ -231,6 +245,10 @@ def forward(self) -> Dict[str, Any]:
batch_scores_filt = batch_scores[triple_mask.flatten()][
:, np.unique(np.concatenate(batch_idx), return_index=True)[1]
][:, : self.bess_module.sharding.n_entity]
if self.candidate_mask is not None:
# Filter scores for entities that are not in
# the given set of candidates
batch_scores_filt[:, self.candidate_mask] = -torch.inf
if ground_truth is not None:
# Scores of positive triples
true_scores = batch_scores_filt[
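A small self-contained sketch (illustrative shapes, not besskge code) of what the `-inf` masking in `forward` achieves: after the mask is applied, ranks and top-k predictions can only come from the candidate set:

```python
import numpy as np
import torch

torch.manual_seed(0)
scores = torch.randn(4, 10)                  # (n_queries, n_entity)
candidate_ents = np.array([0, 2, 5, 7])      # global IDs kept for completion
candidate_mask = torch.from_numpy(
    np.setdiff1d(np.arange(scores.shape[1]), candidate_ents)
)

scores[:, candidate_mask] = -torch.inf       # ignore every non-candidate entity
topk = torch.topk(scores, k=3, dim=-1).indices
assert np.in1d(topk.numpy(), candidate_ents).all()  # predictions are all candidates
```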
43 changes: 35 additions & 8 deletions tests/test_pipeline.py
@@ -36,13 +36,19 @@
test_triples_r = np.random.randint(n_relation_type, size=n_test_triple)
triples = {"test": np.stack([test_triples_h, test_triples_r, test_triples_t], axis=1)}

compl_candidates = np.arange(0, n_entity - 1, step=5)


@pytest.mark.parametrize("corruption_scheme", ["h", "t"])
@pytest.mark.parametrize(
"filter_scores, extra_only", [(True, True), (True, False), (False, False)]
)
@pytest.mark.parametrize("filter_candidates", [True, False])
def test_all_scores_pipeline(
corruption_scheme: str, filter_scores: bool, extra_only: bool
corruption_scheme: str,
filter_scores: bool,
extra_only: bool,
filter_candidates: bool,
) -> None:
ds = KGDataset(
n_entity=n_entity,
@@ -104,6 +110,7 @@ def test_all_scores_pipeline(
score_fn,
evaluation,
filter_triples=triples_to_filter, # type: ignore
candidate_ents=compl_candidates if filter_candidates else None,
return_scores=True,
return_topk=True,
k=10,
Expand Down Expand Up @@ -136,6 +143,14 @@ def test_all_scores_pipeline(
triple_reordered[:, 1],
unsharded_entity_table[triple_reordered[:, 2]],
).flatten()
if filter_candidates:
# positive score -inf if ground truth not in candidate list
pos_scores[
torch.from_numpy(
~np.in1d(triple_reordered[:, ground_truth_col], compl_candidates)
)
] = -torch.inf

# mask positive scores to compute metrics
cpu_scores[
torch.arange(cpu_scores.shape[0]), triple_reordered[:, ground_truth_col]
@@ -161,19 +176,31 @@
assert torch.all(
tr_filter[1::2, 1] == triple_reordered[:, ground_truth_col] + 1
)
if filter_candidates:
cand_mask = np.setdiff1d(np.arange(cpu_scores.shape[-1]), compl_candidates)
cpu_scores[:, cand_mask] = -torch.inf

cpu_ranks = evaluation.ranks_from_scores(pos_scores, cpu_scores)
# we allow for an off-by-one rank difference on at most 1% of triples,
# due to rounding differences in CPU vs IPU score computations
assert torch.all(torch.abs(cpu_ranks - out["ranks"]) <= 1)
assert (cpu_ranks != out["ranks"]).sum() < n_test_triple / 100

# restore positive scores
cpu_scores[
torch.arange(cpu_scores.shape[0]), triple_reordered[:, ground_truth_col]
] = pos_scores

cpu_preds = torch.topk(cpu_scores, k=pipeline.k, dim=-1).indices
assert torch.all(cpu_preds == out["topk_global_id"])

if filter_candidates:
# check that all predictions are in the set of candidates
assert np.all(np.in1d(out["topk_global_id"], compl_candidates))
assert np.all(np.in1d(cpu_preds, compl_candidates))

cpu_scores = cpu_scores[:, compl_candidates]
out["scores"] = out["scores"][:, compl_candidates]

assert_close(cpu_scores, out["scores"], atol=1e-3, rtol=1e-4)
assert torch.all(
torch.topk(cpu_scores, k=pipeline.k, dim=-1).indices == out["topk_global_id"]
)

# we allow for an off-by-one rank difference on at most 1% of triples,
# due to rounding differences in CPU vs IPU score computations
assert torch.all(torch.abs(cpu_ranks - out["ranks"]) <= 1)
assert (cpu_ranks != out["ranks"]).sum() < n_test_triple / 100
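A standalone sketch of the consistency property the test exercises (illustrative sizes, not the test's actual tensors): once non-candidate columns are set to `-inf`, top-k over the full score matrix agrees with top-k over the candidate columns alone, and every prediction lies in the candidate set:

```python
import numpy as np
import torch

torch.manual_seed(0)
n_queries, n_entity, k = 6, 20, 3
compl_candidates = np.arange(0, n_entity - 1, step=5)  # [0, 5, 10, 15], as in the test above

scores = torch.randn(n_queries, n_entity)
cand_mask = torch.from_numpy(np.setdiff1d(np.arange(n_entity), compl_candidates))
scores[:, cand_mask] = -torch.inf            # ignore non-candidate entities

preds = torch.topk(scores, k=k, dim=-1).indices
assert np.in1d(preds.numpy(), compl_candidates).all()

# ranking restricted to the candidate columns gives the same predictions
restricted = scores[:, torch.from_numpy(compl_candidates)]
preds_restricted = torch.from_numpy(compl_candidates)[
    torch.topk(restricted, k=k, dim=-1).indices
]
assert torch.equal(preds, preds_restricted)
```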