
[WIP] [Liger] liger JSD support #2573


Draft · wants to merge 1 commit into `main`
33 changes: 33 additions & 0 deletions tests/test_gkd_trainer.py
@@ -20,6 +20,7 @@
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.testing_utils import require_liger_kernel

from trl import GKDConfig, GKDTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
@@ -262,3 +263,35 @@ def test_generation_config_init(self):
        self.assertEqual(trainer.generation_config.max_new_tokens, training_args.max_new_tokens)
        self.assertEqual(trainer.generation_config.temperature, training_args.temperature)
        self.assertEqual(trainer.generation_config.top_k, 0)

    @require_liger_kernel
    def test_gkd_trainer_with_liger(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GKDConfig(
                output_dir=tmp_dir,
                dataloader_drop_last=True,
                eval_strategy="steps",
                max_steps=4,
                eval_steps=2,
                save_steps=2,
                per_device_train_batch_size=2,
                per_device_eval_batch_size=2,
                report_to="none",
                use_liger_loss=True,  # Enable Liger loss
            )
            dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")

            trainer = GKDTrainer(
                model=self.model_id,
                teacher_model=self.model_id,
                args=training_args,
                train_dataset=dummy_dataset["train"],
                eval_dataset=dummy_dataset["test"],
                processing_class=self.tokenizer,
            )

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
            self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
30 changes: 30 additions & 0 deletions trl/trainer/gkd_config.py
@@ -45,6 +45,16 @@ class GKDConfig(SFTConfig):
        seq_kd (`bool`, *optional*, defaults to `False`):
            Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
            on teacher-generated output).
        use_liger_loss (`bool`, *optional*, defaults to `False`):
            Whether to use Liger loss.
        base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
            Name of the attribute on the model that holds the base model. Used to retrieve the base model when
            `use_liger_loss` is `True` and the model does not expose a `get_decoder` method.
        teacher_base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
            Name of the attribute on the teacher model that holds the base model. Used to retrieve the base model
            when `use_liger_loss` is `True` and the teacher model does not expose a `get_decoder` method.
    """

    temperature: float = field(
@@ -95,6 +105,26 @@ class GKDConfig(SFTConfig):
            "FT on teacher-generated output)."
        },
    )
    use_liger_loss: bool = field(
        default=False,
        metadata={"help": "Whether to use Liger loss."},
    )
    base_model_attribute_name: str = field(
        default="model",
        metadata={
            "help": "Name of the attribute on the model that holds the base model. Used to retrieve the base model "
            "when `use_liger_loss` is `True` and the model does not expose a `get_decoder` method."
        },
    )
    teacher_base_model_attribute_name: str = field(
        default="model",
        metadata={
            "help": "Name of the attribute on the teacher model that holds the base model. Used to retrieve the "
            "base model when `use_liger_loss` is `True` and the teacher model does not expose a `get_decoder` method."
        },
    )

    def __post_init__(self):
        super().__post_init__()
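For context, a minimal sketch of how the new config options might be wired together from user code. The output directory, model ids, dataset, and tokenizer below are placeholders and not part of this PR:

```python
from trl import GKDConfig, GKDTrainer

# Hypothetical example; model ids, dataset, and tokenizer are placeholders.
training_args = GKDConfig(
    output_dir="gkd-liger-output",
    use_liger_loss=True,  # route compute_loss through LigerFusedLinearJSDLoss
    base_model_attribute_name="model",  # fallback lookup when the student has no `get_decoder`
    teacher_base_model_attribute_name="model",  # fallback lookup when the teacher has no `get_decoder`
)

trainer = GKDTrainer(
    model="student-model-id",  # placeholder
    teacher_model="teacher-model-id",  # placeholder
    args=training_args,
    train_dataset=train_dataset,  # assumed prepared conversational dataset
    processing_class=tokenizer,  # assumed tokenizer for the student model
)
trainer.train()
```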
94 changes: 72 additions & 22 deletions trl/trainer/gkd_trainer.py
@@ -55,6 +55,7 @@
    import deepspeed

if is_liger_kernel_available():
    from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss
    from liger_kernel.transformers import AutoLigerKernelForCausalLM

if is_peft_available():
@@ -158,6 +159,16 @@ def __init__(
        ):
            self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

        if args.use_liger_loss:
            if not is_liger_kernel_available():
                raise ValueError(
                    "You set `use_liger_loss=True` but the liger kernel is not available. "
                    "Please install liger-kernel first: `pip install liger-kernel`"
                )
            self.liger_loss = LigerFusedLinearJSDLoss(
                weight_hard_loss=0, weight_soft_loss=1, temperature=args.temperature
            )

    @staticmethod
    def generalized_jsd_loss(
        student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
@@ -218,33 +229,72 @@ def generalized_jsd_loss(
        return jsd

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if self.args.use_liger_loss:
            # compute student output
            if hasattr(model, "get_decoder"):
                base_model = model.get_decoder()
            else:
                base_model = getattr(model, self.args.base_model_attribute_name)
            outputs_student = base_model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )
            lm_head_student = model.get_output_embeddings()

            # compute teacher output in eval mode
            if hasattr(self.teacher_model, "get_decoder"):
                base_teacher_model = self.teacher_model.get_decoder()
            else:
                base_teacher_model = getattr(self.teacher_model, self.args.teacher_base_model_attribute_name)
            base_teacher_model.eval()
            with torch.no_grad():
                outputs_teacher = base_teacher_model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                )
            lm_head_teacher = self.teacher_model.get_output_embeddings()

            # slice the hidden states for the generated tokens using the inputs["prompts"] lengths
            prompt_lengths = inputs["prompts"].shape[1]
            shifted_student_outputs = outputs_student.last_hidden_state[:, prompt_lengths - 1 : -1, :]
            shifted_teacher_outputs = outputs_teacher.last_hidden_state[:, prompt_lengths - 1 : -1, :]
            shifted_labels = inputs["labels"][:, prompt_lengths:]

            # compute loss
            loss = self.liger_loss(
                student_input=shifted_student_outputs,
                student_weight=lm_head_student.weight,
                teacher_input=shifted_teacher_outputs,
                teacher_weight=lm_head_teacher.weight,
                true_labels=shifted_labels,
            )
        else:
            # compute student output
            outputs_student = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

            # compute teacher output in eval mode
            self.teacher_model.eval()
            with torch.no_grad():
                outputs_teacher = self.teacher_model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                )

            # slice the logits for the generated tokens using the inputs["prompts"] lengths
            prompt_lengths = inputs["prompts"].shape[1]
            shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
            shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
            shifted_labels = inputs["labels"][:, prompt_lengths:]

            # compute loss
            loss = self.generalized_jsd_loss(
                student_logits=shifted_student_logits,
                teacher_logits=shifted_teacher_logits,
                labels=shifted_labels,
                beta=self.beta,
            )

        # empty cache
        empty_cache()
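For intuition, the sketch below shows roughly what the fused `LigerFusedLinearJSDLoss` call replaces when only the soft loss is used: projecting the sliced hidden states through each lm_head and computing a generalized JSD on the resulting logits. This is an illustrative approximation rather than the kernel's implementation: label masking is omitted, and the beta weighting convention is an assumption, not taken from Liger. The point of the fused, chunked kernel is that the full `(batch, seq, vocab)` logits tensors in this sketch never have to be materialized.

```python
import torch
import torch.nn.functional as F


def unfused_linear_jsd(student_hidden, student_lm_head_w, teacher_hidden, teacher_lm_head_w,
                       beta=0.5, temperature=1.0):
    # Materialize full logits: this is exactly the memory cost the chunked Liger kernel avoids.
    student_logits = student_hidden @ student_lm_head_w.T
    teacher_logits = teacher_hidden @ teacher_lm_head_w.T

    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)

    # Generalized JSD: KL of each distribution against their beta-weighted mixture
    # (weighting convention assumed for illustration).
    beta_t = torch.tensor(beta, dtype=student_log_probs.dtype)
    mixture_log_probs = torch.logsumexp(
        torch.stack([student_log_probs + torch.log(1 - beta_t),
                     teacher_log_probs + torch.log(beta_t)]),
        dim=0,
    )
    kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="batchmean", log_target=True)
    kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="batchmean", log_target=True)
    return beta * kl_teacher + (1 - beta) * kl_student
```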