Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions nemo_curator/stages/text/deduplication/removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class TextDuplicatesRemovalStage(ProcessingStage[DocumentBatch, DocumentBatch]):
id_field: Field to use for deduplication within the input dataframe. Defaults to CURATOR_DEDUP_ID_STR.
duplicate_id_field: Field to use for deduplication within the removal dataframe. Defaults to "id".
read_kwargs: Additional arguments for reading parquet files
drop_id_field: Whether to drop the deduplication ID field from the output batch.
"""

ids_to_remove_path: str
Expand All @@ -51,6 +52,7 @@ class TextDuplicatesRemovalStage(ProcessingStage[DocumentBatch, DocumentBatch]):

# Optional parameters
read_kwargs: dict[str, Any] | None = None
drop_id_field: bool = False

def __post_init__(self):
"""Initialize parent class after dataclass initialization."""
Expand Down Expand Up @@ -84,6 +86,8 @@ def process(self, task: DocumentBatch) -> DocumentBatch:
time_to_remove_t0 = time.perf_counter()
removal_ids = set(removal_df[self.duplicate_id_field].tolist())
df = df[~df[self.id_field].isin(removal_ids)]
if self.drop_id_field:
df = df.drop(columns=[self.id_field])
removal_ids_time = time.perf_counter() - time_to_remove_t0
self._log_metrics(
{
Expand Down
5 changes: 5 additions & 0 deletions nemo_curator/stages/text/deduplication/removal_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,17 @@ class TextDuplicatesRemovalWorkflow(WorkflowBase):
output_kwargs: dict[str, Any] | None = None
output_fields: list[str] | None = None
output_mode: Literal["ignore", "overwrite", "append", "error"] | None = None
drop_id_field: bool = False

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Conflicting drop_id_field + output_fields not validated

When a caller sets drop_id_field=True and also includes id_field in output_fields, the removal stage will have already dropped that column by the time the writer stage tries to select it, producing a KeyError at runtime. The semantic workflow avoids this with an explicit self.output_fields is None guard, but the base TextDuplicatesRemovalWorkflow has no equivalent check. Adding a __post_init__ guard (after the id_generator warning) like if self.drop_id_field and self.output_fields and self.id_field in self.output_fields: raise ValueError(...) would surface this misconfiguration early with a clear message.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 I think this is decent feedback. WDYT @nightcityblade ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call — I added an early __post_init__ validation that raises a clear ValueError when drop_id_field=True conflicts with output_fields containing the id field, plus a focused unit test for that configuration.

Pushed: nightcityblade/Curator@2dbe71d
Validation: uv run ruff check nemo_curator/stages/text/deduplication/removal_workflow.py tests/stages/text/deduplication/test_removal_workflow.py and python3 -m py_compile ... passed locally. Targeted pytest collection is blocked on this macOS host by Curator’s import-time Linux-only guard.


def __post_init__(self):
"""Initialize parent class after dataclass initialization."""
if self.id_generator_path is None and self.id_field == CURATOR_DEDUP_ID_STR:
logger.warning(
f"Using {CURATOR_DEDUP_ID_STR} as id_field for removal stage, even though we are not using id generator."
)
if self.drop_id_field and self.output_fields and self.id_field in self.output_fields:
msg = f"Cannot drop id_field {self.id_field!r} when it is included in output_fields."
raise ValueError(msg)

def _generate_stages(self, initial_tasks: list[FileGroupTask] | None = None) -> list[ProcessingStage]:
stages = []
Expand Down Expand Up @@ -120,6 +124,7 @@ def _generate_stages(self, initial_tasks: list[FileGroupTask] | None = None) ->
id_field=self.id_field,
duplicate_id_field=self.duplicate_id_field,
read_kwargs=self.duplicate_id_read_kwargs,
drop_id_field=self.drop_id_field,
)
)

Expand Down
1 change: 1 addition & 0 deletions nemo_curator/stages/text/deduplication/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def _run_duplicate_removal(self, executor: BaseExecutor) -> WorkflowRunResult |
output_kwargs=self.write_kwargs,
output_fields=self.output_fields,
output_mode="ignore",
drop_id_field=self.use_id_generator and self.output_fields is None,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I am on the fence about whether to silently drop the IDs generated by use_id_generator in this case (since output_fields default is none). I think it is fine but the information in the tutorial https://github.com/NVIDIA-NeMo/Curator/blob/main/tutorials/text/deduplication/semantic/semantic_e2e.ipynb might need to be updated. Can you check?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checked and updated the tutorial in 3ff725a.

The semantic dedup notebook previously said _curator_dedup_id would appear in the final deduplicated output when use_id_generator=True. That is now outdated with this PR's default drop_id_field behavior, so I revised the note to explain that the generated ID is dropped by default and can be preserved only by including it explicitly in output_fields.

Validation: the notebook still parses as JSON after the edit.

)

return workflow.run(executor=executor)
Expand Down
35 changes: 35 additions & 0 deletions tests/stages/text/deduplication/test_removal_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,39 @@ def test_initial_tasks_partitioning(self, test_config: "TestTextDuplicateRemoval
assert workflow_output.get_metadata("num_duplicates_removed") == expected_removed



def test_removal_stage_can_drop_id_field(tmp_path: Path):
ids_to_remove_path = tmp_path / "ids_to_remove.parquet"
pd.DataFrame({"id": [1]}).to_parquet(ids_to_remove_path, index=False)
task = DocumentBatch(
task_id="task",
dataset_name="dataset",
data=pd.DataFrame({CURATOR_DEDUP_ID_STR: [1, 2], "text": ["drop", "keep"]}),
)

stage = TextDuplicatesRemovalStage(
ids_to_remove_path=str(ids_to_remove_path),
id_field=CURATOR_DEDUP_ID_STR,
drop_id_field=True,
)

result = stage.process(task).to_pandas()

assert result.to_dict(orient="list") == {"text": ["keep"]}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I misread this test at first and was confused why we were checking that this text was kept. Can you add another assertion that explicitly checks that CURATOR_DEDUP_ID_STR is not in the result?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added — the test now explicitly asserts that CURATOR_DEDUP_ID_STR is absent from the result columns, in addition to checking the remaining row contents.

Pushed in nightcityblade/Curator@2dbe71d.

assert CURATOR_DEDUP_ID_STR not in result.columns


class TestTextDuplicatesRemovalWorkflowGenerateStages:
def test_drop_id_field_conflicts_with_output_fields(self):
with pytest.raises(ValueError, match="Cannot drop id_field"):
TextDuplicatesRemovalWorkflow(
input_path="input_path",
ids_to_remove_path="ids_to_remove_path",
output_path="output_path",
output_fields=["text", CURATOR_DEDUP_ID_STR],
drop_id_field=True,
)

def test_invalid_filetypes(self):
read_invalid_file_type_workflow = TextDuplicatesRemovalWorkflow(
input_path="input_path",
Expand Down Expand Up @@ -340,6 +372,7 @@ def test_reader_stage(self, input_filetype: str, id_generator_path: str | None):
assert stages[2].id_field == CURATOR_DEDUP_ID_STR
assert stages[2].duplicate_id_field == "id"
assert stages[2].read_kwargs == {}
assert not stages[2].drop_id_field

# test for writer stage (stages[3]) - default output_filetype is parquet
assert isinstance(stages[3], ParquetWriter)
Expand All @@ -352,13 +385,15 @@ def test_writer_stage(self, output_filetype: str):
output_path="output_path",
output_filetype=output_filetype,
id_generator_path=None,
drop_id_field=True,
)
stages = workflow._generate_stages(initial_tasks=None)
assert len(stages) == 4
assert isinstance(stages[0], FilePartitioningStage)
# reader stage
assert isinstance(stages[1], ParquetReaderStage) # Default input_filetype is parquet
assert isinstance(stages[2], TextDuplicatesRemovalStage)
assert stages[2].drop_id_field
expected_write_stage = ParquetWriter if output_filetype == "parquet" else JsonlWriter
assert isinstance(stages[3], expected_write_stage)

Expand Down