When I set tp=4 (tensor-model-parallel size 4), the model trains normally, but an out-of-memory (OOM) error occurs while saving a checkpoint. When I set tp=8, the following error is raised:
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/trainer/ppo/ray_trainer.py", line 1437, in fit
ref_log_prob = self._compute_ref_log_prob(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/trainer/ppo/ray_trainer.py", line 1125, in _compute_ref_log_prob
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/single_controller/ray/base.py", line 55, in __call__
output = ray.get(output)
^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^
ray.exceptions.RayTaskError: ray::WorkerDict.ref_compute_ref_log_prob() (pid=193257, ip=29.160.49.68, actor_id=a3a4794062e7929dfc45cd5401000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f6bbdbb1f70>)
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 456, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/single_controller/ray/base.py", line 932, in func
return getattr(self.worker_dict[key], name)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/single_controller/base/decorator.py", line 427, in inner
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/utils/profiler/performance.py", line 105, in f
return self.log(decorated_function, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/utils/profiler/performance.py", line 118, in log
output = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/utils/profiler/profile.py", line 173, in wrapper
return func(self_instance, *args, **kwargs_inner)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/workers/megatron_workers.py", line 858, in compute_ref_log_prob
output, _, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/utils/profiler/performance.py", line 105, in f
return self.log(decorated_function, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/utils/profiler/performance.py", line 118, in log
output = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/workers/actor/megatron_actor.py", line 259, in compute_log_prob
output = self.forward_backward_batch(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/workers/actor/megatron_actor.py", line 733, in forward_backward_batch
losses_reduced = forward_backward_func(
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/pipeline_parallel/schedules.py", line 636, in forward_backward_no_pipelining
output_tensor, num_tokens = forward_step(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/pipeline_parallel/schedules.py", line 423, in forward_step
output_tensor, loss_func = forward_step_func(data_iterator, model)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/workers/actor/megatron_actor.py", line 683, in forward_step
output = forward_fn(
^^^^^^^^^^^
File "/apdcephfs_bjzf/share_304704649/allenzpma/deliver/code/verl_qwen35/verl/verl/models/mcore/model_forward.py", line 141, in model_forward
output_orig = model(
^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/module.py", line 489, in forward
outputs = self.module(*inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/mbridge/models/qwen3_5/model.py", line 367, in forward
output = self.language_model(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/models/gpt/gpt_model.py", line 525, in forward
hidden_states = self.decoder(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_block.py", line 619, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/module.py", line 352, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_block.py", line 765, in forward
hidden_states, context = layer(
^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_layer.py", line 1217, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/module.py", line 352, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_layer.py", line 513, in forward
hidden_states, context = self._forward_attention(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_layer.py", line 597, in _forward_attention
attention_output_with_bias = self.self_attention(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1787, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/mbridge/models/qwen3_5/attention.py", line 360, in forward
core_attn_out = self._apply_output_gate(core_attn_out, gate)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 953, in compile_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 2202, in __call__
result = self._torchdynamo_orig_backend(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1945, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 707, in __call__
result = _compile(
^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1752, in _compile
guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_utils_internal.py", line 97, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1433, in compile_inner
return _compile_inner(code, one_graph, hooks)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1467, in _compile_inner
dynamo_output = compile_frame(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1341, in compile_frame
bytecode, tracer_output = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1600, in transform_code_object
tracer_output = transformations(instructions, code_options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 1313, in transform
tracer_output = trace_frame(
^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 328, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 838, in trace_frame
run_tracer()
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py", line 819, in run_tracer
tracer.run()
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1654, in run
while self.step():
^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1334, in step
self.dispatch_table[inst.opcode](self, inst)
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 866, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 2582, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py", line 1240, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/misc.py", line 1148, in call_function
return self.obj.call_method(tx, self.name, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/tensor.py", line 745, in call_method
return wrap_fx_proxy(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2795, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2861, in wrap_fx_proxy_cls
out = _wrap_fx_proxy(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/builder.py", line 2972, in _wrap_fx_proxy
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3626, in get_fake_value
raise TorchRuntimeError(msg).with_traceback(e.traceback) from None
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3524, in get_fake_value
ret_val = wrap_fake_exception(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 2966, in wrap_fake_exception
return fn()
^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3525, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3735, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/utils.py", line 3705, in run_node
return getattr(args[0], node.target)(*args[1:], **kwargs) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_stats.py", line 29, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1397, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2155, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 1544, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_tensor.py", line 2793, in _dispatch_impl
op_impl_out = op_impl(self, func, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_impls.py", line 180, in dispatch_to_op_implementations_dict
return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_subclasses/fake_impls.py", line 629, in _view_meta
return torch._refs._reshape_view_helper(a, shape, allow_copy=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_refs/__init__.py", line 3950, in _reshape_view_helper
shape = utils.infer_size(shape, a.numel())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_prims_common/__init__.py", line 1064, in infer_size
torch._check(
File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1732, in _check
_check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1714, in _check_with
raise error_type(message_evaluated)
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_method view((FakeTensor(..., device='cuda:0', size=(432, 1, 6, 256), dtype=torch.bfloat16), 432, 1, 768), **{}): got RuntimeError("shape '[432, 1, 768]' is invalid for input of size 663552")
from user code:
File "/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/attention.py", line 1221, in _apply_output_gate
gate = gate.view(*x.shape)
My experimental setup is as follows: