From 183e18e590fb5ab43bb0f053260cbf58a5a3a72e Mon Sep 17 00:00:00 2001
From: Ben King
Date: Thu, 28 Aug 2025 22:21:04 +0000
Subject: [PATCH 1/8] Working implementation of quotation denormalization via
 postprocess.py

---
 silnlp/common/paratext.py         |  56 ++++-
 silnlp/common/postprocesser.py    | 332 ++++++++++++++++++++++++++++--
 silnlp/common/translator.py       |  36 ++--
 silnlp/nmt/config.py              | 281 ++-----------------------
 silnlp/nmt/corpora.py             | 266 ++++++++++++++++++++++++
 silnlp/nmt/experiment.py          |  44 ++--
 silnlp/nmt/hugging_face_config.py |   3 +-
 silnlp/nmt/postprocess.py         | 171 +++++++++------
 silnlp/nmt/translate.py           |   2 +
 9 files changed, 808 insertions(+), 383 deletions(-)
 create mode 100644 silnlp/nmt/corpora.py

diff --git a/silnlp/common/paratext.py b/silnlp/common/paratext.py
index 978c1298..ec232d36 100644
--- a/silnlp/common/paratext.py
+++ b/silnlp/common/paratext.py
@@ -2,7 +2,7 @@
 import os
 from contextlib import ExitStack
 from pathlib import Path
-from typing import Dict, List, Optional, Set, TextIO, Tuple
+from typing import Dict, Iterable, List, Optional, Set, TextIO, Tuple
 from xml.sax.saxutils import escape
 
 import regex as re
@@ -15,11 +15,22 @@
     Text,
     TextCorpus,
     TextRow,
+    UsfmFileText,
     UsfmFileTextCorpus,
+    UsfmParserHandler,
     create_versification_ref_corpus,
     extract_scripture_corpus,
+    parse_usfm,
+)
+from machine.scripture import (
+    BOOK_NUMBERS,
+    ORIGINAL_VERSIFICATION,
+    VerseRef,
+    VersificationType,
+    book_id_to_number,
+    book_number_to_id,
+    get_books,
 )
-from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, VersificationType, book_id_to_number, get_books
 from machine.tokenization import WhitespaceTokenizer
 
 from .corpus import get_terms_glosses_path, get_terms_metadata_path, get_terms_vrefs_path, load_corpus
@@ -416,6 +427,15 @@ def get_book_path(project: str, book: str) -> Path:
     return SIL_NLP_ENV.pt_projects_dir / project / book_file_name
 
 
+def get_book_path_by_book_number(project: str, book_number: int) -> Path:
+    project_dir = get_project_dir(project)
+    settings = FileParatextProjectSettingsParser(project_dir).parse()
+    book_id = book_number_to_id(book_number)
+    book_file_name = settings.get_book_file_name(book_id)
+
+    return SIL_NLP_ENV.pt_projects_dir / project / book_file_name
+
+
 def get_last_verse(project_dir: str, book: str, chapter: int) -> int:
     last_verse = "0"
     book_path = get_book_path(project_dir, book)
@@ -571,3 +591,35 @@ def check_versification(project_dir: str) -> Tuple[bool, List[VersificationType]]:
         matching = True
 
     return (matching, detected_versification)
+
+
+def read_usfm(project_dir: str, book_number: int) -> str:
+    project_settings = FileParatextProjectSettingsParser(get_project_dir(project_dir)).parse()
+    book_path: Path = get_book_path_by_book_number(project_dir, book_number)
+
+    if not book_path.exists():
+        raise FileNotFoundError(f"USFM file for book number {book_number} not found in project {project_dir}")
+
+    usfm_text_file = UsfmFileText(
+        project_settings.stylesheet,
+        project_settings.encoding,
+        book_number_to_id(book_number),
+        book_path,
+        project_settings.versification,
+        include_all_text=True,
+        project=project_settings.name,
+    )
+    # This is not a public method, but I don't think any method exists in machine.py
+    # to read raw USFM using the project settings
+    return usfm_text_file._read_usfm()
+
+
+# This is a placeholder until the ParatextProjectQuoteConventionDetector is released in machine.py
+def parse_project(project_dir: str, selected_books: Iterable[int], usfm_parser_handler: UsfmParserHandler) -> None:
+    project_settings = FileParatextProjectSettingsParser(get_project_dir(project_dir)).parse()
+    for book_number in selected_books:
+        try:
+            usfm = read_usfm(project_dir, book_number)
+        except FileNotFoundError:
+            continue
+        parse_usfm(usfm, usfm_parser_handler, project_settings.stylesheet, project_settings.versification)
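
Note: the placeholder parse_project above streams each selected book's raw USFM through a UsfmParserHandler, skipping books that are missing on disk. A minimal sketch of the intended use for quote convention detection, mirroring DenormalizeQuotationMarksPostprocessor._detect_quote_convention later in this patch (the project name and book numbers are hypothetical, and the .name access assumes QuoteConvention exposes the same names used by STANDARD_QUOTE_CONVENTIONS):

    from machine.punctuation_analysis import QuoteConventionDetector

    from silnlp.common.paratext import parse_project

    # Hypothetical project with Matthew (40) and Mark (41); missing books are silently skipped.
    detector = QuoteConventionDetector()
    parse_project("MyProject", [40, 41], detector)

    analysis = detector.detect_quote_convention()
    if analysis is not None:
        print(analysis.best_quote_convention.name)  # e.g. "standard_english"
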
diff --git a/silnlp/common/postprocesser.py b/silnlp/common/postprocesser.py
index 2a50f33e..a1eb61f1 100644
--- a/silnlp/common/postprocesser.py
+++ b/silnlp/common/postprocesser.py
@@ -1,28 +1,248 @@
+import logging
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import List, Optional
+from typing import Dict, List, Optional, Tuple
 
 from machine.corpora import (
     PlaceMarkersAlignmentInfo,
     PlaceMarkersUsfmUpdateBlockHandler,
+    QuotationMarkDenormalizationFirstPass,
+    QuotationMarkDenormalizationUsfmUpdateBlockHandler,
+    QuotationMarkUpdateSettings,
+    QuotationMarkUpdateStrategy,
     ScriptureRef,
     UpdateUsfmMarkerBehavior,
+    UpdateUsfmParserHandler,
     UpdateUsfmRow,
+    UpdateUsfmTextBehavior,
     UsfmUpdateBlockHandler,
+    parse_usfm,
 )
+from machine.punctuation_analysis import (
+    STANDARD_QUOTE_CONVENTIONS,
+    QuoteConvention,
+    QuoteConventionAnalysis,
+    QuoteConventionDetector,
+)
+from machine.scripture import get_chapters
 from machine.tokenization import LatinWordTokenizer
 from machine.translation import WordAlignmentMatrix
 
+from silnlp.common.paratext import parse_project
+from silnlp.nmt.corpora import CorpusPair
+
 from ..alignment.eflomal import to_word_alignment_matrix
 from ..alignment.utils import compute_alignment_scores
 from .corpus import load_corpus, write_corpus
 
+LOGGER = logging.getLogger(__package__ + ".translate")
+
 POSTPROCESS_DEFAULTS = {
     "paragraph_behavior": "end",  # Possible values: end, place, strip
     "include_style_markers": False,
     "include_embeds": False,
+    "denormalize_quotation_marks": False,
+    "source_quote_convention": "standard_english",
+    "target_quote_convention": "standard_english",
+}
+POSTPROCESS_SUFFIX_CHARS = {
+    "paragraph_behavior": {"place": "p", "strip": "x"},
+    "include_style_markers": "s",
+    "include_embeds": "e",
+    "denormalize_quotation_marks": "q",
 }
-POSTPROCESS_SUFFIX_CHARS = [{"place": "p", "strip": "x"}, "s", "e"]
+
+
+class PlaceMarkersPostprocessor:
+    _BEHAVIOR_DESCRIPTION_MAP = {
+        UpdateUsfmMarkerBehavior.PRESERVE: " have positions preserved.",
+        UpdateUsfmMarkerBehavior.STRIP: " were removed.",
+    }
+
+    def __init__(
+        self,
+        paragraph_behavior: UpdateUsfmMarkerBehavior,
+        embed_behavior: UpdateUsfmMarkerBehavior,
+        style_behavior: UpdateUsfmMarkerBehavior,
+    ):
+        self._paragraph_behavior = paragraph_behavior
+        self._embed_behavior = embed_behavior
+        self._style_behavior = style_behavior
+        self._update_block_handlers = [PlaceMarkersUsfmUpdateBlockHandler()]
+
+    def _create_remark(self) -> str:
+        behavior_map: Dict[UpdateUsfmMarkerBehavior, List[str]] = {
+            UpdateUsfmMarkerBehavior.PRESERVE: [],
+            UpdateUsfmMarkerBehavior.STRIP: [],
+        }
+        behavior_map[self._paragraph_behavior].append("paragraph markers")
+        behavior_map[self._embed_behavior].append("embed markers")
+        behavior_map[self._style_behavior].append("style markers")
+
+        remark_sentences = [
+            self._create_remark_sentence_for_behavior(behavior, items)
+            for behavior, items in behavior_map.items()
+            if len(items) > 0
+        ]
+        return " ".join(remark_sentences)
+
+    def _create_remark_sentence_for_behavior(self, behavior: UpdateUsfmMarkerBehavior, items: List[str]) -> str:
+        return self._format_group_of_items_for_remark(items) + self._BEHAVIOR_DESCRIPTION_MAP[behavior]
+
+    def _format_group_of_items_for_remark(self, items: List[str]) -> str:
+        if len(items) == 1:
+            return items[0].capitalize()
+        elif len(items) == 2:
+            return f"{items[0].capitalize()} and {items[1]}"
+        return f"{items[0].capitalize()}, {', '.join(items[1:-1])}, and {items[-1]}"
+
+    def postprocess_usfm(
+        self,
+        usfm: str,
+        rows: List[UpdateUsfmRow],
+        remarks: List[str] = [],
+    ) -> str:
+        handler = UpdateUsfmParserHandler(
+            rows=rows,
+            text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING,
+            paragraph_behavior=self._paragraph_behavior,
+            embed_behavior=self._embed_behavior,
+            style_behavior=self._style_behavior,
+            update_block_handlers=self._update_block_handlers,
+            remarks=remarks + [self._create_remark()],
+        )
+        parse_usfm(usfm, handler)
+        return handler.get_usfm()
+
+
+class UnknownQuoteConventionException(Exception):
+    def __init__(self, convention_name: str):
+        super().__init__(
+            f'"{convention_name}" is not a known quote convention. Skipping quotation mark denormalization.'
+        )
+        self.convention_name = convention_name
+
+
+class NoDetectedQuoteConventionException(Exception):
+    def __init__(self, project_name: str):
+        super().__init__(
+            f'Could not detect quote convention for project "{project_name}". Skipping quotation mark denormalization.'
+        )
+        self.project_name = project_name
+
+
+class DenormalizeQuotationMarksPostprocessor:
+    _REMARK_SENTENCE = (
+        "Quotation marks in the following chapters have been automatically denormalized after translation: "
+    )
+
+    def __init__(
+        self,
+        source_quote_convention_name: str | None,
+        target_quote_convention_name: str | None,
+        source_project_name: str | None = None,
+        target_project_name: str | None = None,
+        selected_training_books: Dict[int, List[int]] = {},
+    ):
+        self._source_quote_convention = self._get_source_quote_convention(
+            source_quote_convention_name, source_project_name, selected_training_books
+        )
+        self._target_quote_convention = self._get_target_quote_convention(
+            target_quote_convention_name, target_project_name, selected_training_books
+        )
+
+    def _get_source_quote_convention(
+        self, convention_name: str | None, project_name: str | None, selected_training_books: Dict[int, List[int]] = {}
+    ) -> QuoteConvention:
+        if convention_name is None or convention_name == "detect":
+            if project_name is None:
+                raise ValueError(
+                    "The experiment's translate_config.yml must exist and specify a source project name, since an explicit source quote convention name was not provided."
+                )
+            if selected_training_books is None:
+                raise ValueError(
+                    "The experiment's config.yml must exist and specify selected training books, since an explicit source quote convention name was not provided."
+                )
+            return self._detect_quote_convention(project_name, selected_training_books)
+        return self._get_named_quote_convention(convention_name)
+
+    def _get_target_quote_convention(
+        self, convention_name: str | None, project_name: str | None, selected_training_books: Dict[int, List[int]] = {}
+    ) -> QuoteConvention:
+        if convention_name is None or convention_name == "detect":
+            if project_name is None:
+                raise ValueError(
+                    "The experiment's config.yml must exist and specify a target project name, since an explicit target quote convention name was not provided."
+                )
+            if selected_training_books is None:
+                raise ValueError(
+                    "The experiment's config.yml must exist and specify selected training books, since an explicit target quote convention name was not provided."
+                )
+            return self._detect_quote_convention(project_name, selected_training_books)
+        return self._get_named_quote_convention(convention_name)
+
+    def _get_named_quote_convention(self, convention_name: str) -> QuoteConvention:
+        convention = STANDARD_QUOTE_CONVENTIONS.get_quote_convention_by_name(convention_name)
+
+        if convention is None:
+            raise UnknownQuoteConventionException(convention_name)
+        return convention
+
+    def _detect_quote_convention(
+        self, project_name: str, selected_training_books: Dict[int, List[int]] = {}
+    ) -> QuoteConvention:
+        quote_convention_detector = QuoteConventionDetector()
+
+        parse_project(project_name, selected_training_books.keys(), quote_convention_detector)
+
+        quote_convention_analysis: QuoteConventionAnalysis | None = quote_convention_detector.detect_quote_convention()
+        if quote_convention_analysis is None:
+            raise NoDetectedQuoteConventionException(project_name)
+        return quote_convention_analysis.best_quote_convention
+
+    def _create_update_block_handlers(
+        self, chapter_strategies: List[QuotationMarkUpdateStrategy]
+    ) -> List[UsfmUpdateBlockHandler]:
+        return [
+            QuotationMarkDenormalizationUsfmUpdateBlockHandler(
+                self._source_quote_convention,
+                self._target_quote_convention,
+                QuotationMarkUpdateSettings(chapter_strategies=chapter_strategies),
+            )
+        ]
+
+    def _get_best_chapter_strategies(self, usfm: str) -> List[QuotationMarkUpdateStrategy]:
+        quotation_mark_update_first_pass = QuotationMarkDenormalizationFirstPass(
+            self._source_quote_convention, self._target_quote_convention
+        )
+
+        parse_usfm(usfm, quotation_mark_update_first_pass)
+        return quotation_mark_update_first_pass.find_best_chapter_strategies()
+
+    def _create_remark(self, best_chapter_strategies: List[QuotationMarkUpdateStrategy]) -> str:
+        return (
+            self._REMARK_SENTENCE
+            + ", ".join(
+                [
+                    str(chapter_num)
+                    for chapter_num, strategy in enumerate(best_chapter_strategies, 1)
+                    if strategy != QuotationMarkUpdateStrategy.SKIP
+                ]
+            )
+            + "."
+        )
+
+    def postprocess_usfm(
+        self,
+        usfm: str,
+    ) -> str:
+        best_chapter_strategies = self._get_best_chapter_strategies(usfm)
+        handler = UpdateUsfmParserHandler(
+            update_block_handlers=self._create_update_block_handlers(best_chapter_strategies),
+            remarks=[self._create_remark(best_chapter_strategies)],
+        )
+        parse_usfm(usfm, handler)
+        return handler.get_usfm()
 
 
 class PostprocessConfig:
@@ -42,9 +262,6 @@ def __init__(self, config: dict = {}) -> None:
         self.update_block_handlers: List[UsfmUpdateBlockHandler] = []
         self.rows: List[UpdateUsfmRow] = []
 
-        if self._config["paragraph_behavior"] == "place" or self._config["include_style_markers"]:
-            self.update_block_handlers.append(PlaceMarkersUsfmUpdateBlockHandler())
-
     def _get_usfm_marker_behavior(self, preserve: bool) -> UpdateUsfmMarkerBehavior:
         return UpdateUsfmMarkerBehavior.PRESERVE if preserve else UpdateUsfmMarkerBehavior.STRIP
 
@@ -60,12 +277,12 @@ def get_embed_behavior(self) -> UpdateUsfmMarkerBehavior:
     # NOTE: Each postprocessing configuration needs to have a unique suffix so files don't overwrite each other
     def get_postprocess_suffix(self) -> str:
         suffix = "_"
-        for (option, default), char in zip(POSTPROCESS_DEFAULTS.items(), POSTPROCESS_SUFFIX_CHARS):
-            if self._config[option] != default:
+        for option, default in POSTPROCESS_DEFAULTS.items():
+            if option in POSTPROCESS_SUFFIX_CHARS and self._config[option] != default:
                 if isinstance(default, str):
-                    suffix += char[self._config[option]]
+                    suffix += POSTPROCESS_SUFFIX_CHARS[option][self._config[option]]
                 else:
-                    suffix += char
+                    suffix += POSTPROCESS_SUFFIX_CHARS[option]
 
         return suffix if len(suffix) > 1 else ""
 
@@ -82,6 +299,79 @@ def get_postprocess_remark(self) -> Optional[str]:
     def is_base_config(self) -> bool:
         return self._config == POSTPROCESS_DEFAULTS
 
+    def is_marker_placement_required(self) -> bool:
+        return self._config["paragraph_behavior"] == "place" or self._config["include_style_markers"]
+
+    def is_quotation_mark_denormalization_required(self) -> bool:
+        return self._config["denormalize_quotation_marks"]
+
+    def is_quote_convention_detection_required(self) -> bool:
+        return self.is_quotation_mark_denormalization_required() and (
+            self._config["source_quote_convention"] is None
+            or self._config["source_quote_convention"] == "detect"
+            or self._config["target_quote_convention"] is None
+            or self._config["target_quote_convention"] == "detect"
+        )
+
+    def create_place_markers_postprocessor(self) -> PlaceMarkersPostprocessor:
+        return PlaceMarkersPostprocessor(
+            paragraph_behavior=self.get_paragraph_behavior(),
+            embed_behavior=self.get_embed_behavior(),
+            style_behavior=self.get_style_behavior(),
+        )
+
+    def create_denormalize_quotation_marks_postprocessor(
+        self,
+        training_corpus_pairs: List[CorpusPair],
+    ) -> DenormalizeQuotationMarksPostprocessor:
+        _, training_target_project_name, selected_training_books = self._get_experiment_training_info(
+            training_corpus_pairs,
+        )
+        translation_source_project_name = self._config.get("src_project")
+
+        return DenormalizeQuotationMarksPostprocessor(
+            self._config["source_quote_convention"],
+            self._config["target_quote_convention"],
+            translation_source_project_name,
+            training_target_project_name,
+            selected_training_books,
+        )
+
+    def _get_experiment_training_info(
+        self,
+        training_corpus_pairs: List[CorpusPair],
+    ) -> Tuple[Optional[str], Optional[str], Dict[int, List[int]]]:
+        # Target project info is only needed for quote convention detection
+        if self.is_quote_convention_detection_required():
+            if len(training_corpus_pairs) > 1:
+                LOGGER.warning(
+                    "The experiment has multiple corpus pairs. Quotation mark denormalization is unlikely to work correctly in this scenario."
+                )
+            if len(training_corpus_pairs) > 0 and len(training_corpus_pairs[0].src_files) > 1:
+                LOGGER.warning(
+                    "The experiment has multiple source projects. Quotation mark denormalization is unlikely to work correctly in this scenario."
+                )
+            if len(training_corpus_pairs) > 0 and len(training_corpus_pairs[0].trg_files) > 1:
+                LOGGER.warning(
+                    "The experiment has multiple target projects. Quotation mark denormalization is unlikely to work correctly in this scenario."
+                )
+
+            source_project_name = (
+                training_corpus_pairs[0].src_files[0].project
+                if len(training_corpus_pairs) > 0 and len(training_corpus_pairs[0].src_files) > 0
+                else None
+            )
+            target_project_name = (
+                training_corpus_pairs[0].trg_files[0].project
+                if len(training_corpus_pairs) > 0 and len(training_corpus_pairs[0].trg_files) > 0
+                else None
+            )
+            selected_training_books = training_corpus_pairs[0].corpus_books if len(training_corpus_pairs) > 0 else {}
+
+            return source_project_name, target_project_name, selected_training_books
+
+        return None, None, {}
+
     def __getitem__(self, key):
         return self._config[key]
 
@@ -118,17 +408,19 @@ def _construct_place_markers_metadata(
                 translation_tokens = list(tokenizer.tokenize(t))
 
             for config in pm_configs:
-                config.rows[i].metadata["alignment_info"] = PlaceMarkersAlignmentInfo(
-                    source_tokens=source_tokens,
-                    translation_tokens=translation_tokens,
-                    alignment=alignment,
-                    paragraph_behavior=(
-                        UpdateUsfmMarkerBehavior.PRESERVE
-                        if config["paragraph_behavior"] == "place"
-                        else UpdateUsfmMarkerBehavior.STRIP
-                    ),
-                    style_behavior=config.get_style_behavior(),
-                )
+                row_metadata = config.rows[i].metadata
+                if row_metadata is not None:
+                    row_metadata["alignment_info"] = PlaceMarkersAlignmentInfo(
+                        source_tokens=source_tokens,
+                        translation_tokens=translation_tokens,
+                        alignment=alignment,
+                        paragraph_behavior=(
+                            UpdateUsfmMarkerBehavior.PRESERVE
+                            if config["paragraph_behavior"] == "place"
+                            else UpdateUsfmMarkerBehavior.STRIP
+                        ),
+                        style_behavior=config.get_style_behavior(),
+                    )
 
     def _get_alignment_matrices(
         self, src_sents: List[str], trg_sents: List[str], aligner: str = "eflomal"
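
Note: the new DenormalizeQuotationMarksPostprocessor can also be exercised on its own with explicitly named conventions. A minimal sketch; "standard_english" is the default used by this commit, while "standard_french" and the USFM string are stand-ins (an unknown name raises UnknownQuoteConventionException):

    from silnlp.common.postprocesser import DenormalizeQuotationMarksPostprocessor

    postprocessor = DenormalizeQuotationMarksPostprocessor(
        source_quote_convention_name="standard_english",  # convention of the draft as produced
        target_quote_convention_name="standard_french",   # assumed to be a valid convention name
    )

    draft_usfm = "\\id MAT\n\\c 1\n\\p\n\\v 1 “Example,” he said.\n"  # stand-in draft USFM
    print(postprocessor.postprocess_usfm(draft_usfm))
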
diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py
index bccf4f43..542b4f25 100644
--- a/silnlp/common/translator.py
+++ b/silnlp/common/translator.py
@@ -6,6 +6,7 @@
 from itertools import groupby
 from math import exp
 from pathlib import Path
+from pydoc import text
 from typing import DefaultDict, Iterable, List, Optional
 
 import docx
@@ -24,9 +25,11 @@
 from machine.scripture import VerseRef, is_book_id_valid
 from scipy.stats import gmean
 
+from silnlp.nmt.corpora import CorpusPair
+
 from .corpus import load_corpus, write_corpus
 from .paratext import get_book_path, get_iso, get_project_dir
-from .postprocesser import PostprocessHandler
+from .postprocesser import NoDetectedQuoteConventionException, PostprocessHandler, UnknownQuoteConventionException
 from .usfm_utils import PARAGRAPH_TYPE_EMBEDS
 
 LOGGER = logging.getLogger(__package__ + ".translate")
@@ -196,6 +199,7 @@ def translate_book(
         trg_project: Optional[str] = None,
         postprocess_handler: PostprocessHandler = PostprocessHandler(),
         experiment_ckpt_str: str = "",
+        training_corpus_pairs: List[CorpusPair] = [],
     ) -> None:
         book_path = get_book_path(src_project, book)
         if not book_path.is_file():
@@ -214,6 +218,7 @@
             trg_project,
             postprocess_handler,
             experiment_ckpt_str,
+            training_corpus_pairs,
         )
 
     def translate_usfm(
@@ -228,6 +233,7 @@
         trg_project: Optional[str] = None,
         postprocess_handler: PostprocessHandler = PostprocessHandler(),
         experiment_ckpt_str: str = "",
+        training_corpus_pairs: List[CorpusPair] = [],
     ) -> None:
         # Create UsfmFileText object for source
         src_from_project = False
@@ -294,6 +300,7 @@
             postprocess_handler.construct_rows(vrefs, sentences, translated_draft)
 
             for config in postprocess_handler.configs:
+                # Compile draft remarks
                 draft_src_str = f"project {src_file_text.project}" if src_from_project else f"file {src_file_path.name}"
                 draft_remark = f"This draft of {vrefs[0].book} was machine translated on {date.today()} from {draft_src_str} using model {experiment_ckpt_str}. It should be reviewed and edited carefully."
@@ -304,17 +311,12 @@
                 # If the target project is not the same as the translated file's original project,
                 # no verses outside of the ones translated will be overwritten
                 if trg_project is not None or src_from_project:
-                    dest_updater = FileParatextProjectTextUpdater(
-                        get_project_dir(trg_project if trg_project is not None else src_file_path.parent.name)
-                    )
+                    project_dir = get_project_dir(trg_project if trg_project is not None else src_file_path.parent.name)
+                    dest_updater = FileParatextProjectTextUpdater(project_dir)
                     usfm_out = dest_updater.update_usfm(
                         book_id=src_file_text.id,
                         rows=config.rows,
                         text_behavior=text_behavior,
-                        paragraph_behavior=config.get_paragraph_behavior(),
-                        embed_behavior=config.get_embed_behavior(),
-                        style_behavior=config.get_style_behavior(),
-                        update_block_handlers=config.update_block_handlers,
                         remarks=remarks,
                     )
 
@@ -329,15 +331,25 @@
                     rows=config.rows,
                     id_text=vrefs[0].book,
                     text_behavior=text_behavior,
-                    paragraph_behavior=config.get_paragraph_behavior(),
-                    embed_behavior=config.get_embed_behavior(),
-                    style_behavior=config.get_style_behavior(),
-                    update_block_handlers=config.update_block_handlers,
                     remarks=remarks,
                 )
                 parse_usfm(usfm, handler)
                 usfm_out = handler.get_usfm()
 
+                # Post-process the USFM output
+                if config.is_marker_placement_required():
+                    place_markers_postprocessor = config.create_place_markers_postprocessor()
+                    usfm_out = place_markers_postprocessor.postprocess_usfm(usfm_out, config.rows)
+
+                if config.is_quotation_mark_denormalization_required():
+                    try:
+                        quotation_denormalization_postprocessor = (
+                            config.create_denormalize_quotation_marks_postprocessor(training_corpus_pairs)
+                        )
+                        usfm_out = quotation_denormalization_postprocessor.postprocess_usfm(usfm_out)
+                    except (UnknownQuoteConventionException, NoDetectedQuoteConventionException) as e:
+                        raise e
+
-                # Construct output file name write to file
+                # Construct output file name and write to file
                 trg_draft_file_path = trg_file_path.with_stem(trg_file_path.stem + config.get_postprocess_suffix())
                 if produce_multiple_translations:
diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py
index 99c26210..1b2b00d6 100644
--- a/silnlp/nmt/config.py
+++ b/silnlp/nmt/config.py
@@ -4,32 +4,40 @@
 import re
 from abc import ABC, abstractmethod
 from contextlib import ExitStack
-from dataclasses import dataclass, field
 from decimal import ROUND_HALF_UP, Decimal
-from enum import Enum, Flag, auto
+from enum import Enum, auto
 from pathlib import Path
 from statistics import mean, median, stdev
 from typing import Any, Dict, Iterable, List, Optional, Set, TextIO, Tuple, Union, cast
 
 import pandas as pd
-from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, get_books, get_chapters
+from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, get_books
 from machine.tokenization import LatinWordTokenizer
 from tqdm import tqdm
 
+from silnlp.common.translator import TranslationGroup
+from silnlp.nmt.corpora import (
+    BASIC_DATA_PROJECT,
+    CorpusPair,
+    DataFile,
+    DataFileMapping,
+    IsoPairInfo,
+    get_data_file_pairs,
+    get_parallel_corpus_size,
+    get_terms_glosses_file_paths,
+    parse_corpus_pairs,
+)
+
 from ..alignment.config import get_aligner_name
 from ..alignment.utils import add_alignment_scores
 from ..common.corpus import (
     Term,
     exclude_chapters,
     filter_parallel_corpus,
-    get_mt_corpus_path,
     get_scripture_parallel_corpus,
     get_terms,
     get_terms_corpus,
     get_terms_data_frame,
-    get_terms_glosses_path,
-    get_terms_list,
-    get_terms_renderings_path,
     include_chapters,
     load_corpus,
     split_corpus,
@@ -37,15 +45,12 @@
     write_corpus,
 )
 from ..common.environment import SIL_NLP_ENV
-from ..common.translator import TranslationGroup
-from ..common.utils import NoiseMethod, Side, create_noise_methods, get_mt_exp_dir, is_set, set_seed
-from .augment import AugmentMethod, create_augment_methods
+from ..common.utils import NoiseMethod, Side, get_mt_exp_dir, set_seed
+from .augment import AugmentMethod
 from .tokenizer import Tokenizer
 
 LOGGER = logging.getLogger(__package__ + ".config")
 
-BASIC_DATA_PROJECT = "BASIC"
-
 ALIGNMENT_SCORES_FILE = re.compile(r"([a-z]{2,3}-.+)_([a-z]{2,3}-.+)")
 
 
@@ -56,258 +61,6 @@ class CheckpointType(Enum):
     OTHER = auto()
 
 
-class DataFileType(Flag):
-    NONE = 0
-    TRAIN = auto()
-    TEST = auto()
-    VAL = auto()
-    DICT = auto()
-
-
-class DataFileMapping(Enum):
-    ONE_TO_ONE = auto()
-    MIXED_SRC = auto()
-    MANY_TO_MANY = auto()
-
-
-@dataclass
-class DataFile:
-    path: Path
-    iso: str = field(init=False)
-    project: str = field(init=False)
-    include_test: bool = True
-
-    def __post_init__(self):
-        file_name = self.path.stem
-        parts = file_name.split("-")
-        if len(parts) < 2:
-            raise RuntimeError(f"The filename {file_name} needs to be of the format <iso>-<project>")
-        self.iso = parts[0]
-        self.project = (
-            parts[1] if str(self.path.parent).startswith(str(SIL_NLP_ENV.mt_scripture_dir)) else BASIC_DATA_PROJECT
-        )
-
-    @property
-    def is_scripture(self) -> bool:
-        return self.project != BASIC_DATA_PROJECT
-
-
-@dataclass
-class CorpusPair:
-    src_files: List[DataFile]
-    trg_files: List[DataFile]
-    type: DataFileType
-    src_noise: List[NoiseMethod]
-    augmentations: List[AugmentMethod]
-    tags: List[str]
-    size: Union[float, int]
-    test_size: Optional[Union[float, int]]
-    val_size: Optional[Union[float, int]]
-    disjoint_test: bool
-    disjoint_val: bool
-    score_threshold: float
-    corpus_books: Dict[int, List[int]]
-    test_books: Dict[int, List[int]]
-    use_test_set_from: str
-    src_terms_files: List[DataFile]
-    trg_terms_files: List[DataFile]
-    is_lexical_data: bool
-    mapping: DataFileMapping
-
-    @property
-    def is_train(self) -> bool:
-        return is_set(self.type, DataFileType.TRAIN)
-
-    @property
-    def is_test(self) -> bool:
-        return is_set(self.type, DataFileType.TEST)
-
-    @property
-    def is_val(self) -> bool:
-        return is_set(self.type, DataFileType.VAL)
-
-    @property
-    def is_dictionary(self) -> bool:
-        return is_set(self.type, DataFileType.DICT)
-
-    @property
-    def is_scripture(self) -> bool:
-        return self.src_files[0].is_scripture
-
-
-@dataclass
-class IsoPairInfo:
-    test_projects: Set[str] = field(default_factory=set)
-    val_projects: Set[str] = field(default_factory=set)
-    has_basic_test_data: bool = False
-
-    @property
-    def has_multiple_test_projects(self) -> bool:
-        return len(self.test_projects) > 1
-
-    @property
-    def has_test_data(self) -> bool:
-        return len(self.test_projects) > 0 or self.has_basic_test_data
-
-
-def parse_corpus_pairs(corpus_pairs: List[dict]) -> List[CorpusPair]:
-    pairs: List[CorpusPair] = []
-    for pair in corpus_pairs:
-        if "type" not in pair:
-            pair["type"] = ["train", "test", "val"]
-        type_strs: Union[str, List[str]] = pair["type"]
-        if isinstance(type_strs, str):
-            type_strs = type_strs.split(",")
-        type = DataFileType.NONE
-        for type_str in type_strs:
-            type_str = type_str.strip().lower()
-            if type_str == "train":
-                type |= DataFileType.TRAIN
-            elif type_str == "test":
-                type |= DataFileType.TEST
-            elif type_str == "val" or type_str == "validation":
-                type |= DataFileType.VAL
-            elif type_str == "dict" or type_str == "dictionary":
-                type |= DataFileType.DICT
-
-        src: Union[str, List[Union[dict, str]]] = pair["src"]
-        src_files = []
-        if isinstance(src, str):
-            src = src.split(",")
-        for file in src:
-            if isinstance(file, str):
-                src_files.append(DataFile(get_mt_corpus_path(file.strip())))
-            else:
-                src_files.append(DataFile(get_mt_corpus_path(file.pop("name"))))
-                for k, v in file.items():
-                    setattr(src_files[-1], k, v)
-        trg: Union[str, List[Union[dict, str]]] = pair["trg"]
-        trg_files = []
-        if isinstance(trg, str):
-            trg = trg.split(",")
-        for file in trg:
-            if isinstance(file, str):
-                trg_files.append(DataFile(get_mt_corpus_path(file.strip())))
-            else:
-                trg_files.append(DataFile(get_mt_corpus_path(file.pop("name"))))
-                for k, v in file.items():
-                    setattr(trg_files[-1], k, v)
-        is_scripture = src_files[0].is_scripture
-        if not all(df.is_scripture == is_scripture for df in (src_files + trg_files)):
-            raise RuntimeError("All corpora in a corpus pair must contain the same type of data.")
-
-        tags: Union[str, List[str]] = pair.get("tags", [])
-        if isinstance(tags, str):
-            tags = [tag.strip() for tag in tags.split(",")]
-
-        src_noise = create_noise_methods(pair.get("src_noise", []))
-        augmentations = create_augment_methods(pair.get("augment", []))
-
-        if "size" not in pair:
-            pair["size"] = 1.0
-        size: Union[float, int] = pair["size"]
-        if "test_size" not in pair and is_set(type, DataFileType.TRAIN | DataFileType.TEST):
-            pair["test_size"] = 0 if "test_books" in pair else 250
-        test_size: Optional[Union[float, int]] = pair.get("test_size")
-        if "val_size" not in pair and is_set(type, DataFileType.TRAIN | DataFileType.VAL):
-            pair["val_size"] = 250
-        val_size: Optional[Union[float, int]] = pair.get("val_size")
-
-        if "disjoint_test" not in pair:
-            pair["disjoint_test"] = True
-        disjoint_test: bool = pair["disjoint_test"]
-        if "disjoint_val" not in pair:
-            pair["disjoint_val"] = True
-        disjoint_val: bool = pair["disjoint_val"]
-        score_threshold: float = pair.get("score_threshold", 0.0)
-        corpus_books = get_chapters(pair.get("corpus_books", []))
-        test_books = get_chapters(pair.get("test_books", []))
-        use_test_set_from: str = pair.get("use_test_set_from", "")
-
-        src_terms_files = get_terms_files(src_files) if is_set(type, DataFileType.TRAIN) else []
-        trg_terms_files = get_terms_files(trg_files) if is_set(type, DataFileType.TRAIN) else []
-
-        if "lexical" not in pair:
-            pair["lexical"] = is_set(type, DataFileType.DICT)
-        is_lexical_data: bool = pair["lexical"]
-
-        if "mapping" not in pair:
-            pair["mapping"] = DataFileMapping.ONE_TO_ONE.name.lower()
-        mapping = DataFileMapping[pair["mapping"].upper()]
-        if not is_scripture and mapping != DataFileMapping.ONE_TO_ONE:
-            raise RuntimeError("Basic corpus pairs only support one-to-one mapping.")
-        if mapping == DataFileMapping.ONE_TO_ONE and len(src_files) != len(trg_files):
-            raise RuntimeError(
-                "A corpus pair with one-to-one mapping must contain the same number of source and target corpora."
-            )
-
-        pairs.append(
-            CorpusPair(
-                src_files,
-                trg_files,
-                type,
-                src_noise,
-                augmentations,
-                tags,
-                size,
-                test_size,
-                val_size,
-                disjoint_test,
-                disjoint_val,
-                score_threshold,
-                corpus_books,
-                test_books,
-                use_test_set_from,
-                src_terms_files,
-                trg_terms_files,
-                is_lexical_data,
-                mapping,
-            )
-        )
-    return pairs
-
-
-def get_terms_files(files: List[DataFile]) -> List[DataFile]:
-    terms_files: List[DataFile] = []
-    for file in files:
-        terms_path = get_terms_renderings_path(file.iso, file.project)
-        if terms_path is None:
-            continue
-        terms_files.append(DataFile(terms_path))
-    return terms_files
-
-
-def get_terms_glosses_file_paths(terms_files: List[DataFile]) -> Set[Path]:
-    glosses_file_paths: Set[Path] = set()
-    for terms_file in terms_files:
-        list_name = get_terms_list(terms_file.path)
-        glosses_path = get_terms_glosses_path(list_name, iso=terms_file.iso)
-        if glosses_path.is_file():
-            glosses_file_paths.add(glosses_path)
-    return glosses_file_paths
-
-
-def get_parallel_corpus_size(src_file_path: Path, trg_file_path: Path) -> int:
-    count = 0
-    with src_file_path.open("r", encoding="utf-8") as src_file, trg_file_path.open("r", encoding="utf-8") as trg_file:
-        for src_line, trg_line in zip(src_file, trg_file):
-            src_line = src_line.strip()
-            trg_line = trg_line.strip()
-            if len(src_line) > 0 and len(trg_line) > 0:
-                count += 1
-    return count
-
-
-def get_data_file_pairs(corpus_pair: CorpusPair) -> Iterable[Tuple[DataFile, DataFile]]:
-    if corpus_pair.mapping == DataFileMapping.ONE_TO_ONE:
-        for file_pair in zip(corpus_pair.src_files, corpus_pair.trg_files):
-            yield file_pair
-    else:
-        for src_file in corpus_pair.src_files:
-            for trg_file in corpus_pair.trg_files:
-                yield (src_file, trg_file)
-
-
 class NMTModel(ABC):
     @abstractmethod
     def train(self) -> None: ...
diff --git a/silnlp/nmt/corpora.py b/silnlp/nmt/corpora.py
new file mode 100644
index 00000000..752f9816
--- /dev/null
+++ b/silnlp/nmt/corpora.py
@@ -0,0 +1,266 @@
+from dataclasses import dataclass, field
+from enum import Enum, Flag, auto
+from pathlib import Path
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+
+from machine.scripture import get_chapters
+
+from ..common.corpus import get_mt_corpus_path, get_terms_glosses_path, get_terms_list, get_terms_renderings_path
+from ..common.environment import SIL_NLP_ENV
+from ..common.utils import NoiseMethod, create_noise_methods, is_set
+from .augment import AugmentMethod, create_augment_methods
+
+
+class DataFileType(Flag):
+    NONE = 0
+    TRAIN = auto()
+    TEST = auto()
+    VAL = auto()
+    DICT = auto()
+
+
+class DataFileMapping(Enum):
+    ONE_TO_ONE = auto()
+    MIXED_SRC = auto()
+    MANY_TO_MANY = auto()
+
+
+BASIC_DATA_PROJECT = "BASIC"
+
+
+@dataclass
+class DataFile:
+    path: Path
+    iso: str = field(init=False)
+    project: str = field(init=False)
+    include_test: bool = True
+
+    def __post_init__(self):
+        file_name = self.path.stem
+        parts = file_name.split("-")
+        if len(parts) < 2:
+            raise RuntimeError(f"The filename {file_name} needs to be of the format <iso>-<project>")
+        self.iso = parts[0]
+        self.project = (
+            parts[1] if str(self.path.parent).startswith(str(SIL_NLP_ENV.mt_scripture_dir)) else BASIC_DATA_PROJECT
+        )
+
+    @property
+    def is_scripture(self) -> bool:
+        return self.project != BASIC_DATA_PROJECT
+
+
+@dataclass
+class CorpusPair:
+    src_files: List[DataFile]
+    trg_files: List[DataFile]
+    type: DataFileType
+    src_noise: List[NoiseMethod]
+    augmentations: List[AugmentMethod]
+    tags: List[str]
+    size: Union[float, int]
+    test_size: Optional[Union[float, int]]
+    val_size: Optional[Union[float, int]]
+    disjoint_test: bool
+    disjoint_val: bool
+    score_threshold: float
+    corpus_books: Dict[int, List[int]]
+    test_books: Dict[int, List[int]]
+    use_test_set_from: str
+    src_terms_files: List[DataFile]
+    trg_terms_files: List[DataFile]
+    is_lexical_data: bool
+    mapping: DataFileMapping
+
+    @property
+    def is_train(self) -> bool:
+        return is_set(self.type, DataFileType.TRAIN)
+
+    @property
+    def is_test(self) -> bool:
+        return is_set(self.type, DataFileType.TEST)
+
+    @property
+    def is_val(self) -> bool:
+        return is_set(self.type, DataFileType.VAL)
+
+    @property
+    def is_dictionary(self) -> bool:
+        return is_set(self.type, DataFileType.DICT)
+
+    @property
+    def is_scripture(self) -> bool:
+        return self.src_files[0].is_scripture
+
+
+@dataclass
+class IsoPairInfo:
+    test_projects: Set[str] = field(default_factory=set)
+    val_projects: Set[str] = field(default_factory=set)
+    has_basic_test_data: bool = False
+
+    @property
+    def has_multiple_test_projects(self) -> bool:
+        return len(self.test_projects) > 1
+
+    @property
+    def has_test_data(self) -> bool:
+        return len(self.test_projects) > 0 or self.has_basic_test_data
+
+
+def parse_corpus_pairs(corpus_pairs: List[dict]) -> List[CorpusPair]:
+    pairs: List[CorpusPair] = []
+    for pair in corpus_pairs:
+        if "type" not in pair:
+            pair["type"] = ["train", "test", "val"]
+        type_strs: Union[str, List[str]] = pair["type"]
+        if isinstance(type_strs, str):
+            type_strs = type_strs.split(",")
+        type = DataFileType.NONE
+        for type_str in type_strs:
+            type_str = type_str.strip().lower()
+            if type_str == "train":
+                type |= DataFileType.TRAIN
+            elif type_str == "test":
+                type |= DataFileType.TEST
+            elif type_str == "val" or type_str == "validation":
+                type |= DataFileType.VAL
+            elif type_str == "dict" or type_str == "dictionary":
+                type |= DataFileType.DICT
+
+        src: Union[str, List[Union[dict, str]]] = pair["src"]
+        src_files = []
+        if isinstance(src, str):
+            src = src.split(",")
+        for file in src:
+            if isinstance(file, str):
+                src_files.append(DataFile(get_mt_corpus_path(file.strip())))
+            else:
+                src_files.append(DataFile(get_mt_corpus_path(file.pop("name"))))
+                for k, v in file.items():
+                    setattr(src_files[-1], k, v)
+        trg: Union[str, List[Union[dict, str]]] = pair["trg"]
+        trg_files = []
+        if isinstance(trg, str):
+            trg = trg.split(",")
+        for file in trg:
+            if isinstance(file, str):
+                trg_files.append(DataFile(get_mt_corpus_path(file.strip())))
+            else:
+                trg_files.append(DataFile(get_mt_corpus_path(file.pop("name"))))
+                for k, v in file.items():
+                    setattr(trg_files[-1], k, v)
+        is_scripture = src_files[0].is_scripture
+        if not all(df.is_scripture == is_scripture for df in (src_files + trg_files)):
+            raise RuntimeError("All corpora in a corpus pair must contain the same type of data.")
+
+        tags: Union[str, List[str]] = pair.get("tags", [])
+        if isinstance(tags, str):
+            tags = [tag.strip() for tag in tags.split(",")]
+
+        src_noise = create_noise_methods(pair.get("src_noise", []))
+        augmentations = create_augment_methods(pair.get("augment", []))
+
+        if "size" not in pair:
+            pair["size"] = 1.0
+        size: Union[float, int] = pair["size"]
+        if "test_size" not in pair and is_set(type, DataFileType.TRAIN | DataFileType.TEST):
+            pair["test_size"] = 0 if "test_books" in pair else 250
+        test_size: Optional[Union[float, int]] = pair.get("test_size")
+        if "val_size" not in pair and is_set(type, DataFileType.TRAIN | DataFileType.VAL):
+            pair["val_size"] = 250
+        val_size: Optional[Union[float, int]] = pair.get("val_size")
+
+        if "disjoint_test" not in pair:
+            pair["disjoint_test"] = True
+        disjoint_test: bool = pair["disjoint_test"]
+        if "disjoint_val" not in pair:
+            pair["disjoint_val"] = True
+        disjoint_val: bool = pair["disjoint_val"]
+        score_threshold: float = pair.get("score_threshold", 0.0)
+        corpus_books = get_chapters(pair.get("corpus_books", []))
+        test_books = get_chapters(pair.get("test_books", []))
+        use_test_set_from: str = pair.get("use_test_set_from", "")
+
+        src_terms_files = get_terms_files(src_files) if is_set(type, DataFileType.TRAIN) else []
+        trg_terms_files = get_terms_files(trg_files) if is_set(type, DataFileType.TRAIN) else []
+
+        if "lexical" not in pair:
+            pair["lexical"] = is_set(type, DataFileType.DICT)
+        is_lexical_data: bool = pair["lexical"]
+
+        if "mapping" not in pair:
+            pair["mapping"] = DataFileMapping.ONE_TO_ONE.name.lower()
+        mapping = DataFileMapping[pair["mapping"].upper()]
+        if not is_scripture and mapping != DataFileMapping.ONE_TO_ONE:
+            raise RuntimeError("Basic corpus pairs only support one-to-one mapping.")
+        if mapping == DataFileMapping.ONE_TO_ONE and len(src_files) != len(trg_files):
+            raise RuntimeError(
+                "A corpus pair with one-to-one mapping must contain the same number of source and target corpora."
+            )
+
+        pairs.append(
+            CorpusPair(
+                src_files,
+                trg_files,
+                type,
+                src_noise,
+                augmentations,
+                tags,
+                size,
+                test_size,
+                val_size,
+                disjoint_test,
+                disjoint_val,
+                score_threshold,
+                corpus_books,
+                test_books,
+                use_test_set_from,
+                src_terms_files,
+                trg_terms_files,
+                is_lexical_data,
+                mapping,
+            )
+        )
+    return pairs
+
+
+def get_terms_files(files: List[DataFile]) -> List[DataFile]:
+    terms_files: List[DataFile] = []
+    for file in files:
+        terms_path = get_terms_renderings_path(file.iso, file.project)
+        if terms_path is None:
+            continue
+        terms_files.append(DataFile(terms_path))
+    return terms_files
+
+
+def get_terms_glosses_file_paths(terms_files: List[DataFile]) -> Set[Path]:
+    glosses_file_paths: Set[Path] = set()
+    for terms_file in terms_files:
+        list_name = get_terms_list(terms_file.path)
+        glosses_path = get_terms_glosses_path(list_name, iso=terms_file.iso)
+        if glosses_path.is_file():
+            glosses_file_paths.add(glosses_path)
+    return glosses_file_paths
+
+
+def get_parallel_corpus_size(src_file_path: Path, trg_file_path: Path) -> int:
+    count = 0
+    with src_file_path.open("r", encoding="utf-8") as src_file, trg_file_path.open("r", encoding="utf-8") as trg_file:
+        for src_line, trg_line in zip(src_file, trg_file):
+            src_line = src_line.strip()
+            trg_line = trg_line.strip()
+            if len(src_line) > 0 and len(trg_line) > 0:
+                count += 1
+    return count
+
+
+def get_data_file_pairs(corpus_pair: CorpusPair) -> Iterable[Tuple[DataFile, DataFile]]:
+    if corpus_pair.mapping == DataFileMapping.ONE_TO_ONE:
+        for file_pair in zip(corpus_pair.src_files, corpus_pair.trg_files):
+            yield file_pair
+    else:
+        for src_file in corpus_pair.src_files:
+            for trg_file in corpus_pair.trg_files:
+                yield (src_file, trg_file)
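
Note: parse_corpus_pairs consumes the corpus_pairs section of an experiment's config.yml after YAML loading. A minimal sketch of the expected shape (the file stems are hypothetical and must resolve via get_mt_corpus_path under the configured environment; "type" defaults to train, test, and val when omitted):

    from silnlp.nmt.corpora import parse_corpus_pairs

    pairs = parse_corpus_pairs(
        [
            {
                "src": "en-WEB",       # hypothetical source file stem
                "trg": "xyz-MYPROJ",   # hypothetical target file stem
                "corpus_books": "MAT;MRK",
            }
        ]
    )
    print(pairs[0].corpus_books)  # book-number-to-chapters mapping from get_chapters
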
diff --git a/silnlp/nmt/experiment.py b/silnlp/nmt/experiment.py
index 15adfa33..466d292a 100644
--- a/silnlp/nmt/experiment.py
+++ b/silnlp/nmt/experiment.py
@@ -90,44 +90,44 @@ def translate(self):
         postprocess_configs = translate_configs.get("postprocess", [])
         postprocess_handler = PostprocessHandler([PostprocessConfig(pc) for pc in postprocess_configs])
 
-        for config in translate_configs.get("translate", []):
+        for translate_config in translate_configs.get("translate", []):
             translator = TranslationTask(
-                name=self.name, checkpoint=config.get("checkpoint", "last"), commit=self.commit
+                name=self.name, checkpoint=translate_config.get("checkpoint", "last"), commit=self.commit
             )
 
             # Backwards compatibility
             if not postprocess_configs:
-                postprocess_handler = PostprocessHandler([PostprocessConfig(config)])
+                postprocess_handler = PostprocessHandler([PostprocessConfig(translate_config)])
 
-            if len(config.get("books", [])) > 0:
-                if isinstance(config["books"], list):
-                    config["books"] = ";".join(config["books"])
+            if len(translate_config.get("books", [])) > 0:
+                if isinstance(translate_config["books"], list):
+                    translate_config["books"] = ";".join(translate_config["books"])
                 translator.translate_books(
-                    config["books"],
-                    config.get("src_project"),
-                    config.get("trg_project"),
-                    config.get("trg_iso"),
+                    translate_config["books"],
+                    translate_config.get("src_project"),
+                    translate_config.get("trg_project"),
+                    translate_config.get("trg_iso"),
                     self.produce_multiple_translations,
                     self.save_confidences,
                     postprocess_handler,
                 )
-            elif config.get("src_prefix"):
+            elif translate_config.get("src_prefix"):
                 translator.translate_text_files(
-                    config.get("src_prefix"),
-                    config.get("trg_prefix"),
-                    config.get("start_seq"),
-                    config.get("end_seq"),
-                    config.get("src_iso"),
-                    config.get("trg_iso"),
+                    translate_config.get("src_prefix"),
+                    translate_config.get("trg_prefix"),
+                    translate_config.get("start_seq"),
+                    translate_config.get("end_seq"),
+                    translate_config.get("src_iso"),
+                    translate_config.get("trg_iso"),
                     self.produce_multiple_translations,
                     self.save_confidences,
                 )
-            elif config.get("src"):
+            elif translate_config.get("src"):
                 translator.translate_files(
-                    config.get("src"),
-                    config.get("trg"),
-                    config.get("src_iso"),
-                    config.get("trg_iso"),
+                    translate_config.get("src"),
+                    translate_config.get("trg"),
+                    translate_config.get("src_iso"),
+                    translate_config.get("trg_iso"),
                     self.produce_multiple_translations,
                     self.save_confidences,
                     postprocess_handler,
diff --git a/silnlp/nmt/hugging_face_config.py b/silnlp/nmt/hugging_face_config.py
index 3fe35e03..105a9dc5 100644
--- a/silnlp/nmt/hugging_face_config.py
+++ b/silnlp/nmt/hugging_face_config.py
@@ -76,7 +76,8 @@
 from ..common.environment import SIL_NLP_ENV
 from ..common.translator import DraftGroup, TranslationGroup, generate_confidence_files
 from ..common.utils import NoiseMethod, ReplaceRandomToken, Side, create_noise_methods, get_mt_exp_dir, merge_dict
-from .config import CheckpointType, Config, DataFile, NMTModel
+from .config import CheckpointType, Config, NMTModel
+from .corpora import DataFile
 from .tokenizer import NullTokenizer, Tokenizer
 
 if is_safetensors_available():
diff --git a/silnlp/nmt/postprocess.py b/silnlp/nmt/postprocess.py
index 8ed35506..1cf34a49 100644
--- a/silnlp/nmt/postprocess.py
+++ b/silnlp/nmt/postprocess.py
@@ -2,47 +2,56 @@
 import logging
 import re
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import List, Optional
 
 import yaml
-from machine.corpora import (
-    FileParatextProjectSettingsParser,
-    ScriptureRef,
-    UpdateUsfmParserHandler,
-    UpdateUsfmTextBehavior,
-    UsfmFileText,
-    UsfmStylesheet,
-    UsfmTextType,
-    parse_usfm,
-)
+from dataclasses import dataclass
+from machine.corpora import FileParatextProjectSettingsParser, ScriptureRef, UsfmFileText, UsfmStylesheet, UsfmTextType
 from machine.scripture import book_number_to_id, get_chapters
 from transformers.trainer_utils import get_last_checkpoint
 
 from ..common.paratext import book_file_name_digits, get_book_path, get_project_dir
-from ..common.postprocesser import PostprocessConfig, PostprocessHandler
+from ..common.postprocesser import (
+    NoDetectedQuoteConventionException,
+    PostprocessConfig,
+    PostprocessHandler,
+    UnknownQuoteConventionException,
+)
 from ..common.usfm_utils import PARAGRAPH_TYPE_EMBEDS
 from ..common.utils import get_git_revision_hash
 from .clearml_connection import SILClearML
 from .config import Config
 from .config_utils import load_config
+from .corpora import CorpusPair
 from .hugging_face_config import get_best_checkpoint
 
 LOGGER = logging.getLogger(__package__ + ".postprocess")
 
 
+@dataclass
+class Sentence:
+    text: str
+    ref: ScriptureRef
+
+
+@dataclass
+class DraftSentences:
+    sentences: List[Sentence]
+    remarks: List[str]
+
+
 # Takes the path to a USFM file and the relevant info to parse it
 # and returns the text of all non-embed sentences and their respective references,
 # along with any remarks (\rem) that were inserted at the beginning of the file
 def get_sentences(
     book_path: Path, stylesheet: UsfmStylesheet, encoding: str, book: str, chapters: List[int] = []
-) -> Tuple[List[str], List[ScriptureRef], List[str]]:
-    sents = []
-    refs = []
-    draft_remarks = []
+) -> DraftSentences:
+    draft_sentences = DraftSentences([], [])
+
     for sent in UsfmFileText(stylesheet, encoding, book, book_path, include_all_text=True):
         marker = sent.ref.path[-1].name if len(sent.ref.path) > 0 else ""
-        if marker == "rem" and len(refs) == 0:
-            draft_remarks.append(sent.text)
+        if marker == "rem" and len(draft_sentences.sentences) == 0:
+            draft_sentences.remarks.append(sent.text)
             continue
         if (
             marker in PARAGRAPH_TYPE_EMBEDS
@@ -51,20 +60,24 @@
         ):
             continue
 
-        sents.append(re.sub(" +", " ", sent.text.strip()))
-        refs.append(sent.ref)
+        draft_sentences.sentences.append(Sentence(re.sub(" +", " ", sent.text.strip()), sent.ref))
+
+    return draft_sentences
 
-    return sents, refs, draft_remarks
 
+@dataclass
+class DraftMetadata:
+    source_path: Path
+    draft_path: Path
+    postprocess_config: PostprocessConfig
 
 
 # Get the paths of all drafts that would be produced by an experiment's translate config and that exist
-def get_draft_paths_from_exp(config: Config) -> Tuple[List[Path], List[Path], List[PostprocessConfig]]:
-    with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file:
-        translate_requests = yaml.safe_load(file).get("translate", [])
+def get_draft_paths_from_exp(config: Config) -> List[DraftMetadata]:
+    with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as translate_config_file:
+        translate_requests = yaml.safe_load(translate_config_file).get("translate", [])
 
-    src_paths = []
-    draft_paths = []
-    postprocess_configs = []
+    draft_metadata_list = []
 
     for translate_request in translate_requests:
         src_project = translate_request.get("src_project", next(iter(config.src_projects)))
@@ -88,18 +101,26 @@
                 config.exp_dir / "infer" / step_str / src_project / f"{book_file_name_digits(book_num)}{book}.SFM"
             )
             if draft_path.exists():
-                src_paths.append(src_path)
-                draft_paths.append(draft_path)
-                postprocess_configs.append(postprocess_config)
+                draft_metadata_list.append(
+                    DraftMetadata(
+                        source_path=src_path,
+                        draft_path=draft_path,
+                        postprocess_config=postprocess_config,
+                    )
+                )
             elif draft_path.with_suffix(f".{1}{draft_path.suffix}").exists():  # multiple drafts
                 for i in range(1, config.infer.get("num_drafts", 1) + 1):
-                    src_paths.append(src_path)
-                    draft_paths.append(draft_path.with_suffix(f".{i}{draft_path.suffix}"))
-                    postprocess_configs.append(postprocess_config)
+                    draft_metadata_list.append(
+                        DraftMetadata(
+                            source_path=src_path,
+                            draft_path=draft_path.with_suffix(f".{i}{draft_path.suffix}"),
+                            postprocess_config=postprocess_config,
+                        )
+                    )
             else:
                 LOGGER.warning(f"Draft not found: {draft_path}")
 
-    return src_paths, draft_paths, postprocess_configs
+    return draft_metadata_list
 
 
 def postprocess_draft(
@@ -108,6 +129,7 @@
     postprocess_handler: PostprocessHandler,
     book: Optional[str] = None,
     out_dir: Optional[Path] = None,
+    training_corpus_pairs: List[CorpusPair] = [],
 ) -> None:
     if str(src_path).startswith(str(get_project_dir(""))):
         settings = FileParatextProjectSettingsParser(src_path.parent).parse()
@@ -118,58 +140,83 @@
         stylesheet = UsfmStylesheet("usfm.sty")
         encoding = "utf-8-sig"
 
-    src_sents, src_refs, _ = get_sentences(src_path, stylesheet, encoding, book)
-    draft_sents, draft_refs, draft_remarks = get_sentences(draft_path, stylesheet, encoding, book)
+    src_sentences = get_sentences(src_path, stylesheet, encoding, book)
+    draft_sentences = get_sentences(draft_path, stylesheet, encoding, book)
 
     # Verify reference parity
-    if len(src_refs) != len(draft_refs):
+    if len(src_sentences.sentences) != len(draft_sentences.sentences):
         LOGGER.warning(f"Can't process {src_path} and {draft_path}: Unequal number of verses/references")
         return
-    for src_ref, draft_ref in zip(src_refs, draft_refs):
-        if src_ref.to_relaxed() != draft_ref.to_relaxed():
+    for src_sentence, draft_sentence in zip(src_sentences.sentences, draft_sentences.sentences):
+        if src_sentence.ref.to_relaxed() != draft_sentence.ref.to_relaxed():
             LOGGER.warning(
-                f"Can't process {src_path} and {draft_path}: Mismatched ref, {src_ref} != {draft_ref}. Files must have the exact same USFM structure"
+                f"Can't process {src_path} and {draft_path}: Mismatched ref, {src_sentence.ref} != {draft_sentence.ref}. Files must have the exact same USFM structure"
             )
             return
 
-    postprocess_handler.construct_rows(src_refs, src_sents, draft_sents)
+    if any(config.is_marker_placement_required() for config in postprocess_handler.configs):
+        postprocess_handler.construct_rows(
+            [s.ref for s in src_sentences.sentences],
+            [s.text for s in src_sentences.sentences],
+            [s.text for s in draft_sentences.sentences],
+        )
 
     with src_path.open(encoding=encoding) as f:
-        usfm = f.read()
+        source_usfm = f.read()
 
     for config in postprocess_handler.configs:
-        handler = UpdateUsfmParserHandler(
-            rows=config.rows,
-            id_text=book,
-            text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING,
-            paragraph_behavior=config.get_paragraph_behavior(),
-            embed_behavior=config.get_embed_behavior(),
-            style_behavior=config.get_style_behavior(),
-            update_block_handlers=config.update_block_handlers,
-            remarks=(draft_remarks + [config.get_postprocess_remark()]),
-        )
-        parse_usfm(usfm, handler)
-        usfm_out = handler.get_usfm()
+        if config.is_marker_placement_required():
+            place_markers_postprocessor = config.create_place_markers_postprocessor()
+            target_usfm = place_markers_postprocessor.postprocess_usfm(
+                source_usfm, config.rows, draft_sentences.remarks
+            )
+        else:
+            with draft_path.open(encoding=encoding) as f:
+                target_usfm = f.read()
+
+        if config.is_quotation_mark_denormalization_required():
+            try:
+                quotation_denormalization_postprocessor = config.create_denormalize_quotation_marks_postprocessor(
+                    training_corpus_pairs
+                )
+                target_usfm = quotation_denormalization_postprocessor.postprocess_usfm(target_usfm)
+            except (UnknownQuoteConventionException, NoDetectedQuoteConventionException) as e:
+                raise e
 
         if not out_dir:
             out_dir = draft_path.parent
         out_path = out_dir / f"{draft_path.stem}{config.get_postprocess_suffix()}{draft_path.suffix}"
-        with out_path.open("w", encoding="utf-8" if encoding == "utf-8-sig" else encoding) as f:
-            f.write(usfm_out)
+        with out_path.open(
+            "w", encoding="utf-8" if encoding == "utf-8-sig" or encoding == "utf_8_sig" else encoding
+        ) as f:
+            f.write(target_usfm)
 
 
 def postprocess_experiment(config: Config, out_dir: Optional[Path] = None) -> None:
-    src_paths, draft_paths, legacy_pcs = get_draft_paths_from_exp(config)
+    draft_metadata_list = get_draft_paths_from_exp(config)
 
     with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file:
-        postprocess_configs = yaml.safe_load(file).get("postprocess", [])
+        translate_config = yaml.safe_load(file)
+        postprocess_configs = [PostprocessConfig(pc) for pc in translate_config.get("postprocess", [])]
 
-    postprocess_handler = PostprocessHandler([PostprocessConfig(pc) for pc in postprocess_configs], include_base=False)
+    postprocess_handler = PostprocessHandler(postprocess_configs, include_base=False)
 
-    for src_path, draft_path, legacy_pc in zip(src_paths, draft_paths, legacy_pcs):
+    for draft_metadata in draft_metadata_list:
         if postprocess_configs:
-            postprocess_draft(src_path, draft_path, postprocess_handler, out_dir=out_dir)
-        elif not legacy_pc.is_base_config():
-            postprocess_draft(src_path, draft_path, PostprocessHandler([legacy_pc], False), out_dir=out_dir)
+            postprocess_draft(
+                draft_metadata.source_path,
+                draft_metadata.draft_path,
+                postprocess_handler,
+                out_dir=out_dir,
+                training_corpus_pairs=config.corpus_pairs,
+            )
+        elif not draft_metadata.postprocess_config.is_base_config():
+            postprocess_draft(
+                draft_metadata.source_path,
+                draft_metadata.draft_path,
+                PostprocessHandler([draft_metadata.postprocess_config], False),
+                out_dir=out_dir,
+                training_corpus_pairs=config.corpus_pairs,
+            )
 
 
 def main() -> None:
diff --git a/silnlp/nmt/translate.py b/silnlp/nmt/translate.py
index f752058f..7589a823 100644
--- a/silnlp/nmt/translate.py
+++ b/silnlp/nmt/translate.py
@@ -115,6 +115,7 @@
                         trg_project,
                         postprocess_handler,
                         experiment_ckpt_str,
+                        config.corpus_pairs,
                     )
                 except Exception as e:
                     translation_failed.append(book)
@@ -257,6 +258,7 @@
             save_confidences,
             postprocess_handler=postprocess_handler,
             experiment_ckpt_str=experiment_ckpt_str,
+            training_corpus_pairs=config.corpus_pairs,
         )
 
     def _init_translation_task(self, experiment_suffix: str) -> Tuple[Translator, Config, str]:
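
Note: with the first commit applied, the new postprocessing options can be driven through a translate_config.yml postprocess entry. A minimal sketch of the equivalent config dict in Python (assuming PostprocessConfig merges the given dict with POSTPROCESS_DEFAULTS; the second commit below changes the convention defaults to "detect", which does not affect this example):

    from silnlp.common.postprocesser import PostprocessConfig

    config = PostprocessConfig(
        {
            "denormalize_quotation_marks": True,
            "source_quote_convention": "standard_english",
            "target_quote_convention": "detect",
        }
    )
    assert config.is_quotation_mark_denormalization_required()
    assert config.is_quote_convention_detection_required()  # "detect" triggers detection
    print(config.get_postprocess_suffix())  # "_q": only options in POSTPROCESS_SUFFIX_CHARS add characters
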
) if selected_training_books is None: raise ValueError( @@ -259,6 +262,9 @@ def __init__(self, config: dict = {}) -> None: if config.get("include_inline_elements"): self._config["include_embeds"] = True + if config.get("src_project"): + self._config["src_project"] = config.get("src_project") + self.update_block_handlers: List[UsfmUpdateBlockHandler] = [] self.rows: List[UpdateUsfmRow] = [] diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index 542b4f25..3dc2071b 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ -6,7 +6,6 @@ from itertools import groupby from math import exp from pathlib import Path -from pydoc import text from typing import DefaultDict, Iterable, List, Optional import docx @@ -297,7 +296,8 @@ def translate_usfm( draft_set: DraftGroup = DraftGroup(translations) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): - postprocess_handler.construct_rows(vrefs, sentences, translated_draft) + if any([config.is_marker_placement_required() for config in postprocess_handler.configs]): + postprocess_handler.construct_rows(vrefs, sentences, translated_draft) for config in postprocess_handler.configs: @@ -313,12 +313,25 @@ def translate_usfm( if trg_project is not None or src_from_project: project_dir = get_project_dir(trg_project if trg_project is not None else src_file_path.parent.name) dest_updater = FileParatextProjectTextUpdater(project_dir) - usfm_out = dest_updater.update_usfm( - book_id=src_file_text.id, - rows=config.rows, - text_behavior=text_behavior, - remarks=remarks, - ) + if config.is_marker_placement_required(): + place_markers_postprocessor = config.create_place_markers_postprocessor() + usfm_out = dest_updater.update_usfm( + book_id=src_file_text.id, + rows=config.rows, + text_behavior=text_behavior, + paragraph_behavior=config.get_paragraph_behavior(), + embed_behavior=config.get_embed_behavior(), + style_behavior=config.get_style_behavior(), + update_block_handlers=place_markers_postprocessor.get_update_block_handlers(), + remarks=remarks, + ) + else: + usfm_out = dest_updater.update_usfm( + book_id=src_file_text.id, + rows=config.rows, + text_behavior=text_behavior, + remarks=remarks, + ) if usfm_out is None: raise FileNotFoundError( @@ -327,20 +340,28 @@ def translate_usfm( else: # Slightly more manual version for updating an individual file with open(src_file_path, encoding="utf-8-sig") as f: usfm = f.read() - handler = UpdateUsfmParserHandler( - rows=config.rows, - id_text=vrefs[0].book, - text_behavior=text_behavior, - remarks=remarks, - ) + if config.is_marker_placement_required(): + place_markers_postprocessor = config.create_place_markers_postprocessor() + handler = UpdateUsfmParserHandler( + rows=config.rows, + id_text=vrefs[0].book, + text_behavior=text_behavior, + paragraph_behavior=config.get_paragraph_behavior(), + embed_behavior=config.get_embed_behavior(), + style_behavior=config.get_style_behavior(), + update_block_handlers=place_markers_postprocessor.get_update_block_handlers(), + remarks=remarks, + ) + else: + handler = UpdateUsfmParserHandler( + rows=config.rows, + id_text=vrefs[0].book, + text_behavior=text_behavior, + remarks=remarks, + ) parse_usfm(usfm, handler) usfm_out = handler.get_usfm() - # Post-process the USFM output - if config.is_marker_placement_required(): - place_markers_postprocessor = config.create_place_markers_postprocessor() - usfm_out = place_markers_postprocessor.postprocess_usfm(usfm_out, config.rows) - if 
config.is_quotation_mark_denormalization_required(): try: quotation_denormalization_postprocessor = ( diff --git a/silnlp/nmt/translate.py b/silnlp/nmt/translate.py index 7589a823..d91a1914 100644 --- a/silnlp/nmt/translate.py +++ b/silnlp/nmt/translate.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Iterable, Optional, Tuple, Union +from flake8 import LOG from machine.scripture import VerseRef, book_number_to_id, get_chapters from ..common.environment import SIL_NLP_ENV @@ -377,6 +378,24 @@ def main() -> None: action="store_true", help="For files in USFM format, attempt to place paragraph markers in translated verses based on the source project's markers", ) + parser.add_argument( + "--denormalize-quotation-marks", + default=False, + action="store_true", + help="For files in USFM format, attempt to change the draft's quotation marks to match the target project's quote convention", + ) + parser.add_argument( + "--source-quote-convention", + default="detect", + type=str, + help="The quote convention for the source project. If not specified, it will be detected automatically.", + ) + parser.add_argument( + "--target-quote-convention", + default="detect", + type=str, + help="The quote convention for the target project. If not specified, it will be detected automatically.", + ) parser.add_argument( "--clearml-queue", default=None, From 8aeb278871d1687a27beb94aa8760e336c593e6b Mon Sep 17 00:00:00 2001 From: Ben King Date: Tue, 2 Sep 2025 20:22:18 +0000 Subject: [PATCH 3/8] Working implementation of quotation denormalization for experiment.py --- silnlp/common/postprocesser.py | 5 +-- silnlp/common/translator.py | 14 ++++++-- silnlp/nmt/experiment.py | 3 +- silnlp/nmt/postprocess.py | 59 ++++++++++++++-------------------- silnlp/nmt/translate.py | 5 +++ 5 files changed, 43 insertions(+), 43 deletions(-) diff --git a/silnlp/common/postprocesser.py b/silnlp/common/postprocesser.py index 2f7d8a01..9cc7243d 100644 --- a/silnlp/common/postprocesser.py +++ b/silnlp/common/postprocesser.py @@ -24,7 +24,6 @@ QuoteConventionAnalysis, QuoteConventionDetector, ) -from machine.scripture import get_chapters from machine.tokenization import LatinWordTokenizer from machine.translation import WordAlignmentMatrix @@ -327,13 +326,11 @@ def create_place_markers_postprocessor(self) -> PlaceMarkersPostprocessor: ) def create_denormalize_quotation_marks_postprocessor( - self, - training_corpus_pairs: List[CorpusPair], + self, training_corpus_pairs: List[CorpusPair], translation_source_project_name: Optional[str] ) -> DenormalizeQuotationMarksPostprocessor: _, training_target_project_name, selected_training_books = self._get_experiment_training_info( training_corpus_pairs, ) - translation_source_project_name = self._config.get("src_project") return DenormalizeQuotationMarksPostprocessor( self._config["source_quote_convention"], diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index 3dc2071b..2cb3c1e7 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ -218,6 +218,7 @@ def translate_book( postprocess_handler, experiment_ckpt_str, training_corpus_pairs, + src_project, ) def translate_usfm( @@ -233,6 +234,7 @@ def translate_usfm( postprocess_handler: PostprocessHandler = PostprocessHandler(), experiment_ckpt_str: str = "", training_corpus_pairs: List[CorpusPair] = [], + src_project: Optional[str] = None, ) -> None: # Create UsfmFileText object for source src_from_project = False @@ -365,7 +367,7 @@ def translate_usfm( if 
config.is_quotation_mark_denormalization_required(): try: quotation_denormalization_postprocessor = ( - config.create_denormalize_quotation_marks_postprocessor(training_corpus_pairs) + config.create_denormalize_quotation_marks_postprocessor(training_corpus_pairs, src_project) ) usfm_out = quotation_denormalization_postprocessor.postprocess_usfm(usfm_out) except (UnknownQuoteConventionException, NoDetectedQuoteConventionException) as e: @@ -375,8 +377,16 @@ def translate_usfm( trg_draft_file_path = trg_file_path.with_stem(trg_file_path.stem + config.get_postprocess_suffix()) if produce_multiple_translations: trg_draft_file_path = trg_draft_file_path.with_suffix(f".{draft_index}{trg_file_path.suffix}") + with trg_draft_file_path.open( - "w", encoding=src_settings.encoding if src_from_project else "utf-8" + "w", + encoding=( + "utf-8" + if not src_from_project + or src_from_project + and (src_settings.encoding == "utf-8-sig" or src_settings.encoding == "utf_8_sig") + else src_settings.encoding + ), ) as f: f.write(usfm_out) diff --git a/silnlp/nmt/experiment.py b/silnlp/nmt/experiment.py index 466d292a..13fe324c 100644 --- a/silnlp/nmt/experiment.py +++ b/silnlp/nmt/experiment.py @@ -95,9 +95,8 @@ def translate(self): name=self.name, checkpoint=translate_config.get("checkpoint", "last"), commit=self.commit ) - # Backwards compatibility if not postprocess_configs: - postprocess_handler = PostprocessHandler([PostprocessConfig(translate_config)]) + postprocess_handler = PostprocessHandler([]) if len(translate_config.get("books", [])) > 0: if isinstance(translate_config["books"], list): diff --git a/silnlp/nmt/postprocess.py b/silnlp/nmt/postprocess.py index 1cf34a49..b60e64d0 100644 --- a/silnlp/nmt/postprocess.py +++ b/silnlp/nmt/postprocess.py @@ -2,7 +2,7 @@ import logging import re from pathlib import Path -from typing import List, Optional +from typing import Dict, List, Optional import yaml from attr import dataclass @@ -69,7 +69,7 @@ def get_sentences( class DraftMetadata: source_path: Path draft_path: Path - postprocess_config: PostprocessConfig + source_project: str # Get the paths of all drafts that would be produced by an experiment's translate config and that exist @@ -89,9 +89,6 @@ def get_draft_paths_from_exp(config: Config) -> List[DraftMetadata]: else: step_str = str(ckpt) - # Backwards compatibility - postprocess_config = PostprocessConfig(translate_request) - book_nums = get_chapters(translate_request.get("books", [])).keys() for book_num in book_nums: book = book_number_to_id(book_num) @@ -102,11 +99,7 @@ def get_draft_paths_from_exp(config: Config) -> List[DraftMetadata]: ) if draft_path.exists(): draft_metadata_list.append( - DraftMetadata( - source_path=src_path, - draft_path=draft_path, - postprocess_config=postprocess_config, - ) + DraftMetadata(source_path=src_path, draft_path=draft_path, source_project=src_project) ) elif draft_path.with_suffix(f".{1}{draft_path.suffix}").exists(): # multiple drafts for i in range(1, config.infer.get("num_drafts", 1) + 1): @@ -114,7 +107,7 @@ def get_draft_paths_from_exp(config: Config) -> List[DraftMetadata]: DraftMetadata( source_path=src_path, draft_path=draft_path.with_suffix(f".{i}{draft_path.suffix}"), - postprocess_config=postprocess_config, + source_project=src_project, ) ) else: @@ -124,33 +117,34 @@ def get_draft_paths_from_exp(config: Config) -> List[DraftMetadata]: def postprocess_draft( - src_path: Path, - draft_path: Path, + draft_metadata: DraftMetadata, postprocess_handler: PostprocessHandler, book: Optional[str] = None, 
out_dir: Optional[Path] = None, training_corpus_pairs: List[CorpusPair] = [], ) -> None: - if str(src_path).startswith(str(get_project_dir(""))): - settings = FileParatextProjectSettingsParser(src_path.parent).parse() + if str(draft_metadata.source_path).startswith(str(get_project_dir(""))): + settings = FileParatextProjectSettingsParser(draft_metadata.source_path.parent).parse() stylesheet = settings.stylesheet encoding = settings.encoding - book = settings.get_book_id(src_path.name) + book = settings.get_book_id(draft_metadata.source_path.name) else: stylesheet = UsfmStylesheet("usfm.sty") encoding = "utf-8-sig" - src_sentences = get_sentences(src_path, stylesheet, encoding, book) - draft_sentences = get_sentences(draft_path, stylesheet, encoding, book) + src_sentences = get_sentences(draft_metadata.source_path, stylesheet, encoding, book) + draft_sentences = get_sentences(draft_metadata.draft_path, stylesheet, encoding, book) # Verify reference parity if len(src_sentences.sentences) != len(draft_sentences.sentences): - LOGGER.warning(f"Can't process {src_path} and {draft_path}: Unequal number of verses/references") + LOGGER.warning( + f"Can't process {draft_metadata.source_path} and {draft_path}: Unequal number of verses/references" + ) return for src_sentence, draft_sentence in zip(src_sentences.sentences, draft_sentences.sentences): if src_sentence.ref.to_relaxed() != draft_sentence.ref.to_relaxed(): LOGGER.warning( - f"Can't process {src_path} and {draft_path}: Mismatched ref, {src_ref} != {draft_ref}. Files must have the exact same USFM structure" + f"Can't process {draft_metadata.source_path} and {draft_path}: Mismatched ref, {src_ref} != {draft_ref}. Files must have the exact same USFM structure" ) return @@ -161,7 +155,7 @@ def postprocess_draft( [s.text for s in draft_sentences.sentences], ) - with src_path.open(encoding=encoding) as f: + with draft_metadata.source_path.open(encoding=encoding) as f: source_usfm = f.read() for config in postprocess_handler.configs: @@ -171,21 +165,25 @@ def postprocess_draft( source_usfm, config.rows, draft_sentences.remarks ) else: - with draft_path.open(encoding=encoding) as f: + with draft_metadata.draft_path.open(encoding=encoding) as f: target_usfm = f.read() if config.is_quotation_mark_denormalization_required(): try: quotation_denormalization_postprocessor = config.create_denormalize_quotation_marks_postprocessor( - training_corpus_pairs + training_corpus_pairs, + draft_metadata.source_project, ) target_usfm = quotation_denormalization_postprocessor.postprocess_usfm(target_usfm) except (UnknownQuoteConventionException, NoDetectedQuoteConventionException) as e: raise e if not out_dir: - out_dir = draft_path.parent - out_path = out_dir / f"{draft_path.stem}{config.get_postprocess_suffix()}{draft_path.suffix}" + out_dir = draft_metadata.draft_path.parent + out_path = ( + out_dir + / f"{draft_metadata.draft_path.stem}{config.get_postprocess_suffix()}{draft_metadata.draft_path.suffix}" + ) with out_path.open( "w", encoding="utf-8" if encoding == "utf-8-sig" or encoding == "utf_8_sig" else encoding ) as f: @@ -203,20 +201,11 @@ def postprocess_experiment(config: Config, out_dir: Optional[Path] = None) -> No for draft_metadata in draft_metadata_list: if postprocess_configs: postprocess_draft( - draft_metadata.source_path, - draft_metadata.draft_path, + draft_metadata, postprocess_handler, out_dir=out_dir, training_corpus_pairs=config.corpus_pairs, ) - elif not draft_metadata.postprocess_config.is_base_config(): - postprocess_draft( - 
draft_metadata.source_path, - draft_metadata.draft_path, - PostprocessHandler([draft_metadata.postprocess_config], False), - out_dir=out_dir, - training_corpus_pairs=config.corpus_pairs, - ) def main() -> None: diff --git a/silnlp/nmt/translate.py b/silnlp/nmt/translate.py index d91a1914..ecb6b65c 100644 --- a/silnlp/nmt/translate.py +++ b/silnlp/nmt/translate.py @@ -260,6 +260,11 @@ def translate_files( postprocess_handler=postprocess_handler, experiment_ckpt_str=experiment_ckpt_str, training_corpus_pairs=config.corpus_pairs, + src_project=( + config.corpus_pairs[0].src_files[0].project + if config.corpus_pairs and config.corpus_pairs[0].src_files + else None + ), ) def _init_translation_task(self, experiment_suffix: str) -> Tuple[Translator, Config, str]: From 16b518309a01740cf8c7c4ac3fcbf291d8f4b5ab Mon Sep 17 00:00:00 2001 From: Ben King Date: Wed, 3 Sep 2025 18:30:29 +0000 Subject: [PATCH 4/8] Working implementation of quotation denormalization for postprocess_draft.py --- silnlp/common/postprocess_draft.py | 100 ++++++++++++----------------- silnlp/common/translator.py | 69 +++++++++----------- silnlp/nmt/postprocess.py | 8 ++- 3 files changed, 77 insertions(+), 100 deletions(-) diff --git a/silnlp/common/postprocess_draft.py b/silnlp/common/postprocess_draft.py index 7fe37893..d9390693 100644 --- a/silnlp/common/postprocess_draft.py +++ b/silnlp/common/postprocess_draft.py @@ -5,7 +5,6 @@ from ..nmt.clearml_connection import SILClearML from ..nmt.config_utils import load_config from ..nmt.postprocess import get_draft_paths_from_exp, postprocess_draft, postprocess_experiment -from .paratext import get_project_dir from .postprocesser import PostprocessConfig, PostprocessHandler from .utils import get_mt_exp_dir @@ -19,28 +18,11 @@ def main() -> None: ) parser.add_argument( "--experiment", + required=True, default=None, help="Name of an experiment directory in MT/experiments. \ If this option is used, the experiment's translate config will be used to find source and draft files.", ) - parser.add_argument( - "--source", - default=None, - help="Path of the source USFM file. \ - If in a Paratext project, the project settings will be used when reading the files.", - ) - parser.add_argument( - "--draft", - default=None, - help="Path of the draft USFM file that postprocessing will be applied to. \ - Must have the exact same USFM structure as 'source', which it will if it is a draft from that source.", - ) - parser.add_argument( - "--book", - default=None, - help="3-letter book id of book being evaluated, e.g. MAT. \ - Only necessary if the source file is not in a Paratext project directory.", - ) parser.add_argument( "--output-folder", default=None, @@ -63,6 +45,30 @@ def main() -> None: action="store_true", help="Carry over embeds from the source project to the output without translating them", ) + parser.add_argument( + "--denormalize-quotation-marks", + default=False, + action="store_true", + help="For files in USFM format, attempt to change the draft's quotation marks to match the target project's quote convention", + ) + parser.add_argument( + "--source-quote-convention", + default="detect", + type=str, + help="The quote convention for the source project. If not specified, it will be detected automatically.", + ) + parser.add_argument( + "--target-quote-convention", + default="detect", + type=str, + help="The quote convention for the target project. 
If not specified, it will be detected automatically.", + ) + parser.add_argument( + "--source-project", + default="", + help="The name of the Paratext project used as the source. When the source quote convention is set to 'detect' or not specified," + + " this project will be used to detect the source quote convention.", + ) parser.add_argument( "--clearml-queue", default=None, @@ -72,52 +78,30 @@ def main() -> None: ) args = parser.parse_args() - experiment = args.experiment.replace("\\", "/") if args.experiment else None + experiment = args.experiment.replace("\\", "/") args.output_folder = Path(args.output_folder.replace("\\", "/")) if args.output_folder else None postprocess_config = PostprocessConfig(vars(args)) - if args.experiment and (args.source or args.draft or args.book): - LOGGER.info("--experiment option used. --source, --draft, and --book will be ignored.") - if not (args.experiment or (args.source and args.draft)): - raise ValueError("Not enough options used. Please use --experiment OR --source and --draft.") + if not get_mt_exp_dir(experiment).exists(): + raise ValueError(f"Experiment {experiment} not found.") - if experiment: - if not get_mt_exp_dir(experiment).exists(): - raise ValueError(f"Experiment {experiment} not found.") - - if args.clearml_queue is not None: - if "cpu" not in args.clearml_queue: - raise ValueError("Running this script on a GPU queue will not speed it up. Please only use CPU queues.") - clearml = SILClearML(experiment, args.clearml_queue) - config = clearml.config - else: - config = load_config(experiment) - - if not (config.exp_dir / "translate_config.yml").exists(): - raise ValueError("Experiment translate_config.yml not found.") - - if not postprocess_config.is_base_config(): - src_paths, draft_paths, _ = get_draft_paths_from_exp(config) - else: - LOGGER.info("No postprocessing options used. Applying postprocessing requests from translate config.") - postprocess_experiment(config, args.output_folder) - exit() - elif args.clearml_queue is not None: - raise ValueError("Must use --experiment option to use ClearML.") + if args.clearml_queue is not None: + if "cpu" not in args.clearml_queue: + raise ValueError("Running this script on a GPU queue will not speed it up. Please only use CPU queues.") + clearml = SILClearML(experiment, args.clearml_queue) + config = clearml.config else: - src_paths = [Path(args.source.replace("\\", "/"))] - draft_paths = [Path(args.draft.replace("\\", "/"))] - if not str(src_paths[0]).startswith(str(get_project_dir(""))) and args.book is None: - raise ValueError( - "--book argument must be passed if the source file is not in a Paratext project directory." - ) + config = load_config(experiment) - if postprocess_config.is_base_config(): - raise ValueError("Please use at least one postprocessing option.") - postprocess_handler = PostprocessHandler([postprocess_config], include_base=False) + if not (config.exp_dir / "translate_config.yml").exists(): + raise ValueError("Experiment translate_config.yml not found.") - for src_path, draft_path in zip(src_paths, draft_paths): - postprocess_draft(src_path, draft_path, postprocess_handler, args.book, args.output_folder) + if postprocess_config.is_base_config(): + LOGGER.info("No postprocessing options used. 
Applying postprocessing requests from translate config.") + postprocess_experiment(config, out_dir=args.output_folder) + else: + postprocess_handler = PostprocessHandler([postprocess_config], include_base=False) + postprocess_experiment(config, postprocess_handler=postprocess_handler, out_dir=args.output_folder) if __name__ == "__main__": diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index 2cb3c1e7..a3f4a0eb 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ -298,8 +298,7 @@ def translate_usfm( draft_set: DraftGroup = DraftGroup(translations) for draft_index, translated_draft in enumerate(draft_set.get_drafts(), 1): - if any([config.is_marker_placement_required() for config in postprocess_handler.configs]): - postprocess_handler.construct_rows(vrefs, sentences, translated_draft) + postprocess_handler.construct_rows(vrefs, sentences, translated_draft) for config in postprocess_handler.configs: @@ -315,25 +314,20 @@ def translate_usfm( if trg_project is not None or src_from_project: project_dir = get_project_dir(trg_project if trg_project is not None else src_file_path.parent.name) dest_updater = FileParatextProjectTextUpdater(project_dir) - if config.is_marker_placement_required(): - place_markers_postprocessor = config.create_place_markers_postprocessor() - usfm_out = dest_updater.update_usfm( - book_id=src_file_text.id, - rows=config.rows, - text_behavior=text_behavior, - paragraph_behavior=config.get_paragraph_behavior(), - embed_behavior=config.get_embed_behavior(), - style_behavior=config.get_style_behavior(), - update_block_handlers=place_markers_postprocessor.get_update_block_handlers(), - remarks=remarks, - ) - else: - usfm_out = dest_updater.update_usfm( - book_id=src_file_text.id, - rows=config.rows, - text_behavior=text_behavior, - remarks=remarks, - ) + usfm_out = dest_updater.update_usfm( + book_id=src_file_text.id, + rows=config.rows, + text_behavior=text_behavior, + paragraph_behavior=config.get_paragraph_behavior(), + embed_behavior=config.get_embed_behavior(), + style_behavior=config.get_style_behavior(), + update_block_handlers=( + config.create_place_markers_postprocessor().get_update_block_handlers() + if config.is_marker_placement_required() + else None + ), + remarks=remarks, + ) if usfm_out is None: raise FileNotFoundError( @@ -342,25 +336,20 @@ def translate_usfm( else: # Slightly more manual version for updating an individual file with open(src_file_path, encoding="utf-8-sig") as f: usfm = f.read() - if config.is_marker_placement_required(): - place_markers_postprocessor = config.create_place_markers_postprocessor() - handler = UpdateUsfmParserHandler( - rows=config.rows, - id_text=vrefs[0].book, - text_behavior=text_behavior, - paragraph_behavior=config.get_paragraph_behavior(), - embed_behavior=config.get_embed_behavior(), - style_behavior=config.get_style_behavior(), - update_block_handlers=place_markers_postprocessor.get_update_block_handlers(), - remarks=remarks, - ) - else: - handler = UpdateUsfmParserHandler( - rows=config.rows, - id_text=vrefs[0].book, - text_behavior=text_behavior, - remarks=remarks, - ) + handler = UpdateUsfmParserHandler( + rows=config.rows, + id_text=vrefs[0].book, + text_behavior=text_behavior, + paragraph_behavior=config.get_paragraph_behavior(), + embed_behavior=config.get_embed_behavior(), + style_behavior=config.get_style_behavior(), + update_block_handlers=( + config.create_place_markers_postprocessor().get_update_block_handlers() + if config.is_marker_placement_required() + else 
None + ), + remarks=remarks, + ) parse_usfm(usfm, handler) usfm_out = handler.get_usfm() diff --git a/silnlp/nmt/postprocess.py b/silnlp/nmt/postprocess.py index b60e64d0..832bf6e4 100644 --- a/silnlp/nmt/postprocess.py +++ b/silnlp/nmt/postprocess.py @@ -190,13 +190,17 @@ def postprocess_draft( f.write(target_usfm) -def postprocess_experiment(config: Config, out_dir: Optional[Path] = None) -> None: +def postprocess_experiment( + config: Config, postprocess_handler: Optional[PostprocessHandler] = None, out_dir: Optional[Path] = None +) -> None: draft_metadata_list = get_draft_paths_from_exp(config) + with (config.exp_dir / "translate_config.yml").open("r", encoding="utf-8") as file: translate_config = yaml.safe_load(file) postprocess_configs = [PostprocessConfig(pc) for pc in translate_config.get("postprocess", [])] - postprocess_handler = PostprocessHandler(postprocess_configs, include_base=False) + if postprocess_handler is None: + postprocess_handler = PostprocessHandler(postprocess_configs, include_base=False) for draft_metadata in draft_metadata_list: if postprocess_configs: From d649a2f748fb54978d6f1bae37721f84dd088c40 Mon Sep 17 00:00:00 2001 From: Ben King Date: Fri, 5 Sep 2025 19:35:02 +0000 Subject: [PATCH 5/8] Handle quote convention detection errors more gracefully --- silnlp/common/postprocesser.py | 11 +++++------ silnlp/common/translator.py | 3 ++- silnlp/nmt/postprocess.py | 3 ++- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/silnlp/common/postprocesser.py b/silnlp/common/postprocesser.py index 9cc7243d..6e67c286 100644 --- a/silnlp/common/postprocesser.py +++ b/silnlp/common/postprocesser.py @@ -119,17 +119,13 @@ def postprocess_usfm( class UnknownQuoteConventionException(Exception): def __init__(self, convention_name: str): - super().__init__( - f'"{convention_name}" is not a known quote convention. Skipping quotation mark denormalization.' - ) + super().__init__(f'"{convention_name}" is not a known quote convention.') self.convention_name = convention_name class NoDetectedQuoteConventionException(Exception): def __init__(self, project_name: str): - super().__init__( - f'Could not detect quote convention for project "{project_name}". Skipping quotation mark denormalization.' - ) + super().__init__(f'Could not detect quote convention for project "{project_name}".') self.project_name = project_name @@ -200,6 +196,9 @@ def _detect_quote_convention( quote_convention_analysis: QuoteConventionAnalysis | None = quote_convention_detector.detect_quote_convention() if quote_convention_analysis is None: raise NoDetectedQuoteConventionException(project_name) + LOGGER.info( + f'Detected quote convention for project "{project_name}" is "{quote_convention_analysis.best_quote_convention.name}" with score {quote_convention_analysis.best_quote_convention_score:.2f}.' 
+ ) return quote_convention_analysis.best_quote_convention def _create_update_block_handlers( diff --git a/silnlp/common/translator.py b/silnlp/common/translator.py index a3f4a0eb..ebbf9a1b 100644 --- a/silnlp/common/translator.py +++ b/silnlp/common/translator.py @@ -360,7 +360,8 @@ def translate_usfm( ) usfm_out = quotation_denormalization_postprocessor.postprocess_usfm(usfm_out) except (UnknownQuoteConventionException, NoDetectedQuoteConventionException) as e: - raise e + LOGGER.warning(str(e) + " Skipping quotation mark denormalization.") + continue # Construct output file name write to file trg_draft_file_path = trg_file_path.with_stem(trg_file_path.stem + config.get_postprocess_suffix()) diff --git a/silnlp/nmt/postprocess.py b/silnlp/nmt/postprocess.py index 832bf6e4..8010b895 100644 --- a/silnlp/nmt/postprocess.py +++ b/silnlp/nmt/postprocess.py @@ -176,7 +176,8 @@ def postprocess_draft( ) target_usfm = quotation_denormalization_postprocessor.postprocess_usfm(target_usfm) except (UnknownQuoteConventionException, NoDetectedQuoteConventionException) as e: - raise e + LOGGER.warning(str(e) + " Skipping quotation mark denormalization.") + continue if not out_dir: out_dir = draft_metadata.draft_path.parent From bc3e5c16a1076cccd570593f941f4a7100e4badb Mon Sep 17 00:00:00 2001 From: Ben King Date: Mon, 8 Sep 2025 14:18:54 +0000 Subject: [PATCH 6/8] Addressing Eli's review comments --- silnlp/nmt/postprocess.py | 6 +++--- silnlp/nmt/translate.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/silnlp/nmt/postprocess.py b/silnlp/nmt/postprocess.py index 8010b895..d93cb1c8 100644 --- a/silnlp/nmt/postprocess.py +++ b/silnlp/nmt/postprocess.py @@ -2,7 +2,7 @@ import logging import re from pathlib import Path -from typing import Dict, List, Optional +from typing import List, Optional import yaml from attr import dataclass @@ -138,13 +138,13 @@ def postprocess_draft( # Verify reference parity if len(src_sentences.sentences) != len(draft_sentences.sentences): LOGGER.warning( - f"Can't process {draft_metadata.source_path} and {draft_path}: Unequal number of verses/references" + f"Can't process {draft_metadata.source_path} and {draft_metadata.draft_path}: Unequal number of verses/references" ) return for src_sentence, draft_sentence in zip(src_sentences.sentences, draft_sentences.sentences): if src_sentence.ref.to_relaxed() != draft_sentence.ref.to_relaxed(): LOGGER.warning( - f"Can't process {draft_metadata.source_path} and {draft_path}: Mismatched ref, {src_ref} != {draft_ref}. Files must have the exact same USFM structure" + f"Can't process {draft_metadata.source_path} and {draft_metadata.draft_path}: Mismatched ref, {src_sentence.ref} != {draft_sentence.ref}. 
Files must have the exact same USFM structure" ) return diff --git a/silnlp/nmt/translate.py b/silnlp/nmt/translate.py index 5cd99bdc..78069f79 100644 --- a/silnlp/nmt/translate.py +++ b/silnlp/nmt/translate.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Iterable, Optional, Tuple, Union -from flake8 import LOG from machine.scripture import VerseRef, book_number_to_id, get_chapters from ..common.environment import SIL_NLP_ENV From 3fb4d51049b0d4aec1788393c2013e00492a6e84 Mon Sep 17 00:00:00 2001 From: Ben King Date: Mon, 15 Sep 2025 17:30:30 +0000 Subject: [PATCH 7/8] Use FileParatextProjectQuoteConventionDetector from Machine.py --- poetry.lock | 8 +-- pyproject.toml | 2 +- silnlp/common/paratext.py | 56 +----------------- silnlp/common/postprocess_draft.py | 2 +- silnlp/common/postprocesser.py | 91 +++++++++++++++--------------- silnlp/nmt/postprocess.py | 5 +- 6 files changed, 56 insertions(+), 108 deletions(-) diff --git a/poetry.lock b/poetry.lock index a654be5e..094171e9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5159,13 +5159,13 @@ type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.12 [[package]] name = "sil-machine" -version = "1.7.4" +version = "1.8.2" description = "A natural language processing library that is focused on providing tools for resource-poor languages." optional = false python-versions = "<3.13,>=3.9" files = [ - {file = "sil_machine-1.7.4-py3-none-any.whl", hash = "sha256:88363b4160af1bb24b1fd523839002975f9c0e2635c3ca73363b55ec3f2d9023"}, - {file = "sil_machine-1.7.4.tar.gz", hash = "sha256:16ba024ae7f3fe5c0e140e3f95dad56f29ac9300741f338f0a297edddc28fde4"}, + {file = "sil_machine-1.8.2-py3-none-any.whl", hash = "sha256:3996d47aa3100b544cac0a033e743b7d5c8b5f812c8ac755d6e59b99e5e7b4dd"}, + {file = "sil_machine-1.8.2.tar.gz", hash = "sha256:8592c0653e8c13c4dbccc71e6cb1289c33ad89d528900540efb38b0618998d02"}, ] [package.dependencies] @@ -6177,4 +6177,4 @@ eflomal = ["eflomal"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.11" -content-hash = "07fb940926cfc8ba2957349240f76a80123adaeab5c3711205da614be3410dc6" +content-hash = "36aca28ca88e6b6e6f8370df8b53e8ca307c3cdd8efb84afc92d318f93276106" diff --git a/pyproject.toml b/pyproject.toml index abd64c72..da448588 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,7 @@ tqdm = "^4.62.2" sacrebleu = "^2.3.1" ctranslate2 = "^3.5.1" libclang = "14.0.6" -sil-machine = {extras = ["thot"], version = "1.7.4"} +sil-machine = {extras = ["thot"], version = "1.8.2"} datasets = "^2.7.1" torch = {version = "^2.4", source = "torch"} sacremoses = "^0.0.53" diff --git a/silnlp/common/paratext.py b/silnlp/common/paratext.py index ec232d36..978c1298 100644 --- a/silnlp/common/paratext.py +++ b/silnlp/common/paratext.py @@ -2,7 +2,7 @@ import os from contextlib import ExitStack from pathlib import Path -from typing import Dict, Iterable, List, Optional, Set, TextIO, Tuple +from typing import Dict, List, Optional, Set, TextIO, Tuple from xml.sax.saxutils import escape import regex as re @@ -15,22 +15,11 @@ Text, TextCorpus, TextRow, - UsfmFileText, UsfmFileTextCorpus, - UsfmParserHandler, create_versification_ref_corpus, extract_scripture_corpus, - parse_usfm, -) -from machine.scripture import ( - BOOK_NUMBERS, - ORIGINAL_VERSIFICATION, - VerseRef, - VersificationType, - book_id_to_number, - book_number_to_id, - get_books, ) +from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef, VersificationType, book_id_to_number, get_books from machine.tokenization import 
WhitespaceTokenizer from .corpus import get_terms_glosses_path, get_terms_metadata_path, get_terms_vrefs_path, load_corpus @@ -427,15 +416,6 @@ def get_book_path(project: str, book: str) -> Path: return SIL_NLP_ENV.pt_projects_dir / project / book_file_name -def get_book_path_by_book_number(project: str, book_number: int) -> Path: - project_dir = get_project_dir(project) - settings = FileParatextProjectSettingsParser(project_dir).parse() - book_id = book_number_to_id(book_number) - book_file_name = settings.get_book_file_name(book_id) - - return SIL_NLP_ENV.pt_projects_dir / project / book_file_name - - def get_last_verse(project_dir: str, book: str, chapter: int) -> int: last_verse = "0" book_path = get_book_path(project_dir, book) @@ -591,35 +571,3 @@ def check_versification(project_dir: str) -> Tuple[bool, List[VersificationType] matching = True return (matching, detected_versification) - - -def read_usfm(project_dir: str, book_number: int) -> str: - project_settings = FileParatextProjectSettingsParser(get_project_dir(project_dir)).parse() - book_path: Path = get_book_path_by_book_number(project_dir, book_number) - - if not book_path.exists(): - raise FileNotFoundError(f"USFM file for book number {book_number} not found in project {project_dir}") - - usfm_text_file = UsfmFileText( - project_settings.stylesheet, - project_settings.encoding, - book_number_to_id(book_number), - book_path, - project_settings.versification, - include_all_text=True, - project=project_settings.name, - ) - # This is not a public method, but I don't think any method exists in machine.py - # to read raw USFM using the project settings - return usfm_text_file._read_usfm() - - -# This is a placeholder until the ParatextProjectQuoteConventionDetector is released in machine.py -def parse_project(project_dir: str, selected_books: Iterable[int], usfm_parser_handler: UsfmParserHandler) -> None: - project_settings = FileParatextProjectSettingsParser(get_project_dir(project_dir)).parse() - for book_number in selected_books: - try: - usfm = read_usfm(project_dir, book_number) - except FileNotFoundError: - continue - parse_usfm(usfm, usfm_parser_handler, project_settings.stylesheet, project_settings.versification) diff --git a/silnlp/common/postprocess_draft.py b/silnlp/common/postprocess_draft.py index d9390693..84354dc6 100644 --- a/silnlp/common/postprocess_draft.py +++ b/silnlp/common/postprocess_draft.py @@ -4,7 +4,7 @@ from ..nmt.clearml_connection import SILClearML from ..nmt.config_utils import load_config -from ..nmt.postprocess import get_draft_paths_from_exp, postprocess_draft, postprocess_experiment +from ..nmt.postprocess import postprocess_experiment from .postprocesser import PostprocessConfig, PostprocessHandler from .utils import get_mt_exp_dir diff --git a/silnlp/common/postprocesser.py b/silnlp/common/postprocesser.py index 6e67c286..211b7f4c 100644 --- a/silnlp/common/postprocesser.py +++ b/silnlp/common/postprocesser.py @@ -1,9 +1,10 @@ import logging from pathlib import Path from tempfile import TemporaryDirectory -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence from machine.corpora import ( + FileParatextProjectQuoteConventionDetector, PlaceMarkersAlignmentInfo, PlaceMarkersUsfmUpdateBlockHandler, QuotationMarkDenormalizationFirstPass, @@ -18,16 +19,11 @@ UsfmUpdateBlockHandler, parse_usfm, ) -from machine.punctuation_analysis import ( - STANDARD_QUOTE_CONVENTIONS, - QuoteConvention, - QuoteConventionAnalysis, - QuoteConventionDetector, -) 
+from machine.punctuation_analysis import STANDARD_QUOTE_CONVENTIONS, QuoteConvention, QuoteConventionDetector from machine.tokenization import LatinWordTokenizer from machine.translation import WordAlignmentMatrix -from silnlp.common.paratext import parse_project +from silnlp.common.paratext import get_project_dir from silnlp.nmt.corpora import CorpusPair from ..alignment.eflomal import to_word_alignment_matrix @@ -102,8 +98,11 @@ def postprocess_usfm( self, usfm: str, rows: List[UpdateUsfmRow], - remarks: List[str] = [], + remarks: Optional[List[str]] = None, ) -> str: + if remarks is None: + remarks = [] + handler = UpdateUsfmParserHandler( rows=rows, text_behavior=UpdateUsfmTextBehavior.STRIP_EXISTING, @@ -130,9 +129,11 @@ def __init__(self, project_name: str): class DenormalizeQuotationMarksPostprocessor: + _NO_CHAPTERS_REMARK_SENTENCE = "Quotation marks were not denormalized in any chapters due to errors." _REMARK_SENTENCE = ( "Quotation marks in the following chapters have been automatically denormalized after translation: " ) + _project_convention_cache: Dict[str, QuoteConvention] = {} def __init__( self, @@ -140,43 +141,30 @@ def __init__( target_quote_convention_name: str | None, source_project_name: str | None = None, target_project_name: str | None = None, - selected_training_books: Dict[int, List[int]] = {}, ): self._source_quote_convention = self._get_source_quote_convention( - source_quote_convention_name, source_project_name, selected_training_books + source_quote_convention_name, source_project_name ) self._target_quote_convention = self._get_target_quote_convention( - target_quote_convention_name, target_project_name, selected_training_books + target_quote_convention_name, target_project_name ) - def _get_source_quote_convention( - self, convention_name: str | None, project_name: str | None, selected_training_books: Dict[int, List[int]] = {} - ) -> QuoteConvention: + def _get_source_quote_convention(self, convention_name: str | None, project_name: str | None) -> QuoteConvention: if convention_name is None or convention_name == "detect": if project_name is None: raise ValueError( "The source project name must be explicitly provided or be present in translate_config.yml, since an explicit source quote convention name was not provided." ) - if selected_training_books is None: - raise ValueError( - "The experiment's config.yml must exist and specify selected training books, since an explicit source quote convention name was not provided." - ) - return self._detect_quote_convention(project_name, selected_training_books) + return self._detect_quote_convention(project_name) return self._get_named_quote_convention(convention_name) - def _get_target_quote_convention( - self, convention_name: str | None, project_name: str | None, selected_training_books: Dict[int, List[int]] = {} - ) -> QuoteConvention: + def _get_target_quote_convention(self, convention_name: str | None, project_name: str | None) -> QuoteConvention: if convention_name is None or convention_name == "detect": if project_name is None: raise ValueError( "The experiment's config.yml must exist and specify a target project name, since an explicit target quote convention name was not provided." ) - if selected_training_books is None: - raise ValueError( - "The experiment's config.yml must exist and specify selected training books, since an explicit target quote convention name was not provided." 
-                )
-            return self._detect_quote_convention(project_name, selected_training_books)
+            return self._detect_quote_convention(project_name)
         return self._get_named_quote_convention(convention_name)
 
     def _get_named_quote_convention(self, convention_name: str) -> QuoteConvention:
@@ -186,19 +174,24 @@ def _get_named_quote_convention(self, convention_name: str) -> QuoteConvention:
             raise UnknownQuoteConventionException(convention_name)
         return convention
 
-    def _detect_quote_convention(
-        self, project_name: str, selected_training_books: Dict[int, List[int]] = {}
-    ) -> QuoteConvention:
+    def _detect_quote_convention(self, project_name: str) -> QuoteConvention:
+        if project_name in self._project_convention_cache:
+            return self._project_convention_cache[project_name]
+
-        quote_convention_detector = QuoteConventionDetector()
-        parse_project(project_name, selected_training_books.keys(), quote_convention_detector)
+        quote_convention_detector = FileParatextProjectQuoteConventionDetector(get_project_dir(project_name))
+        quote_convention_analysis = quote_convention_detector.get_quote_convention_analysis()
 
-        quote_convention_analysis: QuoteConventionAnalysis | None = quote_convention_detector.detect_quote_convention()
         if quote_convention_analysis is None:
             raise NoDetectedQuoteConventionException(project_name)
         LOGGER.info(
-            f'Detected quote convention for project "{project_name}" is "{quote_convention_analysis.best_quote_convention.name}" with score {quote_convention_analysis.best_quote_convention_score:.2f}.'
+            f'Detected quote convention for project "{project_name}" is '
+            + '"{quote_convention_analysis.best_quote_convention.name}" with score '
+            + "{quote_convention_analysis.best_quote_convention_score:.2f}."
         )
+
+        self._project_convention_cache[project_name] = quote_convention_analysis.best_quote_convention
+
         return quote_convention_analysis.best_quote_convention
 
     def _create_update_block_handlers(
@@ -221,6 +214,14 @@ def _get_best_chapter_strategies(self, usfm: str) -> List[QuotationMarkUpdateStr
         return quotation_mark_update_first_pass.find_best_chapter_strategies()
 
     def _create_remark(self, best_chapter_strategies: List[QuotationMarkUpdateStrategy]) -> str:
+        processed_chapters: List[str] = [
+            str(chapter_num)
+            for chapter_num, strategy in enumerate(best_chapter_strategies, 1)
+            if strategy != QuotationMarkUpdateStrategy.SKIP
+        ]
+
+        if len(processed_chapters) == 0:
+            return self._NO_CHAPTERS_REMARK_SENTENCE
         return (
             self._REMARK_SENTENCE
             + ", ".join(
@@ -327,7 +328,7 @@ def create_place_markers_postprocessor(self) -> PlaceMarkersPostprocessor:
         )
 
     def create_denormalize_quotation_marks_postprocessor(
         self, training_corpus_pairs: List[CorpusPair], translation_source_project_name: Optional[str]
     ) -> DenormalizeQuotationMarksPostprocessor:
-        _, training_target_project_name, selected_training_books = self._get_experiment_training_info(
+        training_target_project_name = self._get_training_target_project_name(
             training_corpus_pairs,
         )
 
@@ -336,13 +337,12 @@ def create_denormalize_quotation_marks_postprocessor(
             self._config["target_quote_convention"],
             translation_source_project_name,
             training_target_project_name,
-            selected_training_books,
         )
 
-    def _get_experiment_training_info(
+    def _get_training_target_project_name(
         self,
         training_corpus_pairs: List[CorpusPair],
-    ) -> Tuple[Optional[str], Optional[str], Dict[int, List[int]]]:
+    ) -> Optional[str]:
         # Target project info is only needed for quote convention detection
         if self.is_quote_convention_detection_required():
             if len(training_corpus_pairs) > 1:
@@ -358,28 +358,25
@@ def _get_experiment_training_info( "The experiment has multiple target projects. Quotation mark denormalization is unlikely to work correctly in this scenario." ) - source_project_name = ( - training_corpus_pairs[0].src_files[0].project - if len(training_corpus_pairs) > 0 and len(training_corpus_pairs[0].src_files) > 0 - else None - ) target_project_name = ( training_corpus_pairs[0].trg_files[0].project if len(training_corpus_pairs) > 0 and len(training_corpus_pairs[0].trg_files) > 0 else None ) - selected_training_books = training_corpus_pairs[0].corpus_books if len(training_corpus_pairs) > 0 else {} - return source_project_name, target_project_name, selected_training_books + return target_project_name - return None, None, {} + return None def __getitem__(self, key): return self._config[key] class PostprocessHandler: - def __init__(self, configs: List[PostprocessConfig] = [], include_base: bool = True) -> None: + def __init__(self, configs: Optional[List[PostprocessConfig]] = None, include_base: bool = True) -> None: + if configs is None: + configs = [] + self.configs = ([PostprocessConfig()] if include_base else []) + configs # NOTE: Row metadata may need to be created/recreated at different times diff --git a/silnlp/nmt/postprocess.py b/silnlp/nmt/postprocess.py index d93cb1c8..e41c5520 100644 --- a/silnlp/nmt/postprocess.py +++ b/silnlp/nmt/postprocess.py @@ -121,8 +121,11 @@ def postprocess_draft( postprocess_handler: PostprocessHandler, book: Optional[str] = None, out_dir: Optional[Path] = None, - training_corpus_pairs: List[CorpusPair] = [], + training_corpus_pairs: Optional[List[CorpusPair]] = None, ) -> None: + if training_corpus_pairs is None: + training_corpus_pairs = [] + if str(draft_metadata.source_path).startswith(str(get_project_dir(""))): settings = FileParatextProjectSettingsParser(draft_metadata.source_path.parent).parse() stylesheet = settings.stylesheet From 276f182dd9c759b13f1fb51ca7f85ee3cb72d5f4 Mon Sep 17 00:00:00 2001 From: Ben King Date: Tue, 16 Sep 2025 13:19:48 +0000 Subject: [PATCH 8/8] Use relative imports --- silnlp/common/postprocesser.py | 16 +++------------- silnlp/nmt/config.py | 25 ++++++++++++------------- 2 files changed, 15 insertions(+), 26 deletions(-) diff --git a/silnlp/common/postprocesser.py b/silnlp/common/postprocesser.py index 211b7f4c..f11187b9 100644 --- a/silnlp/common/postprocesser.py +++ b/silnlp/common/postprocesser.py @@ -187,8 +187,8 @@ def _detect_quote_convention(self, project_name: str) -> QuoteConvention: raise NoDetectedQuoteConventionException(project_name) LOGGER.info( f'Detected quote convention for project "{project_name}" is ' - + '"{quote_convention_analysis.best_quote_convention.name}" with score ' - + "{quote_convention_analysis.best_quote_convention_score:.2f}." + + f'"{quote_convention_analysis.best_quote_convention.name}" with score ' + + f"{quote_convention_analysis.best_quote_convention_score:.2f}." ) self._project_convention_cache[project_name] = quote_convention_analysis.best_quote_convention @@ -222,17 +222,7 @@ def _create_remark(self, best_chapter_strategies: List[QuotationMarkUpdateStrate if len(processed_chapters) == 0: return self._NO_CHAPTERS_REMARK_SENTENCE - return ( - self._REMARK_SENTENCE - + ", ".join( - [ - str(chapter_num) - for chapter_num, strategy in enumerate(best_chapter_strategies, 1) - if strategy != QuotationMarkUpdateStrategy.SKIP - ] - ) - + "." - ) + return self._REMARK_SENTENCE + ", ".join(processed_chapters) + "." 
def postprocess_usfm( self, diff --git a/silnlp/nmt/config.py b/silnlp/nmt/config.py index 1b2b00d6..5f81694f 100644 --- a/silnlp/nmt/config.py +++ b/silnlp/nmt/config.py @@ -15,19 +15,6 @@ from machine.tokenization import LatinWordTokenizer from tqdm import tqdm -from silnlp.common.translator import TranslationGroup -from silnlp.nmt.corpora import ( - BASIC_DATA_PROJECT, - CorpusPair, - DataFile, - DataFileMapping, - IsoPairInfo, - get_data_file_pairs, - get_parallel_corpus_size, - get_terms_glosses_file_paths, - parse_corpus_pairs, -) - from ..alignment.config import get_aligner_name from ..alignment.utils import add_alignment_scores from ..common.corpus import ( @@ -45,8 +32,20 @@ write_corpus, ) from ..common.environment import SIL_NLP_ENV +from ..common.translator import TranslationGroup from ..common.utils import NoiseMethod, Side, get_mt_exp_dir, set_seed from .augment import AugmentMethod +from .corpora import ( + BASIC_DATA_PROJECT, + CorpusPair, + DataFile, + DataFileMapping, + IsoPairInfo, + get_data_file_pairs, + get_parallel_corpus_size, + get_terms_glosses_file_paths, + parse_corpus_pairs, +) from .tokenizer import Tokenizer LOGGER = logging.getLogger(__package__ + ".config")
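
For anyone who wants to exercise the new denormalization step by hand, outside of experiment.py and translate.py, the pieces introduced in this series compose directly. The sketch below is illustrative only: it assumes this branch of silnlp plus sil-machine 1.8.2 are installed, that MAT_draft.SFM is a hypothetical draft file produced by translate.py, and that "standard_english" and "typewriter_english" both name entries in machine.py's STANDARD_QUOTE_CONVENTIONS. The error handling mirrors the log-and-skip behavior introduced in PATCH 5/8.

import logging
from pathlib import Path

from silnlp.common.postprocesser import (
    DenormalizeQuotationMarksPostprocessor,
    NoDetectedQuoteConventionException,
    UnknownQuoteConventionException,
)

LOGGER = logging.getLogger(__name__)

# Hypothetical draft produced by translate.py; adjust the path to your data.
draft_path = Path("MAT_draft.SFM")
usfm = draft_path.read_text(encoding="utf-8-sig")

try:
    # Naming both conventions skips detection entirely. Passing "detect" together
    # with source_project_name/target_project_name instead detects each convention
    # from the project's USFM via FileParatextProjectQuoteConventionDetector.
    postprocessor = DenormalizeQuotationMarksPostprocessor(
        source_quote_convention_name="standard_english",
        target_quote_convention_name="typewriter_english",
    )
    usfm = postprocessor.postprocess_usfm(usfm)
except (UnknownQuoteConventionException, NoDetectedQuoteConventionException) as e:
    # Same log-and-skip behavior as PATCH 5/8: a bad or undetectable convention
    # should not fail the whole drafting run.
    LOGGER.warning(str(e) + " Skipping quotation mark denormalization.")

# "_q" mirrors the suffix char registered for denormalize_quotation_marks in
# POSTPROCESS_SUFFIX_CHARS.
draft_path.with_stem(draft_path.stem + "_q").write_text(usfm, encoding="utf-8")

The same step can be run against a finished experiment through the updated CLI added in PATCH 4/8, along the lines of: python -m silnlp.common.postprocess_draft --experiment <exp_name> --denormalize-quotation-marks --source-project <src_project> --target-quote-convention standard_english (the exact module invocation may differ by installation).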