|
import argparse
import logging
import re
from collections import defaultdict
from dataclasses import dataclass
from math import exp, log
from pathlib import Path
from typing import List, Optional, Tuple

from machine.scripture import VerseRef
from openpyxl import load_workbook
from scipy.stats import linregress

from silnlp.nmt.config import get_mt_exp_dir
| 15 | + |
# Module-level logger, namespaced under the parent package so log configuration
# applied to the package also routes this module's messages.
LOGGER = logging.getLogger(__package__ + ".quality_estimation")
| 17 | + |
| 18 | + |
@dataclass
class VerseScore:
    """Model confidence (and, once computed, projected chrF3) for one verse."""

    vref: VerseRef  # verse reference (book, chapter, verse)
    confidence: float  # per-verse confidence read from a confidences TSV
    projected_chrf3: Optional[float] = None  # populated later from the regression line
| 24 | + |
| 25 | + |
def estimate_quality(diff_predictions_file: Path, confidence_files: List[Path]) -> None:
    """Project chrF3 for every verse, then write usability summary files.

    The usability reports are written into the directory containing the first
    confidence file.
    """
    scores = project_chrf3(diff_predictions_file, confidence_files)
    output_dir = confidence_files[0].parent
    compute_usable_proportions(scores, output_dir)
| 29 | + |
| 30 | + |
def project_chrf3(diff_predictions_file: Path, confidence_files: List[Path]) -> List[VerseScore]:
    """Fit chrF3 ~ confidence and project a chrF3 score for every verse.

    A line of best fit is computed from the (confidence, chrF3) pairs in
    ``diff_predictions_file`` and applied to the per-verse confidences in each
    of ``confidence_files``. A ``*.projected_chrf3.tsv`` file is written next
    to each confidence file.

    Returns:
        All verse scores (across every confidence file) with their
        ``projected_chrf3`` field populated.

    Raises:
        ValueError: if the chrF3 and confidence columns differ in length.
    """
    chrf3_scores, confidence_scores = extract_diff_predictions(diff_predictions_file)
    if len(chrf3_scores) != len(confidence_scores):
        raise ValueError(
            f"The number of chrF3 scores ({len(chrf3_scores)}) and confidence scores ({len(confidence_scores)}) "
            f"in {diff_predictions_file} do not match."
        )
    slope, intercept = linregress(confidence_scores, chrf3_scores)[:2]
    verse_scores: List[VerseScore] = []
    for confidence_file in confidence_files:
        file_scores = extract_confidences(confidence_file)
        verse_scores += file_scores
        with open(confidence_file.with_suffix(".projected_chrf3.tsv"), "w", encoding="utf-8") as output_file:
            output_file.write("VRef\tConfidence\tProjected chrF3\n")
            # Iterate only this file's scores. Iterating the accumulated
            # verse_scores list here (the previous behavior) duplicated every
            # earlier file's rows into each subsequent output file.
            for verse_score in file_scores:
                projected_chrf3 = slope * verse_score.confidence + intercept
                verse_score.projected_chrf3 = projected_chrf3
                output_file.write(f"{verse_score.vref}\t{verse_score.confidence}\t{projected_chrf3:.2f}\n")
    return verse_scores
| 50 | + |
| 51 | + |
def extract_diff_predictions(diff_predictions_file_path) -> Tuple[List[float], List[float]]:
    """Read the "chrf3" and "confidence" columns from a diff-predictions workbook."""
    return (
        extract_diff_predictions_column(diff_predictions_file_path, "chrf3"),
        extract_diff_predictions_column(diff_predictions_file_path, "confidence"),
    )
| 56 | + |
| 57 | + |
def extract_diff_predictions_column(file_path: Path, target_header: str) -> List[float]:
    """Return the numeric values below the header ``target_header`` in the workbook.

    The header search is case-insensitive and, within each row, stops at a
    "Score Summary" cell so headers are only matched to its left. Empty cells
    below the header are skipped.

    Raises:
        ValueError: if no matching header cell is found.
    """
    worksheet = load_workbook(file_path).active

    header_row_idx = None
    col_idx = None
    wanted = target_header.lower()

    for row in worksheet.iter_rows():
        for cell in row:
            # Do not match headers at or beyond the "Score Summary" marker.
            if cell.value == "Score Summary":
                break
            if str(cell.value).lower() == wanted:
                header_row_idx, col_idx = cell.row, cell.column
                break
        if header_row_idx:
            break

    if not header_row_idx:
        raise ValueError(f"Header '{target_header}' not found.")

    # Collect every non-empty cell in the header's column, below the header.
    return [
        float(row[0].value)
        for row in worksheet.iter_rows(min_row=header_row_idx + 1, min_col=col_idx, max_col=col_idx)
        if row[0].value is not None
    ]
| 86 | + |
| 87 | + |
def extract_confidences(input_file_path: Path) -> List[VerseScore]:
    """Parse a confidences TSV into a list of per-verse ``VerseScore`` entries.

    A verse-reference line (e.g. "1JN 1:3") sets the current position; the
    first tab-separated column of each following non-reference line is taken
    as that verse's confidence. References with verse number 0 or a trailing
    "/..." segment do not contribute confidences.
    """
    book = ""
    chapter = 0
    verse = 0
    expecting_confidence = False

    scores: List[VerseScore] = []
    with open(input_file_path, "r", encoding="utf-8") as input_file:
        for raw_line in input_file:
            line = raw_line.rstrip("\n")
            lowered = line.lower()
            # Skip header rows.
            if lowered.startswith("vref") or lowered.startswith("sequence score"):
                continue

            ref_match = re.match(r"^([0-9A-Z][A-Z]{2}) (\d+):(\d+)(/.*)?", line)
            if ref_match:
                book, chapter_str, verse_str, extra = ref_match.groups()
                chapter = int(chapter_str)
                verse = int(verse_str)
                # Only plain verse references (non-zero verse, no "/..."
                # suffix) have confidence data on the following line(s).
                expecting_confidence = verse != 0 and not extra
            elif expecting_confidence:
                columns = line.split("\t")
                if columns:
                    scores.append(
                        VerseScore(VerseRef.from_string(f"{book} {chapter}:{verse}"), float(columns[0]))
                    )
    return scores
| 118 | + |
| 119 | + |
@dataclass
class UsabilityParameters:
    """Gaussian parameters for one usability class ("usable" or "unusable")."""

    count: float  # number of observations in the class (acts as a mixture weight)
    mean: float  # mean chrF3 of the class
    variance: float  # chrF3 variance of the class
| 125 | + |
| 126 | + |
def compute_usable_proportions(verse_scores: List[VerseScore], output_dir: Path) -> None:
    """Write average per-book and per-chapter usability probabilities.

    For every verse with a projected chrF3, the probability that it is
    "usable" is computed and averaged per book and per chapter. Results go to
    ``usability_books.tsv`` and ``usability_chapters.tsv`` in ``output_dir``.
    Parameters are read from ``usability_parameters.tsv`` when present.
    """
    usable_params, unusable_params = parse_parameters(output_dir / "usability_parameters.tsv")

    # book -> [probability sum, verse count]
    book_stats = defaultdict(lambda: [0.0, 0])
    # book -> chapter -> [probability sum, verse count]
    chapter_stats = defaultdict(lambda: defaultdict(lambda: [0.0, 0]))

    for score in verse_scores:
        vref = score.vref
        # Verse number 0 marks non-verse material; it carries no usability signal.
        if vref.verse_num == 0:
            continue
        if score.projected_chrf3 is None:
            LOGGER.warning(f"{vref} does not have a projected chrf3. Skipping.")
            continue

        prob = calculate_usable_prob(score.projected_chrf3, usable_params, unusable_params)
        book_entry = book_stats[vref.book]
        book_entry[0] += prob
        book_entry[1] += 1
        chapter_entry = chapter_stats[vref.book][vref.chapter_num]
        chapter_entry[0] += prob
        chapter_entry[1] += 1

    with open(output_dir / "usability_books.tsv", "w", encoding="utf-8", newline="\n") as book_file:
        book_file.write("Book\tUsability\n")
        for book in sorted(book_stats):
            total, count = book_stats[book]
            book_file.write(f"{book}\t{total / count:.6f}\n")

    with open(output_dir / "usability_chapters.tsv", "w", encoding="utf-8", newline="\n") as chapter_file:
        chapter_file.write("Book\tChapter\tUsability\n")
        for book in sorted(chapter_stats):
            for chapter in sorted(chapter_stats[book]):
                total, count = chapter_stats[book][chapter]
                chapter_file.write(f"{book}\t{chapter}\t{total / count:.6f}\n")
| 161 | + |
| 162 | + |
def parse_parameters(parameter_file: Path) -> Tuple[UsabilityParameters, UsabilityParameters]:
    """Load "usable"/"unusable" Gaussian parameters, falling back to defaults.

    Each line of ``parameter_file`` must be
    ``label<TAB>count<TAB>mean<TAB>variance``. Labels present in the file
    override the built-in defaults; a missing file leaves all defaults in place.

    Raises:
        ValueError: for a line without exactly 4 tab-separated columns.
    """
    params = {
        "usable": UsabilityParameters(263, 51.4, 95.19),
        "unusable": UsabilityParameters(97, 45.85, 99.91),
    }
    if not parameter_file.exists():
        LOGGER.warning(f"{parameter_file} does not exist. Using default parameters.")
        return params["usable"], params["unusable"]

    with open(parameter_file, "r", encoding="utf-8") as parameter_stream:
        for line_num, line in enumerate(parameter_stream, start=1):
            parts = line.strip().split("\t")
            if len(parts) != 4:
                raise ValueError(
                    f"Malformed line {line_num} in {parameter_file}: expected 4 tab-separated columns, "
                    f"got {len(parts)}. Line content: {line.strip()}"
                )
            label, count, mean, variance = parts
            params[label] = UsabilityParameters(float(count), float(mean), float(variance))

    return params["usable"], params["unusable"]
| 183 | + |
| 184 | + |
def calculate_usable_prob(
    chrf3: float,
    usable: "UsabilityParameters",
    unusable: "UsabilityParameters",
) -> float:
    """Return P(usable | chrf3) under a two-class, count-weighted Gaussian model.

    Each class is weighted by its observation count times the unnormalized
    Gaussian likelihood exp(-(chrf3 - mean)^2 / (2 * variance)).

    Computed in log space: the direct formula underflows both weights to 0.0
    when chrf3 is far from both means (possible for extreme regression
    projections), which previously raised ZeroDivisionError.
    """
    # A class with no observations carries zero weight.
    if usable.count <= 0:
        return 0.0
    if unusable.count <= 0:
        return 1.0

    # Log of each class's unnormalized posterior weight.
    usable_log = log(usable.count) - ((chrf3 - usable.mean) ** 2) / (2 * usable.variance)
    unusable_log = log(unusable.count) - ((chrf3 - unusable.mean) ** 2) / (2 * unusable.variance)

    # P(usable) = 1 / (1 + exp(unusable_log - usable_log)); guard math.exp
    # overflow (OverflowError above ~exp(709) for IEEE doubles).
    diff = unusable_log - usable_log
    if diff > 709.0:
        return 0.0
    return 1.0 / (1.0 + exp(diff))
| 194 | + |
| 195 | + |
def main() -> None:
    """CLI entry point: project chrF3 from confidence files and write usability reports."""
    parser = argparse.ArgumentParser(description="Estimate the quality of drafts created by an NMT model.")
    parser.add_argument(
        "diff_predictions", help="The diff predictions path relative to MT/experiments to determine line of best fit."
    )
    parser.add_argument(
        "confidence_files",
        nargs="*",
        help="Relative paths for the confidence files to process (relative to MT/experiments or --confidence-dir "
        + "if specified) e.g. 'project_folder/exp_folder/infer/5000/source/631JN.SFM.confidences.tsv' or "
        + "'631JN.SFM.confidences.tsv --confidence-dir project_folder/exp_folder/infer/5000/source'.",
    )
    parser.add_argument(
        "--confidence-dir",
        type=Path,
        default=None,
        help="Folder (relative to experiment MT/experiments) containing confidence files e.g. 'infer/5000/source/'.",
    )
    parser.add_argument(
        "--books",
        nargs="+",
        metavar="book_ids",
        help="Provide book ids (e.g. 1JN LUK) to select confidence files rather than providing file paths with "
        + "the confidence_files positional argument.",
    )
    parser.add_argument(
        "--draft-index", type=int, default=None, help="If using --books with multiple drafts, specify the draft index."
    )
    args = parser.parse_args()

    using_files = bool(args.confidence_files)
    using_books = bool(args.books)

    # Exactly one of the two selection mechanisms must be used.
    if using_files and using_books:
        raise ValueError("Specify either confidence_files or --books, not both.")
    if not using_files and not using_books:
        raise ValueError(
            "You must specify either confidence_files or --books to indicate which confidence files to use."
        )

    confidence_dir = get_mt_exp_dir(args.confidence_dir or Path())

    if using_files:
        # using_files guarantees at least one entry.
        confidence_files = [confidence_dir / confidence_file for confidence_file in args.confidence_files]
    else:
        # argparse's type=int guarantees an int; only the sign needs checking.
        if args.draft_index is not None:
            if args.draft_index < 0:
                raise ValueError("Draft index must be a non-negative integer.")
            draft_suffix = "." + str(args.draft_index)
        else:
            draft_suffix = ""
        confidence_files = []
        for book_id in args.books:
            confidence_files.extend(confidence_dir.glob(f"[0-9]*{book_id}{draft_suffix}.*.confidences.tsv"))
        # Fail fast with a clear message instead of an IndexError deep inside
        # estimate_quality when no file matches the requested books.
        if not confidence_files:
            raise ValueError(
                f"No confidence files found in {confidence_dir} for books: {', '.join(args.books)}"
            )

    estimate_quality(get_mt_exp_dir(args.diff_predictions), confidence_files)
| 260 | + |
| 261 | + |
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
0 commit comments