speculator/benchmark_speculator.py (2 changes: 1 addition & 1 deletion)
@@ -100,7 +100,7 @@
 parser.add_argument(
     "--no_flat",
     action="store_true",
-    help="Disable batch auto-flattening for handling candidate trees?"
+    help="Disable batch auto-flattening for handling candidate trees?",
 )
 
 args = parser.parse_args()
speculator/benchmark_speculator_logical.py (84 changes: 67 additions & 17 deletions)
@@ -1,8 +1,10 @@
 import argparse
 import itertools
+import json
 import os
 import time
 
+import fms_extras.models.paged_gpt_bigcode
 import fms_extras.models.paged_llama
 import torch
 import torch._inductor.config
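
The bare import of fms_extras.models.paged_gpt_bigcode is presumably a registration side effect: fms model modules register their architectures on import, which is what lets get_model resolve the new --architecture=paged_gpt_bigcode value added below.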
@@ -16,13 +18,21 @@
 from fms_fsdp.utils.dataset_utils import Streaming_Doc_Dataset
 
 
-# This example script validates the LLaMA implementation by running inference on a couple of prompts.
-# torchrun --nproc_per_node=1 scripts/inference.py --variant=7b --model_path=~/models/7B-F --tokenizer=~/models/tokenizer.model --model_source=meta --speculator_path=~/models/speculator_7B_F.pth --compile
+# This example script measures the logical speedup of running a speculator atop a base model. Run as:
+# export CUDA_VISIBLE_DEVICES=1
+# e.g., #1: torchrun --nproc_per_node=1 benchmark_speculator_logical.py --architecture=paged_llama --variant=7b --model_path=~/models/7B-F --tokenizer=~/models/tokenizer.model --model_source=hf --speculator_path=~/models/speculator_7B_F.pth --compile --data_path="/path/to/training_dataset_tokens/" --subdata="webhose"
+# e.g., #2: torchrun --nproc_per_node=1 benchmark_speculator_logical.py --architecture=paged_gpt_bigcode --variant=ibm.20b --model_path=~/models/granite-20b-instruct --tokenizer=~/models/granite-20b-instruct --model_source=hf --speculator_path=~/models/speculator_granite20B.pth --data_path="/path/to/training_dataset_tokens/" --subdata="github" --n_predict=4 --threshes=[6,4,3,3]
 
 parser = argparse.ArgumentParser(
     description="Script to run inference on a causal model"
 )
 parser.add_argument("--device_type", type=str, default="cuda")
+parser.add_argument(
+    "--architecture",
+    type=str,
+    default="paged_llama",
+    help="The model architecture to benchmark, e.g. 'paged_llama', 'paged_gpt_bigcode'",
+)
 parser.add_argument(
     "--variant",
     type=str,
@@ -106,11 +116,36 @@
 parser.add_argument(
     "--no_flat",
     action="store_true",
-    help="Disable batch auto-flattening for handling candidate trees?"
+    help="Disable batch auto-flattening for handling candidate trees?",
 )
 
+parser.add_argument(
+    "--seed",
+    type=int,
+    default=42,
+    help="Seed for torch and data loader",
+)
+
+parser.add_argument(
+    "--n_predict",
+    type=int,
+    default=3,
+    help="Number of speculator heads / number of tokens to guess ahead",
+)
+
+parser.add_argument(
+    "--threshes",
+    type=json.loads,
+    default=[6, 4, 3],
+    help="Number of top-k predictions taken from each head to build the speculator candidate pool; should have the same length as n_predict",
+)
+
+
 args = parser.parse_args()
 
+torch.cuda.manual_seed(args.seed)
+torch.manual_seed(args.seed)
+
 local_rank = int(os.getenv("LOCAL_RANK", 0))
 world_size = int(os.getenv("WORLD_SIZE", 1))
 if args.device_type == "cuda":
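
Because --threshes is parsed with json.loads, the value must be valid JSON on the command line (e.g. --threshes=[6,4,3,3]), and it should contain one top-k width per speculator head, i.e. len(threshes) == n_predict. A minimal sketch of that interaction; the length check is an illustration, not something this diff adds:

```python
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--n_predict", type=int, default=3)
parser.add_argument("--threshes", type=json.loads, default=[6, 4, 3])

# argparse feeds the raw string "[6,4,3,3]" through json.loads.
args = parser.parse_args(["--n_predict=4", "--threshes=[6,4,3,3]"])

# Illustrative consistency check (not performed by the script itself):
# each head contributes its top-k candidates to the speculative tree.
assert len(args.threshes) == args.n_predict
print(args.threshes)  # [6, 4, 3, 3]
```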
@@ -138,7 +173,7 @@
     distr_param = None
 
 model = get_model(
-    "paged_llama",
+    args.architecture,
     args.variant,
     model_path=args.model_path,
     checkpoint_sharding=args.checkpoint_sharding,
@@ -157,7 +192,10 @@
 if args.speculator_path is not None:
     print("loading speculator")
     speculator = MLPSpeculator(
-        model.config.emb_dim, 4096, model.config.src_vocab_size, n_predict=3
+        model.config.emb_dim,
+        4096,
+        model.config.src_vocab_size,
+        n_predict=args.n_predict,
     )
     speculator.load_state_dict(
         torch.load(args.speculator_path, map_location=device)["model_state"]
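
The speculator's n_predict now comes from the CLI rather than being hard-coded to 3, so the constructed MLPSpeculator can match whatever checkpoint gets loaded (the Granite example in the header passes --n_predict=4). The hard-coded 4096, presumably the speculator's inner dimension by position, is unchanged.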
@@ -171,11 +209,16 @@
 
 
 use_cache = True
+if hasattr(model.config, "kvheads"):
+    kv_heads = model.config.kvheads
+else:
+    kv_heads = 1 if model.config.multiquery_attn else model.config.nheads
+
 kv_cache_manager = PagedKVCacheManager(
     model.config.nlayers,
     model.config.nheads,
     model.config.emb_dim,
-    kv_heads=model.config.kvheads,
+    kv_heads=kv_heads,
     tensor_parallel_size=dist.get_world_size() if args.distributed else 1,
     dtype=torch.get_default_dtype(),
     device=device,
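
The kv_heads shim is what lets one script size the paged KV cache for both architectures: paged_llama configs expose kvheads directly, while paged_gpt_bigcode configs only carry a multiquery_attn flag (multi-query attention shares a single KV head across all query heads). A sketch of the same dispatch against hypothetical stand-in configs; the real fields live on the fms model configs:

```python
from dataclasses import dataclass


@dataclass
class LlamaLikeConfig:
    # Exposes kvheads directly (grouped-query or full multi-head KV).
    nheads: int = 32
    kvheads: int = 32


@dataclass
class BigCodeLikeConfig:
    # No kvheads field; multi-query attention means one shared KV head.
    nheads: int = 48
    multiquery_attn: bool = True


def resolve_kv_heads(config) -> int:
    # Same logic as the hunk above: prefer an explicit kvheads field,
    # otherwise infer the KV head count from the multi-query flag.
    if hasattr(config, "kvheads"):
        return config.kvheads
    return 1 if config.multiquery_attn else config.nheads


print(resolve_kv_heads(LlamaLikeConfig()))    # 32
print(resolve_kv_heads(BigCodeLikeConfig()))  # 1
```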
@@ -192,6 +235,7 @@
     datasets=[
         args.subdata,
     ],
+    seed=args.seed,
     min_length=2148,
     max_chunksize=8192,
 )
@@ -210,12 +254,12 @@
 data = torch.IntTensor(data).to(device)
 
 
-# def ids_for_prompt(prompt):
-#     tokens = tokenizer.tokenize(prompt)
-#     tokens = ["<s>"] + tokens
-#     ids = tokenizer.convert_tokens_to_ids(tokens)
-#     ids = torch.tensor(ids, dtype=torch.long, device=device)
-#     return ids
+def ids_for_prompt(prompt):
+    tokens = tokenizer.tokenize(prompt)
+    tokens = ["<s>"] + tokens
+    ids = tokenizer.convert_tokens_to_ids(tokens)
+    ids = torch.tensor(ids, dtype=torch.long, device=device)
+    return ids
 
 
 def print_result(result, inp, n_steps):
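
With ids_for_prompt re-enabled (it was previously commented out), ad-hoc prompts can be tokenized for the benchmark alongside the streamed dataset chunks. A hypothetical usage sketch, assuming the script's tokenizer and device globals are initialized and the tokenizer exposes tokenize() and convert_tokens_to_ids():

```python
# Hypothetical usage; "<s>" is the BOS token the helper prepends so
# generation starts from a well-formed sequence.
prompt = "def fibonacci(n):"
ids = ids_for_prompt(prompt)  # 1-D torch.long tensor on `device`, BOS id first
print(ids.shape)
```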
@@ -232,10 +276,15 @@ def print_result(result, inp, n_steps):
     print()
 
 
-def infer(ids, k, warmup, model, decode_model, speculator, flatting):
+def infer(ids, k, warmup, model, decode_model, speculator):
     # With greedy generation (do_sample=False) we _should_ always get the same results.
     # There is currently a bug in start_pos for batched rotary embeddings that can lead
     # varying results for the same prompt.
+    max_seq_len = (
+        model.config.max_expected_seq_len
+        if hasattr(model.config, "max_expected_seq_len")
+        else model.config.max_pos
+    )
 
     if k != 0:
         result, n_steps, generated_token_time_out = speculative_generate(
@@ -244,19 +293,19 @@ def infer(ids, k, warmup, model, decode_model, speculator, flatting):
             speculator,
             kv_cache_manager,
             new_tokens=100,
-            max_seq_len=model.config.max_expected_seq_len,
+            max_seq_len=max_seq_len,
             decode_model=decode_model,
             top_k=k,
-            threshes=[6, 4, 3],
-            flatting=flatting,
+            threshes=args.threshes,
+            flatting=not args.no_flat,
         )
     else:
         result, n_steps, generated_token_time_out = paged_generate(
             model,
             ids,
             kv_cache_manager,
             max_new_tokens=100,
-            max_seq_len=model.config.max_expected_seq_len,
+            max_seq_len=max_seq_len,
             do_sample=False,
             decode_model=decode_model,
         )
@@ -269,6 +318,7 @@ def infer(ids, k, warmup, model, decode_model, speculator, flatting):
         return generated_token_time_out / avg_tokens, avg_tokens / n_steps
     return None
 
+
 torch._dynamo.config.cache_size_limit = 64
 
 torch.cuda.empty_cache()
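
The pair infer() returns reads as (time per generated token, tokens per decode step); the second value is the logical speedup this benchmark is named for, since vanilla autoregressive decoding emits exactly one token per step. A worked example with made-up numbers:

```python
# Illustrative numbers only; the real values come out of speculative_generate.
generated_token_time_out = 2.0  # total generation time in seconds (presumed cumulative)
avg_tokens = 100.0              # tokens generated
n_steps = 30                    # base-model forward passes taken

per_token_latency = generated_token_time_out / avg_tokens  # 0.02 s/token
logical_speedup = avg_tokens / n_steps                     # ~3.33 tokens/step

# One token per step is the non-speculative baseline, so ~3.33 tokens/step
# means roughly 3.3x fewer base-model forward passes for the same output.
print(per_token_latency, round(logical_speedup, 2))
```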