
Commit 27e9d2a

Misc improvements
Enable callbacks injection from plugins
Fix misc issues with axolotl plugins
Fix remote code checking
Enable loss average across devices
Add seq len validation
Enhance sequence lens validation
Remove legacy code for patching _get_unpad_data
Add pre-truncation token counting for completion
Fix plugin callbacks duplication
Enable eval on start
Read extra HF args from cfg
1 parent 6778856 commit 27e9d2a

File tree: 10 files changed, +222 -27 lines changed


src/axolotl/core/builders/base.py

Lines changed: 1 addition & 1 deletion
@@ -439,7 +439,7 @@ def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
    def _set_base_training_args(
        self, total_num_steps
    ) -> tuple[dict[str, Any], dict[str, Any]]:
-        training_args_kwargs: dict[str, Any] = {}
+        training_args_kwargs: dict[str, Any] = self.cfg.get("extra_hf_training_args") or {}
        trainer_kwargs: dict[str, Any] = {}

        self._configure_warmup_and_logging(total_num_steps, training_args_kwargs)
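For readers skimming the diff, a minimal standalone sketch of the new behavior, assuming only that the config object exposes a dict-like `get`: any mapping stored under `extra_hf_training_args` seeds the `TrainingArguments` kwargs, and a missing or `None` value falls back to an empty dict. `DummyCfg` below is a hypothetical stand-in, not axolotl's `DictDefault`.

```python
# Minimal sketch (assumed names): seed the TrainingArguments kwargs from an optional
# "extra_hf_training_args" mapping in the config, falling back to an empty dict.
from typing import Any


class DummyCfg(dict):
    """Hypothetical stand-in for axolotl's config object; only .get() is used here."""


def seed_training_args_kwargs(cfg: DummyCfg) -> dict[str, Any]:
    # `or {}` covers both a missing key and an explicit None in the config
    return cfg.get("extra_hf_training_args") or {}


if __name__ == "__main__":
    print(seed_training_args_kwargs(DummyCfg(extra_hf_training_args={"torch_compile": True})))
    print(seed_training_args_kwargs(DummyCfg()))  # -> {}
```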

src/axolotl/logging_config.py

Lines changed: 7 additions & 1 deletion
@@ -54,11 +54,17 @@ def format(self, record):
            "filters": [],
            "stream": sys.stdout,
        },
+        "file": {
+            "class": "logging.FileHandler",
+            "formatter": "simple",
+            "filename": "train.log",
+            "mode": "w",
+        },
    },
    "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
    "loggers": {
        "axolotl": {
-            "handlers": ["color_console"],
+            "handlers": ["color_console", "file"],
            "level": "DEBUG",
            "propagate": False,
        },
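A standalone sketch of the same idea (not the actual axolotl config; the handler names and formatter below are illustrative): a `logging.FileHandler` writing `train.log` in `"w"` mode is registered next to the console handler, so the named logger mirrors its output to a file that is truncated on each run.

```python
# Standalone sketch: mirror a logger's output to train.log ("w" mode truncates the
# file on every run). Handler and logger names here are illustrative.
import logging.config

LOGGING_CONFIG = {
    "version": 1,
    "formatters": {"simple": {"format": "%(levelname)s:%(name)s: %(message)s"}},
    "handlers": {
        "console": {"class": "logging.StreamHandler", "formatter": "simple"},
        "file": {
            "class": "logging.FileHandler",
            "formatter": "simple",
            "filename": "train.log",
            "mode": "w",
        },
    },
    "loggers": {"axolotl": {"handlers": ["console", "file"], "level": "DEBUG"}},
}

logging.config.dictConfig(LOGGING_CONFIG)
logging.getLogger("axolotl").info("this line goes to the console and to train.log")
```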

src/axolotl/prompt_strategies/alpaca_w_system.py

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,12 @@ def tokenize_prompt(self, prompt):
        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]

+        if "num_tokens_pre_truncation" in tokenized_prompt:
+            tokenized_prompt["num_tokens_pre_truncation"] = (
+                tokenized_prompt["num_tokens_pre_truncation"]
+                + tokenized_res_prompt["num_tokens_pre_truncation"]
+            )
+
        return tokenized_prompt
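A toy illustration of the accumulation being added (the token ids and counts below are made up): when both the instruction segment and the response segment carry `num_tokens_pre_truncation`, the combined example's count is simply their sum, mirroring how the `input_ids` lists are concatenated.

```python
# Toy illustration with made-up token counts: combine the pre-truncation counts of an
# instruction segment and a response segment the same way the input_ids are combined.
tokenized_prompt = {"input_ids": [1, 2, 3], "num_tokens_pre_truncation": 3}
tokenized_res_prompt = {"input_ids": [4, 5], "num_tokens_pre_truncation": 2}

tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
if "num_tokens_pre_truncation" in tokenized_prompt:
    tokenized_prompt["num_tokens_pre_truncation"] += tokenized_res_prompt[
        "num_tokens_pre_truncation"
    ]

print(tokenized_prompt)  # {'input_ids': [1, 2, 3, 4, 5], 'num_tokens_pre_truncation': 5}
```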

src/axolotl/prompt_strategies/chat_template.py

Lines changed: 21 additions & 7 deletions
@@ -91,19 +91,25 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
                images=images,
                return_tensors="pt",
            )
+            # dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
            # workaround since processor works in batches instead of single examples
            for k, val in batch.items():
                if k in ["pixel_values"]:
                    batch[k] = val.tolist()
                else:
                    batch[k] = val.squeeze().tolist()
+            batch["num_tokens_pre_truncation"] = len(batch["input_ids"])
            return batch

-        return self.tokenizer.apply_chat_template(
+        input_ids = self.tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=add_generation_prompt,
            chat_template=self.chat_template,
        )
+        return {
+            "input_ids": input_ids,
+            "num_tokens_pre_truncation": len(input_ids),
+        }

    def get_offsets_for_train_detail(
        self, text: str, train_details: List[Dict], mask_untrainable: bool = True

@@ -373,21 +379,25 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
        ):
            turns = self.get_conversation_thread(prompt)
            images = self.get_images(prompt)
-            prompt_ids = self.prompter.build_prompt(  # type: ignore
+            # We get back {"input_ids": [...], "num_tokens_pre_truncation": ...}
+            _prompt_ids = self.prompter.build_prompt(
                turns[:-1],
                add_generation_prompt=True,
                images=images,
            )
+            prompt_ids = _prompt_ids["input_ids"]
            tokenized_res = self.prompter.build_prompt(
                turns, images=images
            )  # type: ignore
            tokenized_prompt = {}
-            if isinstance(tokenized_res, list):
-                input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
+            if "attention_mask" not in tokenized_res:
+                input_ids = prompt_ids + tokenized_res["input_ids"][len(prompt_ids) :]
                tokenized_prompt["input_ids"] = input_ids
+                num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
                tokenized_prompt["attention_mask"] = [1] * len(input_ids)
            else:
                input_ids = tokenized_res["input_ids"]
+                num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
                tokenized_prompt = tokenized_res

            if not self.train_on_inputs:

@@ -397,11 +407,14 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
                labels = input_ids

            tokenized_prompt["labels"] = labels
+            tokenized_prompt["num_tokens_pre_truncation"] = num_tokens_pre_truncation

            return tokenized_prompt

        turns = self.get_conversation_thread(prompt)
-        input_ids = self.prompter.build_prompt(turns)  # type: ignore
+        tokenized_res = self.prompter.build_prompt(turns)
+        input_ids = tokenized_res["input_ids"]
+        num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
        labels = [IGNORE_TOKEN_ID] * len(input_ids)

        last_eos_idx = -1

@@ -514,6 +527,7 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(input_ids),
+            "num_tokens_pre_truncation": num_tokens_pre_truncation,
        }

    def find_first_eos_token(self, input_ids, start_idx):

@@ -573,10 +587,10 @@ def find_turn(self, turns: list[dict], turn_idx: int):
        turns_with_content = turns[: turn_idx + 1]

        # Generate the conversation up to the turn, with final turn replaced with dummy content
-        dummy_ids = self.prompter.build_prompt(turns_with_empty)  # type: ignore
+        dummy_ids = self.prompter.build_prompt(turns_with_empty)["input_ids"]  # type: ignore

        # Generate the conversation up to the turn, with final turn included
-        full_ids = self.prompter.build_prompt(turns_with_content)  # type: ignore
+        full_ids = self.prompter.build_prompt(turns_with_content)["input_ids"]  # type: ignore

        if not full_ids or not dummy_ids:
            LOG.warning(f"Empty template generated for turn {turn_idx}")
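Since `build_prompt` now returns a dict rather than a bare list of token ids, callers index into the result; a hedged sketch of the new calling convention, using a stub prompter rather than the real class:

```python
# Sketch of the new calling convention with a stub prompter (not the real class):
# build_prompt returns a dict, so callers take ["input_ids"] and can read the
# pre-truncation token count alongside it.
class StubPrompter:
    def build_prompt(self, turns):
        input_ids = list(range(len(turns) * 3))  # fake token ids
        return {"input_ids": input_ids, "num_tokens_pre_truncation": len(input_ids)}


prompter = StubPrompter()
tokenized_res = prompter.build_prompt([{"role": "user", "content": "hi"}])
input_ids = tokenized_res["input_ids"]
num_tokens_pre_truncation = tokenized_res["num_tokens_pre_truncation"]
print(len(input_ids), num_tokens_pre_truncation)
```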

src/axolotl/prompt_tokenizers.py

Lines changed: 24 additions & 4 deletions
@@ -1,6 +1,7 @@
"""Module containing PromptTokenizingStrategy and Prompter classes"""

import abc
+import functools
from typing import Callable, Dict, List, Optional, Tuple, Union

from transformers import BatchEncoding, PreTrainedTokenizer

@@ -62,18 +63,23 @@ def supports_batched(self):
    def _tokenize(
        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
    ) -> BatchEncoding:
-        empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
+        empty = BatchEncoding(
+            data={"input_ids": [], "attention_mask": [], "num_tokens_pre_truncation": 0}
+        )
        if not prompt:
            LOG.warning("Empty text requested for tokenization.")
            return empty

-        result = self.tokenizer(
-            prompt,
-            truncation=True,
+        _tokenize = functools.partial(
+            self.tokenizer,
            max_length=self.max_length,
            padding=False,
            return_tensors=None,
        )
+        result = _tokenize(
+            prompt,
+            truncation=True,
+        )
        if len(result["input_ids"]) == 0:
            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
            return empty

@@ -91,6 +97,20 @@ def _tokenize(
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
+
+        _all_tokens = _tokenize(prompt, truncation=False)
+        num_tokens_pre_truncation = len(_all_tokens["input_ids"])
+        if (
+            _all_tokens["input_ids"][-1] != self.tokenizer.eos_token_id
+            and add_eos_token
+        ):
+            num_tokens_pre_truncation += 1
+        if (
+            _all_tokens["input_ids"][0] == self.tokenizer.bos_token_id
+            and strip_bos_token
+        ):
+            num_tokens_pre_truncation -= 1
+        result["num_tokens_pre_truncation"] = num_tokens_pre_truncation
        return result
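The pre-truncation counting can be read in isolation as follows (a sketch with fake token ids; in the real code the ids come from a second, untruncated tokenizer pass): count the untruncated ids, add one if an EOS token would be appended, subtract one if a BOS token would be stripped.

```python
# Isolated sketch of the pre-truncation count (fake ids; real values come from the tokenizer):
# count the untruncated ids, +1 if an EOS would be appended, -1 if a BOS would be stripped.
def count_pre_truncation(all_input_ids, eos_token_id, bos_token_id,
                         add_eos_token=True, strip_bos_token=False):
    num_tokens = len(all_input_ids)
    if add_eos_token and all_input_ids[-1] != eos_token_id:
        num_tokens += 1
    if strip_bos_token and all_input_ids[0] == bos_token_id:
        num_tokens -= 1
    return num_tokens


# e.g. bos=1, eos=2: untruncated ids start with BOS and lack a trailing EOS
print(count_pre_truncation([1, 17, 23, 42], eos_token_id=2, bos_token_id=1))  # 5
```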

src/axolotl/train.py

Lines changed: 22 additions & 6 deletions
@@ -16,7 +16,11 @@
from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled
from peft import PeftConfig, PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+from transformers import (
+    PreTrainedModel,
+    PreTrainedTokenizer,
+    ProcessorMixin,
+)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer

@@ -25,7 +29,6 @@
from axolotl.contribs.lgpl import (  # pylint: disable = no-name-in-module
    fix_untrained_tokens,
)
-from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.integrations.base import PluginManager
from axolotl.loaders import (
    ModelLoader,

@@ -83,6 +86,9 @@ def setup_model_and_tokenizer(
    if model.generation_config is not None:
        model.generation_config.do_sample = True

+    plugin_manager = PluginManager.get_instance()
+    plugin_manager.post_model_load(cfg, model)
+
    # Apply freezing if specified
    if cfg.unfrozen_parameters:
        freeze_layers_except(model, cfg.unfrozen_parameters)

@@ -159,7 +165,11 @@ def setup_signal_handler(
        safe_serialization: Whether to use safe serialization when saving
    """
    # ray workers don't have access to this signal
-    if cfg.local_rank == 0 and not cfg.use_ray:
+    if (
+        cfg.local_rank == 0
+        and not cfg.use_ray
+        and cfg.get("save_model_on_interrupt", True)
+    ):

        def terminate_handler(_, __, model_weakref):
            if model_weakref() is not None:

@@ -472,7 +482,7 @@ def handle_untrained_tokens_fix(


def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
-    HFRLTrainerBuilder | HFCausalTrainerBuilder,
+    Trainer,
    PeftModel | PreTrainedModel,
    PreTrainedTokenizer,
    PeftConfig | None,

@@ -573,8 +583,14 @@ def train(
    # Save the trained model and cleanup
    save_trained_model(cfg, trainer, model, safe_serialization)
    create_model_card(cfg, trainer)
-    if not cfg.use_ray:
-        cleanup_distributed()
+
+    if cfg.deepspeed:
+        trainer.deepspeed.destroy()
+        trainer.accelerator.free_memory()
+        trainer.model, trainer.model_wrapped, trainer.optimizer = None, None, None
+
+    # if not cfg.use_ray:
+    #     cleanup_distributed()

    plugin_manager.post_train(cfg, model)
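A hedged sketch of the new interrupt gate (the config keys mirror the diff; the handler body and the `save_fn` callback are illustrative): only the local main process registers a save-on-interrupt handler, and the new `save_model_on_interrupt` flag, defaulting to `True`, can turn it off entirely.

```python
# Hedged sketch of the interrupt gate: register a SIGINT handler only on the local
# main process, outside ray, and only when save_model_on_interrupt allows it.
import signal


def maybe_register_interrupt_save(cfg, save_fn):
    if (
        cfg.get("local_rank", 0) == 0
        and not cfg.get("use_ray")
        and cfg.get("save_model_on_interrupt", True)
    ):
        def terminate_handler(signum, frame):
            save_fn()
            raise KeyboardInterrupt

        signal.signal(signal.SIGINT, terminate_handler)


if __name__ == "__main__":
    cfg = {"local_rank": 0, "use_ray": False, "save_model_on_interrupt": False}
    maybe_register_interrupt_save(cfg, lambda: print("saving checkpoint..."))
    # with the flag set to False above, no handler is installed and Ctrl-C behaves normally
```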

src/axolotl/utils/data/sft.py

Lines changed: 13 additions & 3 deletions
@@ -52,7 +52,11 @@
    retry_on_request_exceptions,
)
from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_local_main_process, zero_first
+from axolotl.utils.distributed import (
+    compute_and_broadcast,
+    is_local_main_process,
+    zero_first,
+)
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import (
    calculate_total_num_steps,

@@ -174,9 +178,15 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
        total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
        if total_eval_steps == 0:
-            raise ValueError(
-                "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
+            LOG.warning(
+                "eval dataset split is too small for sample_packing. Setting `eval_sample_packing to False`."
            )
+            if cfg.world_size > 1:
+                _eval_sample_packing = compute_and_broadcast(lambda: 0)
+                if _eval_sample_packing < 1:
+                    cfg.eval_sample_packing = False
+            else:
+                cfg.eval_sample_packing = False

    if cfg.max_steps:
        total_num_steps = min(
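The intent of the eval-packing change, sketched without the real distributed plumbing (`fake_compute_and_broadcast` below stands in for axolotl's `compute_and_broadcast` helper, whose exact semantics are assumed): instead of raising, every rank converges on the same `eval_sample_packing = False` decision, with rank 0's value shared when running multi-GPU so the config never diverges across processes.

```python
# Pure-Python sketch of the decision (the fake broadcast below stands in for axolotl's
# compute_and_broadcast helper): all ranks agree to disable eval_sample_packing when
# the eval split is too small, instead of raising a ValueError.
def fake_compute_and_broadcast(fn, rank, world_size):
    # stand-in: pretend rank 0 ran fn() and every rank received its result
    return fn()


def resolve_eval_sample_packing(cfg, total_eval_steps, rank=0):
    if total_eval_steps == 0:
        print("eval dataset split is too small for sample_packing; disabling it")
        if cfg["world_size"] > 1:
            decision = fake_compute_and_broadcast(lambda: 0, rank, cfg["world_size"])
            if decision < 1:
                cfg["eval_sample_packing"] = False
        else:
            cfg["eval_sample_packing"] = False
    return cfg


print(resolve_eval_sample_packing({"world_size": 2, "eval_sample_packing": True}, 0))
```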
