@@ -17,16 +17,17 @@
 from instructlab.training.config import (
     DataProcessArgs,
     DistributedBackend,
+    LoraOptions,
     TorchrunArgs,
     TrainingArgs,
 )
 from instructlab.training.main_ds import run_training

 MINIMAL_TRAINING_ARGS = {
     "max_seq_len": 140,  # this config fits nicely on 4xL40s and may need modification for other setups
-    "max_batch_len": 15000,
+    "max_batch_len": 5000,
     "num_epochs": 1,
-    "effective_batch_size": 3840,
+    "effective_batch_size": 128,
     "save_samples": 0,
     "learning_rate": 1e-4,
     "warmup_steps": 1,
@@ -50,7 +51,7 @@
 RUNNER_CPUS_EXPECTED = 4

 # Number of samples to randomly sample from the processed dataset for faster training
-NUM_SAMPLES_TO_KEEP = 5000
+NUM_SAMPLES_TO_KEEP = 2500


 @pytest.fixture(scope="module")
@@ -231,25 +232,36 @@ def cached_training_data(
 @pytest.mark.parametrize(
     "dist_backend", [DistributedBackend.FSDP, DistributedBackend.DEEPSPEED]
 )
-@pytest.mark.parametrize("cpu_offload", [True, False])
+@pytest.mark.parametrize("cpu_offload", [False, True])
+@pytest.mark.parametrize("lora_rank", [0])
+@pytest.mark.parametrize("use_liger", [False, True])
 def test_training_feature_matrix(
     cached_test_model: pathlib.Path,
     cached_training_data: pathlib.Path,
     checkpoint_dir: pathlib.Path,
     prepared_data_dir: pathlib.Path,
+    use_liger: bool,
+    lora_rank: int,
     cpu_offload: bool,
     dist_backend: DistributedBackend,
 ) -> None:
+    torch_args = TorchrunArgs(**DEFAULT_TORCHRUN_ARGS)
     train_args = TrainingArgs(
         model_path=str(cached_test_model),
         data_path=str(cached_training_data),
         data_output_dir=str(prepared_data_dir),
         ckpt_output_dir=str(checkpoint_dir),
+        lora=LoraOptions(rank=lora_rank),
+        use_liger=use_liger,
         **MINIMAL_TRAINING_ARGS,
     )

     train_args.distributed_backend = dist_backend

+    if lora_rank > 0:
+        # LoRA doesn't support full state saving.
+        train_args.accelerate_full_state_at_epoch = False
+
     if dist_backend == DistributedBackend.FSDP:
         train_args.fsdp_options.cpu_offload_params = cpu_offload
     else:
@@ -258,6 +270,4 @@ def test_training_feature_matrix(
             pytest.xfail("DeepSpeed CPU Adam isn't currently building correctly")
         train_args.deepspeed_options.cpu_offload_optimizer = cpu_offload

-    torch_args = TorchrunArgs(**DEFAULT_TORCHRUN_ARGS)
-
     run_training(torch_args=torch_args, train_args=train_args)
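
For context on the matrix this change produces: pytest takes the cartesian product of the stacked parametrize decorators, so the test now expands to 2 dist_backend x 2 cpu_offload x 1 lora_rank x 2 use_liger = 8 cases. The sketch below is illustrative only and not part of the change; the list values mirror the decorators above, but the variable names are made up here.

# Illustrative sketch: enumerate the combinations the stacked
# @pytest.mark.parametrize decorators generate (a cartesian product).
import itertools

dist_backends = ["fsdp", "deepspeed"]  # DistributedBackend.FSDP / .DEEPSPEED
cpu_offload = [False, True]
lora_ranks = [0]                       # rank 0 is the only LoRA setting exercised so far
use_liger = [False, True]

cells = list(itertools.product(dist_backends, cpu_offload, lora_ranks, use_liger))
print(len(cells))  # 2 * 2 * 1 * 2 = 8 test cases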