Skip to content

Commit f4ff3b4

Browse files
authored
Add quality estimation module (#793)
* first draft of quality estimation module * keep vref consistent * add options for handling multiple files at a time * fix typing issues, help string spaces * improve error handling * improve error handling * undo change * refactor to minimize loop computation; improve file handling; minor bug fixes
1 parent f0e1e8c commit f4ff3b4

File tree

1 file changed

+263
-0
lines changed

1 file changed

+263
-0
lines changed

silnlp/nmt/quality_estimation.py

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
import argparse
2+
import logging
3+
import re
4+
from collections import defaultdict
5+
from dataclasses import dataclass
6+
from math import exp
7+
from pathlib import Path
8+
from typing import List, Optional, Tuple
9+
10+
from machine.scripture import VerseRef
11+
from openpyxl import load_workbook
12+
from scipy.stats import linregress
13+
14+
from silnlp.nmt.config import get_mt_exp_dir
15+
16+
# Module-level logger, namespaced under the enclosing package
# (e.g. "silnlp.nmt.quality_estimation" given this file's location).
LOGGER = logging.getLogger(__package__ + ".quality_estimation")
17+
18+
19+
@dataclass
class VerseScore:
    """Per-verse model confidence, plus the chrF3 value projected from it."""

    # Scripture reference the score belongs to.
    vref: VerseRef
    # Model confidence for this verse's translation.
    confidence: float
    # Filled in later by project_chrf3 via the linear fit; None until then.
    projected_chrf3: Optional[float] = None
24+
25+
26+
def estimate_quality(diff_predictions_file: Path, confidence_files: List[Path]) -> None:
    """Run the full quality-estimation pipeline.

    Projects a chrF3 score for every verse in the confidence files, then
    writes per-book and per-chapter usability summaries into the directory
    containing the first confidence file.
    """
    scored_verses: List[VerseScore] = project_chrf3(diff_predictions_file, confidence_files)
    summary_dir = confidence_files[0].parent
    compute_usable_proportions(scored_verses, summary_dir)
29+
30+
31+
def project_chrf3(diff_predictions_file: Path, confidence_files: List[Path]) -> List[VerseScore]:
    """Project a chrF3 score for each verse from its confidence score.

    Fits a line (chrF3 vs. confidence) to the paired scores found in
    *diff_predictions_file*, then applies that line to the per-verse
    confidences in each of *confidence_files*. For each confidence file a
    "<name>.projected_chrf3.tsv" report is written alongside it.

    Args:
        diff_predictions_file: Workbook containing "chrf3" and "confidence" columns.
        confidence_files: Confidence TSV files to score.

    Returns:
        All VerseScores across all files, with projected_chrf3 filled in.

    Raises:
        ValueError: if the chrF3 and confidence columns have different lengths.
    """
    chrf3_scores, confidence_scores = extract_diff_predictions(diff_predictions_file)
    if len(chrf3_scores) != len(confidence_scores):
        raise ValueError(
            f"The number of chrF3 scores ({len(chrf3_scores)}) and confidence scores ({len(confidence_scores)}) "
            f"in {diff_predictions_file} do not match."
        )
    # Line of best fit mapping confidence -> chrF3.
    slope, intercept = linregress(confidence_scores, chrf3_scores)[:2]
    verse_scores: List[VerseScore] = []
    for confidence_file in confidence_files:
        file_scores = extract_confidences(confidence_file)
        verse_scores += file_scores
        with open(confidence_file.with_suffix(".projected_chrf3.tsv"), "w", encoding="utf-8") as output_file:
            output_file.write("VRef\tConfidence\tProjected chrF3\n")
            # BUG FIX: iterate only this file's scores. The original iterated the
            # cumulative verse_scores list, so every output file after the first
            # also contained all earlier files' rows (and re-projected them).
            for verse_score in file_scores:
                projected_chrf3 = slope * verse_score.confidence + intercept
                verse_score.projected_chrf3 = projected_chrf3
                output_file.write(f"{verse_score.vref}\t{verse_score.confidence}\t{projected_chrf3:.2f}\n")
    return verse_scores
50+
51+
52+
def extract_diff_predictions(diff_predictions_file_path) -> Tuple[List[float], List[float]]:
    """Read the chrF3 and confidence columns from a diff-predictions workbook.

    Returns:
        A (chrf3_scores, confidence_scores) pair. The two lists are not
        guaranteed to be equal length; the caller validates that.
    """
    return (
        extract_diff_predictions_column(diff_predictions_file_path, "chrf3"),
        extract_diff_predictions_column(diff_predictions_file_path, "confidence"),
    )
56+
57+
58+
def extract_diff_predictions_column(file_path: Path, target_header: str) -> List[float]:
    """Return the values below *target_header* in the workbook's active sheet.

    Scans cell-by-cell for the header (case-insensitive string match), then
    collects every non-empty cell beneath it in the same column as a float.

    Raises:
        ValueError: if the header is never found, or if a collected cell
            cannot be converted to float.
    """
    wb = load_workbook(file_path)
    ws = wb.active

    header_row_idx = None
    col_idx = None

    for row in ws.iter_rows():
        for cell in row:
            # Stop scanning once the "Score Summary" section starts.
            # NOTE(review): this break only exits the current row, so later
            # rows are still scanned — confirm that is the intended behavior.
            if cell.value == "Score Summary":
                break
            # str() makes the comparison safe for None / non-string cells.
            if str(cell.value).lower() == target_header.lower():
                header_row_idx = cell.row
                col_idx = cell.column
                break
        if header_row_idx:
            break

    if not header_row_idx:
        raise ValueError(f"Header '{target_header}' not found.")

    data = []
    # Read the single target column from the row after the header downward.
    # NOTE(review): assumes every non-empty cell here is numeric; a stray
    # text cell (e.g. a later section header) would raise from float().
    for row in ws.iter_rows(min_row=header_row_idx + 1, min_col=col_idx, max_col=col_idx):
        cell_value = row[0].value
        if cell_value is not None:
            data.append(float(cell_value))

    return data
86+
87+
88+
def extract_confidences(input_file_path: Path) -> List[VerseScore]:
    """Parse a confidences TSV file into per-verse confidence scores.

    Reference lines look like "GEN 1:1", optionally with a "/..." suffix for
    non-verse segments. A line following a plain, nonzero verse reference is
    read as tab-separated columns whose first column is the confidence.
    Header lines ("vref...", "sequence score..."), verse 0 references, and
    suffixed references are skipped.
    """
    ref_pattern = re.compile(r"^([0-9A-Z][A-Z]{2}) (\d+):(\d+)(/.*)?")
    book, chapter, verse = "", 0, 0
    expect_scores = False
    scores: List[VerseScore] = []

    with open(input_file_path, "r", encoding="utf-8") as infile:
        for raw_line in infile:
            text = raw_line.rstrip("\n")
            lowered = text.lower()
            if lowered.startswith("vref") or lowered.startswith("sequence score"):
                continue

            ref_match = ref_pattern.match(text)
            if ref_match is not None:
                book = ref_match.group(1)
                chapter = int(ref_match.group(2))
                verse = int(ref_match.group(3))
                suffix = ref_match.group(4)
                # Only a plain, nonzero verse reference has scores we keep.
                expect_scores = verse != 0 and not suffix
                continue

            if not expect_scores:
                continue
            fields = text.split("\t")
            if fields:
                scores.append(
                    VerseScore(VerseRef.from_string(f"{book} {chapter}:{verse}"), float(fields[0]))
                )
    return scores
118+
119+
120+
@dataclass
class UsabilityParameters:
    """Parameters of one class in the usable/unusable mixture.

    Used by calculate_usable_prob as an unnormalized Gaussian weight:
    count * exp(-(chrf3 - mean)^2 / (2 * variance)).
    """

    # Sample count of the class; acts as the mixture weight.
    count: float
    # Mean chrF3 of the class.
    mean: float
    # Variance of the class's chrF3 distribution.
    variance: float
125+
126+
127+
def compute_usable_proportions(verse_scores: List[VerseScore], output_dir: Path) -> None:
    """Write per-book and per-chapter average usability probabilities.

    Loads mixture parameters from output_dir / "usability_parameters.tsv"
    (built-in defaults are used when it is absent), converts each verse's
    projected chrF3 into a usability probability, and writes
    "usability_books.tsv" and "usability_chapters.tsv" into output_dir.
    Verse 0 references are skipped silently; verses lacking a projected
    chrF3 are skipped with a warning.
    """
    usable_params, unusable_params = parse_parameters(output_dir / "usability_parameters.tsv")

    per_book_sum = defaultdict(float)
    per_book_n = defaultdict(int)
    per_chapter_sum = defaultdict(lambda: defaultdict(float))
    per_chapter_n = defaultdict(lambda: defaultdict(int))

    for score in verse_scores:
        ref = score.vref
        if ref.verse_num == 0:
            continue
        if score.projected_chrf3 is None:
            LOGGER.warning(f"{ref} does not have a projected chrf3. Skipping.")
            continue

        usable_prob = calculate_usable_prob(score.projected_chrf3, usable_params, unusable_params)
        per_book_sum[ref.book] += usable_prob
        per_book_n[ref.book] += 1
        per_chapter_sum[ref.book][ref.chapter_num] += usable_prob
        per_chapter_n[ref.book][ref.chapter_num] += 1

    with open(output_dir / "usability_books.tsv", "w", encoding="utf-8", newline="\n") as book_file:
        book_file.write("Book\tUsability\n")
        for book in sorted(per_book_sum):
            book_file.write(f"{book}\t{per_book_sum[book] / per_book_n[book]:.6f}\n")

    with open(output_dir / "usability_chapters.tsv", "w", encoding="utf-8", newline="\n") as chapter_file:
        chapter_file.write("Book\tChapter\tUsability\n")
        for book in sorted(per_chapter_sum):
            for chapter in sorted(per_chapter_sum[book]):
                avg_prob = per_chapter_sum[book][chapter] / per_chapter_n[book][chapter]
                chapter_file.write(f"{book}\t{chapter}\t{avg_prob:.6f}\n")
161+
162+
163+
def parse_parameters(parameter_file: Path) -> Tuple[UsabilityParameters, UsabilityParameters]:
    """Load the usable/unusable mixture parameters.

    Reads "label<TAB>count<TAB>mean<TAB>variance" rows from *parameter_file*,
    overriding the built-in defaults per label. If the file does not exist,
    the defaults are used and a warning is logged.

    Returns:
        The (usable, unusable) parameter pair.

    Raises:
        ValueError: if a non-blank line does not have exactly 4 tab-separated
            columns, or a numeric column cannot be parsed as a float.
    """
    # Defaults, used when the file is absent or a label is not overridden.
    params = {
        "usable": UsabilityParameters(263, 51.4, 95.19),
        "unusable": UsabilityParameters(97, 45.85, 99.91),
    }
    if parameter_file.exists():
        with open(parameter_file, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, start=1):
                # BUG FIX: tolerate blank lines (e.g. a trailing newline at
                # EOF); the original raised "Malformed line" for them.
                if not line.strip():
                    continue
                parts = line.strip().split("\t")
                if len(parts) != 4:
                    raise ValueError(
                        f"Malformed line {line_num} in {parameter_file}: expected 4 tab-separated columns, "
                        f"got {len(parts)}. Line content: {line.strip()}"
                    )
                label, count, mean, variance = parts
                params[label] = UsabilityParameters(float(count), float(mean), float(variance))
    else:
        LOGGER.warning(f"{parameter_file} does not exist. Using default parameters.")

    return params["usable"], params["unusable"]
183+
184+
185+
def calculate_usable_prob(
    chrf3: float,
    usable: UsabilityParameters,
    unusable: UsabilityParameters,
) -> float:
    """Posterior probability that a verse with this chrF3 is "usable".

    Each class contributes an unnormalized Gaussian weight,
    count * exp(-(chrf3 - mean)^2 / (2 * variance)); the result is the
    usable class's share of the combined weight, in (0, 1).
    """

    def class_weight(params: UsabilityParameters) -> float:
        # Unnormalized Gaussian density scaled by the class's sample count.
        return exp(-((chrf3 - params.mean) ** 2) / (2 * params.variance)) * params.count

    w_usable = class_weight(usable)
    w_unusable = class_weight(unusable)
    return w_usable / (w_usable + w_unusable)
194+
195+
196+
def main() -> None:
    """CLI entry point: estimate draft quality from confidence files.

    Selects confidence files either from explicit paths (the confidence_files
    positional, optionally rooted at --confidence-dir) or from book ids
    (--books, matched by glob in the confidence directory), then runs the
    quality-estimation pipeline against the diff-predictions workbook.

    Raises:
        ValueError: if both or neither selection mechanism is used, or the
            draft index is negative.
    """
    parser = argparse.ArgumentParser(description="Estimate the quality of drafts created by an NMT model.")
    parser.add_argument(
        "diff_predictions", help="The diff predictions path relative to MT/experiments to determine line of best fit."
    )
    parser.add_argument(
        "confidence_files",
        nargs="*",
        help="Relative paths for the confidence files to process (relative to MT/experiments or --confidence-dir "
        + "if specified) e.g. 'project_folder/exp_folder/infer/5000/source/631JN.SFM.confidences.tsv' or "
        + "'631JN.SFM.confidences.tsv --confidence-dir project_folder/exp_folder/infer/5000/source'.",
    )
    parser.add_argument(
        "--confidence-dir",
        type=Path,
        default=None,
        help="Folder (relative to experiment MT/experiments) containing confidence files e.g. 'infer/5000/source/'.",
    )
    parser.add_argument(
        "--books",
        nargs="+",
        metavar="book_ids",
        help="Provide book ids (e.g. 1JN LUK) to select confidence files rather than providing file paths with "
        + "the confidence_files positional argument.",
    )
    parser.add_argument(
        "--draft-index", type=int, default=None, help="If using --books with multiple drafts, specify the draft index."
    )
    args = parser.parse_args()

    using_files = bool(args.confidence_files)
    using_books = bool(args.books)

    # Exactly one selection mechanism must be used.
    if using_files and using_books:
        raise ValueError("Specify either confidence_files or --books, not both.")
    if not using_files and not using_books:
        raise ValueError(
            "You must specify either confidence_files or --books to indicate which confidence files to use."
        )

    # Presumably resolves relative to the MT/experiments root when
    # --confidence-dir is not given — see get_mt_exp_dir.
    confidence_dir = get_mt_exp_dir(args.confidence_dir or Path())

    if using_files:
        # Dead code removed: using_files already guarantees a non-empty list,
        # and confidence_dir is always a truthy Path here.
        confidence_files = [confidence_dir / confidence_file for confidence_file in args.confidence_files]
    else:
        if args.draft_index is not None:
            # argparse's type=int guarantees an int; only the sign needs checking.
            if args.draft_index < 0:
                raise ValueError("Draft index must be a non-negative integer.")
            draft_suffix = "." + str(args.draft_index)
        else:
            draft_suffix = ""
        confidence_files = []
        for book_id in args.books:
            # Confidence files are named like "<number><BOOK>[.<draft>].*.confidences.tsv".
            confidence_files.extend(confidence_dir.glob(f"[0-9]*{book_id}{draft_suffix}.*.confidences.tsv"))

    estimate_quality(get_mt_exp_dir(args.diff_predictions), confidence_files)
260+
261+
262+
# Script entry point, e.g. `python -m silnlp.nmt.quality_estimation ...`.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)