78 changes: 59 additions & 19 deletions speculator/benchmark_speculator_logical.py
@@ -2,9 +2,11 @@
import itertools
import os
import time
import torch
import json

import fms_extras.models.paged_llama
import torch
import fms_extras.models.paged_gpt_bigcode
import torch._inductor.config
from fms.models import get_model
from fms.utils import generation, tokenizers
@@ -16,13 +18,21 @@
from fms_fsdp.utils.dataset_utils import Streaming_Doc_Dataset


# This example script validates the LLaMA implementation by running inference on a couple of prompts.
# torchrun --nproc_per_node=1 scripts/inference.py --variant=7b --model_path=~/models/7B-F --tokenizer=~/models/tokenizer.model --model_source=meta --speculator_path=~/models/speculator_7B_F.pth --compile
# This example script measures the logical speedup of running a speculator atop a base model. Run as:
# export CUDA_VISIBLE_DEVICES=1
# e.g., #1: torchrun --nproc_per_node=1 benchmark_speculator_logical.py --architecture=paged_llama --variant=7b --model_path=~/models/7B-F --tokenizer=~/models/tokenizer.model --model_source=hf --speculator_path=~/models/speculator_7B_F.pth --compile --data_path="/path/to/training_dataset_tokens/" --subdata="webhose"
# e.g., #2: torchrun --nproc_per_node=1 benchmark_speculator_logical.py --architecture=paged_gpt_bigcode --variant=ibm.20b --model_path=~/models/granite-20b-instruct --tokenizer=~/models/granite-20b-instruct --model_source=hf --speculator_path=~/models/speculator_granite20B.pth --data_path="/path/to/training_dataset_tokens/" --subdata="github" --n_predict=4 --threshes=[6,4,3,3]

parser = argparse.ArgumentParser(
    description="Script to run inference on a causal model"
)
parser.add_argument("--device_type", type=str, default="cuda")
parser.add_argument(
    "--architecture",
    type=str,
    default="paged_llama",
    help="The model architecture to benchmark, e.g. 'paged_llama', 'paged_gpt_bigcode'",
)
parser.add_argument(
    "--variant",
    type=str,
@@ -109,8 +119,33 @@
help="Disable batch auto-flattening for handling candidate trees?"
)

parser.add_argument(
    "--seed",
    type=int,
    default=42,
    help="Seed for torch and data loader",
)

parser.add_argument(
    "--n_predict",
    type=int,
    default=3,
    help="Number of speculator heads / number of tokens to guess ahead",
)

parser.add_argument(
    "--threshes",
    type=json.loads,
    default=[6, 4, 3],
    help="Number of top-k predictions taken from each head to build the speculator candidate pool; should have the same length as n_predict",
)
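# Note: --threshes is parsed with json.loads, so the value must be valid JSON on
# the command line, e.g. --threshes=[6,4,3,3] (no spaces); json.loads("[6,4,3,3]")
# yields the Python list [6, 4, 3, 3], one top-k width per speculator head.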


args = parser.parse_args()

torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
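# torch.manual_seed also seeds the CUDA generators on all devices, so the
# explicit torch.cuda.manual_seed call above is kept mainly for clarity.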

local_rank = int(os.getenv("LOCAL_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
if args.device_type == "cuda":
@@ -138,7 +173,7 @@
distr_param = None

model = get_model(
    "paged_llama",
    args.architecture,
    args.variant,
    model_path=args.model_path,
    checkpoint_sharding=args.checkpoint_sharding,
@@ -157,7 +192,7 @@
if args.speculator_path is not None:
    print("loading speculator")
    speculator = MLPSpeculator(
        model.config.emb_dim, 4096, model.config.src_vocab_size, n_predict=3
        model.config.emb_dim, 4096, model.config.src_vocab_size, n_predict=args.n_predict
    )
    speculator.load_state_dict(
        torch.load(args.speculator_path, map_location=device)["model_state"]
@@ -169,13 +204,17 @@
# cache setup
from fms_extras.utils.cache.paged import PagedKVCacheManager


use_cache = True
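# paged_llama configs expose `kvheads` directly; configs without it (e.g. the
# paged_gpt_bigcode config) flag multi-query attention instead, which amounts to
# a single KV head shared across all query heads.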
if hasattr(model.config, "kvheads"):
    kv_heads = model.config.kvheads
else:
    kv_heads = 1 if model.config.multiquery_attn else model.config.nheads

kv_cache_manager = PagedKVCacheManager(
    model.config.nlayers,
    model.config.nheads,
    model.config.emb_dim,
    kv_heads=model.config.kvheads,
    kv_heads=kv_heads,
    tensor_parallel_size=dist.get_world_size() if args.distributed else 1,
    dtype=torch.get_default_dtype(),
    device=device,
@@ -192,6 +231,7 @@
    datasets=[
        args.subdata,
    ],
    seed=args.seed,
    min_length=2148,
    max_chunksize=8192,
)
@@ -209,13 +249,12 @@
in_middle = True
data = torch.IntTensor(data).to(device)


# def ids_for_prompt(prompt):
#     tokens = tokenizer.tokenize(prompt)
#     tokens = ["<s>"] + tokens
#     ids = tokenizer.convert_tokens_to_ids(tokens)
#     ids = torch.tensor(ids, dtype=torch.long, device=device)
#     return ids
def ids_for_prompt(prompt):
    tokens = tokenizer.tokenize(prompt)
    tokens = ["<s>"] + tokens
    ids = tokenizer.convert_tokens_to_ids(tokens)
    ids = torch.tensor(ids, dtype=torch.long, device=device)
    return ids
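# Example (illustrative): ids_for_prompt("Hello world") returns a 1-D torch.long
# tensor of token ids on `device`, prefixed with the "<s>" BOS token.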


def print_result(result, inp, n_steps):
@@ -232,10 +271,11 @@ def print_result(result, inp, n_steps):
    print()


def infer(ids, k, warmup, model, decode_model, speculator, flatting):
def infer(ids, k, warmup, model, decode_model, speculator):
    # With greedy generation (do_sample=False) we _should_ always get the same results.
    # There is currently a bug in start_pos for batched rotary embeddings that can lead to
    # varying results for the same prompt.
    max_seq_len = model.config.max_expected_seq_len if hasattr(model.config, "max_expected_seq_len") else model.config.max_pos
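    # paged_llama configs expose `max_expected_seq_len`; architectures without it
    # (presumably paged_gpt_bigcode here) expose `max_pos` instead, hence the fallback.
    # k != 0 exercises the speculative path; k == 0 runs plain paged generation as the
    # non-speculative baseline for the logical-speedup comparison.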

    if k != 0:
        result, n_steps, generated_token_time_out = speculative_generate(
@@ -244,19 +284,19 @@ def infer(ids, k, warmup, model, decode_model, speculator, flatting):
            speculator,
            kv_cache_manager,
            new_tokens=100,
            max_seq_len=model.config.max_expected_seq_len,
            max_seq_len=max_seq_len,
            decode_model=decode_model,
            top_k=k,
            threshes=[6, 4, 3],
            flatting=flatting,
            threshes=args.threshes,
            flatting=not args.no_flat,
        )
    else:
        result, n_steps, generated_token_time_out = paged_generate(
            model,
            ids,
            kv_cache_manager,
            max_new_tokens=100,
            max_seq_len=model.config.max_expected_seq_len,
            max_seq_len=max_seq_len,
            do_sample=False,
            decode_model=decode_model,
        )