-
Notifications
You must be signed in to change notification settings - Fork 292
Add model-based HTML extraction stage #1768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,104 +1,214 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from typing import Any | ||
|
|
||
| from bs4 import BeautifulSoup | ||
| from loguru import logger | ||
|
|
||
| from nemo_curator.stages.resources import Resources | ||
| from nemo_curator.stages.text.download import DocumentExtractor | ||
| from nemo_curator.stages.text.download.html_extractors import HTMLExtractorAlgorithm | ||
| from nemo_curator.stages.text.download.html_extractors.justext import JusTextExtractor | ||
| from nemo_curator.stages.text.download.html_extractors.model_based import ( | ||
| CANDIDATE_ATTRIBUTES_FIELD, | ||
| CANDIDATE_HTML_FIELD, | ||
| CANDIDATE_INDEX_FIELD, | ||
| CANDIDATE_TAG_NAME_FIELD, | ||
| CANDIDATE_TEXT_FIELD, | ||
| HTML_FIELD, | ||
| MODEL_INPUT_FIELD, | ||
| PLACEHOLDER_CANDIDATE_INDEX, | ||
| ModelBasedHTMLExtractionStage, | ||
| extract_candidate_elements, | ||
| serialize_html_element, | ||
| ) | ||
| from nemo_curator.stages.text.download.html_extractors.resiliparse import ResiliparseExtractor | ||
| from nemo_curator.stages.text.download.html_extractors.trafilatura import TrafilaturaExtractor | ||
| from nemo_curator.stages.text.download.html_extractors.utils import get_stop_list_dict | ||
| from nemo_curator.stages.text.download.utils import decode_html, lang_detect | ||
|
|
||
|
|
||
| class CommonCrawlHTMLExtractor(DocumentExtractor): | ||
| def __init__( | ||
| self, | ||
| algorithm: HTMLExtractorAlgorithm | str | None = None, | ||
| algorithm_kwargs: dict | None = None, | ||
| stop_lists: dict[str, frozenset[str]] | None = None, | ||
| ): | ||
| super().__init__() | ||
| algorithm_kwargs = algorithm_kwargs or {} | ||
| if algorithm is None: | ||
| logger.warning("No algorithm provided, using justext with default parameters") | ||
| algorithm = JusTextExtractor() | ||
| elif isinstance(algorithm, str): | ||
| if algorithm == "justext": | ||
| algorithm = JusTextExtractor(**algorithm_kwargs) | ||
| elif algorithm == "resiliparse": | ||
| algorithm = ResiliparseExtractor(**algorithm_kwargs) | ||
| def __init__( | ||
| self, | ||
| algorithm: HTMLExtractorAlgorithm | str | None = None, | ||
| algorithm_kwargs: dict | None = None, | ||
| stop_lists: dict[str, frozenset[str]] | None = None, | ||
| ): | ||
| super().__init__() | ||
| algorithm_kwargs = algorithm_kwargs or {} | ||
| if algorithm is None: | ||
| logger.warning("No algorithm provided, using justext with default parameters") | ||
| algorithm = JusTextExtractor() | ||
| elif isinstance(algorithm, str): | ||
| if algorithm == "justext": | ||
| algorithm = JusTextExtractor(**algorithm_kwargs) | ||
| elif algorithm == "resiliparse": | ||
| algorithm = ResiliparseExtractor(**algorithm_kwargs) | ||
| elif algorithm == "trafilatura": | ||
| algorithm = TrafilaturaExtractor(**algorithm_kwargs) | ||
| elif algorithm in {"model", "model_based"}: | ||
| msg = ( | ||
| "Model-based HTML extraction is only supported through " | ||
| "CommonCrawlDownloadExtractStage with html_extraction='model' or 'model_based'." | ||
| ) | ||
| raise ValueError(msg) | ||
| else: | ||
| msg = f"Invalid algorithm: {algorithm}" | ||
| raise ValueError(msg) | ||
| elif isinstance(algorithm, HTMLExtractorAlgorithm): | ||
| if algorithm_kwargs: | ||
| logger.warning("Algorithm kwargs provided are ignored when an HTMLExtractorAlgorithm is provided") | ||
| else: | ||
| msg = f"Invalid algorithm: {algorithm}" | ||
| raise ValueError(msg) | ||
|
|
||
| if stop_lists is not None: | ||
| self._stop_lists = stop_lists | ||
| elif isinstance(algorithm, HTMLExtractorAlgorithm): | ||
| if algorithm_kwargs: | ||
| logger.warning("Algorithm kwargs provided are ignored when an HTMLExtractorAlgorithm is provided") | ||
| else: | ||
| msg = f"Invalid algorithm: {algorithm}" | ||
| raise ValueError(msg) | ||
| if stop_lists is not None: | ||
| self._stop_lists = stop_lists | ||
| else: | ||
| self._stop_lists = get_stop_list_dict() | ||
|
|
||
| self.algorithm = algorithm | ||
| self.resources = getattr(self.algorithm, "resources", Resources(cpus=1.0)) | ||
|
|
||
| def extract(self, record: dict[str, Any]) -> dict[str, Any] | None: | ||
| """Extract text from HTML content in the record. | ||
|
|
||
| Takes a record dict containing "content" field with HTML and returns | ||
| a new dict with only the output columns: url, warc_id, source_id, language, text. | ||
| """ | ||
| # Extract the HTML content from the record | ||
| html_content = record.get("content") | ||
| if not html_content: | ||
| return None | ||
|
|
||
| # Content from WARC records is bytes, even though type annotation suggests str | ||
| html = decode_html(html_content) | ||
|
|
||
| if html is not None: | ||
| # Language detection and HTML extraction | ||
| lang = lang_detect(html) | ||
|
|
||
| text = None | ||
| if lang in self._stop_lists: | ||
| text = self.algorithm.extract_text(html, self._stop_lists[lang], lang) | ||
|
|
||
| if text is not None: | ||
| if len(text) > 0: | ||
| text = "\n\n".join(text) | ||
| return { | ||
| "url": record["url"], | ||
| "warc_id": record["warc_id"], | ||
| "source_id": record["source_id"], | ||
| "language": lang, | ||
| "text": text, | ||
| } | ||
| else: | ||
| return None | ||
| return None | ||
|
|
||
| def input_columns(self) -> list[str]: | ||
| return ["url", "warc_id", "source_id", "content"] | ||
|
|
||
| def output_columns(self) -> list[str]: | ||
| return ["url", "warc_id", "source_id", "language", "text"] | ||
|
|
||
| def setup_on_node(self, *args, **kwargs) -> None: | ||
| setup_on_node = getattr(self.algorithm, "setup_on_node", None) | ||
| if callable(setup_on_node): | ||
| setup_on_node(*args, **kwargs) | ||
|
|
||
| def setup(self, *args, **kwargs) -> None: | ||
| setup = getattr(self.algorithm, "setup", None) | ||
| if callable(setup): | ||
| setup(*args, **kwargs) | ||
|
|
||
| def extract(self, record: dict[str, Any]) -> dict[str, Any] | None: | ||
| """Extract text from HTML content in the record. | ||
| def teardown(self) -> None: | ||
| teardown = getattr(self.algorithm, "teardown", None) | ||
| if callable(teardown): | ||
| teardown() | ||
|
|
||
| Takes a record dict containing "content" field with HTML and returns | ||
| a new dict with only the output columns: url, warc_id, source_id, language, text. | ||
| """ | ||
| # Extract the HTML content from the record | ||
| def ray_stage_spec(self) -> dict[str, Any]: | ||
| ray_stage_spec = getattr(self.algorithm, "ray_stage_spec", None) | ||
| if callable(ray_stage_spec): | ||
| return ray_stage_spec() | ||
| return {} | ||
|
|
||
|
|
||
| class CommonCrawlModelBasedCandidateExtractor(DocumentExtractor): | ||
| def __init__( | ||
| self, | ||
| algorithm: ModelBasedHTMLExtractionStage, | ||
| stop_lists: dict[str, frozenset[str]] | None = None, | ||
| ): | ||
| super().__init__() | ||
| self.algorithm = algorithm | ||
| self._stop_lists = stop_lists or get_stop_list_dict() | ||
| self.resources = Resources(cpus=1.0) | ||
|
|
||
| def extract(self, record: dict[str, Any]) -> list[dict[str, Any]] | None: | ||
| html_content = record.get("content") | ||
| if not html_content: | ||
| return None | ||
|
|
||
| # Content from WARC records is bytes, even though type annotation suggests str | ||
| html = decode_html(html_content) | ||
| if html is None: | ||
| return None | ||
|
|
||
| language = lang_detect(html) | ||
| if language not in self._stop_lists: | ||
| return None | ||
|
|
||
| if html is not None: | ||
| # Language detection and HTML extraction | ||
| lang = lang_detect(html) | ||
|
|
||
| text = None | ||
| if lang in self._stop_lists: | ||
| text = self.algorithm.extract_text(html, self._stop_lists[lang], lang) | ||
|
|
||
| if text is not None: | ||
| if len(text) > 0: | ||
| text = "\n\n".join(text) | ||
| return { | ||
| "url": record["url"], | ||
| "warc_id": record["warc_id"], | ||
| "source_id": record["source_id"], | ||
| "language": lang, | ||
| "text": text, | ||
| } | ||
| else: | ||
| return None | ||
| return None | ||
| elements = extract_candidate_elements(BeautifulSoup(html, "lxml")) | ||
| base_record = { | ||
| "url": record["url"], | ||
| "warc_id": record["warc_id"], | ||
| "source_id": record["source_id"], | ||
| "language": language, | ||
| HTML_FIELD: html, | ||
| } | ||
|
|
||
| if not elements: | ||
| return [ | ||
| { | ||
| **base_record, | ||
| CANDIDATE_INDEX_FIELD: PLACEHOLDER_CANDIDATE_INDEX, | ||
| CANDIDATE_TAG_NAME_FIELD: None, | ||
| CANDIDATE_TEXT_FIELD: None, | ||
| CANDIDATE_HTML_FIELD: None, | ||
| CANDIDATE_ATTRIBUTES_FIELD: {}, | ||
| MODEL_INPUT_FIELD: "", | ||
| } | ||
| ] | ||
|
|
||
| return [{**base_record, **serialize_html_element(element)} for element in elements] | ||
|
|
||
| def input_columns(self) -> list[str]: | ||
| return ["url", "warc_id", "source_id", "content"] | ||
|
|
||
| def output_columns(self) -> list[str]: | ||
| return ["url", "warc_id", "source_id", "language", "text"] | ||
| return [ | ||
| "url", | ||
| "warc_id", | ||
| "source_id", | ||
| "language", | ||
| HTML_FIELD, | ||
| CANDIDATE_INDEX_FIELD, | ||
| CANDIDATE_TAG_NAME_FIELD, | ||
| CANDIDATE_TEXT_FIELD, | ||
| CANDIDATE_HTML_FIELD, | ||
| CANDIDATE_ATTRIBUTES_FIELD, | ||
| MODEL_INPUT_FIELD, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| ] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`algorithm="model"` entry point is unreachable from `CommonCrawlDownloadExtractStage`

The error message tells users to call `CommonCrawlDownloadExtractStage(html_extraction='model')`, but `stage.py` was not updated in this PR — it still passes `html_extraction` directly to `CommonCrawlHTMLExtractor.__init__`, where it hits this same `ValueError`. Any call to `CommonCrawlDownloadExtractStage(html_extraction="model")` will fail at construction with a self-referential error message. The `stage.py` file needs to detect `html_extraction in {"model", "model_based"}` and compose the new `CommonCrawlModelBasedCandidateExtractor` → `TokenizerStage` → `ModelBasedHTMLInferenceStage` → `AssembleModelBasedHTMLExtractionStage` pipeline instead of delegating to `CommonCrawlHTMLExtractor`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 I think the PR is not usable as is.
fyi the PR #2075 is going to take over this work I think, thank you!