diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py
index 033a8efe..ea50b41d 100644
--- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py
+++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py
@@ -162,6 +162,11 @@
     action="store_true",
     help="set to true to save cpu validation outputs for later consumption",
 )
+parser.add_argument(
+    "--only_save_validation_output",
+    action="store_true",
+    help="set to true to ONLY save cpu validation outputs for later consumption",
+)
 parser.add_argument(
     "--prioritize_large_batch_sizes",
     action="store_true",
@@ -391,14 +396,16 @@ def __load_validation_info(
     and dist.get_world_size() == 4
 ):
     extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
-warmup_model(
-    model,
-    input_ids,
-    max_new_tokens=max_new_tokens,
-    compile_dynamic_sendnn=True,
-    stagger_update_lazyhandle=args.stagger_update_lazyhandle,
-    **extra_kwargs,
-)
+
+if not args.only_save_validation_output:
+    warmup_model(
+        model,
+        input_ids,
+        max_new_tokens=max_new_tokens,
+        compile_dynamic_sendnn=True,
+        stagger_update_lazyhandle=args.stagger_update_lazyhandle,
+        **extra_kwargs,
+    )
 
 if USE_DISTRIBUTED:
     # wait for rank0 to be finished as it is the only one generating the criteria json
@@ -659,7 +666,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
                 **extra_kwargs,
             )
             # save the cpu validation info for later consumption
-            if save_validation_info_outputs:
+            if save_validation_info_outputs or args.only_save_validation_output:
                 cpu_validation_info.save(
                     get_validation_info_path(
                         args.validation_info_outputs_dir,
@@ -674,7 +681,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
             )
         )
 
-        if args.test_type == "metrics":
+        if args.test_type == "metrics" and not args.only_save_validation_output:
             aiu_validation_info = extract_validation_information(
                 model,
                 input_ids,
@@ -718,7 +725,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
 
             if failure_rate >= args.failure_rate_threshold:
                 failed_cases.append((program_id, valid_prompt, failure_rate))
-        elif args.test_type == "tokens":
+        elif args.test_type == "tokens" and not args.only_save_validation_output:
             aiu_validation_info = extract_validation_information(
                 model,
                 input_ids,
@@ -758,9 +765,11 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
             dprint(f"AIU tokens:\n{aiu_tokens_generated}")
             dprint(f"CPU output:\n{tokenizer.decode(cpu_tokens_generated)}")
             dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")
+        elif args.only_save_validation_output:
+            pass
         else:
             raise ValueError("test type must be one of metrics or tokens")
-    else:
+    elif not args.only_save_validation_output:
         aiu_validation_info = extract_validation_information(
             model,
             input_ids,
@@ -784,7 +793,11 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
         dprint(f"AIU tokens:\n{aiu_tokens_generated}")
         dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")
 
-if not args.skip_validation and local_rank == 0:
+if (
+    not args.skip_validation
+    and local_rank == 0
+    and not args.only_save_validation_output
+):
     if len(failed_cases) != 0:
         dprint("the test failed with the following cases:")
         for failed_case in failed_cases:
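---
Usage note (not part of the patch): with this change, the new flag skips model warmup and both AIU-side validation paths (metrics and tokens), so the run only computes and saves the CPU validation outputs for later consumption. A sketch of an invocation, assuming the script's other required arguments are passed as usual (the trailing "..." stands in for those arguments; only --only_save_validation_output is defined by this diff):

    # precompute and persist CPU validation outputs without running the AIU model
    python aiu_fms_testing_utils/scripts/drive_paged_programs.py \
        --only_save_validation_output \
        ...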