Commit 1d354b4

Misc improvements
- Enable callbacks injection from plugins
- Fix misc issues with axolotl plugins
- Fix remote code checking
- Enable loss averaging across devices
- Add sequence length validation
- Enhance sequence length validation
- Remove legacy code for patching _get_unpad_data
- Add pre-truncation token counting for completion
- Fix plugin callbacks duplication
- Enable eval on start
- Read extra HF args from cfg
1 parent 60a8f09 commit 1d354b4

9 files changed: +228 additions, -29 deletions

src/axolotl/core/trainer_builder.py (4 additions, 4 deletions)

@@ -24,7 +24,7 @@
import sys
from abc import abstractmethod
from pathlib import Path
-from typing import List, Type, Union
+from typing import Any, List, Type, Union

import torch
import transformers

@@ -326,7 +326,7 @@ def build(self, total_num_steps):
            else max(min(int(0.005 * total_num_steps), 10), 1)
        )

-        training_arguments_kwargs = {}
+        training_arguments_kwargs = self.cfg.get("extra_hf_training_args") or {}

        if self.cfg.include_tokens_per_second is not None:
            training_arguments_kwargs["include_tokens_per_second"] = (

@@ -795,7 +795,7 @@ def build(self, total_num_steps):
            None
        )

-        data_collator_kwargs = {
+        data_collator_kwargs: dict[str, Any] = {
            "padding": True,  # True/"longest" is the default
        }
        if self.cfg.pad_to_sequence_len:

@@ -955,7 +955,7 @@ def get_post_trainer_create_callbacks(self, trainer):
        return callbacks

    def build_training_arguments(self, total_num_steps):
-        training_args_kwargs = {}
+        training_args_kwargs = self.cfg.get("extra_hf_training_args") or {}
        for arg in [
            "adam_beta1",
            "adam_beta2",

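Both builders now seed their kwargs dict from `extra_hf_training_args` in the config, so any field of transformers' TrainingArguments can be passed through without a dedicated axolotl option. A minimal sketch of the pattern, outside the real builder; the plain dict stands in for axolotl's cfg and the values shown are illustrative, not required keys:

    # Sketch only: extra_hf_training_args seeds the kwargs dict, then explicit
    # axolotl options are layered on top before TrainingArguments is built.
    from transformers import TrainingArguments

    cfg = {
        "extra_hf_training_args": {"dataloader_pin_memory": False},  # illustrative
        "include_tokens_per_second": True,
    }

    training_arguments_kwargs = cfg.get("extra_hf_training_args") or {}
    if cfg.get("include_tokens_per_second") is not None:
        training_arguments_kwargs["include_tokens_per_second"] = cfg["include_tokens_per_second"]

    args = TrainingArguments(output_dir="./outputs", **training_arguments_kwargs)
    print(args.dataloader_pin_memory, args.include_tokens_per_second)
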
src/axolotl/logging_config.py (7 additions, 1 deletion)

@@ -54,11 +54,17 @@ def format(self, record):
            "filters": [],
            "stream": sys.stdout,
        },
+        "file": {
+            "class": "logging.FileHandler",
+            "formatter": "simple",
+            "filename": "train.log",
+            "mode": "w",
+        },
    },
    "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
    "loggers": {
        "axolotl": {
-            "handlers": ["color_console"],
+            "handlers": ["color_console", "file"],
            "level": "DEBUG",
            "propagate": False,
        },
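With the new handler, everything the `axolotl` logger emits is mirrored to a `train.log` file that is truncated at the start of each run. A self-contained sketch of the same dictConfig shape; the formatter string here is an assumption, not axolotl's actual format:

    import logging
    import logging.config

    LOGGING_CONFIG = {
        "version": 1,
        "disable_existing_loggers": False,
        # assumed formatter; axolotl's real config defines its own simple/color formatters
        "formatters": {"simple": {"format": "[%(asctime)s] [%(levelname)s] %(message)s"}},
        "handlers": {
            "console": {"class": "logging.StreamHandler", "formatter": "simple"},
            "file": {
                "class": "logging.FileHandler",
                "formatter": "simple",
                "filename": "train.log",
                "mode": "w",  # overwrite on every run, as in the diff above
            },
        },
        "loggers": {
            "axolotl": {"handlers": ["console", "file"], "level": "DEBUG", "propagate": False}
        },
    }

    logging.config.dictConfig(LOGGING_CONFIG)
    logging.getLogger("axolotl").info("written to the console and to train.log")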

src/axolotl/prompt_strategies/alpaca_w_system.py (6 additions, 0 deletions)

@@ -50,6 +50,12 @@ def tokenize_prompt(self, prompt):
        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]

+        if "num_tokens_pre_truncation" in tokenized_prompt:
+            tokenized_prompt["num_tokens_pre_truncation"] = (
+                tokenized_prompt["num_tokens_pre_truncation"]
+                + tokenized_res_prompt["num_tokens_pre_truncation"]
+            )
+
        return tokenized_prompt
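Since both halves of an example now carry `num_tokens_pre_truncation`, the strategy simply sums them. A tiny illustration with hand-made dicts (the numbers are made up):

    # Toy dicts standing in for the tokenized instruction and response parts.
    tokenized_prompt = {"input_ids": [1, 2, 3], "num_tokens_pre_truncation": 5}
    tokenized_res_prompt = {"input_ids": [4, 5], "num_tokens_pre_truncation": 4}

    if "num_tokens_pre_truncation" in tokenized_prompt:
        tokenized_prompt["num_tokens_pre_truncation"] += tokenized_res_prompt[
            "num_tokens_pre_truncation"
        ]
    print(tokenized_prompt["num_tokens_pre_truncation"])  # 9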

src/axolotl/prompt_strategies/chat_template.py (29 additions, 10 deletions)

@@ -88,19 +88,25 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
                images=images,
                return_tensors="pt",
            )
+            # dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
            # workaround since processor works in batches instead of single examples
            for k, val in batch.items():
                if k in ["pixel_values"]:
                    batch[k] = val.tolist()
                else:
                    batch[k] = val.squeeze().tolist()
+            batch["num_tokens_pre_truncation"] = len(batch["input_ids"])
            return batch

-        return self.tokenizer.apply_chat_template(
+        input_ids = self.tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=add_generation_prompt,
            chat_template=self.chat_template,
        )
+        return {
+            "input_ids": input_ids,
+            "num_tokens_pre_truncation": len(input_ids),
+        }

    def get_offsets_for_train_detail(
        self, text: str, train_details: List[Dict], mask_untrainable: bool = True

@@ -290,20 +296,29 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
        ):
            turns = self.get_conversation_thread(prompt)
            images = self.get_images(prompt)
-            prompt_ids = self.prompter.build_prompt(  # type: ignore
+            prompt_tokenized = self.prompter.build_prompt(
                turns[:-1],
                add_generation_prompt=True,
                images=images,
            )
-            tokenized_res = self.prompter.build_prompt(turns, images=images)  # type: ignore
+            all_turns_tokenized = self.prompter.build_prompt(turns, images=images)
            tokenized_prompt = {}
-            if isinstance(tokenized_res, list):
-                input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
+            if "attention_mask" not in all_turns_tokenized:
+                prompt_ids = prompt_tokenized["input_ids"]
+                input_ids = (
+                    prompt_ids + all_turns_tokenized["input_ids"][len(prompt_ids) :]
+                )
                tokenized_prompt["input_ids"] = input_ids
+                num_tokens_pre_truncation = all_turns_tokenized[
+                    "num_tokens_pre_truncation"
+                ]
                tokenized_prompt["attention_mask"] = [1] * len(input_ids)
            else:
-                input_ids = tokenized_res["input_ids"]
-                tokenized_prompt = tokenized_res
+                input_ids = all_turns_tokenized["input_ids"]
+                num_tokens_pre_truncation = all_turns_tokenized[
+                    "num_tokens_pre_truncation"
+                ]
+                tokenized_prompt = all_turns_tokenized

            if not self.train_on_inputs:
                user_prompt_len = len(prompt_ids)

@@ -312,11 +327,14 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
                labels = input_ids

            tokenized_prompt["labels"] = labels
+            tokenized_prompt["num_tokens_pre_truncation"] = num_tokens_pre_truncation

            return tokenized_prompt

        turns = self.get_conversation_thread(prompt)
-        input_ids = self.prompter.build_prompt(turns)  # type: ignore
+        tokenized_res = self.prompter.build_prompt(turns)
+        input_ids = tokenized_res["input_ids"]
+        num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
        labels = [IGNORE_TOKEN_ID] * len(input_ids)

        last_eos_idx = -1

@@ -393,6 +411,7 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(input_ids),
+            "num_tokens_pre_truncation": num_tokens_pre_truncation,
        }

    def find_first_eos_token(self, input_ids, start_idx):

@@ -433,10 +452,10 @@ def find_turn(self, turns: list[dict], turn_idx: int):
        turns_with_content = turns[: turn_idx + 1]

        # Generate the conversation up to the turn, with final turn replaced with dummy content
-        dummy_ids = self.prompter.build_prompt(turns_with_empty)  # type: ignore
+        dummy_ids = self.prompter.build_prompt(turns_with_empty)["input_ids"]  # type: ignore

        # Generate the conversation up to the turn, with final turn included
-        full_ids = self.prompter.build_prompt(turns_with_content)  # type: ignore
+        full_ids = self.prompter.build_prompt(turns_with_content)["input_ids"]  # type: ignore

        if not full_ids or not dummy_ids:
            LOG.warning(f"Empty template generated for turn {turn_idx}")
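The contract change here is that `build_prompt` now returns a dict instead of a bare list of token ids, so every caller has to unpack `["input_ids"]` and can read the pre-truncation count alongside it. A stand-alone sketch of the new return shape, using a dummy character-level tokenizer rather than the real prompter:

    # Dummy tokenizer: one "token" per character, purely for illustration.
    def fake_apply_chat_template(conversation):
        return [ord(c) for c in " ".join(conversation)]

    def build_prompt(conversation):
        input_ids = fake_apply_chat_template(conversation)
        return {
            "input_ids": input_ids,
            "num_tokens_pre_truncation": len(input_ids),
        }

    res = build_prompt(["user: hi", "assistant: hello"])
    ids = res["input_ids"]  # callers now index the dict, as find_turn() does above
    assert res["num_tokens_pre_truncation"] == len(ids)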

src/axolotl/prompt_tokenizers.py (24 additions, 4 deletions)

@@ -1,6 +1,7 @@
"""Module containing PromptTokenizingStrategy and Prompter classes"""

import abc
+import functools
import logging
from typing import Callable, Dict, List, Optional, Tuple, Union

@@ -62,18 +63,23 @@ def supports_batched(self):
    def _tokenize(
        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
    ) -> BatchEncoding:
-        empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
+        empty = BatchEncoding(
+            data={"input_ids": [], "attention_mask": [], "num_tokens_pre_truncation": 0}
+        )
        if not prompt:
            LOG.warning("Empty text requested for tokenization.")
            return empty

-        result = self.tokenizer(
-            prompt,
-            truncation=True,
+        _tokenize = functools.partial(
+            self.tokenizer,
            max_length=self.max_length,
            padding=False,
            return_tensors=None,
        )
+        result = _tokenize(
+            prompt,
+            truncation=True,
+        )
        if len(result["input_ids"]) == 0:
            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
            return empty

@@ -91,6 +97,20 @@ def _tokenize(
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
+
+        _all_tokens = _tokenize(prompt, truncation=False)
+        num_tokens_pre_truncation = len(_all_tokens["input_ids"])
+        if (
+            _all_tokens["input_ids"][-1] != self.tokenizer.eos_token_id
+            and add_eos_token
+        ):
+            num_tokens_pre_truncation += 1
+        if (
+            _all_tokens["input_ids"][0] == self.tokenizer.bos_token_id
+            and strip_bos_token
+        ):
+            num_tokens_pre_truncation -= 1
+        result["num_tokens_pre_truncation"] = num_tokens_pre_truncation
        return result
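The count is obtained by tokenizing the prompt a second time with truncation disabled, then adjusting for the EOS token that would be appended and the BOS token that would be stripped afterwards. A toy version of that logic; the character-level tokenizer below is a stand-in for a Hugging Face tokenizer, not the real thing:

    import functools

    BOS_ID, EOS_ID = 1, 2

    def toy_tokenizer(text, truncation, max_length):
        ids = [BOS_ID] + [ord(c) for c in text]
        if truncation:
            ids = ids[:max_length]
        return {"input_ids": ids}

    _tokenize = functools.partial(toy_tokenizer, max_length=8)
    result = _tokenize("hello world", truncation=True)        # what the model sees
    _all_tokens = _tokenize("hello world", truncation=False)  # the true length

    add_eos_token, strip_bos_token = True, False
    num_tokens_pre_truncation = len(_all_tokens["input_ids"])
    if _all_tokens["input_ids"][-1] != EOS_ID and add_eos_token:
        num_tokens_pre_truncation += 1  # an EOS would have been appended
    if _all_tokens["input_ids"][0] == BOS_ID and strip_bos_token:
        num_tokens_pre_truncation -= 1  # a BOS would have been stripped

    result["num_tokens_pre_truncation"] = num_tokens_pre_truncation
    print(len(result["input_ids"]), num_tokens_pre_truncation)  # 8 13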

src/axolotl/train.py (23 additions, 6 deletions)

@@ -16,15 +16,19 @@
from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled
from peft import PeftConfig, PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+from transformers import (
+    PreTrainedModel,
+    PreTrainedTokenizer,
+    ProcessorMixin,
+)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer

from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import (  # pylint: disable = no-name-in-module
    fix_untrained_tokens,
)
-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed

@@ -78,6 +82,9 @@ def setup_model_and_tokenizer(
    if model.generation_config is not None:
        model.generation_config.do_sample = True

+    plugin_manager = PluginManager.get_instance()
+    plugin_manager.post_model_load(cfg, model)
+
    # Apply freezing if specified
    if cfg.unfrozen_parameters:
        freeze_layers_except(model, cfg.unfrozen_parameters)

@@ -153,7 +160,11 @@ def setup_signal_handler(
        safe_serialization: Whether to use safe serialization when saving
    """
    # ray workers don't have access to this signal
-    if cfg.local_rank == 0 and not cfg.use_ray:
+    if (
+        cfg.local_rank == 0
+        and not cfg.use_ray
+        and cfg.get("save_model_on_interrupt", True)
+    ):

        def terminate_handler(_, __, model_weakref):
            if model_weakref() is not None:

@@ -418,7 +429,7 @@ def handle_untrained_tokens_fix(


def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
-    HFRLTrainerBuilder | HFCausalTrainerBuilder,
+    Trainer,
    PeftModel | PreTrainedModel,
    PreTrainedTokenizer,
    PeftConfig | None,

@@ -514,7 +525,13 @@ def train(
    # Save the trained model and cleanup
    save_trained_model(cfg, trainer, model, safe_serialization)
    create_model_card(cfg, trainer)
-    if not cfg.use_ray:
-        cleanup_distributed()
+
+    if cfg.deepspeed:
+        trainer.deepspeed.destroy()
+    trainer.accelerator.free_memory()
+    trainer.model, trainer.model_wrapped, trainer.optimizer = None, None, None
+
+    # if not cfg.use_ray:
+    #     cleanup_distributed()

    return model, tokenizer, trainer
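Two behavioural notes from this file: plugins now get a `post_model_load` hook right after the model is built, and the interrupt-save handler can be switched off with a `save_model_on_interrupt` flag that defaults to True. A sketch of the latter gate, with a plain dict in place of axolotl's DictDefault and a hypothetical save callback:

    import signal

    def maybe_install_interrupt_save(cfg: dict, save_fn) -> bool:
        """Install a SIGINT handler that saves the model, unless disabled."""
        if (
            cfg.get("local_rank", 0) == 0
            and not cfg.get("use_ray")
            and cfg.get("save_model_on_interrupt", True)
        ):
            signal.signal(signal.SIGINT, lambda signum, frame: save_fn())
            return True
        return False

    # With the flag set to False, no handler is installed and Ctrl-C interrupts normally.
    installed = maybe_install_interrupt_save(
        {"local_rank": 0, "save_model_on_interrupt": False}, save_fn=lambda: None
    )
    print(installed)  # False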

src/axolotl/utils/data/sft.py (13 additions, 3 deletions)

@@ -52,7 +52,11 @@
    retry_on_request_exceptions,
)
from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_local_main_process, zero_first
+from axolotl.utils.distributed import (
+    compute_and_broadcast,
+    is_local_main_process,
+    zero_first,
+)
from axolotl.utils.trainer import (
    calculate_total_num_steps,
    process_datasets_for_packing,

@@ -156,9 +160,15 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
        total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
        if total_eval_steps == 0:
-            raise ValueError(
-                "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
+            LOG.warning(
+                "eval dataset split is too small for sample_packing. Setting `eval_sample_packing to False`."
            )
+            if cfg.world_size > 1:
+                _eval_sample_packing = compute_and_broadcast(lambda: 0)
+                if _eval_sample_packing < 1:
+                    cfg.eval_sample_packing = False
+            else:
+                cfg.eval_sample_packing = False

    if cfg.max_steps:
        total_num_steps = min(
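The broadcast exists so that every rank reaches the same `eval_sample_packing` decision instead of one rank raising (or flipping the flag) while the others continue unchanged. A conceptual sketch of that agreement step; the torch.distributed calls below are an assumption about what `compute_and_broadcast` does under the hood, not its actual implementation:

    import torch
    import torch.distributed as dist

    def agree_to_disable_eval_packing(cfg: dict) -> None:
        # 0 means "disable eval sample packing"; rank 0's value wins on every rank.
        if cfg.get("world_size", 1) > 1 and dist.is_available() and dist.is_initialized():
            flag = torch.tensor([0], dtype=torch.int64)
            dist.broadcast(flag, src=0)
            if flag.item() < 1:
                cfg["eval_sample_packing"] = False
        else:
            cfg["eval_sample_packing"] = False

    cfg = {"world_size": 1, "eval_sample_packing": True}
    agree_to_disable_eval_packing(cfg)
    print(cfg["eval_sample_packing"])  # False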
