Description
Hi, great work! I think Zamba2 is a great hybrid SSM model for the whole community.
I'm trying to reproduce the time-to-first-token (TTFT) results presented in your tech report, but I couldn't match the reported TTFT for Zamba2. For the 1.2B and 2.7B models, I set the input prompt length to 2048, the number of output tokens to 1, and the batch size to 1.
However, the TTFT of Zamba2 is higher than that of attention-based models such as Phi2-2.7B, Qwen2-1.5B, and Qwen2.5-3B.
For example, on a single A100 (40 GB): Zamba2-2.7B vs. Phi2-2.7B is 150 ms vs. 90 ms, and Zamba2-1.2B vs. Qwen2-1.5B is 94 ms vs. 81 ms.
For Zamba2, I used mamba-ssm and causal-conv1d to speed up inference; for the attention-based LLMs I used FlashAttention-2.
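As a quick sanity check that the fused kernels are actually importable in the benchmarking environment (a minimal sketch; the package names below are the standard module names and may differ in other installs):

```python
import importlib.util

def kernel_available(name: str) -> bool:
    """Return True if the named package is importable in this environment."""
    return importlib.util.find_spec(name) is not None

# The fused-kernel packages the benchmark relies on:
for pkg in ("mamba_ssm", "causal_conv1d", "flash_attn"):
    print(f"{pkg}: {'available' if kernel_available(pkg) else 'MISSING'}")
```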
My machine environment is as follows:
```
PRETTY_NAME="Ubuntu 22.04.2 LTS"
NAME="Ubuntu"
VERSION_ID="22.04"
VERSION="22.04.2 LTS (Jammy Jellyfish)"
VERSION_CODENAME=jammy
ID=ubuntu
ID_LIKE=debian
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
UBUNTU_CODENAME=jammy
```
libnccl2.18.3-1+cuda12.2
Python packages:
Package Version Editable project location
------------------------- -------------------- ----------------------------------------------------------
absl-py 1.4.0
accelerate 0.33.0
accelerator 2024.3.8.dev1
aiohttp 3.8.4
aiosignal 1.3.1
annotated-types 0.7.0
apex 0.1
argon2-cffi 21.3.0
argon2-cffi-bindings 21.2.0
asttokens 2.2.1
astunparse 1.6.3
async-timeout 4.0.2
attrs 23.1.0
audioread 3.0.0
backcall 0.2.0
beautifulsoup4 4.12.2
bitsandbytes 0.43.0
bleach 6.0.0
blis 0.7.10
bottle 0.12.25
cachetools 5.3.1
catalogue 2.0.9
causal-conv1d 1.3.0.post1
certifi 2023.7.22
cffi 1.15.1
charset-normalizer 3.2.0
click 8.1.5
cloudpickle 2.2.1
cmake 3.27.1
comm 0.1.4
confection 0.1.1
contourpy 1.1.0
cubinlinker 0.3.0+2.g7c3675e
cuda-python 12.1.0rc5+1.g994d8d0
cudf 23.6.0
cugraph 23.6.0
cugraph-dgl 23.6.0
cugraph-service-client 23.6.0
cugraph-service-server 23.6.0
cuml 23.6.0
cupy-cuda12x 12.1.0
cycler 0.11.0
cymem 2.0.7
Cython 3.0.0
dask 2023.3.2
dask-cuda 23.6.0
dask-cudf 23.6.0
datasets 2.18.0
debugpy 1.8.7
decorator 5.1.1
deepspeed 0.14.4
defusedxml 0.7.1
dill 0.3.8
distributed 2023.3.2.1
dm-tree 0.1.8
docker-pycreds 0.4.0
einops 0.6.1
exceptiongroup 1.1.2
execnet 2.0.2
executing 1.2.0
expecttest 0.1.3
fastjsonschema 2.18.0
fastrlock 0.8.1
fbgemm-gpu 0.6.0
filelock 3.12.2
flash-attn 2.6.3
fonttools 4.42.0
frozenlist 1.4.0
fsspec 2023.6.0
gast 0.5.4
gin-config 0.5.0
gitdb 4.0.11
GitPython 3.1.43
google-auth 2.22.0
google-auth-oauthlib 0.4.6
graphsurgeon 0.4.6
grpcio 1.56.2
hjson 3.1.0
huggingface-hub 0.26.5
hypothesis 5.35.1
idna 3.4
importlib-metadata 6.8.0
iniconfig 2.0.0
intel-openmp 2021.4.0
iopath 0.1.10
ipykernel 6.25.0
ipython 8.14.0
ipython-genutils 0.2.0
jedi 0.19.0
Jinja2 3.1.2
joblib 1.3.1
json5 0.9.14
jsonlines 4.0.0
jsonschema 4.18.6
jsonschema-specifications 2023.7.1
jupyter_client 8.3.0
jupyter_core 5.3.1
jupyter-tensorboard 0.2.0
jupyterlab 2.3.2
jupyterlab-pygments 0.2.2
jupyterlab-server 1.2.0
jupytext 1.15.0
kiwisolver 1.4.4
langcodes 3.3.0
librosa 0.9.2
llvmlite 0.40.1
locket 1.0.0
mamba-ssm 2.1.0
Markdown 3.4.4
markdown-it-py 3.0.0
MarkupSafe 2.1.3
matplotlib 3.7.2
matplotlib-inline 0.1.6
mdit-py-plugins 0.4.0
mdurl 0.1.2
mistune 3.0.1
mkl 2021.1.1
mkl-devel 2021.1.1
mkl-include 2021.1.1
mock 5.1.0
modelscope 1.22.0
mpmath 1.3.0
msgpack 1.0.5
multidict 6.0.4
multiprocess 0.70.16
murmurhash 1.0.9
nbclient 0.8.0
nbconvert 7.7.3
nbformat 5.9.2
nest-asyncio 1.5.7
networkx 2.6.3
ninja 1.11.1
notebook 6.4.10
numba 0.57.1+1.gc785c8f1f
numpy 1.22.2
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-dali-cuda120 1.28.0
nvidia-ml-py 12.560.30
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.6.68
nvidia-nvtx-cu12 12.1.105
nvidia-pyindex 1.0.9
nvtx 0.2.5
oauthlib 3.2.2
onnx 1.14.0
opencv 4.7.0
packaging 23.1
pandas 1.5.2
pandocfilters 1.5.0
parso 0.8.3
partd 1.4.0
pathy 0.10.2
peft 0.11.0
pexpect 4.8.0
pickleshare 0.7.5
Pillow 9.2.0
pip 23.2.1
platformdirs 3.10.0
pluggy 1.2.0
ply 3.11
polygraphy 0.47.1
pooch 1.7.0
portalocker 2.10.1
preshed 3.0.8
prettytable 3.8.0
prometheus-client 0.17.1
prompt-toolkit 3.0.39
protobuf 4.21.12
psutil 5.9.4
ptxcompiler 0.8.1+1.g4a94326
ptyprocess 0.7.0
pure-eval 0.2.2
py-cpuinfo 9.0.0
pyarrow 17.0.0
pyarrow-hotfix 0.6
pyasn1 0.5.0
pyasn1-modules 0.3.0
pybind11 2.11.1
pycocotools 2.0+nv0.7.3
pycparser 2.21
pydantic 2.9.2
pydantic_core 2.23.4
Pygments 2.16.1
pylibcugraph 23.6.0
pylibcugraphops 23.6.0
pylibraft 23.6.0
Pympler 1.1
pynvml 11.4.1
pyparsing 3.0.9
pytest 7.4.0
pytest-flakefinder 1.1.0
pytest-rerunfailures 12.0
pytest-shard 0.1.2
pytest-xdist 3.3.1
python-dateutil 2.8.2
python-hostlist 1.23.0
pytorch-quantization 2.1.2
pytz 2023.3
PyYAML 6.0.1
pyzmq 25.1.0
raft-dask 23.6.0
referencing 0.30.2
regex 2023.6.3
requests 2.31.0
requests-oauthlib 1.3.1
resampy 0.4.2
rmm 23.6.0
rpds-py 0.9.2
rsa 4.9
safetensors 0.4.5
scikit-learn 1.2.0
scipy 1.11.1
Send2Trash 1.8.2
sentencepiece 0.1.99
sentry-sdk 2.17.0
setproctitle 1.3.3
setuptools 68.0.0
six 1.16.0
smart-open 6.3.0
smmap 5.0.1
sortedcontainers 2.4.0
soundfile 0.12.1
soupsieve 2.4.1
spacy 3.6.0
spacy-legacy 3.0.12
spacy-loggers 1.0.4
sphinx-glpi-theme 0.3
srsly 2.4.7
stack-data 0.6.2
sympy 1.12
tabulate 0.9.0
tbb 2021.10.0
tblib 2.0.0
tensorboard 2.9.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorrt 8.6.1
terminado 0.17.1
thinc 8.1.10
threadpoolctl 3.2.0
thriftpy2 0.4.16
tinycss2 1.2.1
tokenizers 0.19.1
toml 0.10.2
tomli 2.0.1
toolz 0.12.0
torch 2.4.0
torchdata 0.7.1
torchvision 0.19.0
tornado 6.3.2
tqdm 4.65.0
traitlets 5.9.0
transformer_engine_cu12 1.12.0
transformers 4.43.0.dev0 /codes/transformers_zamba2
treelite 3.2.0
treelite-runtime 3.2.0
triton 3.0.0
typer 0.9.0
types-dataclasses 0.6.6
typing_extensions 4.12.2
ucx-py 0.32.0
uff 0.6.9
urllib3 1.26.16
waitress 3.0.2
wandb 0.18.5
wasabi 1.1.2
wcwidth 0.2.6
webencodings 0.5.1
Werkzeug 2.3.6
wheel 0.41.1
xdoctest 1.0.2
xgboost 1.7.5
xxhash 3.5.0
yarl 1.9.2
zict 3.0.0
zipp 3.16.2
My mamba-ssm and causal-conv1d versions match those pinned in setup.py. Below is my code for reproducing TTFT:
```python
# Modified from https://github.com/state-spaces/mamba/blob/main/benchmarks/benchmark_generation_mamba_simple.py
import argparse
import sys
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

parser = argparse.ArgumentParser(description="Generation benchmarking")
parser.add_argument("--model-name", type=str, default="Qwen2-1.5B")
parser.add_argument("--prompt", type=str, default=None)
parser.add_argument("--promptlen", type=int, default=2048)
parser.add_argument("--genlen", type=int, default=1)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--topk", type=int, default=0)
parser.add_argument("--topp", type=float, default=1.0)
parser.add_argument("--minp", type=float, default=0.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--batch", type=int, default=1)
args = parser.parse_args()

repeats = 3
device = "cuda"
dtype = torch.bfloat16
root = '../../model/'

print(f"Loading model {args.model_name}")
tokenizer = AutoTokenizer.from_pretrained(root + args.model_name)
model = AutoModelForCausalLM.from_pretrained(
    root + args.model_name,
    device_map=device,
    torch_dtype=dtype,
    attn_implementation="flash_attention_2",
)
model.eval()
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

torch.random.manual_seed(0)
if args.prompt is None:
    # Random prompt of the requested length.
    input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device=device)
    attn_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
else:
    tokens = tokenizer(args.prompt, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    attn_mask = tokens.attention_mask.to(device=device)
max_length = input_ids.shape[1] + args.genlen
print(input_ids.shape)

fn = lambda: model.generate(
    input_ids=input_ids,
    attention_mask=attn_mask,
    max_new_tokens=args.genlen,
    return_dict_in_generate=True,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False,
    temperature=args.temperature,
    top_k=args.topk,
    top_p=args.topp,
    repetition_penalty=args.repetition_penalty,
)

# Warmup run; also catches OOM before timing starts.
try:
    out = fn()
except RuntimeError as e:
    if "out of memory" in str(e).lower():
        with open('output.txt', mode='a', encoding='utf-8') as f:
            # `out` does not exist here, so report the requested lengths instead.
            print(f"Prompt length: {input_ids.shape[1]}, generation length: {args.genlen}\n", file=f)
            print(f"{args.model_name} OOM!\n\n", file=f)
        print(f"{args.model_name} OOM!")
        sys.exit(1)
    raise
if args.prompt is not None:
    print(tokenizer.batch_decode(out.sequences.tolist()))

torch.cuda.synchronize()
start = time.time()
for _ in range(repeats):
    try:
        fn()
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            with open(f'{args.model_name}-speed.txt', mode='a', encoding='utf-8') as f:
                print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}", file=f)
                print(f"{args.model_name} OOM!\n\n", file=f)
            print(f"{args.model_name} OOM!")
            sys.exit(1)
        raise
torch.cuda.synchronize()
end = time.time()

print(f"Batch size: {args.batch}, prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
print(f"{args.model_name} prompt processing + decoding time: {(end - start) / repeats * 1000:.0f}ms\n\n")
with open(f'{args.model_name}-speed.txt', mode='a', encoding='utf-8') as f:
    print(f"Batch size: {args.batch}, prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}", file=f)
    print(f"{args.model_name} prompt processing + decoding time: {(end - start) / repeats * 1000:.0f}ms\n\n", file=f)
```

I would like to know if my reproduction process is correct and how you calculated the TTFT results shown in the tech report.
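Since `--genlen 1` makes the measured time effectively prompt processing plus one decode step (i.e. TTFT), the timing loop above can also be factored into a small helper. A sketch, pure stdlib; the optional `sync` callable is meant to be torch.cuda.synchronize so queued GPU work is fully counted:

```python
import time

def avg_time_ms(fn, repeats: int = 3, sync=None) -> float:
    """Average wall time of fn() over `repeats` calls, in milliseconds.

    `sync` is an optional callable (e.g. torch.cuda.synchronize) invoked
    before timing starts and after the loop, so asynchronous GPU work
    launched by fn() is included in the measurement.
    """
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / repeats * 1000.0
```

Usage in the script above would be `avg_time_ms(fn, repeats, sync=torch.cuda.synchronize)`.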
Looking forward to your reply, thanks!