Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ If you are an advanced user looking to process data with pre-defined splits, int

## Requirements
- Python version 3.8.10+
- Support for Linux and Mac OS. Not tested on Windows
- Support for Linux, Mac OS and Windows.

</br>

Expand Down
20 changes: 19 additions & 1 deletion generative_data_prep/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@

Entry point to the Text Processing Pipeline.
"""

import json
import logging
import logging.config
import os
import sys
from typing import Optional

from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
Expand Down Expand Up @@ -46,6 +49,21 @@
logger = logging.getLogger("generative_data_prep_logger")
logging.config.fileConfig(get_config_file_path())

# Fix Unicode encoding issues on Windows console
# Configure stdout/stderr to handle Unicode encoding errors on Windows
if sys.platform == "win32":
# Try to reconfigure streams to use UTF-8 with error replacement
if hasattr(sys.stdout, "reconfigure"):
try:
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
except (AttributeError, ValueError):
pass
if hasattr(sys.stderr, "reconfigure"):
try:
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
except (AttributeError, ValueError):
pass


def add_special_tokens_dict(tokenizer: PreTrainedTokenizerBase, special_tokens_dict: str):
"""Add the special tokens dictionary to tokenizer.
Expand Down Expand Up @@ -129,7 +147,7 @@ def get_categories(categories_path: str):
_, file_extension = os.path.splitext(categories_path)
if file_extension != ".json":
raise ValueError(f"Your --categories_path flag must point to a json file, you used {categories_path}")
with open(categories_path, "r") as categories_file:
with open(categories_path, "r", encoding="utf-8") as categories_file:
categories_list = json.load(categories_file)
if not isinstance(categories_list, list):
err_msg = (
Expand Down
20 changes: 15 additions & 5 deletions generative_data_prep/data_prep/data_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,24 @@ def data_prep_main(
dump_categories = category_to_id is not None

with Hdf5FileBuffer(output_file, max_seq_length, dump_categories) as hdf5_text_buffer:
with open(input_file, "r") as reader:
total_processed = 0
with open(input_file, "r", encoding="utf-8", errors="replace") as reader:
for i, line in enumerate(reader):
try:
hdf5_text_buffer.write(article_tokenizer(line))
total_processed = i + 1 # Track total processed (i is 0-indexed)
# Update counter every 100 articles (including the first batch)
# When i+1 is a multiple of 100, we've processed exactly that many
if (
(i != 0 and i % 100 == 0)
total_processed % 100 == 0
and num_tokenized_articles_lock is not None
and num_tokenized_articles is not None
):
with num_tokenized_articles_lock:
num_tokenized_articles.value += 100
except json.JSONDecodeError as exc:
if ignore_input_format_error:
with open(error_log_path, "a") as f:
with open(error_log_path, "a", encoding="utf-8", errors="replace") as f:
f.write(line)
if num_tokenized_articles_lock is not None and num_skipped_articles is not None:
with num_tokenized_articles_lock:
Expand All @@ -133,9 +137,15 @@ def data_prep_main(
exc.doc,
exc.pos,
) from exc
if num_tokenized_articles_lock is not None and num_tokenized_articles is not None:
# Add remaining articles that weren't counted in the batch updates
if num_tokenized_articles_lock is not None and num_tokenized_articles is not None and total_processed > 0:
with num_tokenized_articles_lock:
num_tokenized_articles.value += i % 100
# Calculate remaining articles: total processed minus what we already counted
# We count in batches of 100, so we need to add the remainder
already_counted = (total_processed // 100) * 100 # How many we've already counted
remaining = total_processed - already_counted
if remaining > 0:
num_tokenized_articles.value += remaining
hdf5_text_buffer.write(article_tokenizer(None))
article_tokenizer.metrics.dataset_type = dataset_type
return article_tokenizer.metrics
Loading
Loading