diff --git a/benchmark/text-generation/README.md b/benchmark/text-generation/README.md
new file mode 100644
index 000000000..83b0a846d
--- /dev/null
+++ b/benchmark/text-generation/README.md
@@ -0,0 +1,12 @@
+# Usage
+
+```shell
+python llama2.py
+```
+This will produce one JSON results file per model configuration.
+
+```shell
+python gen_barcharts.py
+```
+
+This will create three bar chart images for encoding times, latency and throughput.
diff --git a/benchmark/text-generation/benchmark.py b/benchmark/text-generation/benchmark.py
new file mode 100644
index 000000000..39f1c909a
--- /dev/null
+++ b/benchmark/text-generation/benchmark.py
@@ -0,0 +1,92 @@
+import argparse
+import json
+import os
+import time
+
+import torch
+from transformers import AutoConfig, AutoTokenizer, set_seed
+
+from optimum.neuron import NeuronModelForCausalLM
+
+
+def generate(model, input_ids, length):
+    start = time.time()
+    with torch.inference_mode():
+        output_tokens = model.generate(input_ids, do_sample=False, min_length=length, max_length=length)
+    end = time.time()
+    return output_tokens, (end - start)
+
+
+def run(model_id, inc_length, max_length, json_path=None):
+    prompts = ["One of my fondest memory"]
+    config = AutoConfig.from_pretrained(model_id)
+    batch_size = config.neuron["batch_size"]
+    if len(prompts) < batch_size:
+        prompts = prompts + [prompts[-1]] * (batch_size - len(prompts))
+    model = NeuronModelForCausalLM.from_pretrained(model_id, export=False, low_cpu_mem_usage=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    # Specify padding options for decoder-only architecture
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+    tokenizer.padding_side = "left"
+    # Encode the first input tokens
+    tokens = tokenizer(prompts, return_tensors="pt", padding=True)
+    bootstrap_input_ids = tokens.input_ids
+    # Generate the first set of inputs
+    input_ids, latency = generate(model, bootstrap_input_ids, inc_length)
+    input_length = input_ids.size()[-1]
+    neuron_config = getattr(model.config, "neuron")
+    benchmark = {"neuron_config": neuron_config, "results": []}
+    while input_length < max_length:
+        # Generate a single extra token, just to evaluate the context encoding time
+        _, encoding_time = generate(model, input_ids, input_length + 1)
+        result = {
+            "input_length": input_length,
+            "batch_size": batch_size,
+            "encoding_time": encoding_time,
+            "generations": [],
+        }
+        for sequence_length in range(input_length + inc_length, max_length + 1, inc_length):
+            output_ids, latency = generate(model, input_ids, sequence_length)
+            throughput = batch_size * sequence_length / latency
+            result["generations"].append(
+                {
+                    "sequence_length": sequence_length,
+                    "new_tokens": sequence_length - input_length,
+                    "latency": latency,
+                    "generation_time": latency - encoding_time,
+                    "throughput": throughput,
+                }
+            )
+        # Reuse the first generated tokens for the next step
+        input_length += inc_length
+        input_ids = output_ids[:, :input_length]
+        benchmark["results"].append(result)
+    if json_path is not None:
+        with open(json_path, "w") as fp:
+            json.dump(benchmark, fp, indent=4)
+    return benchmark
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("model", type=str, help="A neuron model in a local directory.")
+    parser.add_argument("--inc-length", type=int, default=128, help="The number of tokens in each increment.")
+    parser.add_argument("--max-length", type=int, default=2048, help="The maximum sequence length (input + generated tokens).")
+    parser.add_argument("--seed", type=int, default=None, help="Pass a seed for reproducibility.")
+    args = parser.parse_args()
+    if args.seed is not None:
+        set_seed(args.seed)
+    model_name = os.path.basename(os.path.normpath(args.model))
+    benchmark = run(args.model, args.inc_length, args.max_length, json_path=f"{model_name}.json")
+    # Dump encoding times
+    results = benchmark["results"]
+    print(f"{benchmark['neuron_config']}")
+    print("Encoding times")
+    print([result["input_length"] for result in results])
+    print([f"{result['encoding_time']:.2f}" for result in results])
+    # Just look at the first set of generations
+    generations = results[0]["generations"]
+    print(f"Latency and throughput for {args.inc_length} input tokens")
+    print([generation["new_tokens"] for generation in generations])
+    print([f"{generation['latency']:.2f}" for generation in generations])
+    print([f"{generation['throughput']:.2f}" for generation in generations])
diff --git a/benchmark/text-generation/gen_barcharts.py b/benchmark/text-generation/gen_barcharts.py
new file mode 100644
index 000000000..c694b6999
--- /dev/null
+++ b/benchmark/text-generation/gen_barcharts.py
@@ -0,0 +1,98 @@
+import argparse
+import glob
+import json
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def save_bar_chart(title, labels, ylabel, series, save_path):
+    x = np.arange(len(labels))  # the label locations
+    width = 0.15  # the width of the bars
+    multiplier = 0
+
+    fig, ax = plt.subplots(layout="constrained")
+    fig.set_figwidth(10)
+
+    max_value = 0
+
+    for attribute, measurement in series.items():
+        max_value = max(max_value, max(measurement))
+        offset = width * multiplier
+        rects = ax.bar(x + offset, measurement, width, label=attribute)
+        ax.bar_label(rects, padding=5)
+        multiplier += 1
+
+    # Add some text for labels, title and custom x-axis tick labels, etc.
+    ax.set_ylabel(ylabel)
+    ax.set_title(title)
+    ax.set_xticks(x + width, labels)
+    ax.legend(loc="upper left", ncols=3)
+    ax.set_ylim(0, max_value * 1.2)
+
+    plt.savefig(save_path)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("inputs", type=str, nargs="*", help="A list of benchmark results files (.json).")
+    args = parser.parse_args()
+    inputs = args.inputs
+    if len(inputs) == 0:
+        inputs = glob.glob("*.json")
+    benchmarks = {}
+    for input in inputs:
+        model_name = Path(input).stem
+        with open(input) as f:
+            benchmarks[model_name] = json.load(f)
+    model_names = benchmarks.keys()
+    # Generate encoding barchart
+    input_length = []
+    encoding_times = {}
+    for name in model_names:
+        results = benchmarks[name]["results"]
+        cur_input_length = [result["input_length"] for result in results]
+        if len(input_length) == 0:
+            input_length = cur_input_length
+        else:
+            assert cur_input_length == input_length, f"{name} does not have the same number of results"
+        encoding_times[name] = [round(result["encoding_time"], 1) for result in results]
+    save_bar_chart(
+        title="Encoding time per input token",
+        labels=input_length,
+        series=encoding_times,
+        ylabel="Encoding time (s)",
+        save_path="encoding_times.png",
+    )
+    # Generate latency and throughput barcharts (for the first input length only)
+    new_tokens = []
+    latencies = {}
+    throughputs = {}
+    for name in model_names:
+        generations = benchmarks[name]["results"][0]["generations"]
+        cur_new_tokens = [generation["new_tokens"] for generation in generations]
+        if len(new_tokens) == 0:
+            new_tokens = cur_new_tokens
+        else:
+            assert cur_new_tokens == new_tokens, f"{name} does not have the same number of results"
+        latencies[name] = [round(generation["latency"], 1) for generation in generations]
+        throughputs[name] = [round(generation["throughput"], 0) for generation in generations]
+    save_bar_chart(
+        title="End-to-end latency per generated tokens for 256 input tokens",
+        labels=new_tokens,
+        series=latencies,
+        ylabel="Latency (s)",
+        save_path="latencies.png",
+    )
+    save_bar_chart(
+        title="Throughput per generated tokens for 256 input tokens",
+        labels=new_tokens,
+        series=throughputs,
+        ylabel="Throughput (tokens/s)",
+        save_path="throughputs.png",
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmark/text-generation/llama2.py b/benchmark/text-generation/llama2.py
new file mode 100644
index 000000000..737e8a5cf
--- /dev/null
+++ b/benchmark/text-generation/llama2.py
@@ -0,0 +1,33 @@
+import os
+from tempfile import TemporaryDirectory
+
+from transformers import AutoTokenizer
+
+from benchmark import run
+from optimum.neuron import NeuronModelForCausalLM
+
+
+model_configurations = {
+    "Llama-2-7BL": ["meta-llama/Llama-2-7b-chat-hf", 1, 2048],
+    "Llama-2-7BT": ["meta-llama/Llama-2-7b-chat-hf", 4, 2048],
+}
+
+num_cores = len(os.listdir("/sys/class/neuron_device/")) * 2
+if num_cores >= 4:
+    extra_model_configurations = {
+        "Llama-2-13BL": ["meta-llama/Llama-2-13b-chat-hf", 1, 2048],
+        "Llama-2-13BT": ["meta-llama/Llama-2-13b-chat-hf", 4, 2048],
+    }
+    model_configurations = {**model_configurations, **extra_model_configurations}
+
+for model_name, model_configuration in model_configurations.items():
+    model_id, batch_size, seq_length = model_configuration
+    model = NeuronModelForCausalLM.from_pretrained(
+        model_id, export=True, batch_size=batch_size, sequence_length=seq_length, auto_cast_type="fp16"
+    )
+    with TemporaryDirectory() as tmpdir:
+        model.save_pretrained(tmpdir)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
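+        # The tokenizer is saved next to the exported model so that benchmark.run() can reload both from the same directory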
+        tokenizer.save_pretrained(tmpdir)
+        json_path = f"{model_name}.json"
+        run(tmpdir, 256, 1024, json_path=json_path)
diff --git a/docs/assets/benchmarks/inferentia-llama2/encoding_times.png b/docs/assets/benchmarks/inferentia-llama2/encoding_times.png
new file mode 100644
index 000000000..144342227
Binary files /dev/null and b/docs/assets/benchmarks/inferentia-llama2/encoding_times.png differ
diff --git a/docs/assets/benchmarks/inferentia-llama2/latencies.png b/docs/assets/benchmarks/inferentia-llama2/latencies.png
new file mode 100644
index 000000000..3fe8270cc
Binary files /dev/null and b/docs/assets/benchmarks/inferentia-llama2/latencies.png differ
diff --git a/docs/assets/benchmarks/inferentia-llama2/throughputs.png b/docs/assets/benchmarks/inferentia-llama2/throughputs.png
new file mode 100644
index 000000000..e76649565
Binary files /dev/null and b/docs/assets/benchmarks/inferentia-llama2/throughputs.png differ
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 7730aa35c..2416dd141 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -17,7 +17,7 @@
     - local: tutorials/llama2-13b-chatbot
      title: Create your own chatbot with llama-2-13B on AWS Inferentia
    - local: tutorials/fine_tune_llama_7b
-     title: Fine-tune Llama 2 7B on AWS Trainium 
+     title: Fine-tune Llama 2 7B on AWS Trainium
    - local: tutorials/sentence_transformers
      title: Sentence Transformers on AWS Inferentia
    title: Tutorials
@@ -51,5 +51,9 @@
     - local: package_reference/modeling
      title: Neuron Models
    title: Reference
+  - sections:
+    - local: benchmarks/inferentia-llama2
+      title: Llama on AWS Inferentia2
+    title: Benchmarks
  title: Optimum Neuron
  isExpanded: true
diff --git a/docs/source/benchmarks/inferentia-llama2.mdx b/docs/source/benchmarks/inferentia-llama2.mdx
new file mode 100644
index 000000000..c7c6d7fbe
--- /dev/null
+++ b/docs/source/benchmarks/inferentia-llama2.mdx
@@ -0,0 +1,77 @@
+# Llama performance on AWS Inferentia2 (Latency & Throughput)
+
+How fast is Llama on Inferentia2? Let's find out!
+
+For this benchmark we will use the Llama 2 7B and 13B models with different configurations:
+
+| Model type                  | num cores | batch_size |
+|-----------------------------|-----------|------------|
+| Llama2 7B - L (latency)     | 24        | 1          |
+| Llama2 7B - T (throughput)  | 24        | 4          |
+| Llama2 13B - L (latency)    | 24        | 1          |
+| Llama2 13B - T (throughput) | 24        | 4          |
+
+*Note: all models are compiled with a maximum sequence length of 2048.*
+
+All models are compiled to use the full extent of cores available on the `inf2.48xlarge` instance.
+
+*Note: please refer to the [inferentia2 product page](https://aws.amazon.com/ec2/instance-types/inf2/) for details on the available instances.*
+
+We created two "latency"-oriented configurations for the `llama2 7B` and `llama2 13B` models that can serve only one request at a time, but at full speed, and two "throughput"-oriented configurations that can serve up to four requests in parallel.
+
+To evaluate the models, we generate tokens up to a total sequence length of 1024, starting from
+256 input tokens (i.e. we generate 256, 512 and 768 tokens).
+
+## Encoding time (time to first token)
+
+The encoding time, or time to first token, is the time required to process the input tokens and generate the first output token.
+It is a very important metric, as it corresponds to the latency directly perceived by the user when streaming generated tokens.
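+
+As a rough sketch of how this metric is obtained in the `benchmark.py` script accompanying this page (simplified here for illustration, not the exact implementation), the context encoding time can be isolated by asking the model to generate a single extra token:
+
+```python
+import time
+
+import torch
+
+
+def measure_encoding_time(model, input_ids):
+    # Generating exactly one new token forces a full pass over the input context,
+    # so the elapsed time approximates the time to first token.
+    length = input_ids.size(-1) + 1
+    start = time.time()
+    with torch.inference_mode():
+        model.generate(input_ids, do_sample=False, min_length=length, max_length=length)
+    return time.time() - start
+```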
+
+We test the encoding time for increasing context sizes: 256 input tokens corresponds roughly to a typical Q/A usage,
+while 768 is more typical of a Retrieval Augmented Generation (RAG) use-case.
+
+Encoding time is expressed in **seconds**.
+
+![Llama2 inferentia2 encoding-time](https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/benchmarks/inferentia-llama2/encoding_times.png "Encoding time")
+
+We can see that all deployed models exhibit excellent response times, even for long contexts.
+
+## End-to-end Latency
+
+The end-to-end latency corresponds to the total time to reach a sequence length of 1024 tokens.
+
+It therefore includes the encoding and generation time.
+
+Latency is expressed in **seconds**.
+
+![Llama2 inferentia2 end-to-end latency](https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/benchmarks/inferentia-llama2/latencies.png "Latency")
+
+All models deployed on the high-end instance exhibit a good latency, even those configured to optimize throughput.
+
+## Throughput
+
+We adopt the same convention as other benchmarks to evaluate the throughput: we divide the total number of tokens
+(both input and output) by the end-to-end latency.
+In other words, we divide `batch_size * sequence_length` by the end-to-end latency to obtain the number of tokens generated per second (a small worked example is given at the end of this page).
+
+Throughput is expressed in **tokens/second**.
+
+![Llama2 inferentia2 throughput](https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/benchmarks/inferentia-llama2/throughputs.png "Throughput")
+
+Again, the models deployed on the high-end instance have a very good throughput, even those optimized for latency.
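+
+As a purely illustrative reference for how the throughput numbers above are computed (the values below are made up and do not come from the charts), a batch of 4 sequences reaching 1024 tokens in 10 seconds end-to-end would be reported as:
+
+```python
+# Hypothetical values, for illustration only
+batch_size = 4
+sequence_length = 1024     # input + generated tokens
+end_to_end_latency = 10.0  # seconds
+
+throughput = batch_size * sequence_length / end_to_end_latency
+print(f"{throughput:.0f} tokens/s")  # ~410 tokens/s
+```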