Skip to content

Commit 557ea72

Browse files
committed
return hash tp label dict from predict_from_rows and remove statement information
1 parent cb68f8c commit 557ea72

1 file changed

Lines changed: 11 additions & 11 deletions

File tree

indra/statements/classifier.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,6 @@ def build_input_from_rows(self, rows):
234234

235235
df = df.groupby(["subject", "object", "type"], as_index=False).agg({
236236
"hash": list,
237-
"statement": list,
238237
"source_count": self.merge_source_count_dicts,
239238
"rel_evidence": "sum",
240239
"in_signor": "max",
@@ -297,35 +296,36 @@ def predict_from_rows(self, rows):
297296
"""
298297
A list of relation records. Each record is a dict with keys:
299298
``subject`` (str), ``object`` (str), ``type`` (str),
300-
``hash`` (int), ``statement`` (str), and
301-
``source_count`` (dict[str, int]).
299+
``hash`` (int), and ``source_count`` (dict[str, int]).
302300
e.g. rows = [
303301
{
304302
"subject": "MAP2K1",
305303
"object": "MAPK1",
306304
"type": "Phosphorylation",
307305
"hash": 123,
308-
"statement": "MAP2K1 phosphorylates MAPK1.",
309306
"source_count": {"reach": 3, "sparser": 1},
310307
}, ...]
311308
"""
312309

313310
df_input = self.build_input_from_rows(rows)
314311

315312
if df_input.empty:
316-
return df_input
313+
return {}
317314

318315
X = df_input[self.feature_cols].copy()
319316

320317
df_input["pred_prob"] = self.model.predict_proba(X)[:, 1]
321318
df_input["pred_label"] = self.model.predict(X)
322319

323-
first_cols = ["subject", "object", "type", "pred_prob", "pred_label"]
324-
other_cols = [col for col in df_input.columns if col not in first_cols]
325-
326-
df_input = df_input[first_cols + other_cols]
327-
328-
return df_input.sort_values("pred_prob", ascending=False)
320+
hash_to_label = {}
321+
for _, row in df_input.iterrows():
322+
label = int(row["pred_label"])
323+
hashes = row["hash"]
324+
if not isinstance(hashes, list):
325+
hashes = [hashes]
326+
for stmt_hash in hashes:
327+
hash_to_label[stmt_hash] = label
328+
return hash_to_label
329329

330330
def predict_from_hashes(self, hashes):
331331
"""

0 commit comments

Comments
 (0)