diff --git a/pyproject.toml b/pyproject.toml index e29e4c2f6..452d549d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "requests", "rich", "textual", + "pillow>=10.0.0", "pydantic>=2.11.9", "prime-sandboxes>=0.1.0", ] diff --git a/tests/test_environment.py b/tests/test_environment.py index c7dac0442..aec0304ee 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -255,6 +255,8 @@ def apply_template(conversation, tokenize=False, add_generation_prompt=True): ( prompt_ids, prompt_mask, + prompt_image_grid, + prompt_pixel_value, completion_ids, completion_mask, completion_logprobs, @@ -300,6 +302,8 @@ def test_process_completion_format(self, mock_openai_client, sample_dataset): ( prompt_ids, prompt_mask, + prompt_image_grid, + prompt_pixel_value, completion_ids, completion_mask, completion_logprobs, diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 86a39d26f..5bede75b6 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -4,8 +4,7 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from copy import deepcopy -from typing import TYPE_CHECKING, Literal - +from typing import TYPE_CHECKING, Literal, Union from datasets import Dataset from openai import AsyncOpenAI, BadRequestError, OpenAI @@ -27,6 +26,7 @@ SamplingArgs, State, ) +from verifiers.utils.processor_utils import encode_text_with_processor, encode_chat_with_processor from verifiers.utils.message_utils import ( cleanup_messages, get_overlong_prompt_dummy_response, @@ -35,10 +35,9 @@ if TYPE_CHECKING: from transformers.tokenization_utils_base import ( # type: ignore - PreTrainedTokenizerBase, + PreTrainedTokenizerBase, ProcessorMixin ) - - + class Environment(ABC): """ Base class for all environments. @@ -69,7 +68,6 @@ def __init__( self.logger.warning( "The parser and rubric parser are different. This may cause unexpected behavior." ) - if self.message_type == "chat": if dataset is not None: self.dataset = self.format_dataset( @@ -228,6 +226,7 @@ async def get_model_response( ): sampling_args.pop("max_completion_tokens") clean_sampling_args = {k: v for k, v in sampling_args.items() if v is not None} + try: if message_type == "chat": assert isinstance(prompt, list) @@ -444,6 +443,7 @@ async def a_generate( reward=[], metrics={}, ) + n = len(results.prompt) # Resolve concurrency knobs @@ -593,6 +593,7 @@ def generate( ) -> GenerateOutputs: if isinstance(client, OpenAI): client = AsyncOpenAI(api_key=client.api_key, base_url=client.base_url) + coro = self.a_generate( inputs, client, @@ -819,9 +820,9 @@ def process_chat_format_vllm( prompt: list[ChatMessage], completion: list[ChatMessage], state: State, - processing_class: "PreTrainedTokenizerBase", + processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], mask_env_responses: bool = False, - ) -> tuple[list[int], list[int], list[int], list[int], list[float]]: + ) -> tuple[list[int], list[int], list[int], list[int], list[int], list[int], list[float]]: """ Process chat format conversations using incremental prefixes. 
""" @@ -836,10 +837,13 @@ def process_chat_format_vllm( zipped.append((turn, None)) assert len(responses) == responses_idx, "Responses not fully consumed" assert len(zipped) == len(completion), "Length mismatch" - prompt_ids: list[int] = processing_class.apply_chat_template( - conversation=prompt, # type: ignore + + prompt_ids, prompt_image_grid, prompt_pixel_value = encode_chat_with_processor( + conversation=prompt, + processing_class=processing_class, add_generation_prompt=True, ) + messages_consumed = [m for m in prompt] prompt_mask: list[int] = [0] * len(prompt_ids) completion_ids: list[int] = [] @@ -900,13 +904,15 @@ def deserialize_tool_call(tool_call) -> dict: while j < len(zipped) and zipped[j][0]["role"] != "assistant": consecutive_messages.append(zipped[j][0]) j += 1 - token_prefix: list[int] = processing_class.apply_chat_template( - conversation=messages_consumed # type: ignore + token_prefix, token_prefix_image_grid, token_prefix_pixel_values = encode_chat_with_processor( + conversation=messages_consumed, # type: ignore + processing_class=processing_class, + add_generation_prompt=False, ) - token_prefix_with_turn: list[int] = ( - processing_class.apply_chat_template( - conversation=messages_consumed + consecutive_messages, # type: ignore - ) + token_prefix_with_turn, token_prefix_with_turn_image_grid,token_prefix_with_turn_pixel_values = encode_chat_with_processor( + conversation=messages_consumed + consecutive_messages, # type: ignore + processing_class=processing_class, + add_generation_prompt=False, ) assert token_prefix_with_turn[: len(token_prefix)] == token_prefix, ( f"Token prefix mismatch. Token prefix: {token_prefix}, token prefix with turn: {token_prefix_with_turn}" @@ -916,6 +922,7 @@ def deserialize_tool_call(tool_call) -> dict: completion_turn_mask = [0] * len(completion_turn_ids) else: completion_turn_mask = [1] * len(completion_turn_ids) + completion_turn_logprobs = [0.0] * len(completion_turn_ids) completion_ids.extend(completion_turn_ids) completion_mask.extend(completion_turn_mask) @@ -925,6 +932,8 @@ def deserialize_tool_call(tool_call) -> dict: return ( prompt_ids, prompt_mask, + prompt_image_grid, + prompt_pixel_value, completion_ids, completion_mask, completion_logprobs, @@ -935,9 +944,9 @@ def process_completion_format_vllm( prompt: str, completion: str, state: State, - processing_class: "PreTrainedTokenizerBase", + processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], mask_env_responses: bool = False, - ) -> tuple[list[int], list[int], list[int], list[int], list[float]]: + ) -> tuple[list[int], list[int], list[int], list[int], list[int], list[int], list[float]]: """ Process completion format conversations using incremental prefixes. 
""" @@ -958,12 +967,16 @@ def process_completion_format_vllm( idx = response_start_idx + len(response_text) assert idx == len(completion), "Completion not fully consumed" - prompt_ids: list[int] = processing_class.encode(prompt) + prompt_ids, prompt_image_grid, prompt_pixel_value = encode_text_with_processor( + text=prompt, + processing_class=processing_class, + ) rollout_consumed = prompt prompt_mask: list[int] = [0] * len(prompt_ids) completion_ids: list[int] = [] completion_mask: list[int] = [] completion_logprobs: list[float] = [] + i = 0 while i < len(zipped): text, response = zipped[i] @@ -1000,6 +1013,8 @@ def process_completion_format_vllm( return ( prompt_ids, prompt_mask, + prompt_image_grid, + prompt_pixel_value, completion_ids, completion_mask, completion_logprobs, @@ -1011,7 +1026,7 @@ def process_env_results_vllm( completions: list[Messages], states: list[State], rewards: list[float], - processing_class: "PreTrainedTokenizerBase", + processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], max_seq_len: int = -1, mask_env_responses: bool = False, mask_truncated_completions: bool = False, @@ -1024,6 +1039,8 @@ def process_env_results_vllm( all_prompt_ids = [] all_prompt_masks = [] + all_prompt_image_grid = [] + all_prompt_pixel_value = [] all_completion_ids = [] all_completion_masks = [] all_completion_logprobs = [] @@ -1037,6 +1054,8 @@ def process_env_results_vllm( ( prompt_ids, prompt_mask, + prompt_image_grid, + prompt_pixel_value, completion_ids, completion_mask, completion_logprobs, @@ -1048,6 +1067,8 @@ def process_env_results_vllm( ( prompt_ids, prompt_mask, + prompt_image_grid, + prompt_pixel_value, completion_ids, completion_mask, completion_logprobs, @@ -1080,16 +1101,22 @@ def process_env_results_vllm( ) all_prompt_ids.append(prompt_ids) all_prompt_masks.append(prompt_mask) + all_prompt_image_grid.append(prompt_image_grid) + all_prompt_pixel_value.append(prompt_pixel_value) all_completion_ids.append(completion_ids) all_completion_masks.append(completion_mask) all_completion_logprobs.append(completion_logprobs) + if zero_truncated_completions and is_truncated: all_rewards.append(0) else: all_rewards.append(reward) + return ProcessedOutputs( prompt_ids=all_prompt_ids, prompt_mask=all_prompt_masks, + image_grid_thw = all_prompt_image_grid, + pixel_values= all_prompt_pixel_value, completion_ids=all_completion_ids, completion_mask=all_completion_masks, completion_logprobs=all_completion_logprobs, diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index b7c44a96d..b5debe8c1 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -4,14 +4,13 @@ import threading import time from collections import deque -from typing import Any +from typing import Any, Optional from pydantic import BaseModel, Field from verifiers import GenerateOutputs from verifiers.types import ProcessedOutputs - class BatchRequest(BaseModel): """Request for batch generation""" @@ -38,6 +37,7 @@ class BatchResult(BaseModel): default_factory=list ) # Store completions for logging prompts: list[Any] = Field(default_factory=list) # Store prompts for logging + answers : Optional[list[Any]] class AsyncBatchGenerator: @@ -264,6 +264,7 @@ async def _generate_batch_async(self, request: BatchRequest) -> BatchResult: """ # Call environment generation self.is_generating = True + env_results = await self.env.a_generate( request.env_inputs, client=self.client, @@ -272,6 +273,7 @@ async def 
_generate_batch_async(self, request: BatchRequest) -> BatchResult: score_rollouts=True, max_concurrent=request.max_concurrent, ) + self.is_generating = False # Extract all reward-related keys @@ -281,6 +283,7 @@ async def _generate_batch_async(self, request: BatchRequest) -> BatchResult: for k in env_results.metrics: all_reward_dict[k] = env_results.metrics[k] + # Process results processed_results = self.env.process_env_results_vllm( prompts=env_results.prompt, @@ -300,6 +303,7 @@ async def _generate_batch_async(self, request: BatchRequest) -> BatchResult: all_reward_dict=all_reward_dict, completions=env_results.completion, prompts=env_results.prompt, + answers=request.env_inputs.get("answer") ) async def _evaluate_async(self, num_samples: int = -1) -> GenerateOutputs: diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index c8887e45e..aa4cd50bf 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -4,7 +4,8 @@ import time from collections import defaultdict, deque from contextlib import nullcontext -from typing import Any, Dict, List, Optional, Sized, Tuple, Union +from typing import Any, Dict, List, Optional, Sized, Union +import inspect import datasets import numpy as np @@ -17,7 +18,6 @@ ) from peft import PeftConfig, get_peft_model # type: ignore[unresolved-import] from torch.utils.data import DataLoader, Sampler # type: ignore[unresolved-import] -from transformers import AutoModelForCausalLM # type: ignore[unresolved-import] from transformers.integrations.deepspeed import ( # type: ignore[unresolved-import] is_deepspeed_zero3_enabled, ) @@ -31,6 +31,7 @@ from transformers.trainer_callback import ( # type: ignore[unresolved-import] TrainerCallback, ) +from transformers import ProcessorMixin, AutoModelForCausalLM # type: ignore[unresolved-import] from transformers.trainer_utils import seed_worker # type: ignore[unresolved-import] from trl.models import ( # type: ignore[unresolved-import] create_reference_model, @@ -44,13 +45,12 @@ pad, selective_log_softmax, ) - from verifiers import Environment from verifiers.trainers.async_batch_generator import AsyncBatchGenerator, BatchRequest from verifiers.trainers.async_dataloader_wrapper import AsyncDataLoaderWrapper from verifiers.trainers.grpo_config import GRPOConfig -from verifiers.utils.logging_utils import print_prompt_completions_sample - +from verifiers.utils.logging_utils import print_prompt_completions_sample, serialize_for_wandb, extract_images +from verifiers.utils.image_utils import pil_to_base64_url class RepeatSampler(Sampler): """ @@ -210,33 +210,30 @@ def split_tensor_dict( ] -def shuffle_tensor_dict( - tensor_dict: dict[str, Optional[torch.Tensor]], -) -> dict[str, Optional[torch.Tensor]]: +def shuffle_dict_with_lists( + data_dict: Dict[str, Optional[Union[torch.Tensor, List]]], +) -> Dict[str, Optional[Union[torch.Tensor, List]]]: """ - Shuffles a dictionary of tensors along the first dimension in unison. 
- - Example: - >>> x = torch.arange(6).reshape(3, 2) - >>> y = torch.arange(3).reshape(3, 1) - >>> tensor_dict = {"x": x, "y": y} - >>> shuffle_tensor_dict(tensor_dict) - {'x': tensor([[2, 3], - [0, 1], - [4, 5]]), - 'y': tensor([[1], - [0], - [2]])} + Shuffles a dictionary of tensors and/or lists along the first dimension in unison. Lists are shuffled element-wise because pixel values cannot currently be stacked into a single shuffleable tensor. """ - first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None) - batch_size = first_tensor.shape[0] + first_item = next(item for item in data_dict.values() if item is not None) + batch_size = len(first_item) + permutation = torch.randperm(batch_size) - return { - key: tensor[permutation] if tensor is not None else None - for key, tensor in tensor_dict.items() - } - + + shuffled_dict = {} + for key, value in data_dict.items(): + if value is None: + shuffled_dict[key] = None + elif isinstance(value, torch.Tensor): + shuffled_dict[key] = value[permutation] + elif isinstance(value, list): + shuffled_dict[key] = [value[i] for i in permutation] + else: + shuffled_dict[key] = value + return shuffled_dict + def nanmin(tensor: torch.Tensor) -> torch.Tensor: """ Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors. @@ -266,14 +263,13 @@ def nanmax(tensor: torch.Tensor) -> torch.Tensor: return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) return torch.max(tensor[~torch.isnan(tensor)]) - class GRPOTrainer(Trainer): def __init__( self, model: PreTrainedModel, env: Environment, args: GRPOConfig, - processing_class: PreTrainedTokenizerBase, + processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None, callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[ Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR] ] = (None, None), @@ -283,6 +279,12 @@ def __init__( ): self.logger = logging.getLogger(__name__) + self.model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature(model.get_base_model().forward).parameters.keys() + ) + # Models if peft_config is not None: model = get_peft_model(model, peft_config) # type: ignore @@ -299,10 +301,25 @@ def __init__( # Suppress irrelevant warning model.warnings_issued["estimate_tokens"] = True + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - # Tokenizer pad token - if processing_class.pad_token is None: # type: ignore - processing_class.pad_token = processing_class.eos_token # type: ignore + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + self.pad_token = tokenizer.pad_token + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + self.image_token = getattr(processing_class, "image_token", None) + self.image_token_id = getattr(processing_class, "image_token_id", None) + self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None) + self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None) # Training arguments self.per_device_train_batch_size = args.per_device_train_batch_size @@ -421,24 +438,37 @@ def __init__( def filter_by_prompt_length(example, 
processing_class): prompt = example["prompt"] - # Tokenize prompt to check length if isinstance(prompt, list): - # Chat format prompt_text = processing_class.apply_chat_template( prompt, tokenize=False, add_generation_prompt=True ) else: - # Completion format prompt_text = prompt - prompt_ids = processing_class.encode(prompt_text) # type: ignore + + if isinstance(processing_class, PreTrainedTokenizerBase): + prompt_ids = processing_class.encode(prompt_text) + elif isinstance(processing_class, ProcessorMixin): + kwargs = {} + if "image" in example: + kwargs["images"] = [example["image"]] + + inputs = processing_class( + text=prompt_text, + return_tensors="pt", + add_special_tokens=False, + **kwargs, + ) + prompt_ids = inputs["input_ids"][0].tolist() + else: + raise ValueError(f"Unsupported processing class: {type(processing_class)}") return len(prompt_ids) <= max_length original_size = len(train_dataset) - train_dataset = train_dataset.filter( - filter_by_prompt_length, - num_proc=self.max_data_workers, - fn_kwargs={"processing_class": processing_class}, - ) + #train_dataset = train_dataset.filter( + # filter_by_prompt_length, + # num_proc=self.max_data_workers, + # fn_kwargs={"processing_class": processing_class}, + #) filtered_size = len(train_dataset) if filtered_size < original_size: self.logger.info( @@ -536,10 +566,12 @@ def data_collator(features): * args.per_device_train_batch_size * args.gradient_accumulation_steps ) - self._textual_logs = { + self._logs = { + "image": deque(maxlen=maxlen), "prompt": deque(maxlen=maxlen), "completion": deque(maxlen=maxlen), "rewards": defaultdict(lambda: deque(maxlen=maxlen)), + "answers": deque(maxlen=maxlen), } # OpenAI client for Environment generation (using vLLM server) @@ -725,44 +757,77 @@ def _inner_training_loop(self, *args, **kwargs): self.async_generator.stop() self._async_started = False + def _get_last_hidden_state( - self, unwrapped_model, input_ids, attention_mask, logits_to_keep=None + self, + unwrapped_model, + input_ids, + attention_mask, + logits_to_keep, + pixel_values=None, + image_grid_thw=None, ): if is_peft_model(unwrapped_model): unwrapped_model = unwrapped_model.base_model.model - last_hidden_state = unwrapped_model.model( - input_ids=input_ids, attention_mask=attention_mask - ).last_hidden_state + + # Build model inputs, adding vision inputs (pixel_values, image_grid_thw) when present + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} + + if pixel_values is not None: + model_inputs["pixel_values"] = pixel_values + if image_grid_thw is not None: + model_inputs["image_grid_thw"] = image_grid_thw + + last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state + # Exclude the last value: it corresponds to the next token pred last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H) - if logits_to_keep is not None: - last_hidden_state = last_hidden_state[ - :, -logits_to_keep:, : - ] # (B, logits_to_keep, H) + # Only keep the hidden states for the last logits_to_keep positions 
+ last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H) return last_hidden_state # Get the per-token log probabilities for the completions for the model and the reference model def _get_per_token_logps( - self, model, input_ids, attention_mask, logits_to_keep, batch_size=None + self, + model, + input_ids, + attention_mask, + logits_to_keep, + batch_size=None, + pixel_values=None, + image_grid_thw=None, ) -> torch.Tensor: batch_size = batch_size or input_ids.size( 0 ) # Chunk inputs into smaller batches to reduce memory peak all_logps = [] + for i in range(0, input_ids.size(0), batch_size): input_ids_batch = input_ids[i : i + batch_size] attention_mask_batch = attention_mask[i : i + batch_size] - logits = model( - input_ids=input_ids_batch, - attention_mask=attention_mask_batch, - logits_to_keep=logits_to_keep + 1, - ).logits - logits = logits[ - :, :-1, : - ] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + + # Build model inputs for this chunk, adding vision inputs when present + model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + + if image_grid_thw is not None and pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[i : i + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[i : i + batch_size] + model_inputs["pixel_values"] = torch.cat(model_inputs["pixel_values"], dim=0) + model_inputs["image_grid_thw"] = model_inputs["image_grid_thw"].reshape(-1, *model_inputs["image_grid_thw"].shape[2:]) + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[i : i + batch_size] + + # Only add logits_to_keep if the model supports it (some models and VLMs don't) + if "logits_to_keep" in self.model_kwarg_keys: + # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + logits = model(**model_inputs).logits + + # Exclude the last value: it corresponds to the next token pred + logits = logits[:, :-1, :] # (B, L-1, V) input_ids_batch = input_ids_batch[:, -logits_to_keep:] - # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. - # See https://github.com/huggingface/trl/issues/2770 - logits = logits[:, -logits_to_keep:] + # Only keep the last logits_to_keep. For models that support logits_to_keep, this is a no-op. + logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, V) # Divide logits by sampling temperature. # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details logits = logits / self.temperature @@ -770,6 +835,7 @@ def _get_per_token_logps( logits, input_ids_batch ) # compute logprobs for the input tokens all_logps.append(logps) + return torch.cat(all_logps, dim=0) def _move_model_to_vllm(self): @@ -921,41 +987,96 @@ def _ids_to_tensors( ids = torch.stack(ids, dim=0) mask = torch.stack(mask, dim=0) return {"ids": ids, "mask": mask} - - def _gather_batch_data( - self, batch_offset: int = 0 - ) -> Tuple[List[Any], List[Any], List[Any], List[Any]]: + + def _gather_batch_data(self, batch_offset: int = 0): """ - Gather batch data from all processes. - - Args: - batch_offset: 0 for current batch, >0 for future batches (peek ahead) - - Returns: - Tuple of (all_prompts, all_answers, all_tasks) + Gather batch data from all processes and convert PIL images (single or multiple) + in prompts to base64 image_url. 
+ + Handles: + - 'image': single PIL.Image + - 'images': list of PIL.Image objects (matched in order to placeholders) + - placeholders {"type": "image"} + - existing image_url (skipped) + - direct PIL.Image objects inside content """ + batches = self._async_dataloader.peek_ahead(batch_offset) - - if batch_offset == 0: - batch = batches[0] if batches else None - else: - batch = batches[batch_offset - 1] if batches else None - + + if not batches: + return [], [], [], [] + + batch = ( + batches[0] + if batch_offset == 0 + else batches[batch_offset - 1] if batch_offset - 1 < len(batches) else None + ) + if batch is None: return [], [], [], [] - + if isinstance(batch, dict): batch = [batch] - - # Gather batch data from all processes - prompts = [x["prompt"] for x in batch] - answers = [x["answer"] for x in batch] + + prompts = [] + + for x in batch: + prompt = x.get("prompt", []) + single_image = x.get("image", None) + multiple_images = x.get("images", []) + img_index = 0 # track which image in x["images"] we're up to + + for message in prompt: + content = message.get("content", []) + if not isinstance(content, list): + continue + + for i, c in enumerate(content): + if hasattr(c, "save"): + img_url = pil_to_base64_url(c) + content[i] = { + "type": "image_url", + "image_url": {"url": img_url}, + } + continue + + if not isinstance(c, dict): + continue + + ctype = c.get("type") + + if ctype == "image_url": #already an image_url -> skip + continue + + if ctype == "image": # placeholder and we have an image list + # Prefer multiple_images if available + if multiple_images and img_index < len(multiple_images): + pil_img = multiple_images[img_index] + img_index += 1 + elif single_image is not None: + pil_img = single_image + else: + pil_img = None + + if pil_img is not None: + img_url = pil_to_base64_url(pil_img) + c.clear() + c.update({ + "type": "image_url", + "image_url": {"url": img_url}, + }) + + prompts.append(prompt) + + answers = [x.get("answer") for x in batch] tasks = [x.get("task", "default") for x in batch] infos = [x.get("info", {}) for x in batch] + all_prompts = gather_object(prompts) all_answers = gather_object(answers) all_tasks = gather_object(tasks) all_infos = gather_object(infos) + return all_prompts, all_answers, all_tasks, all_infos def _prepare_inputs( # type: ignore @@ -973,7 +1094,6 @@ def _prepare_inputs( # type: ignore self.accelerator.wait_for_everyone() # inputs = list of dicts for all gradient accumulation steps generate_every = self.gradient_accumulation_steps * self.num_iterations - # Check if we need to generate new completions if self._step % generate_every == 0 or self._buffered_inputs is None: # Update weights to vLLM if needed @@ -1020,6 +1140,7 @@ def _prepare_inputs( # type: ignore ) break batch_offset = batch_id - batch_id_to_retrieve + all_prompts, all_answers, all_tasks, all_infos = ( self._gather_batch_data(batch_offset) ) @@ -1028,17 +1149,18 @@ def _prepare_inputs( # type: ignore f"No prompts for batch {batch_id}, stopping batch generation" ) break - + + env_inputs = { + "prompt": all_prompts, + "answer": all_answers, + "task": all_tasks, + "info": all_infos, + } # Submit batch (main process only) if self.accelerator.is_main_process: request = BatchRequest( batch_id=batch_id, - env_inputs={ - "prompt": all_prompts, - "answer": all_answers, - "task": all_tasks, - "info": all_infos, - }, + env_inputs=env_inputs, processing_class=self.processing_class, mask_env_responses=self.mask_env_responses, max_seq_len=self.max_seq_len or -1, @@ -1065,6 +1187,7 @@ def 
_prepare_inputs( # type: ignore # Now retrieve the batch we need for this step if self.accelerator.is_main_process: # Get batch result + batch_result = self.async_generator.get_batch(batch_id_to_retrieve) processed_results = batch_result.processed_results @@ -1078,6 +1201,9 @@ def _prepare_inputs( # type: ignore "all_reward_dict": batch_result.all_reward_dict, "completions": batch_result.completions, "prompts": batch_result.prompts, + "pixel_values":processed_results.pixel_values, + "image_grid_thw":processed_results.image_grid_thw, + "answers": batch_result.answers, } else: broadcast_data = None @@ -1107,6 +1233,13 @@ def _prepare_inputs( # type: ignore # Now create tensors only for this process's slice input_ids_list = [] attention_mask_list = [] + pixel_values_list = [] + image_grid_list = [] + + has_images = any( + broadcast_data["pixel_values"][i] is not None + for i in range(process_slice.start, process_slice.stop) + ) for i in range(process_slice.start, process_slice.stop): input_ids_list.append( @@ -1123,14 +1256,33 @@ def _prepare_inputs( # type: ignore device=self.accelerator.device, ) ) + if has_images: + if broadcast_data["pixel_values"][i] is not None: + pixel_values_list.append( + torch.tensor(broadcast_data["pixel_values"][i], device=self.accelerator.device) + ) + image_grid_list.append( + torch.tensor(broadcast_data["image_grid_thw"][i], device=self.accelerator.device) + ) + else: + # If some examples have no image insert dummy with correct shape + pixel_values_list.append(torch.zeros_like(pixel_values_list[0])) + image_grid_list.append(torch.zeros_like(image_grid_list[0])) input_ids = pad( input_ids_list, - padding_value=self.processing_class.pad_token_id, # type: ignore + padding_value=self.pad_token_id, # type: ignore padding_side="right", ) # type: ignore attention_mask = pad(attention_mask_list, padding_side="right") # type: ignore + if has_images: + pixel_values = pixel_values_list + image_grid_thw = torch.stack(image_grid_list, dim=0) + else : + pixel_values = None + image_grid_thw = None + # Truncate if needed if self.max_seq_len is not None and input_ids.size(1) > self.max_seq_len: input_ids = input_ids[:, -self.max_seq_len :] @@ -1152,6 +1304,7 @@ def _prepare_inputs( # type: ignore all_prompts=broadcast_data["prompts"], all_completions=broadcast_data["completions"], all_reward_dict=broadcast_data["all_reward_dict"], + all_answers=broadcast_data["answers"], ) # Log completion metrics using full batch data on CPU to save memory @@ -1168,8 +1321,10 @@ def _prepare_inputs( # type: ignore self.model, input_ids, attention_mask, - logits_to_keep, - batch_size=self.per_device_train_batch_size, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + logits_to_keep=logits_to_keep, + batch_size=self.per_device_train_batch_size ) # Concatenate all data for shuffling @@ -1179,9 +1334,12 @@ def _prepare_inputs( # type: ignore "old_per_token_logps": old_per_token_logps, "advantages": advantages, } + if has_images: + full_batch["pixel_values"] = pixel_values + full_batch["image_grid_thw"] = image_grid_thw # Shuffle and split for gradient accumulation - full_batch = shuffle_tensor_dict(full_batch) + full_batch = shuffle_dict_with_lists(full_batch) self._buffered_inputs = split_tensor_dict( full_batch, self.gradient_accumulation_steps ) @@ -1226,17 +1384,19 @@ def compute_loss( # type: ignore completion_mask = attention_mask[:, 1:] logits_to_keep = completion_mask.size(1) per_token_logps = self._get_per_token_logps( - model, input_ids, attention_mask, logits_to_keep + 
model, + input_ids, + attention_mask, + logits_to_keep, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), ) # Compute the loss advantages = inputs["advantages"] # When using num_iterations == 1, old_per_token_logps == per_token_logps, # so we can skip it's computation (see _generate_and_score_completions) and use per_token_logps.detach() instead. - old_per_token_logps = ( - per_token_logps.detach() - if inputs["old_per_token_logps"] is None - else inputs["old_per_token_logps"] - ) + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps coef_1 = torch.exp(per_token_logps - old_per_token_logps) coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) @@ -1257,12 +1417,12 @@ def compute_loss( # type: ignore with torch.no_grad(): if self.ref_model is not None: ref_per_token_logps = self._get_per_token_logps( - self.ref_model, input_ids, attention_mask, logits_to_keep + self.ref_model, input_ids, attention_mask, logits_to_keep, pixel_values=inputs.get("pixel_values"),image_grid_thw=inputs.get("image_grid_thw"), ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): # type: ignore ref_per_token_logps = self._get_per_token_logps( - self.model, input_ids, attention_mask, logits_to_keep + self.model, input_ids, attention_mask, logits_to_keep, pixel_values=inputs.get("pixel_values"),image_grid_thw=inputs.get("image_grid_thw"), ) per_token_kl = ( torch.exp(ref_per_token_logps - per_token_logps) @@ -1344,7 +1504,7 @@ def _sanitize_tool_calls( msg.pop("tool_call_id") return completion - def evaluate( + def evaluate( # TODO : check for images self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval", **kwargs ): """ @@ -1494,11 +1654,12 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: self._metrics[mode].clear() if self.accelerator.is_main_process and self.log_completions: - if len(self._textual_logs["prompt"]) > 0: + if len(self._logs["prompt"]) > 0: print_prompt_completions_sample( - self._textual_logs["prompt"], - self._textual_logs["completion"], - self._textual_logs["rewards"]["reward"], + self._logs["prompt"], + self._logs["completion"], + self._logs["rewards"]["reward"], + self._logs["answers"], self.state.global_step, ) @@ -1511,25 +1672,46 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: table = { "step": [str(self.state.global_step)] - * len(self._textual_logs["prompt"]), - "prompt": list(self._textual_logs["prompt"]), + * len(self._logs["prompt"]), + "prompt": list(self._logs["prompt"]), "completion": [ self._sanitize_tool_calls(c) - for c in self._textual_logs["completion"] + for c in self._logs["completion"] ], - **{k: list(v) for k, v in self._textual_logs["rewards"].items()}, + "answer" : list(self._logs["answers"]), + **{k: list(v) for k, v in self._logs["rewards"].items()}, } + + if self._logs["image"]: + table["image"] = [] + for img in self._logs["image"]: + if img is not None: + table["image"].append(wandb.Image(img)) + else: + table["image"].append(None) + if len(table["prompt"]) > 0: + all_images = [extract_images(p) for p in table["prompt"]] # list of lists + wandb_images = [[wandb.Image(img) for img in imgs] for imgs in all_images] + + if any(len(imgs) > 0 for imgs in wandb_images): + table["images"] = wandb_images + + table["prompt"] = [serialize_for_wandb(p) for p in table["prompt"]] + table["completion"] = 
[serialize_for_wandb(c) for c in table["completion"]] + df = pd.DataFrame(table) + if self.wandb_log_unique_prompts: df = df.drop_duplicates(subset=["prompt"]) + wandb.log({"completions": wandb.Table(dataframe=df)}) # Clear the textual logs after logging - self._textual_logs["prompt"].clear() - self._textual_logs["completion"].clear() - for key in self._textual_logs["rewards"]: - self._textual_logs["rewards"][key].clear() + self._logs["prompt"].clear() + self._logs["completion"].clear() + for key in self._logs["rewards"]: + self._logs["rewards"][key].clear() def _log_reward_metrics_primary( self, @@ -1566,23 +1748,39 @@ def _log_textual_data_primary( all_prompts: List[Union[str, List[Dict[str, Any]]]], all_completions: List[Union[str, List[Dict[str, Any]]]], all_reward_dict: Dict[str, Any], + all_answers : List[Any] ) -> None: """ Log textual data for wandb (PRIMARY PROCESS ONLY). This logs the full batch of prompts, completions, and rewards. """ - self._textual_logs["prompt"].extend(all_prompts) - self._textual_logs["completion"].extend(all_completions) + self._logs["prompt"].extend(all_prompts) + self._logs["completion"].extend(all_completions) + self._logs["answers"].extend(all_answers) # Log all reward scores - both individual functions and consolidated for reward_key in all_reward_dict: reward_values = all_reward_dict[reward_key] - self._textual_logs["rewards"][reward_key].extend( + self._logs["rewards"][reward_key].extend( reward_values.tolist() if isinstance(reward_values, torch.Tensor) else reward_values ) + def _log_image_data_primary(self, all_images: List[Any]) -> None: + """ + Log images for wandb (PRIMARY PROCESS ONLY). + Converts each image to wandb.Image and stores it in the _logs deque. + """ + if "image" not in self._logs: + self._logs["image"] = deque(maxlen=self._logs_maxlen) + + for img in all_images: + if img is not None: + self._logs["image"].append(wandb.Image(img)) + else: + self._logs["image"].append(None) + def _log_completion_metrics_primary( self, mode: str, @@ -1619,7 +1817,7 @@ def _log_completion_metrics_primary( term_lengths = [] for comp_ids, comp_mask in zip(all_completion_ids, all_completion_mask): has_eos = any( - token == self.processing_class.eos_token_id # type: ignore + token == self.eos_token_id # type: ignore for token, mask in zip(comp_ids, comp_mask) if mask ) diff --git a/verifiers/types.py b/verifiers/types.py index b6c53850c..6dc3068ce 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -3,6 +3,7 @@ Awaitable, Callable, Literal, + Optional, TypedDict, ) @@ -54,8 +55,8 @@ class GenerateInputs(BaseModel): class GenerateOutputs(BaseModel): """Pydantic model for generation outputs.""" - prompt: list[Messages] - completion: list[Messages] + prompt: list[list[dict]] + completion: list[list[dict]] answer: list[str] state: list[State] info: list[Info] @@ -83,6 +84,8 @@ class ProcessedOutputs(BaseModel): prompt_ids: list[list[int]] prompt_mask: list[list[int]] + image_grid_thw: Optional[list[Optional[list[list[int]]]]] = None + pixel_values: Optional[list[Optional[list[list[float]]]]] = None completion_ids: list[list[int]] completion_mask: list[list[int]] completion_logprobs: list[list[float]] diff --git a/verifiers/utils/image_utils.py b/verifiers/utils/image_utils.py new file mode 100644 index 000000000..ef061935c --- /dev/null +++ b/verifiers/utils/image_utils.py @@ -0,0 +1,20 @@ +import base64 +from io import BytesIO +from PIL import Image + +def _base64_to_pil(data_uri: str) -> Image.Image: + """Convert a base64 data URI 
(data:image/...;base64,...) to a PIL Image.""" + if not data_uri.startswith("data:image"): + raise ValueError(f"Expected base64 image data URI, got: {data_uri[:30]}") + header, b64data = data_uri.split(",", 1) + image_data = base64.b64decode(b64data) + return Image.open(BytesIO(image_data)).convert("RGB") + +def pil_to_base64_url(pil_image) -> str: + """ + Convert a PIL image to a base64 URL string suitable for OpenAI/vLLM messages. + """ + buffered = BytesIO() + pil_image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + return f"data:image/png;base64,{img_str}" \ No newline at end of file diff --git a/verifiers/utils/logging_utils.py b/verifiers/utils/logging_utils.py index fba11437a..03998b323 100644 --- a/verifiers/utils/logging_utils.py +++ b/verifiers/utils/logging_utils.py @@ -1,6 +1,7 @@ import json import logging import sys +import copy from rich.console import Console from rich.panel import Panel @@ -10,6 +11,79 @@ from verifiers.types import Messages from collections.abc import Mapping +import base64 +from io import BytesIO +from PIL import Image + + +def extract_images(obj): + """ + Extract and decode Base64 images into a list of PIL.Image objects. + """ + images = [] + + def _extract(o): + if isinstance(o, dict): + for v in o.values(): + _extract(v) + if "image_url" in o and isinstance(o["image_url"], dict): + url = o["image_url"].get("url") + if isinstance(url, str) and url.startswith("data:image/"): + try: + header, b64_data = url.split(",", 1) + image_data = base64.b64decode(b64_data) + image = Image.open(BytesIO(image_data)) + images.append(image) + except Exception: + pass + elif isinstance(o, list): + for v in o: + _extract(v) + + _extract(obj) + return images + +def sanitize_and_serialize(obj): + """ + Sanitize Base64 images and convert nested dict/list to string for WandB. + """ + if isinstance(obj, dict): + obj = {k: sanitize_and_serialize(v) for k, v in obj.items()} + if "image_url" in obj and isinstance(obj["image_url"], dict): + url = obj["image_url"].get("url") + if isinstance(url, str) and url.startswith("data:image/"): + obj["image_url"]["url"] = "" + return obj + elif isinstance(obj, list): + return [sanitize_and_serialize(x) for x in obj] + else: + return obj + +def serialize_for_wandb(obj): + sanitized = sanitize_and_serialize(obj) + return json.dumps(sanitized, ensure_ascii=False) + + +def sanitize_message_for_logging(msg): + """ + Recursively sanitize a message dict, removing Base64 data from image URLs. 
+ """ + msg = copy.deepcopy(msg) + + if isinstance(msg, dict): + for k, v in msg.items(): + if k == "image_url" and isinstance(v, dict) and "url" in v: + url = v["url"] + if url.startswith("data:image/"): + v["url"] = "" + else: + msg[k] = sanitize_message_for_logging(v) + + elif isinstance(msg, list): + msg = [sanitize_message_for_logging(x) for x in msg] + + return msg + def setup_logging( level: str = "INFO", @@ -86,7 +160,9 @@ def _format_messages(messages) -> Text: style = "bright_cyan" if role == "assistant" else "bright_magenta" out.append(f"{role}: ", style="bold") - out.append(content, style=style) + + safe_content = sanitize_message_for_logging(content) + out.append(str(safe_content), style=style) for tc in msg.get("tool_calls") or []: # treat None as empty list payload = _normalize_tool_call(tc) diff --git a/verifiers/utils/message_utils.py b/verifiers/utils/message_utils.py index 12b224156..65c4a82bf 100644 --- a/verifiers/utils/message_utils.py +++ b/verifiers/utils/message_utils.py @@ -49,45 +49,53 @@ def messages_to_printable(messages: Messages) -> Messages: def cleanup_message(message: ChatMessage) -> ChatMessage: - new_message = {} - new_message["role"] = message["role"] + new_message = { + "role": message["role"], + "content": [] + } + if "tool_calls" in message: new_message["tool_calls"] = message["tool_calls"] - if "tool_call_id" in message: new_message["tool_call_id"] = message["tool_call_id"] - new_message["content"] = [] content = message.get("content") if content is None: return cast(ChatMessage, new_message) + if isinstance(content, str): new_message["content"] = content - else: - for c in content: - new_c = c.copy() - c_dict = dict(c) - if "image_url" in c_dict and "type" in c_dict and c_dict["type"] == "text": - new_c.pop("image_url") - new_message["content"].append(new_c) - elif ( - "image_url" in c_dict - and "type" in c_dict - and c_dict["type"] == "image_url" - ): - new_c.pop("text") - new_message["content"].append(new_c) - elif str(c_dict.get("type", "")).startswith("input_audio"): - # Ensure input_audio content blocks only have the required fields - clean_c = { + return cast(ChatMessage, new_message) + + for c in content: + c_dict = dict(c) + c_type = c_dict.get("type") + + if c_type == "text": + if c_dict.get("text") is not None: + new_message["content"].append({ + "type": "text", + "text": c_dict["text"] + }) + + elif c_type == "image_url": + if "image_url" in c_dict: + new_message["content"].append({ + "type": "image_url", + "image_url": c_dict["image_url"] + }) + + elif c_type == "input_audio": + if "input_audio" in c_dict: + new_message["content"].append({ "type": "input_audio", - "input_audio": c_dict.get("input_audio", {}), - } - new_message["content"].append(clean_c) - else: - new_message["content"].append(new_c) - return cast(ChatMessage, new_message) + "input_audio": c_dict["input_audio"] + }) + else: + new_message["content"].append(c_dict) + + return cast(ChatMessage, new_message) def cleanup_messages(messages: Messages) -> Messages: if isinstance(messages, str): diff --git a/verifiers/utils/model_utils.py b/verifiers/utils/model_utils.py index 3003dea17..47f3e5748 100644 --- a/verifiers/utils/model_utils.py +++ b/verifiers/utils/model_utils.py @@ -1,6 +1,7 @@ from importlib.util import find_spec from typing import Any, Callable + import torch # type: ignore[unresolved-import] import torch.nn as nn # type: ignore[unresolved-import] from transformers import ( # type: ignore[unresolved-import] @@ -8,7 +9,6 @@ AutoTokenizer, ) - class 
_ForwardRedirection: """Implements the `forward-redirection`. @@ -110,4 +110,4 @@ def get_model_and_tokenizer( ) -> tuple[Any, Any]: model = get_model(model_name, use_liger, model_kwargs) tokenizer = get_tokenizer(model_name) - return model, tokenizer + return model, tokenizer \ No newline at end of file diff --git a/verifiers/utils/processor_utils.py b/verifiers/utils/processor_utils.py new file mode 100644 index 000000000..890101426 --- /dev/null +++ b/verifiers/utils/processor_utils.py @@ -0,0 +1,81 @@ +from verifiers.utils.image_utils import _base64_to_pil +from typing import Union, List, Dict, Any, TYPE_CHECKING +from inspect import signature + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase, ProcessorMixin + +def supports_images(obj): # works as a replacement for isinstance(processing_class, ProcessorMixin): ProcessorMixin is only imported under TYPE_CHECKING, so processors are detected by whether the callable accepts an `images` argument + if callable(obj): + try: + sig = signature(obj) + return "images" in sig.parameters + except TypeError: + return False + return False + +def encode_chat_with_processor( + conversation: List[Dict], + processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], + add_generation_prompt: bool = False, + add_special_tokens: bool = False, +) -> tuple[list[int], Any, Any]: + """ + Apply chat template and return token IDs, handling both tokenizer and processor. + Supports base64-encoded images in the conversation. + """ + + if supports_images(processing_class): + prompt_text = processing_class.apply_chat_template( + conversation=conversation, + add_generation_prompt=add_generation_prompt, + tokenize=False, + ) + images = [] + for msg in conversation: + for c in msg.get("content", []): + if isinstance(c, dict) and c.get("type") == "image_url": + pil_img = _base64_to_pil(c["image_url"]["url"]) + images.append(pil_img) + + inputs = processing_class( + text=[prompt_text], + images=images if images else None, + return_tensors="pt", + add_special_tokens=add_special_tokens, + ) + input_ids_list = inputs["input_ids"][0].tolist() + image_grid_list = inputs["image_grid_thw"].tolist() if "image_grid_thw" in inputs else None + pixel_values_list = inputs["pixel_values"].tolist() if "pixel_values" in inputs else None + + return input_ids_list, image_grid_list, pixel_values_list + + else: + prompt_ids: List[int] = processing_class.apply_chat_template( + conversation=conversation, + add_generation_prompt=add_generation_prompt, + ) + return prompt_ids, None, None + +def encode_text_with_processor( + text: str, + processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], +) -> tuple[list[int], Any, Any]: + """ + Encode plain text and return token IDs, handling both tokenizer and processor. + """ + if supports_images(processing_class): + inputs = processing_class( + text=[text], + images=None, + return_tensors="pt", + ) + input_ids = inputs["input_ids"][0].tolist() + image_grid = inputs["image_grid_thw"].tolist() if "image_grid_thw" in inputs else None + pixel_values = inputs["pixel_values"].tolist() if "pixel_values" in inputs else None + return input_ids, image_grid, pixel_values + else: + prompt_ids: list[int] = processing_class.encode( + text + ) + return prompt_ids, None, None
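
Review note: a minimal usage sketch (not part of the diff) of the new processor utilities. It assumes a Qwen2-VL-style AutoProcessor whose chat template understands image_url content blocks; the checkpoint name and the conversation are illustrative only, and a plain tokenizer would return None for the image fields.

from PIL import Image
from transformers import AutoProcessor

from verifiers.utils.image_utils import pil_to_base64_url
from verifiers.utils.processor_utils import encode_chat_with_processor

# Illustrative checkpoint; any processor whose __call__ accepts an `images` argument is handled the same way.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# Build a chat prompt whose image has already been converted to a base64 data URI,
# mirroring what _gather_batch_data now does before submitting a BatchRequest.
image = Image.new("RGB", (64, 64), "red")
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": pil_to_base64_url(image)}},
            {"type": "text", "text": "What color is this square?"},
        ],
    }
]

# With a processor, image_grid_thw and pixel_values are populated; with a tokenizer they are None.
prompt_ids, image_grid_thw, pixel_values = encode_chat_with_processor(
    conversation,
    processing_class=processor,
    add_generation_prompt=True,
)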