Commit 4ce9591

Author: chenfeiyu (committed)
initial commit: add piecewise attention
0 parents  commit 4ce9591

11 files changed, +1171 -0 lines changed

.gitignore (+132)

@@ -0,0 +1,132 @@
# results of benchmark
benchmark/results/

# version, since setuptools-scm is used, this file is automatically generated when building the package
src/flag_attn/_version.py

# Editors
.vscode/
.idea/

# Vagrant
.vagrant/

# Mac/OSX
.DS_Store

# Windows
Thumbs.db

# Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

LICENSE (+13)

@@ -0,0 +1,13 @@
Copyright 2023 BAAI

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

README.md (+102)

@@ -0,0 +1,102 @@
# FlagAttention

FlagAttention is a project for memory-efficient attention operators implemented in the Triton language. It is inspired by [FlashAttention](https://arxiv.org/abs/2205.14135) and [FlashAttention v2](https://tridao.me/publications/flash2/flash2.pdf) and extends them to satisfy the needs of research on large language modeling. FlashAttention and FlashAttention-2 save memory footprint and traffic to improve memory efficiency, but modifying them to add more options and functionalities requires proficiency in CUDA programming. FlagAttention is therefore implemented in the Triton language, which is easier to use for writing custom GPU kernels.

## Installation

### Requirements

FlagAttention requires torch and triton. To use new features of triton, triton nightly is recommended.

Instructions for installing torch nightly can be found at https://pytorch.org/get-started/locally/ . Triton is now a dependency of torch nightly, so it is installed automatically.

FlagAttention requires Ampere Nvidia GPUs (e.g. A100, RTX-3090, ...) and CUDA Toolkit 11.6 or above. Other GPUs may work but have not been tested yet.

FlagAttention can be installed in either of the ways below.

1. Editable installation. This includes tests and benchmarks; changes to the code in the local source tree take effect without re-installation.
2. Build a distribution and then install. Only the package is installed.

### Editable Installation

Editable installation with `pip`:

```sh
git clone https://github.com/FlagOpen/FlagAttention && cd FlagAttention
pip install -e .
```

### Build a Distribution & Install

Following modern Python packaging conventions, `FlagAttention` is configured by [`pyproject.toml`](https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml/), and no `setup.py` is provided. To build a distribution, either a source distribution or a binary distribution, the Python module `build` is recommended.

First, install the `build` package via pip.

```sh
pip install build
```

Then build the package.

```sh
git clone https://github.com/FlagOpen/FlagAttention && cd FlagAttention
python -m build --no-isolation
```

The built package is placed in `dist/`, ready for installation.

```sh
pip install dist/flag_attn-xxx.whl
```
## Usage

FlagAttention provides customized attention operators. When an operator is equivalent to a torch function, it can be used as a drop-in replacement.
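
For example, the piecewise attention operator described under Operators below can be called directly on CUDA tensors. The following is a minimal sketch that mirrors the call in `benchmark/piecewise_benchmark.py`; the positional argument order follows that script and the shapes are illustrative only (see `src/flag_attn/piecewise.py` for the authoritative signature).

```python
import math
import torch
from flag_attn import piecewise_attn

B, H, T, D = 2, 16, 2048, 64  # batch, heads, sequence length, head dim (illustrative values)
q1, k1, q2, k2, v = (
    torch.randn(B, H, T, D, dtype=torch.float16, device="cuda") for _ in range(5)
)
w = T // 2  # distance threshold
o = piecewise_attn(q1, k1, q2, k2, v, w, True, 1.0 / math.sqrt(D))  # causal=True, sm_scale
print(o.shape)  # expected: (B, H, T, D)
```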

## Run the Tests

A recent version of `pytest` (>=7.1.0) is required to run the tests in `tests`. Operators in `FlagAttention` are tested against a reference implementation in PyTorch, both forward and backward. For `float16` and `bfloat16`, we set the absolute and relative tolerance to `1e-2` and `1e-3`, respectively.

```sh
pytest .
```

## Run the Benchmark

Benchmarks are provided to measure the achieved TFLOPs/s. Since the operators in `FlagAttention` deviate from flash attention, the total amount of computation differs even when the batch size, sequence length, number of heads, head dimension, and other configurations are the same. To calculate the FLOPs of an operator, only matmuls are counted. The FLOP count is divided by the median runtime to get the achieved FLOPs/s.
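
As a rough illustration of this counting (mirroring `benchmark/piecewise_benchmark.py`; the helper name below is ours, not part of the package), the forward FLOPs of piecewise attention can be estimated as:

```python
def piecewise_attention_fwd_flops(batch, heads, seqlen, head_dim, causal=False):
    # only matmuls are counted: Q1@K1, Q2@K2 and P@V, each batch*heads*seqlen*seqlen*head_dim MACs
    macs = 3 * batch * heads * seqlen * seqlen * head_dim
    flops = 2 * macs  # 2 FLOPs per multiply-accumulate
    return flops * 0.5 if causal else flops  # causal masking roughly halves the work
```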

## Operators

### Piecewise Attention

The first extension of flash attention is [piecewise attention](src/flag_attn/piecewise.py).

```
piecewise_attention(q1, k1, q2, k2, v, dist_threshold, softmax_scale=None, causal=False)
```

It is named `piecewise_attention` in that it takes two `q`'s and two `k`'s to compute attention scores (S) before applying softmax to get the attention weights (P). The design originates from the fact that a transformer with rotary position embedding is not good at predicting sequences longer than the longest sequence it is trained on. Pairs of (q, k) get unexpectedly high attention scores when the distance between them is greater than the maximum sequence length in the training set. A proposal to solve this problem is to compute the attention score in different ways, depending on whether the distance between `q` and `k` is greater than a threshold.

In practice, `q` and `k` can be preprocessed in two different ways to get `q1, q2` and `k1, k2`. Then the attention score is computed as the dot product of `q1, k1` or `q2, k2`, depending on the distance between `q` and `k`.

![piecewise attention](assets/piecewise_attention.png)
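
For illustration only, here is a PyTorch sketch of the intended semantics (not the Triton kernel); which `(q, k)` pair serves the within-threshold branch, and aligning queries to the end of the key sequence when the lengths differ, are assumptions of this sketch:

```python
import math
import torch

def piecewise_attn_ref(q1, k1, q2, k2, v, dist_threshold, causal=False, sm_scale=None):
    # q1, q2: (B, H, M, D); k1, k2, v: (B, H, N, D), N >= M allowed
    M, D = q1.shape[-2], q1.shape[-1]
    N = k1.shape[-2]
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    s1 = torch.einsum("bhmd,bhnd->bhmn", q1, k1) * sm_scale
    s2 = torch.einsum("bhmd,bhnd->bhmn", q2, k2) * sm_scale
    # distance between query position i and key position j; queries are aligned to
    # the end of the key sequence when M != N (an assumption of this sketch)
    q_pos = torch.arange(M, device=q1.device)[:, None] + (N - M)
    k_pos = torch.arange(N, device=q1.device)[None, :]
    dist = q_pos - k_pos
    # assumption: (q1, k1) scores are used within the threshold, (q2, k2) beyond it
    s = torch.where(dist <= dist_threshold, s1, s2)
    if causal:
        s = s.masked_fill(dist < 0, float("-inf"))
    p = torch.softmax(s.float(), dim=-1).to(v.dtype)
    return torch.einsum("bhmn,bhnd->bhmd", p, v)
```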

Features:

- the sequence length of k/v can be larger than that of q;
- float16 and bfloat16 data types are supported on Ampere Nvidia GPUs;
- causal and non-causal modes are supported;
- both forward and backward are supported.

Limitations:

- headdim should be in `[16, 32, 64, 128]`;
- dropout is not supported yet.

## TODOs

1. Test on other GPUs;
2. Test on more triton versions;
3. Improve performance of the attention operators;
4. Support other extensions of flash attention.

assets/piecewise_attention.png (1.44 MB)

benchmark/piecewise_benchmark.py (+91)

@@ -0,0 +1,91 @@
import math
import pathlib
import torch
import triton

from flag_attn import piecewise_attn

try:
    from flash_attn.flash_attn_interface import \
        flash_attn_qkvpacked_func as flash_attn_func
    FLASH_VER = 2
except BaseException:
    try:
        from flash_attn.flash_attn_interface import flash_attn_func
        FLASH_VER = 1
    except BaseException:
        FLASH_VER = None
HAS_FLASH = FLASH_VER is not None

BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64

configs = [triton.testing.Benchmark(
    x_names=['N_CTX'],
    x_vals=[2**i for i in range(10, 16)],
    line_arg='provider',
    line_vals=['triton', ] + (['flash'] if HAS_FLASH else []),
    line_names=['triton', ] + ([f'flash-{FLASH_VER}'] if HAS_FLASH else []),
    styles=[('red', '-'), ('green', '-')],
    ylabel='tflop/s',
    plot_name=f'piecewise_attention_batch-{BATCH}_head-{N_HEADS}_d-{D_HEAD}_mode-{mode}_causal-{causal}_dtype-{dtype}',
    args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode, 'causal': causal}
) for mode in ['fwd', 'bwd'] for causal in [False, True] for dtype in [torch.float16, torch.bfloat16]]

@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"):
    assert mode in ['fwd', 'bwd']
    w = N_CTX // 2  # dist threshold
    warmup = 25
    rep = 100
    if provider == "triton":
        q1 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        k1 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        q2 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        k2 = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
        sm_scale = 1. / math.sqrt(D_HEAD)
        fn = lambda: piecewise_attn(q1, k1, q2, k2, v, w, causal, sm_scale)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
    if provider == "flash":
        qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
        if FLASH_VER == 1:
            lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
            cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
            cu_seqlens[1:] = lengths.cumsum(0)
            qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD)
            fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal)
        elif FLASH_VER == 2:
            fn = lambda: flash_attn_func(qkv, causal=causal)
        else:
            raise ValueError(f'unknown {FLASH_VER = }')
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)

    # total TFLOPS: following Flash Attention v2, only gemms are counted.
    # NOTE: this is not an apples-to-apples comparison; the total amount of flops and the elapsed time differ,
    # so the tflop/s is used as a metric of the operator's performance, for reference only.
    if provider == "flash":
        macs = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD  # Q@K, P@V
        if mode == 'bwd':
            macs *= 2.5  # Q@K, dO@V, dO@P, dS@Q, dS@K
    else:
        macs = 3. * BATCH * H * N_CTX * N_CTX * D_HEAD  # Q1@K1, Q2@K2, P@V
        if mode == 'bwd':
            macs *= 8/3.0  # Q1@K1, Q2@K2, dO@V, dO@P, dS1@Q1, dS1@K1, dS2@Q2, dS2@K2
    total_flops = 2 * macs

    if causal:
        total_flops *= 0.5
    return total_flops / ms * 1e-9

# only works on post-Ampere GPUs right now
output_dir = pathlib.Path("results")
output_dir.mkdir(exist_ok=True)
bench_flash_attention.run(save_path=output_dir, print_data=True)

pyproject.toml (+52)

@@ -0,0 +1,52 @@
[project]
name = "flag_attn"
dynamic = ["version"]
authors = [
    {name = "Chen Feiyu", email = "[email protected]"},
]
description = "A collection of memory efficient attention operators implemented in triton language."
readme = {file = "README.md", content-type = "text/markdown"}
requires-python = ">=3.7"
license = {text = "LICENSE.txt"}
classifiers = [
    "Development Status :: 3 - Alpha",
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
]
dependencies = [
    "torch>=2.1.0",
]

[project.optional-dependencies]
test = [
    "pytest>=7.1.0",
]

[project.urls]
homepage = "https://github.com/FlagOpen/FlagAttention"


[build-system]
requires = ["setuptools>=60", "setuptools-scm>=8.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools_scm]
version_file = "src/flag_attn/_version.py"

[tool.setuptools.packages.find]
where = ["src"]
include = ["flag_attn"]
namespaces = false

# helps with setting up pytest in pyproject.toml
# https://docs.pytest.org/en/7.3.x/reference/customize.html#rootdir
# https://docs.pytest.org/en/7.3.x/reference/reference.html#confval-pythonpath
[tool.pytest.ini_options]
testpaths = [
    "tests",
]
pythonpath = [
    "src",
    "tests/flag_attn",
]
