Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions pipelinerl/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,18 @@ def preprocess_dataset(
entry["step_index"] = entry["metadata"]["step_index"]
if not isinstance(tokenizer.eos_token_id, int):
raise ValueError(f"Tokenizer {tokenizer} does not have an eos_token_id")
dataset = populate_rl_data(dataset=dataset, eos_token_id=tokenizer.eos_token_id, config=rl_config)
try:
dataset = populate_rl_data(dataset=dataset, eos_token_id=tokenizer.eos_token_id, config=rl_config)
except Exception as e:
logger.error(f"Error in populate_rl_data: {e}")
logger.error(f"Data: {data}")
logger.error(f"Dataset: {dataset}")
logger.error(f"Tokenizer: {tokenizer}")
logger.error(f"Tokenizer eos_token_id: {tokenizer.eos_token_id}")
logger.error(f"RL config: {rl_config}")
logger.error(f"LLM: {llm}")
logger.error(f"Seq length: {seq_length}")
raise e
Copy link
Collaborator

@rafapi rafapi Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All these logger calls are heavier than they look; you could keep this lighter and simpler with something like:

    except Exception:
        logger.exception(
            "Error in populate_rl_data",
            extra={
                "Seq length": seq_length,
                "eos_token_id": int(eos_id),
                ...
            },
        )
        raise  <-- this (no e) already keeps the stack trace

return dataset


Expand Down Expand Up @@ -533,7 +544,7 @@ def run_preprocessing_loop(
while len(buffer) > 0:
if len(processed_entries_queue) == processed_entries_queue.maxlen:
if not pop_old_data:
break
break
else:
processed_entries_queue_popped_data += 1
if processed_entries_queue_popped_data % 100 == 0 and last_time_notice != processed_entries_queue_popped_data // 100:
Expand Down Expand Up @@ -573,6 +584,10 @@ def run_preprocessing_loop(
sample_length = len(entry["input_ids"])

if current_length + sample_length > cfg.finetune.seq_length:
if len(current_batch) == 0:
raise ValueError(
f"sample_length is {sample_length}, but cfg.finetune.seq_length is {cfg.finetune.seq_length}"
)
time_to_write = True
break # Current micro batch is full

Expand Down