diff --git a/README.md b/README.md index 13c2a60..b705cbc 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ If you are an advanced user looking to process data with pre-defined splits, int ## Requirements - Python version 3.8.10+ -- Support for Linux and Mac OS. Not tested on Windows +- Support for Linux, macOS, and Windows.
diff --git a/generative_data_prep/__main__.py b/generative_data_prep/__main__.py index 85811f2..fabe79e 100644 --- a/generative_data_prep/__main__.py +++ b/generative_data_prep/__main__.py @@ -15,9 +15,12 @@ Entry point to the Text Processing Pipeline. """ + import json import logging +import logging.config import os +import sys from typing import Optional from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase @@ -46,6 +49,21 @@ logger = logging.getLogger("generative_data_prep_logger") logging.config.fileConfig(get_config_file_path()) +# Fix Unicode encoding issues on Windows console +# Configure stdout/stderr to handle Unicode encoding errors on Windows +if sys.platform == "win32": + # Try to reconfigure streams to use UTF-8 with error replacement + if hasattr(sys.stdout, "reconfigure"): + try: + sys.stdout.reconfigure(encoding="utf-8", errors="replace") + except (AttributeError, ValueError): + pass + if hasattr(sys.stderr, "reconfigure"): + try: + sys.stderr.reconfigure(encoding="utf-8", errors="replace") + except (AttributeError, ValueError): + pass + def add_special_tokens_dict(tokenizer: PreTrainedTokenizerBase, special_tokens_dict: str): """Add the special tokens dictionary to tokenizer. 
@@ -129,7 +147,7 @@ def get_categories(categories_path: str): _, file_extension = os.path.splitext(categories_path) if file_extension != ".json": raise ValueError(f"Your --categories_path flag must point to a json file, you used {categories_path}") - with open(categories_path, "r") as categories_file: + with open(categories_path, "r", encoding="utf-8") as categories_file: categories_list = json.load(categories_file) if not isinstance(categories_list, list): err_msg = ( diff --git a/generative_data_prep/data_prep/data_prep.py b/generative_data_prep/data_prep/data_prep.py index 7d970d6..34a056d 100644 --- a/generative_data_prep/data_prep/data_prep.py +++ b/generative_data_prep/data_prep/data_prep.py @@ -106,12 +106,16 @@ def data_prep_main( dump_categories = category_to_id is not None with Hdf5FileBuffer(output_file, max_seq_length, dump_categories) as hdf5_text_buffer: - with open(input_file, "r") as reader: + total_processed = 0 + with open(input_file, "r", encoding="utf-8", errors="replace") as reader: for i, line in enumerate(reader): try: hdf5_text_buffer.write(article_tokenizer(line)) + total_processed = i + 1 # Track total processed (i is 0-indexed) + # Update counter every 100 articles (including the first batch) + # When i+1 is a multiple of 100, we've processed exactly that many if ( - (i != 0 and i % 100 == 0) + total_processed % 100 == 0 and num_tokenized_articles_lock is not None and num_tokenized_articles is not None ): @@ -119,7 +123,7 @@ def data_prep_main( num_tokenized_articles.value += 100 except json.JSONDecodeError as exc: if ignore_input_format_error: - with open(error_log_path, "a") as f: + with open(error_log_path, "a", encoding="utf-8", errors="replace") as f: f.write(line) if num_tokenized_articles_lock is not None and num_skipped_articles is not None: with num_tokenized_articles_lock: @@ -133,9 +137,15 @@ def data_prep_main( exc.doc, exc.pos, ) from exc - if num_tokenized_articles_lock is not None and num_tokenized_articles is not None: + # 
Add remaining articles that weren't counted in the batch updates + if num_tokenized_articles_lock is not None and num_tokenized_articles is not None and total_processed > 0: with num_tokenized_articles_lock: - num_tokenized_articles.value += i % 100 + # Calculate remaining articles: total processed minus what we already counted + # We count in batches of 100, so we need to add the remainder + already_counted = (total_processed // 100) * 100 # How many we've already counted + remaining = total_processed - already_counted + if remaining > 0: + num_tokenized_articles.value += remaining hdf5_text_buffer.write(article_tokenizer(None)) article_tokenizer.metrics.dataset_type = dataset_type return article_tokenizer.metrics diff --git a/generative_data_prep/data_prep/pipeline.py b/generative_data_prep/data_prep/pipeline.py index bd8d573..de7d742 100644 --- a/generative_data_prep/data_prep/pipeline.py +++ b/generative_data_prep/data_prep/pipeline.py @@ -12,7 +12,6 @@ See the License for the specific language governing permissions and limitations under the License. - Data preparation pipeline for converting a jsonl file to tokenized hdf5 files consumable by SambaSuite. 
""" @@ -23,10 +22,10 @@ import os import random import shutil +import sys import time import uuid from pathlib import Path -from sys import platform from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np @@ -35,6 +34,13 @@ from alive_progress import alive_bar from transformers import PretrainedConfig, PreTrainedTokenizerBase +# Set multiprocessing start method for Windows compatibility +if sys.platform == "win32": + try: + multiprocessing.set_start_method("spawn", force=True) + except RuntimeError: + # Start method already set, ignore + pass from generative_data_prep.data_prep import data_prep_main from generative_data_prep.processors.metrics import Metrics from generative_data_prep.utils import ( @@ -42,7 +48,6 @@ PackingConfig, balance_hdf5_files, create_sha256, - execute_and_return_stdout, get_num_training_splits, large_file_shuffle, log_sep_str, @@ -91,11 +96,11 @@ def combine_input_dir_files(input_path: str) -> Tuple[str, List[Path]]: output_file = input_path_obj / f"combined_output_{uuid.uuid4().hex[:8]}{ext}" # Open the output file and concatenate all input files - with open(output_file, "w") as f_out: + with open(output_file, "w", encoding="utf-8", errors="replace") as f_out: for input_file in input_files: if "combined_output_" not in str(input_file): verify_input_file(str(input_file)) - with open(input_file, "r") as f_in: + with open(input_file, "r", encoding="utf-8", errors="replace") as f_in: if input_file.stat().st_size == 0: continue # Skip empty files @@ -104,16 +109,33 @@ def combine_input_dir_files(input_path: str) -> Tuple[str, List[Path]]: return str(output_file), input_files -def split_file_linux(num_splits: int, input_file_path: str, split_dir: str) -> None: - """Split the [input_file_path] into num_splits and places it in [split_dir]. 
+def split_file_round_robin(num_splits: int, input_file_path: str, split_dir: str) -> None: + """Split the [input_file_path] into num_splits and places it in [split_dir] using round-robin distribution. + + This is a cross-platform replacement for the Linux 'split -d -n r/' command. Args: num_splits (int): number of output file splits input_file_path (str): input jsonl file path split_dir (str): The directory to place all the outputted splits """ - split_command = f"split -d -n r/{num_splits} {input_file_path} {split_dir}/" - execute_and_return_stdout(split_command) + # Create file handles for all split files + split_files = [] + num_digits = len(str(num_splits)) + for i in range(num_splits): + out_file_path = os.path.join(split_dir, str(i).zfill(max(2, num_digits))) + split_files.append(open(out_file_path, "w", encoding="utf-8", errors="replace")) + + try: + # Read input file and distribute lines in round-robin fashion + with open(input_file_path, "r", encoding="utf-8", errors="replace") as infile: + for line_num, line in enumerate(infile): + split_index = line_num % num_splits + split_files[split_index].write(line) + finally: + # Close all file handles + for f in split_files: + f.close() def check_RAM(input_file_size_in_bytes: int): @@ -161,11 +183,11 @@ def rename_files( num_digits = len(str(num_splits)) for i in range(num_splits): if i < train_count: - new_name = f"train_{i+1}_of_{train_count}{file_ext}" + new_name = f"train_{i + 1}_of_{train_count}{file_ext}" elif i < train_count + test_count: - new_name = f"test_{i-train_count+1}_of_{test_count}{file_ext}" + new_name = f"test_{i - train_count + 1}_of_{test_count}{file_ext}" else: - new_name = f"dev_{i-train_count-test_count+1}_of_{dev_count}{file_ext}" + new_name = f"dev_{i - train_count - test_count + 1}_of_{dev_count}{file_ext}" new_file_path = os.path.join(split_dir, new_name) @@ -189,8 +211,40 @@ def rename_files( return files_to_tokenize +def count_exact_total_num_articles(files_to_tokenize, split_dir): + 
"""Counts the exact total number of articles by counting all non-empty lines in all files. + + Args: + files_to_tokenize: List of files to tokenize. + split_dir: Directory where the split files are located. + + Returns: + Exact count of the total number of articles to tokenize + """ + if not files_to_tokenize: + return 0 + + total_lines = 0 + LOGGER.info(f"Counting articles in {len(files_to_tokenize)} files to get exact total...") + + for file_name in files_to_tokenize: + file_path = os.path.join(split_dir, file_name) + lines_in_file = 0 + with open(file_path, "r", encoding="utf-8", errors="replace") as file: + for line in file: + # Skip empty lines to match actual processing behavior + if line.strip(): + lines_in_file += 1 + total_lines += lines_in_file + + LOGGER.info(f"Exact total articles counted: {total_lines}") + return total_lines + + def estimate_total_num_articles(files_to_tokenize, split_dir): - """Estimates the total number of articles based on number of artiles in first split times number of splits. + """Estimates the total number of articles based on number of articles in sample files times number of splits. + + DEPRECATED: Use count_exact_total_num_articles for exact count instead. Args: files_to_tokenize: List of files to tokenize. 
@@ -199,12 +253,31 @@ def estimate_total_num_articles(files_to_tokenize, split_dir): Returns: Estimate of the total number of articles needed to tokenize """ - lines_per_file = 0 - with open(os.path.join(split_dir, files_to_tokenize[0]), "r") as file: - for _ in file: - lines_per_file += 1 + if not files_to_tokenize: + return 0 + + # Sample up to 5 files to get a better average estimate + sample_size = min(5, len(files_to_tokenize)) + total_lines = 0 + files_sampled = 0 - return lines_per_file * len(files_to_tokenize) + for i in range(sample_size): + file_path = os.path.join(split_dir, files_to_tokenize[i]) + lines_in_file = 0 + with open(file_path, "r", encoding="utf-8", errors="replace") as file: + for line in file: + # Skip empty lines to match actual processing behavior + if line.strip(): + lines_in_file += 1 + total_lines += lines_in_file + files_sampled += 1 + + if files_sampled == 0: + return 0 + + # Calculate average lines per file and multiply by total files + avg_lines_per_file = total_lines / files_sampled + return int(avg_lines_per_file * len(files_to_tokenize)) def get_split_counts( @@ -345,7 +418,8 @@ def multiprocess_data_prep( # noqa: C901 ) train_hdf5_files = list(filter(lambda file_name: "train" in file_name, sub_output_file_paths)) dev_hdf5_files = list(filter(lambda file_name: "dev" in file_name, sub_output_file_paths)) - total_num_articles = estimate_total_num_articles(files_to_tokenize, split_dir) + # Count exact total to guarantee 100% accuracy + total_num_articles = count_exact_total_num_articles(files_to_tokenize, split_dir) # create manager for shared variables to keep track of tokenization progress manager = multiprocessing.Manager() num_tokenized_articles_lock = manager.Lock() @@ -353,7 +427,13 @@ def multiprocess_data_prep( # noqa: C901 num_skipped_articles = manager.Value(int, 0) prev_num_tokenized_articles = 0 prev_num_skipped_articles = 0 + # Track how much we've actually updated the progress bar to prevent exceeding total + 
bar_update_tracker = 0 # Submit multiprocessing workers + # On Windows, reduce workers to avoid pickling issues with large tokenizers + if sys.platform == "win32" and num_workers > 4: + LOGGER.warning(f"Reducing workers from {num_workers} to 4 on Windows to avoid multiprocessing issues.") + num_workers = 4 executor = concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) futures = [] for input_file_path, output_file_path in zip(sub_input_file_paths, sub_output_file_paths): @@ -401,7 +481,8 @@ def multiprocess_data_prep( # noqa: C901 tokenization_start_time = time.time() finished_futures = set() # Loop while processes are running, update progress bar. - with alive_bar(total_num_articles) as bar: + # Use manual mode to have better control over the progress bar + with alive_bar(total_num_articles, manual=True, title="Tokenizing articles") as bar: while True: for i, future in enumerate(futures): if future.done() and future not in finished_futures: @@ -445,15 +526,56 @@ def multiprocess_data_prep( # noqa: C901 if all(future.done() for future in futures): if len(finished_futures) != len(futures): raise ValueError("All futures done, but finished futures set does not equal all futures list.") + # Final update to ensure progress bar reflects all processed articles + with num_tokenized_articles_lock: + num_new_tokenized_articles = num_tokenized_articles.value - prev_num_tokenized_articles + if num_new_tokenized_articles > 0: + # Use our tracker to ensure we never exceed total + remaining_until_total = max(0, total_num_articles - bar_update_tracker) + if remaining_until_total > 0: + # Cap update to not exceed total + max_update = min(num_new_tokenized_articles, remaining_until_total) + if max_update > 0: + bar_update_tracker += max_update + # Set bar to exact position (as fraction of total, capped at 1.0) + bar_position = ( + min(1.0, bar_update_tracker / total_num_articles) if total_num_articles > 0 else 0.0 + ) + bar(bar_position) + # Ensure progress bar reaches 
exactly 100% (1.0 in manual mode) + # Use tracker to set final position + bar_update_tracker = total_num_articles + bar(1.0) # Set to 100% completion break # Update the progress bar with how every many new articles were tokenized with num_tokenized_articles_lock: num_new_tokenized_articles = num_tokenized_articles.value - prev_num_tokenized_articles - bar(num_new_tokenized_articles) - perc_complete = round((bar.current / total_num_articles) * 100, 2) + if num_new_tokenized_articles > 0: + # Use our tracker to ensure we never exceed total + remaining_until_total = max(0, total_num_articles - bar_update_tracker) + # Only update if there's room and we have new articles + if remaining_until_total > 0: + # Cap update to not exceed total + max_update = min(num_new_tokenized_articles, remaining_until_total) + if max_update > 0: + bar_update_tracker += max_update + # Set bar to exact position (as fraction of total, capped at 1.0) + bar_position = ( + min(1.0, bar_update_tracker / total_num_articles) if total_num_articles > 0 else 0.0 + ) + bar(bar_position) + # Calculate percentage based on our tracker (more accurate than bar.current in manual mode) + if total_num_articles > 0: + # Use tracker to calculate accurate percentage + actual_current = min(bar_update_tracker, total_num_articles) + perc_complete = min(100.0, round((actual_current / total_num_articles) * 100, 2)) + else: + perc_complete = 0.0 elapsed_time_str = f"--- elapsed time: {time.time() - tokenization_start_time}" LOGGER.debug( - f"{total_num_articles}, {perc_complete}% complete => Time remaining: {bar.eta} {elapsed_time_str}" + f"Counter: {num_tokenized_articles.value}, Progress tracker: " + f"{bar_update_tracker}/{total_num_articles}, {perc_complete}% complete => " + f"Time remaining: {bar.eta} {elapsed_time_str}" ) prev_num_tokenized_articles = num_tokenized_articles.value @@ -464,9 +586,83 @@ def multiprocess_data_prep( # noqa: C901 prev_num_skipped_articles = num_skipped_articles.value time.sleep(5) + # 
Log final article count and validate 100% completion + log_sep_str() + total_actual_articles = train_metrics.articles + dev_metrics.articles + LOGGER.info( + f"Total articles processed (from metrics): {total_actual_articles} " + f"(Train: {train_metrics.articles}, Dev: {dev_metrics.articles})" + ) + LOGGER.info(f"Total articles counted in input files: {total_num_articles}") + if ignore_input_format_error: - LOGGER.info(f"Total processed lines: {num_tokenized_articles.value}") - LOGGER.info(f"Total skipped lines: {num_skipped_articles.value}") + LOGGER.info(f"Progress counter value: {num_tokenized_articles.value}") + LOGGER.info(f"Total skipped lines (format errors): {num_skipped_articles.value}") + + # Validate 100% completion + if total_num_articles > 0: + counter_articles = num_tokenized_articles.value + metrics_articles = total_actual_articles + skipped_articles = num_skipped_articles.value if ignore_input_format_error else 0 + + # Calculate expected articles (total - skipped due to format errors) + # Note: Articles dropped during processing (prompt-only, packing drops) + # are still counted in metrics.articles because metrics.articles is + # incremented before processing/dropping + expected_articles = total_num_articles - skipped_articles + + # Compare metrics with expected count + metrics_diff = abs(metrics_articles - expected_articles) + metrics_diff_percent = (metrics_diff / total_num_articles) * 100 if total_num_articles > 0 else 0 + + log_sep_str() + if metrics_diff == 0: + LOGGER.info( + f"[SUCCESS] 100% DATA UTILIZATION: All {total_num_articles} " + f"articles from input files were processed!" 
+ ) + if skipped_articles > 0: + LOGGER.info( + f" Note: {skipped_articles} articles were skipped due to " f"JSON format errors (expected)" + ) + LOGGER.info(f" All {metrics_articles} processed articles are included in " f"the output dataset.") + elif metrics_diff_percent <= 0.1: # Less than 0.1% difference + LOGGER.warning( + f"Near-complete data utilization: {metrics_articles}/" + f"{expected_articles} articles processed " + f"({metrics_diff_percent:.3f}% difference). This is likely due " + f"to rounding or minor counting differences." + ) + LOGGER.info(f" {metrics_articles} articles are included in the output " f"dataset.") + else: + LOGGER.error( + f"[WARNING] INCOMPLETE DATA UTILIZATION: Only " + f"{metrics_articles}/{expected_articles} articles processed " + f"({metrics_diff_percent:.2f}% difference, " + f"{expected_articles - metrics_articles} articles missing)." + ) + LOGGER.error( + f" This means {expected_articles - metrics_articles} articles " + f"from your input files were not processed. " + f"Please check for errors in processing or data format issues." + ) + + # Compare counter with metrics to identify counting issues + if abs(counter_articles - metrics_articles) > 10: + LOGGER.warning( + f"Counter discrepancy detected: Progress counter shows " + f"{counter_articles} articles, but metrics show " + f"{metrics_articles} articles were actually processed. " + f"Difference: {abs(metrics_articles - counter_articles)} articles. " + f"The metrics count ({metrics_articles}) is the accurate one." + ) + else: + LOGGER.info( + f"[OK] Progress counter matches metrics: {counter_articles} " + f"articles counted, {metrics_articles} articles processed." 
+ ) + + log_sep_str() if dataset_metadata_json is not None: dataset_metadata_json["max_batch_size_train"] = max_batch_size_train @@ -580,7 +776,7 @@ def pipeline_main( # noqa: C901 ) num_splits_greater_lines = False - with open(input_file_path, "r") as input_file: + with open(input_file_path, "r", encoding="utf-8", errors="replace") as input_file: for i, line in enumerate(input_file): if i > num_splits: num_splits_greater_lines = True @@ -611,7 +807,7 @@ def pipeline_main( # noqa: C901 if category_to_id is not None: category_to_id_output_file_path = os.path.join(output_dir, "category_to_id.json") verify_output_file(category_to_id_output_file_path, overwrite_output_path) - with open(category_to_id_output_file_path, "w") as f: + with open(category_to_id_output_file_path, "w", encoding="utf-8") as f: json.dump(category_to_id, f) test_dir = os.path.join(output_dir, "test_files") @@ -622,66 +818,31 @@ def pipeline_main( # noqa: C901 # ========================================================= # Case 1: large file shuffle specified. REQUIRES: linux OS if shuffle == "large_file": - err_msg = "You specified --shuffle=large_file, but this is only supported on linux operating systems, " - err_msg += f"your operating system is {platform}. 
Please change the flag to --shuffle=on_RAM or --shuffle=False" - if "linux" not in platform.lower(): - raise OSError(err_msg) split_dir = large_file_shuffle(input_file_path, output_dir, False, num_splits) - # Case 2: Shuffling on RAM with linux OS - elif shuffle == "on_RAM" and "linux" in platform.lower(): + # Case 2: Shuffling on RAM (cross-platform) + elif shuffle == "on_RAM": check_RAM(input_file_size_in_bytes) log_sep_str() LOGGER.info("Shuffling input file, please be patient.") - file_ext = os.path.splitext(input_file_path)[1] - shuffle_file_path = os.path.join(output_dir, f"tmp_shuf{file_ext}") - shuffle_command = f"shuf {input_file_path} > {shuffle_file_path}" - try: - out = execute_and_return_stdout(shuffle_command) - err_msg = f"Shuffle command killed, with print stdout:{out.stdout} stderr:{out.stderr}" - if "killed" in out.stdout or "killed" in out.stderr: - raise MemoryError(err_msg) - except Exception as e: - err_msg = f"Failed with exception {e}, shuffling on RAM is not possible," - err_msg += " try specifying argument --shuffle=large_file" - raise RuntimeError(err_msg) - split_file_linux(num_splits, shuffle_file_path, split_dir) - os.remove(shuffle_file_path) - - # Case 3: shuffle on RAM without linux OS - elif shuffle == "on_RAM" and "linux" not in platform.lower(): - check_RAM(input_file_size_in_bytes) - lines = open(input_file_path).readlines() + # Read all lines into memory + with open(input_file_path, "r", encoding="utf-8", errors="replace") as f: + lines = f.readlines() + # Shuffle the lines random.shuffle(lines) + # Split into chunks splits = np.array_split(lines, num_splits) num_digits = len(str(num_splits)) for i, split in enumerate(splits): out_file_path = os.path.join(split_dir, str(i).zfill(max(2, num_digits))) - with open(out_file_path, "w") as out_file: + with open(out_file_path, "w", encoding="utf-8", errors="replace") as out_file: out_file.writelines(split) - # Case 4: Do not shuffle, split file without linux OS - elif shuffle == 
"False" and "linux" not in platform.lower(): - log_sep_str() - LOGGER.warning("WARNING: you did not specify the --shuffle flag, so no shuffling was done!") - out_files = [] - num_digits = len(str(num_splits)) - for i in range(num_splits): - out_file_path = os.path.join(split_dir, str(i).zfill(max(2, num_digits))) - out_files.append(out_file_path) - with open(out_file_path, "w") as _: - pass - - with open(input_file_path, "r") as input_file: - for i, line in enumerate(input_file): - with open(out_files[i % len(out_files)], "a") as out_f: - out_f.write(line) - - # Case 5: Do not shuffle, split file with linux OS - elif shuffle == "False" and "linux" in platform.lower(): + # Case 3: Do not shuffle, split file (cross-platform) + elif shuffle == "False": log_sep_str() LOGGER.warning("WARNING: you did not specify the --shuffle flag, so no shuffling was done!") - split_file_linux(num_splits, input_file_path, split_dir) + split_file_round_robin(num_splits, input_file_path, split_dir) # rename files to include the corresponding names of 'test', 'dev' and 'train' files_to_tokenize = rename_files( @@ -742,9 +903,9 @@ def pipeline_main( # noqa: C901 for file_name in os.listdir(json_error_log_dir): file_names.append(os.path.join(json_error_log_dir, file_name)) if file_names: - with open(os.path.join(output_dir, "json_load_failed_lines.log"), "w") as outfile: + with open(os.path.join(output_dir, "json_load_failed_lines.log"), "w", encoding="utf-8") as outfile: for file_name in file_names: - with open(file_name) as reader: + with open(file_name, "r", encoding="utf-8") as reader: for line in reader: outfile.write(line) shutil.rmtree(json_error_log_dir) @@ -755,7 +916,7 @@ def pipeline_main( # noqa: C901 update_dataset_metadata(train_metrics, dataset_metadata_json) update_dataset_metadata(dev_metrics, dataset_metadata_json) metadata_file_path = os.path.join(output_dir, "metadata.yaml") - with open(metadata_file_path, "w") as file: + with open(metadata_file_path, "w", 
encoding="utf-8") as file: yaml.dump(dataset_metadata_json, file, default_flow_style=False) # Create sha256 of all the files within the directory diff --git a/generative_data_prep/utils/add_metadata_to_dataset.py b/generative_data_prep/utils/add_metadata_to_dataset.py index 4ecb445..02f5571 100644 --- a/generative_data_prep/utils/add_metadata_to_dataset.py +++ b/generative_data_prep/utils/add_metadata_to_dataset.py @@ -59,7 +59,7 @@ def save_metadata(metadata_path, metadata): metadata_path (str): Path to the metadata YAML file. metadata (dict): Metadata dictionary to save. """ - with open(metadata_path, "w") as f: + with open(metadata_path, "w", encoding="utf-8") as f: yaml.safe_dump(metadata, f, default_flow_style=False) @@ -72,7 +72,7 @@ def add_seq_metadata_dataset(dataset_path): metadata_path = os.path.join(dataset_path, "metadata.yaml") metadata = {} if os.path.exists(metadata_path): - with open(metadata_path, "r") as f: + with open(metadata_path, "r", encoding="utf-8") as f: metadata = yaml.safe_load(f) or {} train_sequences = 0 diff --git a/generative_data_prep/utils/large_file_shuffle.py b/generative_data_prep/utils/large_file_shuffle.py index ffb3a22..ca5f27a 100644 --- a/generative_data_prep/utils/large_file_shuffle.py +++ b/generative_data_prep/utils/large_file_shuffle.py @@ -12,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import logging import os import random @@ -25,6 +26,55 @@ LOGGER = logging.getLogger("generative_data_prep_logger") +def _split_file_round_robin(input_file_path: str, split_dir: str, num_splits: int): + """Split a file into multiple files using round-robin distribution. + + This is a cross-platform replacement for the Linux 'split -d -n r/' command. + Each line from the input file is distributed to output files in round-robin fashion. 
+ + Args: + input_file_path (str): Path to the input file to split + split_dir (str): Directory where split files will be created + num_splits (int): Number of split files to create + """ + # Create file handles for all split files + split_files = [] + for i in range(num_splits): + split_file_path = os.path.join(split_dir, f"x{i:02d}") + split_files.append(open(split_file_path, "w", encoding="utf-8", errors="replace")) + + try: + # Read input file and distribute lines in round-robin fashion + with open(input_file_path, "r", encoding="utf-8", errors="replace") as infile: + for line_num, line in enumerate(infile): + split_index = line_num % num_splits + split_files[split_index].write(line) + finally: + # Close all file handles + for f in split_files: + f.close() + + +def _shuffle_file(file_path: str): + """Shuffle the lines of a file in-place. + + This is a cross-platform replacement for the Linux 'shuf' command. + + Args: + file_path (str): Path to the file to shuffle + """ + # Read all lines + with open(file_path, "r", encoding="utf-8", errors="replace") as f: + lines = f.readlines() + + # Shuffle the lines + random.shuffle(lines) + + # Write back to file + with open(file_path, "w", encoding="utf-8", errors="replace") as f: + f.writelines(lines) + + def large_file_shuffle( input_file_path: str, output_dir: str, @@ -86,8 +136,7 @@ def large_file_shuffle( prev_time = time.time() LOGGER.info("splitting file") - split_command = f"split -d -n r/{num_splits} {input_file_path} {split_dir}/" - os.system(split_command) # nosec + _split_file_round_robin(input_file_path, split_dir, num_splits) LOGGER.info(f"splitting took {time.time() - prev_time} seconds (used round robin splitting).") prev_time = time.time() @@ -95,19 +144,19 @@ def large_file_shuffle( file_list = list(os.listdir(split_dir)) for file in tqdm(file_list): curr_file_path = os.path.join(split_dir, file) - shuf_command = f"shuf {curr_file_path} --output={curr_file_path}" - os.system(shuf_command) # nosec + 
_shuffle_file(curr_file_path) if concat_splits: random_split_list = list(range(num_splits)) random.shuffle(random_split_list) prev_time = time.time() LOGGER.info("Concatenating shuffled splits.") - for rand_ind in tqdm(random_split_list): - curr_file_path = os.path.join(split_dir, file_list[rand_ind]) - concat_command = f"cat {curr_file_path} >> {output_path}" - os.system(concat_command) # nosec - os.remove(curr_file_path) + with open(output_path, "wb") as outfile: + for rand_ind in tqdm(random_split_list): + curr_file_path = os.path.join(split_dir, file_list[rand_ind]) + with open(curr_file_path, "rb") as infile: + shutil.copyfileobj(infile, outfile) + os.remove(curr_file_path) LOGGER.info(f"Finished concatenating files. Took {time.time() - prev_time} seconds.") shutil.rmtree(split_dir) diff --git a/generative_data_prep/utils/logger.py b/generative_data_prep/utils/logger.py index 716948c..65971cf 100644 --- a/generative_data_prep/utils/logger.py +++ b/generative_data_prep/utils/logger.py @@ -15,6 +15,7 @@ This class creates a common logger. 
""" + import argparse import datetime import importlib.metadata @@ -89,7 +90,31 @@ def log_input_args(args): def log_metrics(metrics): """Log the metrics table.""" if not metrics.is_empty: - LOGGER.info(f"{get_header('')}\n{metrics}\n{get_header('')}") + metrics_str = f"{get_header('')}\n{metrics}\n{get_header('')}" + # Replace Unicode box-drawing characters with ASCII equivalents for Windows compatibility + if sys.platform == "win32": + replacements = { + "╒": "+", + "═": "=", + "╤": "+", + "╕": "+", + "├": "+", + "─": "-", + "┼": "+", + "┤": "+", + "╘": "+", + "╧": "+", + "╛": "+", + "│": "|", + } + for old, new in replacements.items(): + metrics_str = metrics_str.replace(old, new) + try: + LOGGER.info(metrics_str) + except UnicodeEncodeError: + # Fallback: encode as ASCII with replacement + safe_str = metrics_str.encode("ascii", errors="replace").decode("ascii") + LOGGER.info(safe_str) def get_header(header_name: str): diff --git a/generative_data_prep/utils/utils.py b/generative_data_prep/utils/utils.py index 7ce6329..265cc57 100644 --- a/generative_data_prep/utils/utils.py +++ b/generative_data_prep/utils/utils.py @@ -80,7 +80,7 @@ def validate_sha256(output_dir: str): """ files_to_hash = _get_walk_files_to_hash(output_dir, "sha256") sha_info_file = os.path.join(output_dir, "sha256", "files_metadata.json") - with open(sha_info_file, "r") as output_file: + with open(sha_info_file, "r", encoding="utf-8", errors="replace") as output_file: file_info_dict = json.load(output_file) for file, hash_file_name in files_to_hash: if "logs" not in hash_file_name: @@ -128,7 +128,7 @@ def create_sha256(output_dir: str): "size": os.path.getsize(file), "modified_time": os.path.getmtime(file), } - with open(output_file_hash, "w") as output_file: + with open(output_file_hash, "w", encoding="utf-8") as output_file: json.dump(file_info_dict, output_file)