From 2c7fdf875569a647812bf7cb673beac2e4d29505 Mon Sep 17 00:00:00 2001
From: Jonah Samost
Date: Mon, 20 Apr 2026 13:16:20 -0700
Subject: [PATCH 1/3] working lora + gemma4

---
 trl/chat_template_utils.py                |  38 ++
 trl/chat_templates/gemma4.jinja           | 344 ++++++++++++++++++
 .../async_grpo/async_grpo_config.py       |  27 ++
 .../async_grpo/async_grpo_trainer.py      |  60 ++-
 .../async_grpo/async_rollout_worker.py    |  29 +-
 5 files changed, 483 insertions(+), 15 deletions(-)
 create mode 100644 trl/chat_templates/gemma4.jinja

diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py
index 35abd3aa696..118b5009b1b 100644
--- a/trl/chat_template_utils.py
+++ b/trl/chat_template_utils.py
@@ -305,6 +305,40 @@ def clone_chat_template(
     },
 }
 
+gemma4_schema = {
+    "type": "object",
+    "properties": {
+        "role": {"const": "assistant"},
+        "thinking": {"type": "string"},
+        "content": {"type": "string"},
+        "tool_calls": {
+            "x-regex-iterator": r"<\|tool_call>(.*?)",
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "type": {"const": "function"},
+                    "function": {
+                        "type": "object",
+                        "x-regex": r"call\:(?P<name>\w+)(?P<arguments>\{.*\})",
+                        "properties": {
+                            "name": {
+                                "type": "string",
+                            },
+                            "arguments": {
+                                "type": "object",
+                                "x-parser": "gemma4-tool-call",
+                                "additionalProperties": {},
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    },
+    "x-regex": r"(\<\|channel\>thought\n(?P<thinking>.*?)\)?(?P<content>(?:(?!\<\|tool_call\>).)+)?(?P<tool_calls>\<\|tool_call\>.*\)?",
+}
+
 
 deepseekv3_chat_template = (_CHAT_TEMPLATES_DIR / "deepseekv3.jinja").read_text()
 
@@ -328,6 +362,8 @@ def clone_chat_template(
 qwen3_5_chat_template_4b_and_above = (_CHAT_TEMPLATES_DIR / "qwen3_5_4b_and_above.jinja").read_text()
 
+gemma4_chat_template = (_CHAT_TEMPLATES_DIR / "gemma4.jinja").read_text()
+
 
 ProcessingClassT = TypeVar("ProcessingClassT", PreTrainedTokenizer, ProcessorMixin)
 
@@ -382,6 +418,8 @@ def add_response_schema(processing_class: ProcessingClassT) -> ProcessingClassT:
         tokenizer.response_schema = qwen3_schema
     elif chat_template in [qwen3_5_chat_template_2b_and_below, qwen3_5_chat_template_4b_and_above]:
         tokenizer.response_schema = qwen3_5_schema
+    elif chat_template in [gemma4_chat_template]:
+        tokenizer.response_schema = gemma4_schema
     else:
         raise ValueError(
             "Unrecognized chat template, failed to add response schema. 
Please manually set the response schema on " diff --git a/trl/chat_templates/gemma4.jinja b/trl/chat_templates/gemma4.jinja new file mode 100644 index 00000000000..07e50e69a8c --- /dev/null +++ b/trl/chat_templates/gemma4.jinja @@ -0,0 +1,344 @@ +{%- macro format_parameters(properties, required) -%} + {%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in properties | dictsort -%} + {%- set add_comma = false -%} + {%- if key not in standard_keys -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {{ key }}:{ + {%- if value['description'] -%} + description:<|"|>{{ value['description'] }}<|"|> + {%- set add_comma = true -%} + {%- endif -%} + {%- if value['type'] | upper == 'STRING' -%} + {%- if value['enum'] -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + enum:{{ format_argument(value['enum']) }} + {%- endif -%} + {%- elif value['type'] | upper == 'ARRAY' -%} + {%- if value['items'] is mapping and value['items'] -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + items:{ + {%- set ns_items = namespace(found_first=false) -%} + {%- for item_key, item_value in value['items'] | dictsort -%} + {%- if item_value is not none -%} + {%- if ns_items.found_first %},{% endif -%} + {%- set ns_items.found_first = true -%} + {%- if item_key == 'properties' -%} + properties:{ + {%- if item_value is mapping -%} + {{- format_parameters(item_value, value['items']['required'] | default([])) -}} + {%- endif -%} + } + {%- elif item_key == 'required' -%} + required:[ + {%- for req_item in item_value -%} + <|"|>{{- req_item -}}<|"|> + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- elif item_key == 'type' -%} + {%- if item_value is string -%} + type:{{ format_argument(item_value | upper) }} + {%- else -%} + type:{{ format_argument(item_value | map('upper') | list) }} + {%- endif -%} + {%- else -%} + {{ item_key }}:{{ format_argument(item_value) }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + } + {%- endif -%} + {%- endif -%} + {%- if value['nullable'] %} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + nullable:true + {%- endif -%} + {%- if value['type'] | upper == 'OBJECT' -%} + {%- if value['properties'] is defined and value['properties'] is mapping -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + properties:{ + {{- format_parameters(value['properties'], value['required'] | default([])) -}} + } + {%- elif value is mapping -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + properties:{ + {{- format_parameters(value, value['required'] | default([])) -}} + } + {%- endif -%} + {%- if value['required'] -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + required:[ + {%- for item in value['required'] | default([]) -%} + <|"|>{{- item -}}<|"|> + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + ] + {%- endif -%} + {%- endif -%} + {%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%} + type:<|"|>{{ value['type'] | upper }}<|"|>} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} +{%- macro format_function_declaration(tool_data) -%} + declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|> + {%- set params = tool_data['function']['parameters'] -%} + {%- if params -%} + ,parameters:{ + {%- 
if params['properties'] -%} + properties:{ {{- format_parameters(params['properties'], params['required']) -}} }, + {%- endif -%} + {%- if params['required'] -%} + required:[ + {%- for item in params['required'] -%} + <|"|>{{- item -}}<|"|> + {{- ',' if not loop.last -}} + {%- endfor -%} + ], + {%- endif -%} + {%- if params['type'] -%} + type:<|"|>{{- params['type'] | upper -}}<|"|>} + {%- endif -%} + {%- endif -%} + {%- if 'response' in tool_data['function'] -%} + {%- set response_declaration = tool_data['function']['response'] -%} + ,response:{ + {%- if response_declaration['description'] -%} + description:<|"|>{{- response_declaration['description'] -}}<|"|>, + {%- endif -%} + {%- if response_declaration['type'] | upper == 'OBJECT' -%} + type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>} + {%- endif -%} + {%- endif -%} + } +{%- endmacro -%} +{%- macro format_argument(argument, escape_keys=True) -%} + {%- if argument is string -%} + {{- '<|"|>' + argument + '<|"|>' -}} + {%- elif argument is boolean -%} + {{- 'true' if argument else 'false' -}} + {%- elif argument is mapping -%} + {{- '{' -}} + {%- set ns = namespace(found_first=false) -%} + {%- for key, value in argument | dictsort -%} + {%- if ns.found_first %},{% endif -%} + {%- set ns.found_first = true -%} + {%- if escape_keys -%} + {{- '<|"|>' + key + '<|"|>' -}} + {%- else -%} + {{- key -}} + {%- endif -%} + :{{- format_argument(value, escape_keys=escape_keys) -}} + {%- endfor -%} + {{- '}' -}} + {%- elif argument is sequence -%} + {{- '[' -}} + {%- for item in argument -%} + {{- format_argument(item, escape_keys=escape_keys) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- ']' -}} + {%- else -%} + {{- argument -}} + {%- endif -%} +{%- endmacro -%} +{%- macro strip_thinking(text) -%} + {%- set ns = namespace(result='') -%} + {%- for part in text.split('') -%} + {%- if '<|channel>' in part -%} + {%- set ns.result = ns.result + part.split('<|channel>')[0] -%} + {%- else -%} + {%- set ns.result = ns.result + part -%} + {%- endif -%} + {%- endfor -%} + {{- ns.result | trim -}} +{%- endmacro -%} + +{%- macro format_tool_response_block(tool_name, response) -%} + {{- '<|tool_response>' -}} + {%- if response is mapping -%} + {{- 'response:' + tool_name + '{' -}} + {%- for key, value in response | dictsort -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- if not loop.last %},{% endif -%} + {%- endfor -%} + {{- '}' -}} + {%- else -%} + {{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}} + {%- endif -%} + {{- '' -}} +{%- endmacro -%} + +{%- set ns = namespace(prev_message_type=None) -%} +{%- set loop_messages = messages -%} +{{- bos_token -}} +{#- Handle System/Tool Definitions Block -#} +{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%} + {{- '<|turn>system\n' -}} + + {#- Inject Thinking token at the very top of the FIRST system turn -#} + {%- if enable_thinking is defined and enable_thinking -%} + {{- '<|think|>\n' -}} + {%- set ns.prev_message_type = 'think' -%} + {%- endif -%} + + {%- if messages[0]['role'] in ['system', 'developer'] -%} + {{- messages[0]['content'] | trim -}} + {%- set loop_messages = messages[1:] -%} + {%- endif -%} + + {%- if tools -%} + {%- for tool in tools %} + {{- '<|tool>' -}} + {{- format_function_declaration(tool) | trim -}} + {{- '' -}} + {%- endfor %} + {%- set ns.prev_message_type = 'tool' -%} + {%- endif -%} + + {{- '\n' -}} +{%- endif %} + +{#- 
Pre-scan: find last user message index for reasoning guard -#} +{%- set ns_turn = namespace(last_user_idx=-1) -%} +{%- for i in range(loop_messages | length) -%} + {%- if loop_messages[i]['role'] == 'user' -%} + {%- set ns_turn.last_user_idx = i -%} + {%- endif -%} +{%- endfor -%} + +{#- Loop through messages -#} +{%- for message in loop_messages -%} + {%- if message['role'] != 'tool' -%} + {%- set ns.prev_message_type = None -%} + {%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%} + {#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#} + {%- set prev_nt = namespace(role=None, found=false) -%} + {%- if loop.index0 > 0 -%} + {%- for j in range(loop.index0 - 1, -1, -1) -%} + {%- if not prev_nt.found -%} + {%- if loop_messages[j]['role'] != 'tool' -%} + {%- set prev_nt.role = loop_messages[j]['role'] -%} + {%- set prev_nt.found = true -%} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%} + {%- if not continue_same_model_turn -%} + {{- '<|turn>' + role + '\n' }} + {%- endif -%} + + {#- Render reasoning/reasoning_content as thinking channel -#} + {%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%} + {%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%} + {{- '<|channel>thought\n' + thinking_text + '\n' -}} + {%- endif -%} + + {%- if message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {%- set function = tool_call['function'] -%} + {{- '<|tool_call>call:' + function['name'] + '{' -}} + {%- if function['arguments'] is mapping -%} + {%- set ns_args = namespace(found_first=false) -%} + {%- for key, value in function['arguments'] | dictsort -%} + {%- if ns_args.found_first %},{% endif -%} + {%- set ns_args.found_first = true -%} + {{- key -}}:{{- format_argument(value, escape_keys=False) -}} + {%- endfor -%} + {%- elif function['arguments'] is string -%} + {{- function['arguments'] -}} + {%- endif -%} + {{- '}' -}} + {%- endfor -%} + {%- set ns.prev_message_type = 'tool_call' -%} + {%- endif -%} + + {%- set ns_tr_out = namespace(flag=false) -%} + {%- if message.get('tool_responses') -%} + {#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#} + {%- for tool_response in message['tool_responses'] -%} + {{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}} + {%- set ns_tr_out.flag = true -%} + {%- set ns.prev_message_type = 'tool_response' -%} + {%- endfor -%} + {%- elif message.get('tool_calls') -%} + {#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#} + {%- set ns_tool_scan = namespace(stopped=false) -%} + {%- for k in range(loop.index0 + 1, loop_messages | length) -%} + {%- if ns_tool_scan.stopped -%} + {%- elif loop_messages[k]['role'] != 'tool' -%} + {%- set ns_tool_scan.stopped = true -%} + {%- else -%} + {%- set follow = loop_messages[k] -%} + {#- Resolve tool_call_id to function name -#} + {%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%} + {%- for tc in message['tool_calls'] -%} + {%- if tc.get('id') == follow.get('tool_call_id') -%} + {%- set ns_tname.name = tc['function']['name'] -%} + {%- endif -%} + {%- endfor -%} + {#- Handle content as string or content-parts array -#} + {%- set tool_body = follow.get('content') -%} + {%- if tool_body is string -%} + {{- 
format_tool_response_block(ns_tname.name, tool_body) -}} + {%- elif tool_body is sequence and tool_body is not string -%} + {%- set ns_txt = namespace(s='') -%} + {%- for part in tool_body -%} + {%- if part.get('type') == 'text' -%} + {%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%} + {%- endif -%} + {%- endfor -%} + {{- format_tool_response_block(ns_tname.name, ns_txt.s) -}} + {%- else -%} + {{- format_tool_response_block(ns_tname.name, tool_body) -}} + {%- endif -%} + {%- set ns_tr_out.flag = true -%} + {%- set ns.prev_message_type = 'tool_response' -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + + {%- if message['content'] is string -%} + {%- if role == 'model' -%} + {{- strip_thinking(message['content']) -}} + {%- else -%} + {{- message['content'] | trim -}} + {%- endif -%} + {%- elif message['content'] is sequence -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'text' -%} + {%- if role == 'model' -%} + {{- strip_thinking(item['text']) -}} + {%- else -%} + {{- item['text'] | trim -}} + {%- endif -%} + {%- elif item['type'] == 'image' -%} + {{- '<|image|>' -}} + {%- set ns.prev_message_type = 'image' -%} + {%- elif item['type'] == 'audio' -%} + {{- '<|audio|>' -}} + {%- set ns.prev_message_type = 'audio' -%} + {%- elif item['type'] == 'video' -%} + {{- '<|video|>' -}} + {%- set ns.prev_message_type = 'video' -%} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + + {%- if ns.prev_message_type == 'tool_call' and not ns_tr_out.flag -%} + {{- '<|tool_response>' -}} + {%- elif not (ns_tr_out.flag and not message.get('content')) -%} + {{- '\n' -}} + {%- endif -%} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt -%} + {%- if ns.prev_message_type != 'tool_response' and ns.prev_message_type != 'tool_call' -%} + {{- '<|turn>model\n' -}} + {%- endif -%} +{%- endif -%} \ No newline at end of file diff --git a/trl/experimental/async_grpo/async_grpo_config.py b/trl/experimental/async_grpo/async_grpo_config.py index 2afd760e7fc..8da09efac2d 100644 --- a/trl/experimental/async_grpo/async_grpo_config.py +++ b/trl/experimental/async_grpo/async_grpo_config.py @@ -185,6 +185,30 @@ class AsyncGRPOConfig(_BaseConfig): metadata={"help": "Number of training steps between weight synchronizations to the vLLM server."}, ) + # Parameters that control LoRA training and weight sync + use_lora: bool = field( + default=False, + metadata={ + "help": "Enable LoRA mode. When True, the model is loaded as a PEFT adapter (base model auto-resolved " + "from adapter_config.json), only LoRA weights are trained, and weight sync saves the adapter to disk " + "then tells vLLM to hot-reload via /v1/load_lora_adapter instead of streaming all weights over NCCL." + }, + ) + lora_adapter_path: str | None = field( + default=None, + metadata={ + "help": "Path to the PEFT LoRA adapter directory. Required when use_lora=True. This is where the " + "adapter is saved during weight sync and where vLLM reads it from." + }, + ) + lora_name: str = field( + default="sft", + metadata={ + "help": "The LoRA adapter name registered in vLLM (via --lora-modules name=path). Used both as the " + "'model' field in generation requests and as the adapter name in /v1/load_lora_adapter calls." 
+ }, + ) + # Parameters that control the logging log_completions: bool = field( default=False, @@ -201,6 +225,9 @@ class AsyncGRPOConfig(_BaseConfig): def __post_init__(self): super().__post_init__() + if self.use_lora and not self.lora_adapter_path: + raise ValueError("lora_adapter_path is required when use_lora=True") + # Accelerator config: required for the async IterableDataset-backed dataloader to work correctly. # split_batches=True and dispatch_batches=True ensure that the main process drives the dataloader # and batches are broadcast to other processes rather than each process pulling independently. diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index a81dad5639f..4ac17a19edd 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -14,6 +14,7 @@ import math +import os import queue import textwrap import time @@ -354,15 +355,15 @@ def __init__( # Use the injected worker (e.g. a stub in tests). The queue is owned by the worker. self.rollout_worker = rollout_worker else: - # Collect weight metadata once — names/dtypes/shapes are fixed for the lifetime of training. - # DTensor.shape returns the global shape without triggering any all-gather. + # NCCL weight transfer needs full metadata; LoRA mode skips this entirely. weight_names, weight_dtype_names, weight_shapes = [], [], [] - for name, param in model.named_parameters(): - # DDP/FSDP1 wrapping, avoids vllm module not exist error - name = name.removeprefix("module.") - weight_names.append(name) - weight_dtype_names.append(str(param.dtype).split(".")[-1]) - weight_shapes.append(list(param.shape)) + if not self.args.use_lora: + for name, param in model.named_parameters(): + name = name.removeprefix("module.") # DDP/FSDP1 wrapping + weight_names.append(name) + weight_dtype_names.append(str(param.dtype).split(".")[-1]) + weight_shapes.append(list(param.shape)) + self.rollout_worker = AsyncRolloutWorker( model_name=model_name, dataset=train_dataset, @@ -384,6 +385,8 @@ def __init__( weight_names=weight_names, weight_dtype_names=weight_dtype_names, weight_shapes=weight_shapes, + use_lora=self.args.use_lora, + lora_name=self.args.lora_name, ) self.rollout_queue = self.rollout_worker.rollout_buffer else: @@ -579,6 +582,44 @@ def _streaming_iter(self): def _sync_weight(self): t0 = time.time() + + if self.args.use_lora: + self._sync_weight_lora(t0) + else: + self._sync_weight_nccl(t0) + + weight_sync_time_s = time.time() - t0 + self._metrics["train"]["weight_sync_time_s"].append(weight_sync_time_s) + logger.info(f"Weight sync: done. 
Total {weight_sync_time_s:.1f}s") + + def _sync_weight_lora(self, t0: float): + """LoRA sync: save adapter to disk, then tell vLLM to hot-reload it.""" + adapter_path = self.args.lora_adapter_path + lora_name = self.args.lora_name + + # Pause vLLM FIRST so no requests trigger lazy LoRA loading mid-write + if self.accelerator.is_main_process and self.rollout_worker: + self.rollout_worker.pause() + + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process: + unwrapped = self.accelerator.unwrap_model(self.model) + logger.info(f"Weight sync (LoRA): saving adapter to {adapter_path}...") + unwrapped.save_pretrained(adapter_path) + os.sync() + t_save = time.time() + logger.info(f"Weight sync (LoRA): save took {t_save - t0:.1f}s") + + self.accelerator.wait_for_everyone() + + if self.accelerator.is_main_process and self.rollout_worker: + self.rollout_worker.reload_lora(adapter_path, lora_name) + self.rollout_worker.resume() + self.model_version += 1 + self.rollout_worker.update_model_version(self.model_version) + + def _sync_weight_nccl(self, t0: float): + """Original NCCL path: stream all weights to vLLM.""" logger.info("Weight sync: pausing vLLM...") if self.accelerator.is_main_process and self.rollout_worker: self.rollout_worker.pause() @@ -604,9 +645,6 @@ def _sync_weight(self): self.rollout_worker.resume() self.model_version += 1 self.rollout_worker.update_model_version(self.model_version) - weight_sync_time_s = time.time() - t0 - self._metrics["train"]["weight_sync_time_s"].append(weight_sync_time_s) - logger.info(f"Weight sync: done. Total {weight_sync_time_s:.1f}s") def _inner_training_loop(self, *args, **kwargs): # Start the rollout worker here (not in __init__) so that checkpoint loading in Trainer.train() diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 4fd11312fd2..e49cd29adbb 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -107,12 +107,14 @@ def __init__( weight_names: list[str] | None = None, weight_dtype_names: list[str] | None = None, weight_shapes: list[list[int]] | None = None, + use_lora: bool = False, + lora_name: str | None = None, ): if not is_vllm_available(min_version="0.17.1"): raise ImportError( "vLLM >= 0.17.1 is required to use AsyncRolloutWorker. Install it with: pip install 'vllm>=0.17.1'" ) - self.model_name = model_name + self.lora_sync = use_lora self.max_tool_calling_iterations = max_tool_calling_iterations self.dataset = dataset self._dataset_iter = iter(dataset) @@ -127,6 +129,10 @@ def __init__( "is_checkpoint_format": True, } + # When LoRA sync is active, generation requests use the LoRA adapter name + # (e.g. "sft") while the tokenizer still loads from model_name (adapter dir). + self.model_name = lora_name if self.lora_sync else model_name + self.reward_funcs = reward_funcs self.reward_func_names = [f.__name__ for f in reward_funcs] self.num_generations = num_generations @@ -165,7 +171,7 @@ def __init__( self.chat_template_kwargs = chat_template_kwargs or {} self.log_completions = log_completions self.num_completions_to_print = num_completions_to_print - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) # Always use original path for tokenizer self.tokenizer = add_response_schema(self.tokenizer) # In multi-turn training, the chat template *must* be prefix-preserving. 
If the tokenizer's original template # isn't, we replace it at initialization with a training-safe, prefix-preserving template. @@ -181,9 +187,12 @@ def __init__( self.model_version = 0 self.session = None - # Wait for the vLLM server and initialize NCCL weight transfer. self._wait_for_server_ready_sync(timeout_s=self.server_timeout) - self._init_weight_transfer() + if self.lora_sync: + logger.info("LoRA sync mode: skipping NCCL weight transfer init (will use save-to-disk + HTTP reload)") + self.model_update_group = None + else: + self._init_weight_transfer() def _wait_for_server_ready_sync(self, timeout_s: float = 240.0, poll_interval_s: float = 2.0) -> None: """Block until the vLLM server is healthy.""" @@ -296,6 +305,18 @@ def resume(self) -> None: requests.post(f"{self.vllm_server_url}/resume") logger.debug(f"[weight_sync] resume HTTP took {time.time() - t0:.1f}s") + def reload_lora(self, adapter_path: str, lora_name: str) -> None: + """Tell vLLM to hot-reload a LoRA adapter from disk.""" + t0 = time.time() + payload = { + "lora_name": lora_name, + "lora_path": adapter_path, + "load_inplace": True, + } + resp = requests.post(f"{self.vllm_server_url}/v1/load_lora_adapter", json=payload, timeout=120) + resp.raise_for_status() + logger.info(f"[weight_sync] LoRA reload ({lora_name} from {adapter_path}) took {time.time() - t0:.1f}s") + def send_weights(self, iterator) -> None: if self.model_update_group is None: return From e15880419de7019869c3525bc9727e38020c5e3b Mon Sep 17 00:00:00 2001 From: Jonah Samost Date: Mon, 20 Apr 2026 17:44:12 -0700 Subject: [PATCH 2/3] training --- trl/experimental/async_grpo/async_grpo_trainer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index 4ac17a19edd..139c237d5eb 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -292,6 +292,19 @@ def __init__( model_name = model model = AutoModelForCausalLM.from_pretrained(model, device_map=None, dtype=torch.float32) + if self.args.use_lora: + lora_count = 0 + for name, param in model.named_parameters(): + param.requires_grad = "lora_" in name + if param.requires_grad: + lora_count += 1 + if lora_count == 0: + raise ValueError( + "use_lora=True but no LoRA parameters found in model. " + "Ensure the model path contains adapter_config.json and adapter weights." + ) + logger.info(f"Enabled gradients on {lora_count} LoRA parameter tensors") + if self.args.use_liger_kernel: raise NotImplementedError("`use_liger_kernel` is not supported yet.") From 9c5daf1ca8f6b3e88aa69fe52a88d345e92a293d Mon Sep 17 00:00:00 2001 From: Jonah Samost Date: Tue, 21 Apr 2026 04:59:40 -0700 Subject: [PATCH 3/3] bugbot --- trl/experimental/async_grpo/async_grpo_trainer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index 139c237d5eb..cacda7ed084 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -61,6 +61,7 @@ def stop(self) -> None: ... def pause(self) -> None: ... def resume(self) -> None: ... def send_weights(self, iterator: Iterator[tuple[str, torch.Tensor]]) -> None: ... + def reload_lora(self, adapter_path: str, lora_name: str) -> None: ... def update_model_version(self, version: int) -> None: ... 
@@ -615,10 +616,14 @@ def _sync_weight_lora(self, t0: float): self.rollout_worker.pause() self.accelerator.wait_for_everyone() + + # All ranks must call save_pretrained so that FSDP2 DTensor full_tensor() collectives + # (which are all-gathers) don't deadlock. Only rank 0 actually writes files to disk. + unwrapped = self.accelerator.unwrap_model(self.model) if self.accelerator.is_main_process: - unwrapped = self.accelerator.unwrap_model(self.model) logger.info(f"Weight sync (LoRA): saving adapter to {adapter_path}...") - unwrapped.save_pretrained(adapter_path) + unwrapped.save_pretrained(adapter_path, is_main_process=self.accelerator.is_main_process) + if self.accelerator.is_main_process: os.sync() t_save = time.time() logger.info(f"Weight sync (LoRA): save took {t_save - t0:.1f}s")
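
A minimal end-to-end sketch of how the new LoRA options are meant to be wired up. This is not part of the patch and rests on several assumptions: that AsyncGRPOConfig and AsyncGRPOTrainer are importable from trl.experimental.async_grpo, that the trainer constructor mirrors GRPOTrainer's (model, reward_funcs, args, train_dataset), and that the rollout server is a vLLM (>= 0.17.1) instance started with --enable-lora, --lora-modules sft=<adapter dir>, and VLLM_ALLOW_RUNTIME_LORA_UPDATING=True so that /v1/load_lora_adapter is reachable. The model id, dataset, reward function, and adapter path below are placeholders.

# Usage sketch (assumptions noted above; identifiers are placeholders).
#
# Server side, for example:
#   VLLM_ALLOW_RUNTIME_LORA_UPDATING=True vllm serve google/gemma-4-12b-it \
#       --enable-lora --lora-modules sft=/tmp/grpo_lora_adapter

from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer

adapter_path = "/tmp/grpo_lora_adapter"

# 1) Initialize and save an adapter so adapter_path contains adapter_config.json and
#    adapter weights; both the trainer and the vLLM server read from this directory.
base = AutoModelForCausalLM.from_pretrained("google/gemma-4-12b-it")  # placeholder model id
peft_model = get_peft_model(base, LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM"))
peft_model.save_pretrained(adapter_path)

# 2) Train with LoRA weight sync: only "lora_" parameters receive gradients, and each sync
#    saves the adapter to disk and asks vLLM to hot-reload it via /v1/load_lora_adapter.
def reward_len(completions, **kwargs):
    # toy reward: shorter completions score higher
    return [-float(len(c)) for c in completions]

training_args = AsyncGRPOConfig(
    output_dir="gemma4-async-grpo-lora",
    use_lora=True,
    lora_adapter_path=adapter_path,
    lora_name="sft",  # must match the name registered with --lora-modules
)
trainer = AsyncGRPOTrainer(
    model=adapter_path,  # adapter dir; base model resolved from adapter_config.json
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=load_dataset("trl-lib/tldr", split="train"),  # placeholder dataset
)
trainer.train()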