Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/bioid_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def get_results_tables(
total.loc[:, 'entity_type'] = 'Total'
stats = res_df.groupby('entity_type', as_index=False).sum()
stats = stats[stats['entity_type'] != 'unknown']
stats = stats.append(total, ignore_index=True)
stats = pd.concat([stats, total], ignore_index=True)
stats.loc[:, stats.columns[1:]] = stats[stats.columns[1:]].astype(int)
if match == 'strict':
score_cols = ['top_correct', 'exists_correct']
Expand Down
214 changes: 195 additions & 19 deletions benchmarks/bioid_ner_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import json
import pathlib
import logging
import pickle
from datetime import datetime
from collections import defaultdict, Counter
import xml.etree.ElementTree as ET
from textwrap import dedent
from typing import List, Dict

import pystow
import pandas as pd
from tqdm import tqdm
Expand All @@ -16,7 +16,7 @@
from gilda.ner import annotate

#from benchmarks.bioid_evaluation import fplx_members
from benchmarks.bioid_evaluation import BioIDBenchmarker
from bioid_evaluation import BioIDBenchmarker

logging.getLogger('gilda.grounder').setLevel('WARNING')
logger = logging.getLogger('bioid_ner_benchmark')
Expand Down Expand Up @@ -51,6 +51,7 @@ def __init__(self):
self.counts_table = None
self.precision_recall = None
self.false_positives_counter = Counter()
self.result = None

def process_xml_files(self):
"""Extract relevant information from XML files."""
Expand All @@ -60,7 +61,8 @@ def process_xml_files(self):
for filename in os.listdir(DATA_DIR):
if filename.endswith('.xml'):
filepath = os.path.join(DATA_DIR, filename)
tree = ET.parse(filepath)
with open(filepath, 'r', encoding='utf-8') as file:
tree = ET.parse(file)
root = tree.getroot()
for document in root.findall('.//document'):
doc_id_full = document.find('.//id').text.strip()
Expand Down Expand Up @@ -145,11 +147,11 @@ def annotate_entities_with_gilda(self):
doc_id = item['doc_id']
figure = item['figure']
text = item['text']

# Get the full text for the paper-level disambiguation
full_text = self._get_plaintext(doc_id)

gilda_annotations = annotate(text, context_text=full_text)
gilda_annotations = annotate(text, context_text=full_text,
organisms=self._get_organism_priority(doc_id))

for annotation in gilda_annotations:
total_gilda_annotations += 1
Expand Down Expand Up @@ -199,32 +201,29 @@ def evaluate_gilda_performance(self):
if match_found:
break

if not match_found:
if not match_found and matching_refs != []:
metrics['all_matches']['fp'] += 1
self.false_positives_counter[annotation.text] += 1
if annotation.matches: # Check if there are any matches
metrics['top_match']['fp'] += 1

# False negative calculation using ref dict
# The number entries of annotion in reference with no annotion in grounding
for key, refs in tqdm(ref_dict.items(),
desc="Calculating False Negatives"):
doc_id, figure = key[0], key[1]
gilda_annotations = self.gilda_annotations_map.get((doc_id, figure),
[])
for original_curies, synonyms in refs:
match_found = any(
ann.text == key[2] and
ann.start == key[3] and
ann.end == key[4] and
any(f"{match.term.db}:{match.term.id}" in original_curies or
f"{match.term.db}:{match.term.id}" in synonyms
for match in ann.matches)
for ann in gilda_annotations
match_found = any(
ann.text == key[2] and
ann.start == key[3] and
ann.end == key[4]
for ann in gilda_annotations
)

if not match_found:
metrics['all_matches']['fn'] += 1
metrics['top_match']['fn'] += 1
if not match_found:
metrics['all_matches']['fn'] += 1
metrics['top_match']['fn'] += 1

results = {}
for match_type, counts in metrics.items():
Expand Down Expand Up @@ -292,13 +291,174 @@ def get_tables(self):
self.precision_recall,
self.false_positives_counter)

def check_match(self, row):
"""
Check if 'obj' or 'obj_synonyms' match any grounding identifiers.

Parameters:
- row (pd.Series): A DataFrame row with 'obj', 'obj_synonyms', and 'groundings'.

Returns:
- bool: True if a match is found, else False.
"""
obj=row['obj']
obj_synonyms = row['obj_synonyms']
groundings = row['groundings']
if obj_synonyms is None or groundings is None:
return False
for elem in obj_synonyms:
for tup in groundings:
if elem == tup[0]:
return True
for elem in obj:
for tup in groundings:
if elem == tup[0]:
return True
return False

def generate_result_table(self):
"""
Generates a results DataFrame by aligning Gilda annotation matches with
reference annotations.
"""
ref_dict = defaultdict()

for _, row in self.annotations_df.iterrows():
key = (str(row['don_article']), row['figure'], row['text'],
row['first left'], row['last right'])
ref_dict[key] = (row['obj'], row['obj_synonyms'])

text_list, obj_synonyms_list, don_articles_list = [], [], []
groundings_list, entity_type_list, obj_list = [], [], []
figure_list = []
all_annotation = {}

for (doc_id, figure), annotations in (
tqdm(self.gilda_annotations_map.items(),
desc="Getting result")):
for annotation in annotations:
key = (doc_id, figure, annotation.text, annotation.start,
annotation.end)
all_annotation[key] = annotation

matching_refs = ref_dict.get(key, None)

groundings = []
if matching_refs:
obj = matching_refs[0]
obj_synonyms = matching_refs[1]
else:
obj, obj_synonyms = None, None

text = annotation.text
for scored_match in annotation.matches:
curies = []
curie = f"{scored_match.term.db}:{scored_match.term.id}"
score = scored_match.score
groundings.append((curie, score))
curies.append(curie)

if obj:
entity_type = self._get_entity_type(obj)
else:
entity_type = None

obj_list.append(obj)
text_list.append(text)
obj_synonyms_list.append(obj_synonyms)
don_articles_list.append(doc_id)
figure_list.append(figure)
entity_type_list.append(entity_type)
groundings_list.append(groundings)

for key, refs in tqdm(ref_dict.items(),
desc="Things in reference but not in grounding"):
doc_id, figure = key[0], key[1]
text, start, end = key[2], key[3], key[4]

if not all_annotation.get((doc_id, figure, text, start, end)):
obj_list.append(refs[0]) # ([i[0] for i in refs])
entity_type = self._get_entity_type(refs[0])
entity_type_list.append(entity_type)
text_list.append(key[2])
figure_list.append(key[1])
obj_synonyms_list.append(refs[1]) # ([i[1] for i in refs])
don_articles_list.append(key[0])
groundings_list.append(None)


data = {
'text': text_list,
'obj': obj_list,
'obj_synonyms': obj_synonyms_list,
'don_article': don_articles_list,
'figure': figure_list,
'entity_type': entity_type_list,
'groundings': groundings_list,
}
self.result = pd.DataFrame(data)
self.result['match'] = self.result.apply(self.check_match, axis=1)
self.result = self.result.sort_values(by=['don_article', 'figure'])
def get_entity_result(self):
"""
Compute precision and recall for entity recognition.

- True Positives (TP): Matched objects with groundings.
- False Negatives (FN): Objects without groundings.
- False Positives (FP): Objects with groundings but no match.

Returns:
- pd.DataFrame: TP, FP, FN counts, precision, and recall per entity type.
"""
df = self.result
#True Positives
df_tp = df[(df['obj'].notna()) & (df['groundings'].notna()) & (df['match'] == True)]
true_positive_counts = df_tp.groupby('entity_type').size().reset_index(name='true_positive_count')
#False Negatives
df_fn = df[(df['obj'].notna()) & (df['groundings'].isna())]
false_negative_counts = df_fn.groupby('entity_type').size().reset_index(name='false_negative_count')
#False Positives
df_fp = df[(df['obj'].notna()) & (df['groundings'].notna()) & (df['match'] == False)]
false_positive_counts = df_fp.groupby('entity_type').size().reset_index(name='false_positive_count')
#Merge
merged_df = pd.merge(true_positive_counts, false_negative_counts, on='entity_type', how='outer').fillna(0)
merged_df = pd.merge(merged_df, false_positive_counts, on='entity_type', how='outer').fillna(0)
#Recall
merged_df['recall'] = merged_df['true_positive_count'] / (
merged_df['true_positive_count'] + merged_df['false_negative_count'])
#Precision
merged_df['precision'] = merged_df['true_positive_count'] / (
merged_df['true_positive_count'] + merged_df['false_positive_count'])

total_tp = merged_df['true_positive_count'].sum()
total_fn = merged_df['false_negative_count'].sum()
total_fp = merged_df['false_positive_count'].sum()
total_recall = total_tp / (total_tp + total_fn)
total_precision = total_tp / (total_tp + total_fp)

total_row = pd.DataFrame({
'entity_type': ['Total'],
'true_positive_count': [total_tp],
'false_negative_count': [total_fn],
'false_positive_count': [total_fp],
'recall': [total_recall],
'precision': [total_precision]
})
final_df = pd.concat([merged_df, total_row], ignore_index=True)
final_df = final_df[
['entity_type', 'true_positive_count', 'false_positive_count', 'false_negative_count', 'precision',
'recall']]
return final_df


def main(results: str = RESULTS_DIR):
results_path = os.path.expandvars(os.path.expanduser(results))
os.makedirs(results_path, exist_ok=True)

benchmarker = BioIDNERBenchmarker()
benchmarker.annotate_entities_with_gilda()
df = pd.DataFrame(list(benchmarker.gilda_annotations_map.items()), columns=['Key', 'Value'])
benchmarker.generate_result_table()
benchmarker.evaluate_gilda_performance()
counts, precision_recall, false_positives_counter = benchmarker.get_tables()

Expand All @@ -311,6 +471,7 @@ def main(results: str = RESULTS_DIR):

outname = f'benchmark_{time}'
result_stub = pathlib.Path(results_path).joinpath(outname)
entity_result = benchmarker.get_entity_result()

caption0 = dedent(f"""\
# Gilda NER Benchmarking
Expand Down Expand Up @@ -339,6 +500,16 @@ def main(results: str = RESULTS_DIR):
table2 = precision_recall.to_markdown(index=False)

caption3 = dedent("""\
## Table 3

Precision and recall values for Gilda performance by entity type. Values
are given both for the case where Gilda is considered correct only if the
top grounding matches and the case where Gilda is considered correct if
any of its groundings match.
""")
table_by_entity = entity_result.to_markdown(index=False)

caption4 = dedent("""\
## 50 Most Common False Positive Words

A list of 50 most common false positive annotations created by Gilda.
Expand All @@ -351,7 +522,8 @@ def main(results: str = RESULTS_DIR):
caption0,
caption1, table1,
caption2, table2,
caption3, false_positives_list
caption3, table_by_entity,
caption4, false_positives_list
])

md_path = result_stub.with_suffix(".md")
Expand All @@ -361,6 +533,10 @@ def main(results: str = RESULTS_DIR):
counts.to_csv(result_stub.with_suffix(".counts.csv"), index=False)
precision_recall.to_csv(result_stub.with_suffix(".precision_recall.csv"),
index=False)
benchmarker.result.to_csv(
result_stub.with_suffix(".ner_result.tsv"),
sep='\t', index=False)

print(f'Results saved to {results_path}')


Expand Down
2 changes: 1 addition & 1 deletion gilda/grounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .term import Term, get_identifiers_curie, get_identifiers_url
from .process import normalize, replace_dashes, replace_greek_uni, \
replace_greek_latin, replace_greek_spelled_out, depluralize, \
replace_roman_arabic
replace_roman_arabic, strip_greek_letters
from .scorer import Match, generate_match, score
from .resources import get_gilda_models, get_grounding_terms

Expand Down
Loading