Skip to content

Commit 30c52ec

Browse files
committed
Misc improvements
Enable callbacks injection from plugins. Fix misc issues with axolotl plugins. Fix remote code checking. Enable loss average across devices. Add seq len validation. Enhance sequence lens validation. Remove legacy code for patching _get_unpad_data. Add pre-truncation token counting for completion. Fix plugin callbacks duplication. Enable eval on start. Read extra hf args from cfg.
1 parent 7878802 commit 30c52ec

File tree

11 files changed

+229
-27
lines changed

11 files changed

+229
-27
lines changed

src/axolotl/core/builders/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
439439
def _set_base_training_args(
440440
self, total_num_steps
441441
) -> tuple[dict[str, Any], dict[str, Any]]:
442-
training_args_kwargs: dict[str, Any] = {}
442+
training_args_kwargs: dict[str, Any] = self.cfg.get("extra_hf_training_args") or {}
443443
trainer_kwargs: dict[str, Any] = {}
444444

445445
self._configure_warmup_and_logging(total_num_steps, training_args_kwargs)

src/axolotl/loaders/tokenizer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import os
55

6+
from axolotl.utils.dict import DictDefault
67
import transformers
78
from transformers import (
89
AddedToken,
@@ -185,6 +186,12 @@ def load_tokenizer(cfg):
185186
setattr(tokenizer, attr_name, "<|endoftext|>")
186187

187188
additional_special_tokens = None
189+
190+
if not tokenizer.pad_token:
191+
if not cfg.special_tokens:
192+
cfg.special_tokens = DictDefault({})
193+
cfg.special_tokens.pad_token = tokenizer.eos_token
194+
188195
if cfg.special_tokens:
189196
special_tokens = cfg.special_tokens.to_dict()
190197
additional_special_tokens = special_tokens.pop(

src/axolotl/logging_config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ def format(self, record):
9696
"filters": [],
9797
"stream": sys.stdout,
9898
},
99+
"file": {
100+
"class": "logging.FileHandler",
101+
"formatter": "simple",
102+
"filename": "train.log",
103+
"mode": "w",
104+
},
99105
},
100106
# log level will be superseded by the AxolotlLogger
101107
"root": {
@@ -104,7 +110,7 @@ def format(self, record):
104110
},
105111
"loggers": {
106112
"axolotl": {
107-
"handlers": ["color_console"],
113+
"handlers": ["color_console", "file"],
108114
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL),
109115
"propagate": False,
110116
},

src/axolotl/prompt_strategies/alpaca_w_system.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def tokenize_prompt(self, prompt):
5050
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
5151
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
5252

53+
if "num_tokens_pre_truncation" in tokenized_prompt:
54+
tokenized_prompt["num_tokens_pre_truncation"] = (
55+
tokenized_prompt["num_tokens_pre_truncation"]
56+
+ tokenized_res_prompt["num_tokens_pre_truncation"]
57+
)
58+
5359
return tokenized_prompt
5460

5561

src/axolotl/prompt_strategies/chat_template.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,20 +94,26 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
9494
images=images,
9595
return_tensors="pt",
9696
)
97+
# dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
9798
# workaround since processor works in batches instead of single examples
9899
for k, val in batch.items():
99100
if k in ["pixel_values"]:
100101
batch[k] = val.tolist()
101102
else:
102103
batch[k] = val.squeeze().tolist()
104+
batch["num_tokens_pre_truncation"] = len(batch["input_ids"])
103105
return batch
104106

105-
return self.tokenizer.apply_chat_template(
107+
input_ids = self.tokenizer.apply_chat_template(
106108
conversation,
107109
add_generation_prompt=add_generation_prompt,
108110
chat_template=self.chat_template,
109111
**self.chat_template_kwargs,
110112
)
113+
return {
114+
"input_ids": input_ids,
115+
"num_tokens_pre_truncation": len(input_ids),
116+
}
111117

112118
def get_offsets_for_train_detail(
113119
self, text: str, train_details: List[Dict], mask_untrainable: bool = True
@@ -377,21 +383,25 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
377383
):
378384
turns = self.get_conversation_thread(prompt)
379385
images = self.get_images(prompt)
380-
prompt_ids = self.prompter.build_prompt( # type: ignore
386+
# We get back {"input_ids": [...], "num_tokens_pre_truncation": ...}
387+
_prompt_ids = self.prompter.build_prompt(
381388
turns[:-1],
382389
add_generation_prompt=True,
383390
images=images,
384391
)
392+
prompt_ids = _prompt_ids["input_ids"]
385393
tokenized_res = self.prompter.build_prompt(
386394
turns, images=images
387395
) # type: ignore
388396
tokenized_prompt = {}
389-
if isinstance(tokenized_res, list):
390-
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
397+
if "attention_mask" not in tokenized_res:
398+
input_ids = prompt_ids + tokenized_res["input_ids"][len(prompt_ids) :]
391399
tokenized_prompt["input_ids"] = input_ids
400+
num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
392401
tokenized_prompt["attention_mask"] = [1] * len(input_ids)
393402
else:
394403
input_ids = tokenized_res["input_ids"]
404+
num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
395405
tokenized_prompt = tokenized_res
396406

397407
if not self.train_on_inputs:
@@ -401,11 +411,14 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
401411
labels = input_ids
402412

403413
tokenized_prompt["labels"] = labels
414+
tokenized_prompt["num_tokens_pre_truncation"] = num_tokens_pre_truncation
404415

405416
return tokenized_prompt
406417

407418
turns = self.get_conversation_thread(prompt)
408-
input_ids = self.prompter.build_prompt(turns) # type: ignore
419+
tokenized_res = self.prompter.build_prompt(turns)
420+
input_ids = tokenized_res["input_ids"]
421+
num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
409422
labels = [IGNORE_TOKEN_ID] * len(input_ids)
410423

411424
last_eos_idx = -1
@@ -518,6 +531,7 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
518531
"input_ids": input_ids,
519532
"labels": labels,
520533
"attention_mask": [1] * len(input_ids),
534+
"num_tokens_pre_truncation": num_tokens_pre_truncation,
521535
}
522536

523537
def find_first_eos_token(self, input_ids, start_idx):
@@ -577,10 +591,10 @@ def find_turn(self, turns: list[dict], turn_idx: int):
577591
turns_with_content = turns[: turn_idx + 1]
578592

579593
# Generate the conversation up to the turn, with final turn replaced with dummy content
580-
dummy_ids = self.prompter.build_prompt(turns_with_empty) # type: ignore
594+
dummy_ids = self.prompter.build_prompt(turns_with_empty)["input_ids"] # type: ignore
581595

582596
# Generate the conversation up to the turn, with final turn included
583-
full_ids = self.prompter.build_prompt(turns_with_content) # type: ignore
597+
full_ids = self.prompter.build_prompt(turns_with_content)["input_ids"] # type: ignore
584598

585599
if not full_ids or not dummy_ids:
586600
LOG.warning(f"Empty template generated for turn {turn_idx}")

src/axolotl/prompt_tokenizers.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Module containing PromptTokenizingStrategy and Prompter classes"""
22

33
import abc
4+
import functools
45
from typing import Callable, Dict, List, Optional, Tuple, Union
56

67
from transformers import BatchEncoding, PreTrainedTokenizer
@@ -62,18 +63,23 @@ def supports_batched(self):
6263
def _tokenize(
6364
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
6465
) -> BatchEncoding:
65-
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
66+
empty = BatchEncoding(
67+
data={"input_ids": [], "attention_mask": [], "num_tokens_pre_truncation": 0}
68+
)
6669
if not prompt:
6770
LOG.warning("Empty text requested for tokenization.")
6871
return empty
6972

70-
result = self.tokenizer(
71-
prompt,
72-
truncation=True,
73+
_tokenize = functools.partial(
74+
self.tokenizer,
7375
max_length=self.max_length,
7476
padding=False,
7577
return_tensors=None,
7678
)
79+
result = _tokenize(
80+
prompt,
81+
truncation=True,
82+
)
7783
if len(result["input_ids"]) == 0:
7884
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
7985
return empty
@@ -91,6 +97,20 @@ def _tokenize(
9197
result["attention_mask"] = result["attention_mask"][1:]
9298

9399
result["labels"] = result["input_ids"].copy()
100+
101+
_all_tokens = _tokenize(prompt, truncation=False)
102+
num_tokens_pre_truncation = len(_all_tokens["input_ids"])
103+
if (
104+
_all_tokens["input_ids"][-1] != self.tokenizer.eos_token_id
105+
and add_eos_token
106+
):
107+
num_tokens_pre_truncation += 1
108+
if (
109+
_all_tokens["input_ids"][0] == self.tokenizer.bos_token_id
110+
and strip_bos_token
111+
):
112+
num_tokens_pre_truncation -= 1
113+
result["num_tokens_pre_truncation"] = num_tokens_pre_truncation
94114
return result
95115

96116

src/axolotl/train.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
from datasets import Dataset
1717
from huggingface_hub.errors import OfflineModeIsEnabled
1818
from peft import PeftConfig, PeftModel
19-
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
19+
from transformers import (
20+
PreTrainedModel,
21+
PreTrainedTokenizer,
22+
ProcessorMixin,
23+
)
2024
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
2125
from transformers.trainer import Trainer
2226

@@ -25,7 +29,6 @@
2529
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
2630
fix_untrained_tokens,
2731
)
28-
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
2932
from axolotl.integrations.base import PluginManager
3033
from axolotl.loaders import (
3134
ModelLoader,
@@ -83,6 +86,9 @@ def setup_model_and_tokenizer(
8386
if model.generation_config is not None:
8487
model.generation_config.do_sample = True
8588

89+
plugin_manager = PluginManager.get_instance()
90+
plugin_manager.post_model_load(cfg, model)
91+
8692
# Apply freezing if specified
8793
if cfg.unfrozen_parameters:
8894
freeze_layers_except(model, cfg.unfrozen_parameters)
@@ -159,7 +165,11 @@ def setup_signal_handler(
159165
safe_serialization: Whether to use safe serialization when saving
160166
"""
161167
# ray workers don't have access to this signal
162-
if cfg.local_rank == 0 and not cfg.use_ray:
168+
if (
169+
cfg.local_rank == 0
170+
and not cfg.use_ray
171+
and cfg.get("save_model_on_interrupt", True)
172+
):
163173

164174
def terminate_handler(_, __, model_weakref):
165175
if model_weakref() is not None:
@@ -472,7 +482,7 @@ def handle_untrained_tokens_fix(
472482

473483

474484
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
475-
HFRLTrainerBuilder | HFCausalTrainerBuilder,
485+
Trainer,
476486
PeftModel | PreTrainedModel,
477487
PreTrainedTokenizer,
478488
PeftConfig | None,
@@ -573,8 +583,14 @@ def train(
573583
# Save the trained model and cleanup
574584
save_trained_model(cfg, trainer, model, safe_serialization)
575585
create_model_card(cfg, trainer)
576-
if not cfg.use_ray:
577-
cleanup_distributed()
586+
587+
if cfg.deepspeed:
588+
trainer.deepspeed.destroy()
589+
trainer.accelerator.free_memory()
590+
trainer.model, trainer.model_wrapped, trainer.optimizer = None, None, None
591+
592+
# if not cfg.use_ray:
593+
# cleanup_distributed()
578594

579595
plugin_manager.post_train(cfg, model)
580596

src/axolotl/utils/data/sft.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@
5252
retry_on_request_exceptions,
5353
)
5454
from axolotl.utils.dict import DictDefault
55-
from axolotl.utils.distributed import is_local_main_process, zero_first
55+
from axolotl.utils.distributed import (
56+
compute_and_broadcast,
57+
is_local_main_process,
58+
zero_first,
59+
)
5660
from axolotl.utils.logging import get_logger
5761
from axolotl.utils.trainer import (
5862
calculate_total_num_steps,
@@ -174,9 +178,15 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
174178
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
175179
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
176180
if total_eval_steps == 0:
177-
raise ValueError(
178-
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
181+
LOG.warning(
182+
"eval dataset split is too small for sample_packing. Setting `eval_sample_packing` to False."
179183
)
184+
if cfg.world_size > 1:
185+
_eval_sample_packing = compute_and_broadcast(lambda: 0)
186+
if _eval_sample_packing < 1:
187+
cfg.eval_sample_packing = False
188+
else:
189+
cfg.eval_sample_packing = False
180190

181191
if cfg.max_steps:
182192
total_num_steps = min(

0 commit comments

Comments
 (0)