diff --git a/aiu_fms_testing_utils/testing/utils.py b/aiu_fms_testing_utils/testing/utils.py new file mode 100644 index 00000000..cacff899 --- /dev/null +++ b/aiu_fms_testing_utils/testing/utils.py @@ -0,0 +1,24 @@ +from collections.abc import Iterable + + +def format_kwargs_to_string(**kwargs): + """ + Turns kwargs into a str with variable names using `-`, variables separated by `_` and iterable separated by `,` + """ + formatted_pairs = [] + for key, value in sorted(kwargs.items()): + formatted_value = None + if isinstance(value, str): + formatted_value = value + elif isinstance(value, Iterable): + formatted_value = ",".join(map(str, value)) + elif value: + formatted_value = str(value) + # only append if formatted_value exists + if formatted_value: + # Keep previous convention of variable names with `-` instead of `_` + formatted_pairs.append( + f"{key.replace('_', '-')}-{formatted_value.replace('/', '--')}" + ) + + return "_".join(formatted_pairs) diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py index 0c655ff5..5bf120a0 100644 --- a/aiu_fms_testing_utils/testing/validation.py +++ b/aiu_fms_testing_utils/testing/validation.py @@ -5,6 +5,9 @@ from aiu_fms_testing_utils.utils.aiu_setup import dprint from aiu_fms_testing_utils._version import version_tuple import os +from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string + +import hashlib class LogitsExtractorHook( @@ -125,13 +128,7 @@ def __len__(self): def get_default_validation_prefix( - model_id: str, - max_new_tokens: int, - batch_size: int, - seq_length: int, - dtype: str, - attn_type: str, - aftu_version: str, + **kwargs, ): """ Args: @@ -144,9 +141,17 @@ def get_default_validation_prefix( aftu_version (str): introduced in v0.3.0 to track changed in log Returns: - str: A prefix that will be prepended to the file name + str: A hashed prefix that will be prepended to the file name """ - return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}.{aftu_version}" + aftu_version = kwargs.pop( + "aftu_version", ".".join([str(_) for _ in version_tuple[:3]]) + ) + kwargs_str = format_kwargs_to_string(**kwargs) + + filename = f"{kwargs_str}" + hash_object = hashlib.sha256(filename.encode("utf-8")) + hex_digest = hash_object.hexdigest() + return f"{hex_digest}_{aftu_version}" def load_validation_information( @@ -416,11 +421,14 @@ def get_validation_info_path( aftu_version: Optional[Tuple[int, int, int]] = None, device_type: str = "cpu", dtype: str = "fp16", + **kwargs, ): if aftu_version is None: aftu_version = version_tuple - validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out" + sample_key = kwargs.get("sample_key", None) + + validation_file_name = f"{get_default_validation_prefix(aftu_version='.'.join([str(_) for _ in aftu_version[:3]]), model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, sample_key=sample_key)}.{device_type}_validation_info.{seed}.out" full_path = os.path.join(validation_info_dir, validation_file_name) return full_path @@ -452,10 +460,12 @@ def find_validation_info_path( version_allow_decrement: bool = False, device_type: str = "cpu", dtype: str = "fp16", + **kwargs, ): """ Find the validation info path if it exists, 
otherwise return None """ + sample_key = kwargs.get("sample_key", None) if aftu_version is None: loc_version_tuple = version_tuple[:3] @@ -476,6 +486,7 @@ def find_validation_info_path( loc_version_tuple, device_type, dtype, + sample_key=sample_key, ) # if the path is found, we are done searching and can return if os.path.exists(full_path): diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py index 65a0f9ab..6615c5c9 100644 --- a/aiu_fms_testing_utils/utils/__init__.py +++ b/aiu_fms_testing_utils/utils/__init__.py @@ -11,6 +11,7 @@ from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string from fms.utils.generation import pad_input_ids import torch @@ -482,6 +483,7 @@ def sample_rag_factoid_requests( enforce_sizes: List[int] = [], truncation: bool = False, pad_multiple: int = 64, + return_key: bool = False, ) -> List[Tuple[str, int]]: if not os.path.exists(dataset_path): print("error dataset does not exist") @@ -492,7 +494,7 @@ def sample_rag_factoid_requests( for line in f: dataset.append(line) - return __sample_requests( + sample_request = __sample_requests( dataset, num_requests, tokenizer, @@ -506,6 +508,24 @@ def sample_rag_factoid_requests( _cached_dataset_key=dataset_path, ) + if return_key: + sample_key: str = format_kwargs_to_string( + dataset="rag_factoid", + num_requests=num_requests, + tokenizer=tokenizer.name_or_path.replace("/", "--"), + prompt_length_min=prompt_length_min, + prompt_length_max=prompt_length_max, + seed=seed, + enforce_heterogeneous=enforce_heterogeneous, + enforce_sizes=enforce_sizes, + truncate=truncation, + pad_multiple=pad_multiple, + ) + + return sample_request, sample_key + else: + return sample_request + def sample_sharegpt_requests( dataset_path: str, @@ -518,6 +538,7 @@ def sample_sharegpt_requests( enforce_sizes: List[int] | None = None, truncation: bool = False, pad_multiple: int = 64, + return_key: bool = False, ) -> List[Tuple[str, int]]: if not os.path.exists(dataset_path): print("downloading share-gpt dataset as it does not exist") @@ -543,7 +564,7 @@ def sample_sharegpt_requests( dataset = [data for data in dataset if len(data["conversations"]) >= 2] dataset: List[str] = [data["conversations"][0]["value"] for data in dataset] - return __sample_requests( + sample_request = __sample_requests( dataset, num_requests, tokenizer, @@ -557,6 +578,23 @@ def sample_sharegpt_requests( _cached_dataset_key=dataset_path, ) + if return_key: + sample_key: str = format_kwargs_to_string( + dataset="sharegpt", + num_requests=num_requests, + tokenizer=tokenizer.name_or_path.replace("/", "--"), + prompt_length_min=prompt_length_min, + prompt_length_max=prompt_length_max, + seed=seed, + enforce_heterogeneous=enforce_heterogeneous, + enforce_sizes=enforce_sizes, + truncate=truncation, + pad_multiple=pad_multiple, + ) + return sample_request, sample_key + else: + return sample_request + def sample_squad_v2_qa_requests( dataset_path: str, @@ -569,6 +607,7 @@ def sample_squad_v2_qa_requests( enforce_sizes: List[int] | None = None, truncation: bool = False, pad_multiple: int = 64, + return_key: bool = False, ) -> List[Tuple[str, int]]: from datasets import load_dataset @@ -582,7 +621,7 @@ def sample_squad_v2_qa_requests( ds = [f"{data['context']}\n{data['question']}" for data in ds] - return __sample_requests( + sample_request = __sample_requests( ds, num_requests, 
tokenizer, @@ -595,6 +634,23 @@ def sample_squad_v2_qa_requests( pad_multiple, ) + if return_key: + sample_key: str = format_kwargs_to_string( + dataset="squad_v2", + num_requests=num_requests, + tokenizer=tokenizer.name_or_path.replace("/", "--"), + prompt_length_min=prompt_length_min, + prompt_length_max=prompt_length_max, + seed=seed, + enforce_heterogeneous=enforce_heterogeneous, + enforce_sizes=enforce_sizes, + truncate=truncation, + pad_multiple=pad_multiple, + ) + return sample_request, sample_key + else: + return sample_request + def prepare_inputs( batch_size, seq_length, tokenizer, ds_path, seed=0, ds_type="sharegpt" diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py index edf0c548..1d6bcbc7 100644 --- a/aiu_fms_testing_utils/utils/paged.py +++ b/aiu_fms_testing_utils/utils/paged.py @@ -153,7 +153,9 @@ def generate( raise ValueError("model must have a distributed_strategy") kvheads = kvheads // tensor_parallel_size if kvheads > 1 else kvheads - head_size = model.config.emb_dim // nheads + head_size = getattr( + model.config, "head_dim", model.config.emb_dim // model.config.nheads + ) if "fp8" in kwargs["attn_name"]: from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor diff --git a/scripts/drive_paged_programs.py b/scripts/drive_paged_programs.py index ea51bad8..033a8efe 100644 --- a/scripts/drive_paged_programs.py +++ b/scripts/drive_paged_programs.py @@ -40,6 +40,7 @@ get_programs_prompts, KVCACHE_NUM_BLOCKS_HINT, ) +from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string parser = argparse.ArgumentParser( description="Script which will drive paged programs for debugging" @@ -195,6 +196,10 @@ custom_shape = (len(result), max([_[1] for _ in result])) def __custom_line_sampler(*args, **kwargs): + return_key = kwargs.get("return_key", False) + sample_key = format_kwargs_to_string(**kwargs) + if return_key: + return result, sample_key return result sampler = __custom_line_sampler @@ -245,7 +250,7 @@ def __custom_line_sampler(*args, **kwargs): def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0): start = time.time() - prompts_and_sizes = sampler( + prompts_and_sizes, sample_key = sampler( DATASET_PATH, batch_size, tokenizer, @@ -254,6 +259,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0 seed, enforce_sizes=enforce_sizes, truncation=allow_truncation, + return_key=True, ) end = time.time() if local_rank == 0: @@ -274,7 +280,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0 input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length) extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) - return input_ids, extra_kwargs + return input_ids, extra_kwargs, sample_key def __maybe_prepare_fp8_weights(model_in, is_fp8): @@ -296,7 +302,9 @@ def __load_validation_info( tokenizer, seed, attn_type: str, + **kwargs, ): + sample_key = kwargs.get("sample_key", None) full_path = find_validation_info_path( args.validation_info_outputs_dir, model_variant, @@ -307,6 +315,7 @@ def __load_validation_info( attn_type, version_allow_decrement=True, dtype=CPU_DTYPE, + sample_key=sample_key, ) if full_path is not None: dprint(f"cpu validation info found for seed={seed} -- loading it") @@ -367,13 +376,14 @@ def __load_validation_info( # warmup with any input so compiler produces criteria json # TODO: Swap this with __prepare_inputs once fix for shape_id is available -# input_ids, extra_kwargs = __prepare_inputs(2, max_tkv, 
tokenizer) +# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer) prompt_list = [torch.arange(0, 64, dtype=torch.int64)] # matching vllm warmup to pad to 2 on fp8, and no pad for fp16 if is_fp8: prompt_list = prompt_list * 2 input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64) extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16) + extra_kwargs["attn_name"] = ATTN_NAME if ( "granite-3.3-8b-instruct" in model_variant @@ -494,7 +504,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: for valid_prompt_shape in valid_prompt_shapes: if valid_prompt_shape == custom_shape: enforce_sizes = [valid_prompt_shape[1]] - input_ids, extra_kwargs = __prepare_inputs( + input_ids, extra_kwargs, sample_key = __prepare_inputs( valid_prompt_shape[0], valid_prompt_shape[1], tokenizer, @@ -506,6 +516,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: custom_shape, input_ids, extra_kwargs, + sample_key, ) ] break @@ -566,7 +577,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: ) ) try: - input_ids, extra_kwargs = __prepare_inputs( + input_ids, extra_kwargs, sample_key = __prepare_inputs( valid_prompt_shape[0], valid_prompt_shape[1], tokenizer, @@ -578,6 +589,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]: valid_prompt_shape, input_ids, extra_kwargs, + sample_key, ) ) used_keys.add(program_seq_key[0]) @@ -609,7 +621,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor): failed_cases = [] # for each program and valid prompt (batch size, sequence length) -for program_id, valid_prompt, input_ids, extra_kwargs in valid_prompts: +for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts: extra_kwargs["attn_name"] = ATTN_NAME if ( "granite-3.3-8b-instruct" in model_variant @@ -634,6 +646,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor): tokenizer, seed=0, attn_type=ATTN_NAME, + sample_key=sample_key, ) # if the cpu validation info is not yet computed, compute it if cpu_validation_info is None: @@ -657,6 +670,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor): 0, ATTN_NAME, dtype=CPU_DTYPE, + sample_key=sample_key, ) ) diff --git a/scripts/generate_layers_metrics.py b/scripts/generate_layers_metrics.py index d3245123..ffc01930 100644 --- a/scripts/generate_layers_metrics.py +++ b/scripts/generate_layers_metrics.py @@ -473,7 +473,11 @@ def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens): cos_sim = tensor_cos_sim(tensor_cpu_out, cuda_output) prefix = get_default_validation_prefix( - model_path, max_new_token, batch_size, seq_length, "float16" + model_id=model_path, + max_new_tokens=max_new_token, + batch_size=batch_size, + seq_length=seq_length, + dtype="float16", ) layer_name = str(layer_key).replace("[", "").replace("]", "") diff --git a/scripts/generate_metrics.py b/scripts/generate_metrics.py index 8ec3f028..f65149fa 100644 --- a/scripts/generate_metrics.py +++ b/scripts/generate_metrics.py @@ -134,11 +134,11 @@ # this follows the same pattern of naming in test_shapes. This way we can save and re-use for quicker shape testing. 
prefix = get_default_validation_prefix( - args.variant, - args.max_new_tokens, - args.batch_size, - args.min_pad_length, - args.default_dtype, + model_id=args.variant, + max_new_tokens=args.max_new_tokens, + batch_size=args.batch_size, + seq_len=args.min_pad_length, + dtype=args.default_dtype, ) if os.path.exists(os.path.join(args.output_dir, f"{prefix}.prob_mean.csv")): print("skipping metric generation as it has already been done") diff --git a/scripts/inference.py b/scripts/inference.py index 3ec33f0e..78919673 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -20,6 +20,7 @@ from fms.models.llama import LLaMAConfig, _llama_factory_factory from fms.utils import generation from fms.utils.generation import pad_input_ids +from fms.utils import serialization from transformers import AutoTokenizer @@ -257,6 +258,15 @@ default=0, help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group", ) + +parser.add_argument( + "--head_dim", + type=int, + default=None, + help="Override the head_dim in the model config", +) + + args = parser.parse_args() attention_map = { @@ -504,6 +514,12 @@ def select_int8_module( dprint(f"data_type={default_dtype}") dprint("=" * 60 + "\n") +if args.device_type == "aiu" and args.head_dim is not None: + serialization.extend_adapter( + "granite", "hf", ["weight_expansion_for_mismatched_head_dim"] + ) + + with stagger_region(args.stagger_load): model = get_model( args.architecture, @@ -516,6 +532,10 @@ def select_int8_module( group=dist.group.WORLD, linear_config=linear_config, fused_weights=fused_weights, + override_hf_pretrained_config=True + if args.device_type == "aiu" and args.head_dim is not None + else False, + head_dim=args.head_dim, ) ### Quantization diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 7e716618..e2d06176 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -23,6 +23,10 @@ def pytest_sessionstart(session): os.environ.setdefault("DTLOG_LEVEL", "error") os.environ.setdefault("DT_DEEPRT_VERBOSE", "-1") + # NOTE: we should configure the cachedir before importing torchsendnn's + # graph cache to prevent it from being initialized in the wrong place. 
+ os.environ.setdefault("TORCH_SENDNN_CACHE_DIR", os.path.join(os.getcwd(), ".cache")) + def pytest_addoption(parser): parser.addoption( diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 122c9664..4f95e61e 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -29,7 +29,6 @@ from transformers import AutoTokenizer from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup - import os try: @@ -50,7 +49,7 @@ GRANITE_20B_CODE_INSTRUCT_8K = "ibm-granite/granite-20b-code-instruct-8k" LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct" -micro_model_mapping = { +MICRO_MODEL_MAPPING = { LLAMA_3p1_8B_INSTRUCT: os.path.join( MICRO_MODELS_HOME, "llama-3.1-8b-layers-3-step-24000" ), @@ -76,24 +75,24 @@ os.environ.get("FMS_TEST_SHAPES_CUMULATIVE_TEST_TOKENS_PER_SEQUENCE", "1024") ) ATTN_TYPE = os.environ.get("FMS_TEST_SHAPES_ATTN_TYPE", "sdpa") -attention_map = { +ATTENTION_MAP = { "sdpa": "sdpa_causal", "paged": "spyre_paged_attn", "math_fp8": "math_fp8", "paged_fp8": "spyre_paged_attn_fp8", } -ATTN_NAME = attention_map[ATTN_TYPE] +ATTN_NAME = ATTENTION_MAP[ATTN_TYPE] CPU_DTYPE = "fp8" if "fp8" in ATTN_TYPE else "fp32" FORCE_VALIDATION_LEVEL_1 = ( os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1" ) -skip_assertions = os.environ.get("FMS_TEST_SHAPES_SKIP_ASSERTIONS", {}) -validation_info_dir = os.environ.get( +SKIP_ASSERTIONS = os.environ.get("FMS_TEST_SHAPES_SKIP_ASSERTIONS", {}) +VALIDATION_INFO_DIR = os.environ.get( "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info" ) -common_model_paths = os.environ.get( +COMMON_MODEL_PATHS = os.environ.get( "FMS_TEST_SHAPES_COMMON_MODEL_PATHS", [ LLAMA_3p1_8B_INSTRUCT, @@ -103,96 +102,96 @@ LLAMA_3p1_70B_INSTRUCT, ], ) -model_configuration_path = os.environ.get( +MODEL_CONFIGURATION_PATH = os.environ.get( "FMS_TEST_SHAPES_FROM_MODEL_CONFIGURATION", "" ) -model_configuration_frequency = os.environ.get( +MODEL_CONFIGURATION_FREQUENCY = os.environ.get( "FMS_TEST_SHAPES_FROM_MODEL_CONFIGURATION_FREQUENCY", "0" ) # for validation level 1, the default is a failure rate of 1% # set this environment variable if you would like to relax that threshold -failure_rate_threshold = os.environ.get("FMS_TEST_SHAPES_FAILURE_THRESHOLD", 0.01) -default_metrics_threshold = os.environ.get( +FAILURE_RATE_THRESHOLD = os.environ.get("FMS_TEST_SHAPES_FAILURE_THRESHOLD", 0.01) +DEFAULT_METRICS_THRESHOLD = os.environ.get( "FMS_TEST_SHAPES_METRICS_THRESHOLD", (3.0, 0.001) ) -save_validation_info_outputs = ( +SAVE_VALIDATION_INFO_OUTPUTS = ( os.environ.get("FMS_TEST_SHAPES_SAVE_VALIDATION_INFO_OUTPUTS", "0") == "1" ) -common_batch_sizes = os.environ.get("FMS_TEST_SHAPES_COMMON_BATCH_SIZES", [1, 2, 4, 8]) -common_seq_lengths = os.environ.get("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64, 2048]) -common_max_new_tokens = os.environ.get("FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS", [128]) +COMMON_BATCH_SIZES = os.environ.get("FMS_TEST_SHAPES_COMMON_BATCH_SIZES", [1, 2, 4, 8]) +COMMON_SEQ_LENGTHS = os.environ.get("FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS", [64, 2048]) +COMMON_MAX_NEW_TOKENS = os.environ.get("FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS", [128]) if USE_DISTRIBUTED: dist.init_process_group() aiu_dist_setup(dist.get_rank(), dist.get_world_size()) - save_validation_info_outputs = save_validation_info_outputs and ( + SAVE_VALIDATION_INFO_OUTPUTS = SAVE_VALIDATION_INFO_OUTPUTS and ( dist.get_rank() == 0 ) if USE_MICRO_MODELS: - validation_info_dir = os.path.join(validation_info_dir, 
"tiny_models") + VALIDATION_INFO_DIR = os.path.join(VALIDATION_INFO_DIR, "tiny_models") -# pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base" -if isinstance(common_model_paths, str): - common_model_paths = common_model_paths.split(",") +# pass custom model path list for eg: EXPORT FMS_TEST_SHAPES_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base" +if isinstance(COMMON_MODEL_PATHS, str): + COMMON_MODEL_PATHS = COMMON_MODEL_PATHS.split(",") # pass custom failure rate threshold as float -if isinstance(failure_rate_threshold, str): - failure_rate_threshold = float(failure_rate_threshold) +if isinstance(FAILURE_RATE_THRESHOLD, str): + FAILURE_RATE_THRESHOLD = float(FAILURE_RATE_THRESHOLD) # pass custom default metrics threshold as a comma separated str of floats , -if isinstance(default_metrics_threshold, str): - default_metrics_threshold = tuple( - [float(m) for m in default_metrics_threshold.split(",")] +if isinstance(DEFAULT_METRICS_THRESHOLD, str): + DEFAULT_METRICS_THRESHOLD = tuple( + [float(m) for m in DEFAULT_METRICS_THRESHOLD.split(",")] ) # pass custom common batch sizes as a comma separated str of ints -if isinstance(common_batch_sizes, str): - common_batch_sizes = [int(bs) for bs in common_batch_sizes.split(",")] +if isinstance(COMMON_BATCH_SIZES, str): + COMMON_BATCH_SIZES = [int(bs) for bs in COMMON_BATCH_SIZES.split(",")] # pass custom common seq lengths as a comma separated str of ints -if isinstance(common_seq_lengths, str): - common_seq_lengths = [int(sl) for sl in common_seq_lengths.split(",")] +if isinstance(COMMON_SEQ_LENGTHS, str): + COMMON_SEQ_LENGTHS = [int(sl) for sl in COMMON_SEQ_LENGTHS.split(",")] # pass custom common max new tokens as a comma separated str of ints -if isinstance(common_max_new_tokens, str): - common_max_new_tokens = [int(mnt) for mnt in common_max_new_tokens.split(",")] +if isinstance(COMMON_MAX_NEW_TOKENS, str): + COMMON_MAX_NEW_TOKENS = [int(mnt) for mnt in COMMON_MAX_NEW_TOKENS.split(",")] # pass metrics to skip as a comma separated list (ce,mean_diff) -if isinstance(skip_assertions, str): +if isinstance(SKIP_ASSERTIONS, str): _skip_assertions = [] - for metric in skip_assertions.split(","): + for metric in SKIP_ASSERTIONS.split(","): metric = metric.lower() if metric not in {"ce", "mean_diff"}: pytest.fail( "FMS_TEST_SHAPES_SKIP_ASSERTIONS can only accept metrics ce and mean_diff" ) _skip_assertions.append(metric) - skip_assertions = set(_skip_assertions) + SKIP_ASSERTIONS = set(_skip_assertions) -compile_dynamic_sendnn = ATTN_TYPE == "paged" +COMPILE_DYNAMIC_SENDNN = ATTN_TYPE == "paged" -if compile_dynamic_sendnn: +if COMPILE_DYNAMIC_SENDNN: import bisect # the compiler supports certain max context lengths (VLLM_DT_MAX_CONTEXT_LEN) # this will ensure that we select smallest supported VLLM_DT_MAX_CONTEXT_LEN that fits the largest possible context (prompt size + max_new_tokens) - __largest_context = max(common_seq_lengths) + max(common_max_new_tokens) + __largest_context = max(COMMON_SEQ_LENGTHS) + max(COMMON_MAX_NEW_TOKENS) __supported_context_lengths = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768] os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str( __supported_context_lengths[ bisect.bisect_left(__supported_context_lengths, __largest_context) ] ) - os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(common_batch_sizes), 2)) + os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(COMMON_BATCH_SIZES), 2)) fx_config.backed_size_oblivious = True 
# thresholds are chosen based on 1024 tokens per sequence # 1% error threshold rate between cpu fp32 and cuda fp16 # if a models failure thresholds do not exist in this dict, default to the default_metrics_threshold defined above # threshold key is (model_id, is_tiny_model) -fail_thresholds = { +FAIL_THRESHOLDS = { (LLAMA_3p1_8B_INSTRUCT, False): ( 2.6994638133048965, 0.00047589250549208347, @@ -215,33 +214,33 @@ ), } -if model_configuration_path != "": +if MODEL_CONFIGURATION_PATH != "": print( "ignoring FMS_TEST_SHAPES_COMMON_MODEL_PATHS, FMS_TEST_SHAPES_USE_MICRO_MODELS as configuration will be set by FMS_TEST_SHAPES_FROM_MODEL_CONFIGURATION" ) USE_MICRO_MODELS = False - common_model_paths = [] - frequency = int(model_configuration_frequency) - with open(model_configuration_path, "r") as f: + COMMON_MODEL_PATHS = [] + FREQUENCY = int(MODEL_CONFIGURATION_FREQUENCY) + with open(MODEL_CONFIGURATION_PATH, "r") as f: for line in f: try: - model_config = json.loads(line) - if model_config["frequency"] <= frequency: - common_model_paths.append(model_config["model_id"]) + MODEL_CONFIG = json.loads(line) + if MODEL_CONFIG["frequency"] <= FREQUENCY: + COMMON_MODEL_PATHS.append(MODEL_CONFIG["model_id"]) # assume fullsize models - fail_thresholds[(model_config["model_id"], USE_MICRO_MODELS)] = ( - model_config["ce"], - model_config["mean_diff"], + FAIL_THRESHOLDS[(MODEL_CONFIG["model_id"], USE_MICRO_MODELS)] = ( + MODEL_CONFIG["ce"], + MODEL_CONFIG["mean_diff"], ) except json.JSONDecodeError: print(f"config contained an improper json line: {line.strip()}") -common_shapes = list( +COMMON_SHAPES = list( itertools.product( - common_model_paths, - common_batch_sizes, - common_seq_lengths, - common_max_new_tokens, + COMMON_MODEL_PATHS, + COMMON_BATCH_SIZES, + COMMON_SEQ_LENGTHS, + COMMON_MAX_NEW_TOKENS, ) ) @@ -255,7 +254,7 @@ @pytest.fixture(autouse=True) def reset_compiler(): yield # run the test - if not compile_dynamic_sendnn: + if not COMPILE_DYNAMIC_SENDNN: torch.compiler.reset() torch._dynamo.reset() os.environ.pop("COMPILATION_MODE", None) @@ -309,7 +308,7 @@ def __maybe_get_gptq_kwargs(model_path): return gptq_kwargs_aiu, gptq_kwargs_cpu -def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0): +def __prepare_inputs(batch_size, seq_length, tokenizer, model_path, seed=0): if "paged" in ATTN_NAME: prompts_and_sizes = sample_sharegpt_requests( SHARE_GPT_DATASET_PATH, @@ -337,6 +336,14 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0): prompt_list.append(tokenizer.encode(prompt, return_tensors="pt").squeeze(0)) input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length) + extra_kwargs["attn_name"] = ATTN_NAME + if ( + "paged" in ATTN_NAME + and "ibm-granite/granite-3.3-8b-instruct" in model_path + and USE_DISTRIBUTED + and dist.get_world_size() == 4 + ): + extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT return input_ids, extra_kwargs @@ -364,11 +371,17 @@ def __filter_before_eos(metrics, filter_indexes): def __load_validation_info( - model_path, batch_size, seq_length, max_new_tokens, tokenizer, seed, attn_type: str + model_path, + batch_size, + seq_length, + max_new_tokens, + tokenizer, + seed, + attn_type: str, ): # if path doesn't exist and paged isn't in the attention name, remove `attn_type` and recheck again, warn that we will no longer in the future have paths without 'attn_type' full_path = find_validation_info_path( - validation_info_dir, + VALIDATION_INFO_DIR, model_path, batch_size, seq_length, @@ -405,10 +418,10 @@ def 
get_or_create(self, is_gptq, is_fp8, **kwargs): model.eval() model.compile( - backend="sendnn", options={"sendnn.dynamic": compile_dynamic_sendnn} + backend="sendnn", options={"sendnn.dynamic": COMPILE_DYNAMIC_SENDNN} ) - if compile_dynamic_sendnn: + if COMPILE_DYNAMIC_SENDNN: self.model = model return model @@ -457,31 +470,58 @@ def persistent_model(): return PersistentModel() -@pytest.mark.parametrize( - "model_path,batch_size,seq_length,max_new_tokens", common_shapes -) -def test_common_shapes( - model_path, - batch_size, - seq_length, - max_new_tokens, - persistent_model, - record_property, +##### Common utils +# metric calculator based on the cross-entropy and mean diff for each decode step +def _metric_calculator(r: torch.Tensor, t: torch.Tensor): + cross_entropy = torch.nn.CrossEntropyLoss()( + r, t.softmax(dim=1).to(dtype=torch.float32) + ) + diff = torch.mean( + torch.abs( + r.softmax(dim=1).to(dtype=torch.float32) + - t.softmax(dim=1).to(dtype=torch.float32) + ) + ) + return (cross_entropy, diff) + + +def _check_failure_thresholds( + diff_fail_responses_list, + ce_fail_responses_list, + total_tokens, + record_property=None, ): - torch.manual_seed(42) - torch.set_grad_enabled(False) - os.environ["COMPILATION_MODE"] = "offline_decoder" + # test the failure rates for across all tokens + diff_failure_rate = len(diff_fail_responses_list) / total_tokens + ce_failure_rate = len(ce_fail_responses_list) / total_tokens + dprint(f"mean diff failure rate: {diff_failure_rate}") + dprint(f"cross entropy loss failure rate: {ce_failure_rate}") - dprint( - f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, attn_type={ATTN_TYPE}" - ) + if record_property is not None: + # Add failure rates to xml report + record_property("mean_diff_failure_rate", diff_failure_rate) + record_property("cross_entropy_loss_failure_rate", ce_failure_rate) + + if "mean_diff" not in SKIP_ASSERTIONS: + assert diff_failure_rate < FAILURE_RATE_THRESHOLD, ( + f"failure rate for mean diff was too high: {diff_failure_rate}" + ) + if "ce" not in SKIP_ASSERTIONS: + assert ce_failure_rate < FAILURE_RATE_THRESHOLD, ( + f"failure rate for cross entropy loss was too high: {ce_failure_rate}" + ) + print("passed validation level 1") + else: + print("passed validation level 0") - # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured - gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path) - is_gptq = len(gptq_kwargs_aiu) != 0 - is_fp8 = "fp8" in ATTN_NAME - micro_model_path = micro_model_mapping.get(model_path, None) +def _get_common_model_kwargs(is_gptq, model_path): + if is_gptq: + return {} + # Get the micro model kwargs + # TODO clean up path handling for micro models + micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) + if USE_MICRO_MODELS and micro_model_path is None: dprint("using randomly initialized model") micro_model_kwargs = {"architecture": "hf_configured", "nlayers": 3} @@ -489,6 +529,7 @@ def test_common_shapes( dprint("using trained model") micro_model_kwargs = {"architecture": "hf_pretrained"} + # Get the model path kwargs if not USE_MICRO_MODELS and os.path.exists(model_path): model_path_kwargs = {"model_path": model_path} elif USE_MICRO_MODELS and micro_model_path is not None: @@ -496,84 +537,152 @@ def test_common_shapes( else: model_path_kwargs = {"variant": model_path} + # Get the distributed kwargs distributed_kwargs = {} if USE_DISTRIBUTED: 
distributed_kwargs["distributed_strategy"] = "tp" distributed_kwargs["group"] = dist.group.WORLD - get_model_kwargs = {} - if not is_gptq: - get_model_kwargs = { - **model_path_kwargs, - **micro_model_kwargs, - **distributed_kwargs, - } - - tokenizer = AutoTokenizer.from_pretrained(model_path) + return { + **model_path_kwargs, + **micro_model_kwargs, + **distributed_kwargs, + } - # prepare the AIU model - model = persistent_model.get_or_create( - is_gptq, is_fp8, **gptq_kwargs_aiu, **get_model_kwargs - ) +# NOTE micro_model_state_dict should be None if USE_MICRO_MODELS is true +# Otherwise it should be model.state_dict() where model is the AIU model +def _get_cpu_model(is_gptq, is_fp8, micro_model_state_dict=None, **kwargs): # prepare the cpu model validation_model = get_model( device_type="cpu", data_type=None if is_fp8 or is_gptq else torch.float32, fused_weights=False, - **gptq_kwargs_cpu, - **get_model_kwargs, + **kwargs, ) - if USE_MICRO_MODELS: + # This is a micro model, so we need to copy the state dict directly. + if micro_model_state_dict is not None: serialization.load_state_dict_into_model( - validation_model, model.state_dict(), **__custom_adapter + validation_model, micro_model_state_dict, **__custom_adapter ) + return validation_model - # prepare input_ids - input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer) - extra_kwargs["attn_name"] = ATTN_NAME - if ( - "paged" in ATTN_NAME - and "ibm-granite/granite-3.3-8b-instruct" in model_path - and USE_DISTRIBUTED - and dist.get_world_size() == 4 - ): - extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT - # warmup aiu model - warmup_model( - model, input_ids, max_new_tokens, compile_dynamic_sendnn, **extra_kwargs - ) +def _get_device_validation_information( + model_path, + batch_size, + seq_length, + max_new_tokens, + post_iteration_hook, + model, + input_ids, + extra_kwargs, + token_iter, + device="aiu", + tokenizer=None, +): + # For CPU, we try to load it from disk first if it exists + if device == "cpu": + cpu_validation_info = __load_validation_info( + model_path, + batch_size, + seq_length, + max_new_tokens, + tokenizer, + token_iter, + ATTN_NAME, + ) + + if cpu_validation_info is not None: + return cpu_validation_info + + # overrides for validation info that are device specific + device_dependent_kwargs = {} + if device == "cpu": + device_dependent_kwargs["attn_algorithm"] = "math" + + if device == "aiu": + device_dependent_kwargs["last_n_tokens"] = 64 if "paged" in ATTN_NAME else 1 - # generate cpu validation info - cpu_validation_info = __load_validation_info( - model_path, batch_size, seq_length, max_new_tokens, tokenizer, 0, ATTN_NAME + # Otherwise we need to get the AIU / CPU validation info + validation_info = extract_validation_information( + model, + input_ids, + max_new_tokens, + post_iteration_hook, + timing=TIMING, + **extra_kwargs, + **device_dependent_kwargs, ) - if cpu_validation_info is None: - cpu_validation_info = extract_validation_information( - validation_model, - input_ids, - max_new_tokens, - LogitsExtractorHook(), - attn_algorithm="math", - timing=TIMING, - **extra_kwargs, + if SAVE_VALIDATION_INFO_OUTPUTS: + dprint(f"saving {device} validation for - iter={token_iter}") + # TODO - there is probably a cleaner way to handle this too + kwargs = {} + if device == "cpu": + kwargs["dtype"] = CPU_DTYPE + + validation_info.save( + get_validation_info_path( + VALIDATION_INFO_DIR, + model_path, + batch_size, + seq_length, + max_new_tokens, + token_iter, + ATTN_NAME, + 
device_type=device, + **kwargs, + ) ) + return validation_info + - if save_validation_info_outputs: - cpu_validation_info.save( - get_validation_info_path( - validation_info_dir, - model_path, - batch_size, - seq_length, - max_new_tokens, - 0, - ATTN_NAME, - dtype=CPU_DTYPE, - ) +def _resolve_thresholds(model_path, micro_model_path): + # if we do not have real model weights, use a default_metrics_threshold + if USE_MICRO_MODELS and micro_model_path is None: + ce_threshold, diff_threshold = DEFAULT_METRICS_THRESHOLD + # if we have real weights, try and get the proper validation metrics threshold + else: + # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds + if USE_MICRO_MODELS: + ce_threshold, diff_threshold = FAIL_THRESHOLDS.get( + (model_path, True), + FAIL_THRESHOLDS.get((model_path, False), DEFAULT_METRICS_THRESHOLD), + ) + else: + ce_threshold, diff_threshold = FAIL_THRESHOLDS.get( + (model_path, False), DEFAULT_METRICS_THRESHOLD ) + return ce_threshold, diff_threshold + + +def _run_validation_level_0( + model_path, + batch_size, + seq_length, + max_new_tokens, + tokenizer, + validation_model, + input_ids, + extra_kwargs, + model, +): + cpu_validation_info = _get_device_validation_information( + model_path=model_path, + batch_size=batch_size, + seq_length=seq_length, + max_new_tokens=max_new_tokens, + post_iteration_hook=LogitsExtractorHook(), + model=validation_model, + input_ids=input_ids, + extra_kwargs=extra_kwargs, + token_iter=0, + device="cpu", + tokenizer=tokenizer, + ) + + # Get the cpu static toks / initial eos sequences for iter 0 cpu_static_tokens = cpu_validation_info.get_info("tokens") eos_indexes = __find_eos_index( cpu_static_tokens, tokenizer.eos_token_id, seq_length, max_new_tokens @@ -583,14 +692,18 @@ def test_common_shapes( ) # first test validation level 0 - aiu_validation_info = extract_validation_information( - model, - input_ids, - max_new_tokens, - None, - last_n_tokens=64 if "paged" in ATTN_NAME else 1, - timing=TIMING, - **extra_kwargs, + aiu_validation_info = _get_device_validation_information( + model_path=model_path, + batch_size=batch_size, + seq_length=seq_length, + max_new_tokens=max_new_tokens, + post_iteration_hook=None, + model=model, + input_ids=input_ids, + extra_kwargs=extra_kwargs, + token_iter=0, + device="aiu", + tokenizer=tokenizer, ) dprint("aiu validation info extracted for validation level 0") @@ -599,7 +712,159 @@ def test_common_shapes( aiu_validation_info.get_info("tokens"), cpu_static_tokens ) - failed_validation_level_0 = len(failed_responses) != 0 + # Keep things we may need on the first iter for validation 1 + validation_zero_info = { + "cpu_validation_info": cpu_validation_info, + "cpu_static_tokens": cpu_static_tokens, + "eos_indexes": eos_indexes, + } + return len(failed_responses) != 0, validation_zero_info + + +def _run_validation_level_1( + model_path, + batch_size, + seq_length, + max_new_tokens, + tokenizer, + validation_model, + input_ids, + extra_kwargs, + model, + micro_model_path, + validation_zero_info, + record_property, +): + iters = int(CUMULATIVE_TEST_TOKENS_PER_SEQUENCE) // max_new_tokens + ce_fail_responses_list = [] + diff_fail_responses_list = [] + total_tokens = 0 + for i in range(iters): + # for iteration 0, we have computed the cpu validation info in the prior step for seed=0, so skip + if i != 0: + input_ids, extra_kwargs = __prepare_inputs( + batch_size, seq_length, tokenizer, model_path, seed=i + ) + cpu_validation_info = 
_get_device_validation_information( + model_path=model_path, + batch_size=batch_size, + seq_length=seq_length, + max_new_tokens=max_new_tokens, + post_iteration_hook=LogitsExtractorHook(), + model=validation_model, + input_ids=input_ids, + extra_kwargs=extra_kwargs, + token_iter=i, + device="cpu", + tokenizer=tokenizer, + ) + dprint(f"cpu validation info extracted for validation level 1 - iter={i}") + + cpu_static_tokens = cpu_validation_info.get_info("tokens") + eos_indexes = __find_eos_index( + cpu_static_tokens, + tokenizer.eos_token_id, + seq_length, + max_new_tokens, + ) + else: + # TODO this can be cleaned up further + cpu_validation_info = validation_zero_info["cpu_validation_info"] + cpu_static_tokens = validation_zero_info["cpu_static_tokens"] + eos_indexes = validation_zero_info["eos_indexes"] + + aiu_validation_info = _get_device_validation_information( + model_path=model_path, + batch_size=batch_size, + seq_length=seq_length, + max_new_tokens=max_new_tokens, + post_iteration_hook=GoldenTokenHook(cpu_static_tokens), + model=model, + input_ids=input_ids, + extra_kwargs=extra_kwargs, + token_iter=i, + device="aiu", + tokenizer=tokenizer, + ) + dprint(f"aiu validation info extracted for validation level 1 - iter={i}") + + # capture all level 1 metrics + level_1_metrics = capture_level_1_metrics( + cpu_validation_info.get_info("logits"), + aiu_validation_info.get_info("logits"), + top_k_loss_calculator(20, _metric_calculator), + ) + # only consider those metrics captured prior to the eos + level_1_metrics = __filter_before_eos(level_1_metrics, eos_indexes) + + ce_threshold, diff_threshold = _resolve_thresholds(model_path, micro_model_path) + + # get all failed responses for each metric + ce_fail_responses = filter_failed_level_1_cases( + level_1_metrics, lambda m: m[0] >= ce_threshold + ) + diff_fail_responses = filter_failed_level_1_cases( + level_1_metrics, + lambda m: m[1] >= diff_threshold, + ) + + ce_fail_responses_list.extend(ce_fail_responses) + diff_fail_responses_list.extend(diff_fail_responses) + total_tokens += len(level_1_metrics) + + _check_failure_thresholds( + diff_fail_responses_list, + ce_fail_responses_list, + total_tokens, + record_property, + ) + + +##### Test definitions +def _run_cpu_aiu_validation_test( + model_path, + batch_size, + seq_length, + max_new_tokens, + cpu_model, + aiu_model, + micro_model_path, + record_property, +): + # Get the tokenizer and AIU / CPU models to compare + tokenizer = AutoTokenizer.from_pretrained(model_path) + + # prepare input_ids + input_ids, extra_kwargs = __prepare_inputs( + batch_size, seq_length, tokenizer, model_path + ) + + extra_kwargs["attn_name"] = ATTN_NAME + if ( + "paged" in ATTN_NAME + and "ibm-granite/granite-3.3-8b-instruct" in model_path + and USE_DISTRIBUTED + and dist.get_world_size() == 4 + ): + extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT + + # warmup aiu model + warmup_model( + aiu_model, input_ids, max_new_tokens, COMPILE_DYNAMIC_SENDNN, **extra_kwargs + ) + + # Run validation level 0 + failed_validation_level_0, validation_zero_info = _run_validation_level_0( + model_path, + batch_size, + seq_length, + max_new_tokens, + tokenizer, + cpu_model, + input_ids, + extra_kwargs, + aiu_model, + ) # if level 0 fails validation, validate level 1 if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0: @@ -607,164 +872,68 @@ def test_common_shapes( dprint("failed validation level 0, testing validation level 1") else: dprint("passed validation level 0, testing validation level 1") + 
_run_validation_level_1( + model_path, + batch_size, + seq_length, + max_new_tokens, + tokenizer, + cpu_model, + input_ids, + extra_kwargs, + aiu_model, + micro_model_path, + validation_zero_info, + record_property, + ) - # metric calculator based on the cross-entropy and mean diff for each decode step - def _metric_calculator(r: torch.Tensor, t: torch.Tensor): - cross_entropy = torch.nn.CrossEntropyLoss()( - r, t.softmax(dim=1).to(dtype=torch.float32) - ) - diff = torch.mean( - torch.abs( - r.softmax(dim=1).to(dtype=torch.float32) - - t.softmax(dim=1).to(dtype=torch.float32) - ) - ) - return (cross_entropy, diff) - - iters = int(CUMULATIVE_TEST_TOKENS_PER_SEQUENCE) // max_new_tokens - ce_fail_responses_list = [] - diff_fail_responses_list = [] - total_tokens = 0 - for i in range(iters): - # for iteration 0, we have computed the cpu validation info in the prior step for seed=0, so skip - if i != 0: - input_ids, extra_kwargs = __prepare_inputs( - batch_size, seq_length, tokenizer, seed=i - ) - extra_kwargs["attn_name"] = ATTN_NAME - if ( - "paged" in ATTN_NAME - and "ibm-granite/granite-3.3-8b-instruct" in model_path - and USE_DISTRIBUTED - and dist.get_world_size() == 4 - ): - extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT - - cpu_validation_info = __load_validation_info( - model_path, - batch_size, - seq_length, - max_new_tokens, - tokenizer, - i, - ATTN_NAME, - ) - if cpu_validation_info is None: - cpu_validation_info = extract_validation_information( - validation_model, - input_ids, - max_new_tokens, - LogitsExtractorHook(), - attn_algorithm="math", - timing=TIMING, - **extra_kwargs, - ) - dprint( - f"cpu validation info extracted for validation level 1 - iter={i}" - ) - if save_validation_info_outputs: - cpu_validation_info.save( - get_validation_info_path( - validation_info_dir, - model_path, - batch_size, - seq_length, - max_new_tokens, - i, - ATTN_NAME, - dtype=CPU_DTYPE, - ) - ) - cpu_static_tokens = cpu_validation_info.get_info("tokens") - eos_indexes = __find_eos_index( - cpu_static_tokens, - tokenizer.eos_token_id, - seq_length, - max_new_tokens, - ) - - # generate aiu validation info - aiu_validation_info = extract_validation_information( - model, - input_ids, - max_new_tokens, - GoldenTokenHook(cpu_static_tokens), - last_n_tokens=64 if "paged" in ATTN_NAME else 1, - timing=TIMING, - **extra_kwargs, - ) - dprint(f"aiu validation info extracted for validation level 1 - iter={i}") - if save_validation_info_outputs: - aiu_validation_info.save( - get_validation_info_path( - validation_info_dir, - model_path, - batch_size, - seq_length, - max_new_tokens, - i, - ATTN_NAME, - device_type="aiu", - ) - ) - # capture all level 1 metrics - level_1_metrics = capture_level_1_metrics( - cpu_validation_info.get_info("logits"), - aiu_validation_info.get_info("logits"), - top_k_loss_calculator(20, _metric_calculator), - ) - # only consider those metrics captured prior to the eos - level_1_metrics = __filter_before_eos(level_1_metrics, eos_indexes) +@pytest.mark.parametrize( + "model_path,batch_size,seq_length,max_new_tokens", COMMON_SHAPES +) +def test_common_shapes( + model_path, + batch_size, + seq_length, + max_new_tokens, + persistent_model, + record_property, +): + torch.manual_seed(42) + torch.set_grad_enabled(False) + os.environ["COMPILATION_MODE"] = "offline_decoder" + micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) - # if we do not have real model weights, use a default_metrics_threshold - if USE_MICRO_MODELS and micro_model_path is None: - 
ce_threshold, diff_threshold = default_metrics_threshold - # if we have real weights, try and get the proper validation metrics threshold - else: - # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds - if USE_MICRO_MODELS: - ce_threshold, diff_threshold = fail_thresholds.get( - (model_path, True), - fail_thresholds.get( - (model_path, False), default_metrics_threshold - ), - ) - else: - ce_threshold, diff_threshold = fail_thresholds.get( - (model_path, False), default_metrics_threshold - ) + dprint( + f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, attn_type={ATTN_TYPE}" + ) - # get all failed responses for each metric - ce_fail_responses = filter_failed_level_1_cases( - level_1_metrics, lambda m: m[0] >= ce_threshold - ) - diff_fail_responses = filter_failed_level_1_cases( - level_1_metrics, - lambda m: m[1] >= diff_threshold, - ) + # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured + gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path) + is_gptq = len(gptq_kwargs_aiu) != 0 + is_fp8 = "fp8" in ATTN_NAME + model_kwargs = _get_common_model_kwargs(is_gptq, model_path) - ce_fail_responses_list.extend(ce_fail_responses) - diff_fail_responses_list.extend(diff_fail_responses) - total_tokens += len(level_1_metrics) + # Get the AIU model w/ the persistent model fixture + model = persistent_model.get_or_create( + is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs + ) - # test the failure rates for across all tokens - diff_failure_rate = len(diff_fail_responses_list) / total_tokens - ce_failure_rate = len(ce_fail_responses_list) / total_tokens - dprint(f"mean diff failure rate: {diff_failure_rate}") - dprint(f"cross entropy loss failure rate: {ce_failure_rate}") - # Add failure rates to xml report - record_property("mean_diff_failure_rate", diff_failure_rate) - record_property("cross_entropy_loss_failure_rate", ce_failure_rate) - if "mean_diff" not in skip_assertions: - assert diff_failure_rate < failure_rate_threshold, ( - f"failure rate for mean diff was too high: {diff_failure_rate}" - ) - if "ce" not in skip_assertions: - assert ce_failure_rate < failure_rate_threshold, ( - f"failure rate for cross entropy loss was too high: {ce_failure_rate}" - ) + validation_model = _get_cpu_model( + is_gptq, + is_fp8, + micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None, + **gptq_kwargs_cpu, + **model_kwargs, + ) - print("passed validation level 1") - else: - print("passed validation level 0") + _run_cpu_aiu_validation_test( + model_path, + batch_size, + seq_length, + max_new_tokens, + validation_model, + model, + micro_model_path, + record_property, + ) diff --git a/tests/models/test_scripts.py b/tests/models/test_scripts.py index 79bd9952..21b47fb4 100644 --- a/tests/models/test_scripts.py +++ b/tests/models/test_scripts.py @@ -6,6 +6,7 @@ from pathlib import Path import itertools import math +from aiu_fms_testing_utils.utils.paged import get_programs_prompts, ProgramCriteria FMS_DIR = Path(__file__).parent AIU_FMS_DIR = os.path.join(FMS_DIR, "../../../aiu-fms-testing-utils/") @@ -291,28 +292,48 @@ def test_dpp_script( ) print(result_text) with open(os.environ["DT_PROG_CRITERIA_FILEPATH"], "r") as f: - program_criteria_list = json.load(f)["programs"] + program_criteria_json_list = json.load(f)["programs"] + program_criteria_list = [] + for i, d in 
enumerate(program_criteria_json_list): + program_criteria_list.append( + ProgramCriteria( + i, + d["max_batch"], + d["max_tkv"], + d["batch_granularity"], + d["tkv_granularity"], + ) + ) if programs is None: program_assertions = [i for i in range(len(program_criteria_list))] shape_assertions = [">=0", ">=0"] else: + program_map = get_programs_prompts( + program_criteria_list, + multiple=64, + max_batch_size=2, + max_tkv=512, + program_cycles=max_new_tokens, + ) programs_split = programs.split(":") program_ids_str = programs_split[0] shape_assertions = [ f">={_}" if _.isnumeric() else _ for _ in programs_split[1].split(",") ] - match_number = r"\d+" - valid_program_assertions = [ - f">={re.search(match_number, _).group()}" for _ in shape_assertions - ] - # need to add 1 for tkv as that is the first decode - program_assertions = [ - i - for i, p in enumerate(program_criteria_list) - if eval(f"p['max_batch']{valid_program_assertions[0]}") - and eval(f"p['max_tkv']{valid_program_assertions[1]}+1") - ] + + program_assertions = [] + for program_id_seq, shapes in program_map.items(): + if any( + ( + eval( + f"shape[0]{shape_assertions[0]} and shape[1]{shape_assertions[1]}" + ) + for shape in shapes + ) + ): + program_assertions.append(program_id_seq[0].program_id) + if program_ids_str == "?": program_assertions = program_assertions[:1] elif program_ids_str.isnumeric(): diff --git a/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_output b/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_output index 098709e3..bfbcd6b1 100644 --- a/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_output +++ b/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_output @@ -1 +1 @@ -9.625,9.625,9.6875,9.625,10.53125,37.375,8.65625,14.90625,1.03125,5.875,15.6875,6.0625,9.5,17.5625,37.0,10.34375,6.25,13.125,3.8125,9.21875,21.96875,14.28125,0.0,13.09375,7.6875,6.4375,19.09375,10.6875,23.9375,13.0,11.84375,46.4375,6.59375,0.0,13.0,23.125,16.34375,3.125,12.65625,6.03125,14.375,6.84375,14.9375,20.9375,5.625,37.0,4.875,3.25,7.40625,2.6875,18.9375,4.1875,13.5,8.4375,21.1875,13.21875,35.25,21.78125,8.3125,4.75,12.0625,3.90625,9.34375,4.25 \ No newline at end of file +0.18359375,0.18359375,0.181640625,0.189453125,0.2734375,0.544921875,0.607421875,0.365234375,0.30078125,0.25,0.078125,0.302734375,0.0,0.322265625,0.142578125,0.099609375,0.296875,0.28125,0.673828125,0.44921875,0.13671875,0.42578125,1.072265625,0.18359375,0.388671875,0.177734375,0.193359375,0.296875,0.484375,0.3515625,0.826171875,0.349609375,0.296875,0.720703125,0.634765625,0.607421875,0.14453125,0.29296875,0.154296875,0.287109375,0.482421875,0.2421875,0.48046875,0.203125,0.349609375,0.21484375,0.28515625,0.17578125,0.162109375,0.3203125,0.3125,0.54296875,0.287109375,0.361328125,0.390625,0.08984375,0.2109375,0.5,0.18359375,0.228515625,0.314453125,0.291015625,0.248046875,0.5078125 \ No newline at end of file diff --git a/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_weight_keys b/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_weight_keys index 6329cb98..3fcc470f 100644 --- 
a/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_weight_keys +++ b/tests/resources/expectations/models.test_model_expectations.TestAIUDecoderModels.Llama-3.1-8B-Instruct.test_model_weight_keys @@ -1 +1 @@ -dec_norm.weight,layers.0.attn.dense.weight,layers.0.attn.in_proj.key.weight,layers.0.attn.in_proj.query.weight,layers.0.attn.in_proj.value.weight,layers.0.ff_ln.weight,layers.0.ff_sub_layer.w1.weight,layers.0.ff_sub_layer.w2.weight,layers.0.ff_sub_layer.wg.weight,layers.0.ln.weight,layers.1.attn.dense.weight,layers.1.attn.in_proj.key.weight,layers.1.attn.in_proj.query.weight,layers.1.attn.in_proj.value.weight,layers.1.ff_ln.weight,layers.1.ff_sub_layer.w1.weight,layers.1.ff_sub_layer.w2.weight,layers.1.ff_sub_layer.wg.weight,layers.1.ln.weight,layers.2.attn.dense.weight,layers.2.attn.in_proj.key.weight,layers.2.attn.in_proj.query.weight,layers.2.attn.in_proj.value.weight,layers.2.ff_ln.weight,layers.2.ff_sub_layer.w1.weight,layers.2.ff_sub_layer.w2.weight,layers.2.ff_sub_layer.wg.weight,layers.2.ln.weight,shared.emb.weight,shared.head.weight \ No newline at end of file +base_model.dec_norm.weight,base_model.embedding.weight,base_model.layers.0.attn.dense.weight,base_model.layers.0.attn.in_proj.key.weight,base_model.layers.0.attn.in_proj.query.weight,base_model.layers.0.attn.in_proj.value.weight,base_model.layers.0.ff_ln.weight,base_model.layers.0.ff_sub_layer.w1.weight,base_model.layers.0.ff_sub_layer.w2.weight,base_model.layers.0.ff_sub_layer.wg.weight,base_model.layers.0.ln.weight,base_model.layers.1.attn.dense.weight,base_model.layers.1.attn.in_proj.key.weight,base_model.layers.1.attn.in_proj.query.weight,base_model.layers.1.attn.in_proj.value.weight,base_model.layers.1.ff_ln.weight,base_model.layers.1.ff_sub_layer.w1.weight,base_model.layers.1.ff_sub_layer.w2.weight,base_model.layers.1.ff_sub_layer.wg.weight,base_model.layers.1.ln.weight,base_model.layers.2.attn.dense.weight,base_model.layers.2.attn.in_proj.key.weight,base_model.layers.2.attn.in_proj.query.weight,base_model.layers.2.attn.in_proj.value.weight,base_model.layers.2.ff_ln.weight,base_model.layers.2.ff_sub_layer.w1.weight,base_model.layers.2.ff_sub_layer.w2.weight,base_model.layers.2.ff_sub_layer.wg.weight,base_model.layers.2.ln.weight,head.weight \ No newline at end of file diff --git a/tests/testing/test_validation.py b/tests/testing/test_validation.py index ac3367ae..95f2ff4e 100644 --- a/tests/testing/test_validation.py +++ b/tests/testing/test_validation.py @@ -8,7 +8,14 @@ get_validation_info_path, find_validation_info_path, __decrement_version, + get_default_validation_prefix, ) +import hashlib +import os +from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string +from aiu_fms_testing_utils.utils import sample_sharegpt_requests +from transformers import AutoTokenizer + from aiu_fms_testing_utils._version import version_tuple from fms.models import get_model from fms.utils.generation import pad_input_ids @@ -73,12 +80,21 @@ def test_validation_info_round_trip(validation_type, post_iteration_hook): def test_get_validation_info_path(tmp_path): + check_pathname = "attn-type-sdpa_batch-size-4_dtype-fp16_max-new-tokens-128_model-id-ibm-granite--granite-3.3-8b-instruct_seq-length-64" + hash_object = hashlib.sha256(check_pathname.encode("utf-8")) + hex_digest = hash_object.hexdigest() + assert ( get_validation_info_path( tmp_path, "ibm-granite/granite-3.3-8b-instruct", 4, 64, 128, 0, "sdpa" ) - == 
f"{tmp_path}/ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa.{'.'.join([str(_) for _ in version_tuple[:3]])}.cpu_validation_info.0.out" + == f"{tmp_path}/{hex_digest}_{'.'.join([str(_) for _ in version_tuple[:3]])}.cpu_validation_info.0.out" ) + + check_pathname = "attn-type-sdpa_batch-size-4_dtype-fp16_max-new-tokens-128_model-id-ibm-granite--granite-3.3-8b-instruct_seq-length-64" + hash_object = hashlib.sha256(check_pathname.encode("utf-8")) + hex_digest = hash_object.hexdigest() + assert ( get_validation_info_path( tmp_path, @@ -90,7 +106,7 @@ def test_get_validation_info_path(tmp_path): "sdpa", aftu_version=(1, 2, 3), ) - == f"{tmp_path}/ibm-granite--granite-3.3-8b-instruct_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16_attn-type-sdpa.1.2.3.cpu_validation_info.0.out" + == f"{tmp_path}/{hex_digest}_1.2.3.cpu_validation_info.0.out" ) @@ -238,3 +254,81 @@ def test_decrement_version(max_minor, max_patch, current_version): + patch + 1 ) + + +def test_format_kwargs_to_string(): + kwargs = { + "enforce_sizes": [1, 32, 4, 8], + "batch_size": 1, + "model_id": "granite-3.3-8b", + "seq_len": 64, + } + kwargs_str = format_kwargs_to_string(**kwargs) + assert ( + kwargs_str + == "batch-size-1_enforce-sizes-1,32,4,8_model-id-granite-3.3-8b_seq-len-64" + ) + + +DATASET_PATH = os.getenv( + "DATASET_PATH", "/mnt/home/models/ShareGPT_V3_unfiltered_cleaned_split.json" +) +TOKENIZER = os.getenv("TOKENIZER", "ibm-granite/granite-3.3-8b-Instruct") + + +@pytest.mark.parametrize( + "model_variant,max_new_tokens,batch_size,seq_length,dtype,attn_type,device_type,seed,aftu_version", + [("granite-3.3-8b", 64, 2, 64, "fp16", "spda", "cpu", 0, (1, 2, 3))], +) +def test_get_default_validation_prefix( + model_variant, + max_new_tokens, + batch_size, + seq_length, + dtype, + attn_type, + device_type, + seed, + aftu_version, +): + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER) + + sample_key = None + # get_default_validation_prefix with sample_key set to None + check_prefix_sample_key_none = f"attn-type-{attn_type}_batch-size-{batch_size}_dtype-{dtype}_max-new-tokens-{max_new_tokens}_model-id-{model_variant}_seq-length-{seq_length}" + hash_object = hashlib.sha256(check_prefix_sample_key_none.encode("utf-8")) + hex_digest = hash_object.hexdigest() + prefix_sample_key_none = f"{get_default_validation_prefix(model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, aftu_version='.'.join([str(_) for _ in aftu_version[:3]]), sample_key=sample_key)}.{device_type}_validation_info.{seed}.out" + + assert prefix_sample_key_none == f"{hex_digest}_1.2.3.cpu_validation_info.0.out" + + # get_default_validation_prefix with no kwargs using legacy case + legacy_prefix = f"{get_default_validation_prefix(model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, aftu_version='.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out" + assert prefix_sample_key_none == legacy_prefix + + # retrieve a sample_key with return_key is True + dataset_1, sample_key = sample_sharegpt_requests( + DATASET_PATH, + batch_size, + tokenizer, + 32, + seq_length * 2, + seed=seed, + enforce_sizes=[], + return_key=True, + ) + + # Check sample key sorted by parameter name + assert sample_key.split("_") == sorted(sample_key.split("_")) + + dataset_2 = sample_sharegpt_requests( + DATASET_PATH, + batch_size, 
+ tokenizer, + 32, + seq_length * 2, + seed=seed, + enforce_sizes=[], + ) + + assert dataset_1 == dataset_2
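
Reviewer note, not part of the patch: a minimal usage sketch of the prefix scheme this change introduces, assuming the patch is applied and `aiu_fms_testing_utils` is importable. `format_kwargs_to_string` canonicalizes keyword arguments by sorting them by name, replacing `_` with `-` in names and `/` with `--` in values, and joining iterables with `,`; `get_default_validation_prefix` then hashes that string with SHA-256 and appends the aftu version.

```python
import hashlib

from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
from aiu_fms_testing_utils.testing.validation import get_default_validation_prefix

# Canonical key: kwargs sorted by name, '_' -> '-' in names, '/' -> '--' in values.
key = format_kwargs_to_string(
    model_id="ibm-granite/granite-3.3-8b-instruct",
    max_new_tokens=128,
    batch_size=4,
    seq_length=64,
    dtype="fp16",
    attn_type="sdpa",
)
# key == "attn-type-sdpa_batch-size-4_dtype-fp16_max-new-tokens-128_"
#        "model-id-ibm-granite--granite-3.3-8b-instruct_seq-length-64"

# The prefix is the SHA-256 of that key plus the aftu version, so additional
# kwargs (e.g. sample_key) no longer blow up the file name.
prefix = get_default_validation_prefix(
    model_id="ibm-granite/granite-3.3-8b-instruct",
    max_new_tokens=128,
    batch_size=4,
    seq_length=64,
    dtype="fp16",
    attn_type="sdpa",
    aftu_version="1.2.3",
)
assert prefix == f"{hashlib.sha256(key.encode('utf-8')).hexdigest()}_1.2.3"
```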
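
And a sketch of the new `return_key` flow, again assuming the patch is applied; the dataset path and output directory below are illustrative placeholders. Samplers can now also return a `sample_key` describing the sampling configuration, and that key is threaded through the validation-path helpers as a keyword argument so it becomes part of the hashed prefix.

```python
from transformers import AutoTokenizer

from aiu_fms_testing_utils.testing.validation import find_validation_info_path
from aiu_fms_testing_utils.utils import sample_sharegpt_requests

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-instruct")

# With return_key=True the sampler also returns a canonical key for the
# sampling configuration (dataset, prompt lengths, seed, enforce_sizes, ...).
prompts_and_sizes, sample_key = sample_sharegpt_requests(
    "/tmp/ShareGPT_V3_unfiltered_cleaned_split.json",  # placeholder path
    2,
    tokenizer,
    32,
    128,
    seed=0,
    enforce_sizes=[],
    return_key=True,
)

# Forwarding sample_key folds it into the hashed prefix, so validation info
# captured under one sampling configuration is not silently reused for another.
path = find_validation_info_path(
    "/tmp/models/validation_info",  # placeholder output dir
    "ibm-granite/granite-3.3-8b-instruct",
    2,       # batch_size
    64,      # seq_length
    128,     # max_new_tokens
    0,       # seed
    "sdpa",  # attn_type
    sample_key=sample_key,
)
```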