diff --git a/speculator/benchmark_speculator.py b/speculator/benchmark_speculator.py index 5658eec0..45d9b70e 100644 --- a/speculator/benchmark_speculator.py +++ b/speculator/benchmark_speculator.py @@ -100,7 +100,7 @@ parser.add_argument( "--no_flat", action="store_true", - help="Disable batch auto-flattening for handling candidate trees?" + help="Disable batch auto-flattening for handling candidate trees?", ) args = parser.parse_args() diff --git a/speculator/benchmark_speculator_logical.py b/speculator/benchmark_speculator_logical.py index 93cd561f..0f4eca6f 100644 --- a/speculator/benchmark_speculator_logical.py +++ b/speculator/benchmark_speculator_logical.py @@ -1,8 +1,10 @@ import argparse import itertools +import json import os import time +import fms_extras.models.paged_gpt_bigcode import fms_extras.models.paged_llama import torch import torch._inductor.config @@ -16,13 +18,21 @@ from fms_fsdp.utils.dataset_utils import Streaming_Doc_Dataset -# This example script validates the LLaMA implementation by running inference on a couple of prompts. -# torchrun --nproc_per_node=1 scripts/inference.py --variant=7b --model_path=~/models/7B-F --tokenizer=~/models/tokenizer.model --model_source=meta --speculator_path=~/models/speculator_7B_F.pth --compile +# This example script measures the logical speedup of running a speculator atop a base model. Run as: +# export CUDA_VISIBLE_DEVICES=1 +# e.g., #1: torchrun --nproc_per_node=1 benchmark_speculator_logical.py --architecture=paged_llama --variant=7b --model_path=~/models/7B-F --tokenizer=~/models/tokenizer.model --model_source=hf --speculator_path=~/models/speculator_7B_F.pth --compile --data_path="/path/to/training_dataset_tokens/" --subdata="webhose" +# e.g., #2: torchrun --nproc_per_node=1 benchmark_speculator_logical.py --architecture=paged_gpt_bigcode --variant=ibm.20b --model_path=~/models/granite-20b-instruct --tokenizer=~/models/granite-20b-instruct --model_source=hf --speculator_path=~/models/speculator_granite20B.pth --data_path="/path/to/training_dataset_tokens/" --subdata="github" --n_predict=4 --threshes=[6,4,3,3] parser = argparse.ArgumentParser( description="Script to run inference on a causal model" ) parser.add_argument("--device_type", type=str, default="cuda") +parser.add_argument( + "--architecture", + type=str, + default="paged_llama", + help="The model architecture to benchmark, e.g. 'paged_llama', 'paged_gpt_bigcode'", +) parser.add_argument( "--variant", type=str, @@ -106,11 +116,36 @@ parser.add_argument( "--no_flat", action="store_true", - help="Disable batch auto-flattening for handling candidate trees?" + help="Disable batch auto-flattening for handling candidate trees?", +) + +parser.add_argument( + "--seed", + type=int, + default=42, + help="Seed for torch and data loader", +) + +parser.add_argument( + "--n_predict", + type=int, + default=3, + help="Number of speculator heads / number of tokens to guess ahead", ) +parser.add_argument( + "--threshes", + type=json.loads, + default=[6, 4, 3], + help="number of top k predictions from each head to generate speculator candidate pool; should be same len as n_predict", +) + + args = parser.parse_args() +torch.cuda.manual_seed(args.seed) +torch.manual_seed(args.seed) + local_rank = int(os.getenv("LOCAL_RANK", 0)) world_size = int(os.getenv("WORLD_SIZE", 1)) if args.device_type == "cuda": @@ -138,7 +173,7 @@ distr_param = None model = get_model( - "paged_llama", + args.architecture, args.variant, model_path=args.model_path, checkpoint_sharding=args.checkpoint_sharding, @@ -157,7 +192,10 @@ if args.speculator_path is not None: print("loading speculator") speculator = MLPSpeculator( - model.config.emb_dim, 4096, model.config.src_vocab_size, n_predict=3 + model.config.emb_dim, + 4096, + model.config.src_vocab_size, + n_predict=args.n_predict, ) speculator.load_state_dict( torch.load(args.speculator_path, map_location=device)["model_state"] @@ -171,11 +209,16 @@ use_cache = True +if hasattr(model.config, "kvheads"): + kv_heads = model.config.kvheads +else: + kv_heads = 1 if model.config.multiquery_attn else model.config.nheads + kv_cache_manager = PagedKVCacheManager( model.config.nlayers, model.config.nheads, model.config.emb_dim, - kv_heads=model.config.kvheads, + kv_heads=kv_heads, tensor_parallel_size=dist.get_world_size() if args.distributed else 1, dtype=torch.get_default_dtype(), device=device, @@ -192,6 +235,7 @@ datasets=[ args.subdata, ], + seed=args.seed, min_length=2148, max_chunksize=8192, ) @@ -210,12 +254,12 @@ data = torch.IntTensor(data).to(device) -# def ids_for_prompt(prompt): -# tokens = tokenizer.tokenize(prompt) -# tokens = [""] + tokens -# ids = tokenizer.convert_tokens_to_ids(tokens) -# ids = torch.tensor(ids, dtype=torch.long, device=device) -# return ids +def ids_for_prompt(prompt): + tokens = tokenizer.tokenize(prompt) + tokens = [""] + tokens + ids = tokenizer.convert_tokens_to_ids(tokens) + ids = torch.tensor(ids, dtype=torch.long, device=device) + return ids def print_result(result, inp, n_steps): @@ -232,10 +276,15 @@ def print_result(result, inp, n_steps): print() -def infer(ids, k, warmup, model, decode_model, speculator, flatting): +def infer(ids, k, warmup, model, decode_model, speculator): # With greedy generation (do_sample=False) we _should_ always get the same results. # There is currently a bug in start_pos for batched rotary embeddings that can lead # varying results for the same prompt. + max_seq_len = ( + model.config.max_expected_seq_len + if hasattr(model.config, "max_expected_seq_len") + else model.config.max_pos + ) if k != 0: result, n_steps, generated_token_time_out = speculative_generate( @@ -244,11 +293,11 @@ def infer(ids, k, warmup, model, decode_model, speculator, flatting): speculator, kv_cache_manager, new_tokens=100, - max_seq_len=model.config.max_expected_seq_len, + max_seq_len=max_seq_len, decode_model=decode_model, top_k=k, - threshes=[6, 4, 3], - flatting=flatting, + threshes=args.threshes, + flatting=not args.no_flat, ) else: result, n_steps, generated_token_time_out = paged_generate( @@ -256,7 +305,7 @@ def infer(ids, k, warmup, model, decode_model, speculator, flatting): ids, kv_cache_manager, max_new_tokens=100, - max_seq_len=model.config.max_expected_seq_len, + max_seq_len=max_seq_len, do_sample=False, decode_model=decode_model, ) @@ -269,6 +318,7 @@ def infer(ids, k, warmup, model, decode_model, speculator, flatting): return generated_token_time_out / avg_tokens, avg_tokens / n_steps return None + torch._dynamo.config.cache_size_limit = 64 torch.cuda.empty_cache()