Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,7 +1224,9 @@ def _update_actor(self, batch: DataProto) -> DataProto:
batch_td = batch.to_tensordict()
# step 2: convert from padding to no-padding
batch_td = left_right_2_no_padding(batch_td)
- calculate_entropy = self.config.actor_rollout_ref.actor.entropy_coeff != 0.0
+ calculate_entropy = self.config.actor_rollout_ref.actor.get("calculate_entropy", False) or (
+ self.config.actor_rollout_ref.actor.entropy_coeff != 0.0
+ )
distillation_use_topk = (
self.distillation_config.distillation_loss.loss_settings.use_topk
if is_distillation_enabled(self.config.get("distillation"))
Expand Down
6 changes: 4 additions & 2 deletions verl/workers/actor/megatron_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,8 +551,10 @@ def loss_func(output, data, meta_info):
entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
if not forward_only:
entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+ stats["actor/entropy"] = entropy_loss.detach().item()
entropy_coeff = meta_info["entropy_coeff"]
- policy_loss = pg_loss - entropy_coeff * entropy_loss
+ if entropy_coeff != 0:
+ policy_loss = pg_loss - entropy_coeff * entropy_loss
else:
ret_entropy = entropy

Expand Down Expand Up @@ -788,7 +790,7 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals
# if use distributed optimizer, zero grad buffer will be handled by optimizer
chunk.zero_grad_buffer()

- calculate_entropy = self.config.entropy_coeff != 0
+ calculate_entropy = self.config.get("calculate_entropy", False) or (self.config.entropy_coeff != 0)
if data.meta_info.get("micro_batch_size", None) is not None:
micro_batch_size = data.meta_info["micro_batch_size"]
else:
Expand Down
Loading