@@ -234,7 +234,6 @@ def build_input_from_rows(self, rows):
234234
235235 df = df .groupby (["subject" , "object" , "type" ], as_index = False ).agg ({
236236 "hash" : list ,
237- "statement" : list ,
238237 "source_count" : self .merge_source_count_dicts ,
239238 "rel_evidence" : "sum" ,
240239 "in_signor" : "max" ,
@@ -297,35 +296,36 @@ def predict_from_rows(self, rows):
297296 """
298297 A list of relation records. Each record is a dict with keys:
299298 ``subject`` (str), ``object`` (str), ``type`` (str),
300- ``hash`` (int), ``statement`` (str), and
301- ``source_count`` (dict[str, int]).
299+ ``hash`` (int), and ``source_count`` (dict[str, int]).
302300 e.g. rows = [
303301 {
304302 "subject": "MAP2K1",
305303 "object": "MAPK1",
306304 "type": "Phosphorylation",
307305 "hash": 123,
308- "statement": "MAP2K1 phosphorylates MAPK1.",
309306 "source_count": {"reach": 3, "sparser": 1},
310307 }, ...]
311308 """
312309
313310 df_input = self .build_input_from_rows (rows )
314311
315312 if df_input .empty :
316- return df_input
313+ return {}
317314
318315 X = df_input [self .feature_cols ].copy ()
319316
320317 df_input ["pred_prob" ] = self .model .predict_proba (X )[:, 1 ]
321318 df_input ["pred_label" ] = self .model .predict (X )
322319
323- first_cols = ["subject" , "object" , "type" , "pred_prob" , "pred_label" ]
324- other_cols = [col for col in df_input .columns if col not in first_cols ]
325-
326- df_input = df_input [first_cols + other_cols ]
327-
328- return df_input .sort_values ("pred_prob" , ascending = False )
320+ hash_to_label = {}
321+ for _ , row in df_input .iterrows ():
322+ label = int (row ["pred_label" ])
323+ hashes = row ["hash" ]
324+ if not isinstance (hashes , list ):
325+ hashes = [hashes ]
326+ for stmt_hash in hashes :
327+ hash_to_label [stmt_hash ] = label
328+ return hash_to_label
329329
330330 def predict_from_hashes (self , hashes ):
331331 """
0 commit comments