
Commit e82f72b

PyTorch Custom Operator Integration (#1544)
* Sketch out first custom op registration
* Add note
* Initial int8 op registration
* Cleanup some deprecated functions.
* Int8 ops updates; tests
* Implement 4bit quant/dequant ops
* Fix nested quant
* cleanup
* Test improvements
* Clean up and improve tests
* Add higher level custom op for int8 matmul + dequant + bias
* Add gemv 4bit custom op
* Cleanup
* Implement out kwarg overloads for custom ops
* Update PyTorch minimum to 2.1
* Deprecation updates
* Deprecation updates
* Cleanup; rename int8_linear_dequant -> int8_scaled_mm
* Bump min pytorch to 2.2
* cleanup
* Test reorganization
* Remove deprecated supports_igemmlt
* More cleanup
* Cleanup obsolete C++/CUDA code
* Cleanup
* Create 'default' backend for fallback op implementations; initial CPU nf4 work
* Stub out for multi-platform
* Fix serialization tests for torch>=2.6.0
* Add example for torch.compile e2e inference
* Test update

---------

Co-authored-by: Titus von Koeller <[email protected]>
1 parent f0735f9 commit e82f72b

28 files changed: +2689 −3342 lines
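The commit moves bitsandbytes kernels behind torch.library custom operators so that torch.compile and other PyTorch subsystems can see them (the commit message also mentions a torch.compile end-to-end inference example). As a hedged sketch of how a registered op is reached, not part of the diff itself: the ops live in the torch.ops.bitsandbytes namespace once the package is imported, and int8_vectorwise_dequant ships with a PyTorch-native default kernel, so the call below should work without a device-specific backend.

import torch
import bitsandbytes  # noqa: F401  (importing registers the bitsandbytes:: ops)

A = torch.randint(-128, 127, (4, 16), dtype=torch.int8)
stats = torch.rand(4, dtype=torch.float32)

# Eager call through the custom-op namespace.
out = torch.ops.bitsandbytes.int8_vectorwise_dequant(A, stats)

# The registered fake (meta) implementation lets torch.compile trace the op.
fn = torch.compile(lambda a, s: torch.ops.bitsandbytes.int8_vectorwise_dequant(a, s))
out_compiled = fn(A, stats)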

bitsandbytes/__init__.py (+21 −4)
@@ -3,18 +3,35 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from . import research, utils
+
+import torch
+
+from . import _ops, research, utils
 from .autograd._functions import (
     MatmulLtState,
-    bmm_cublas,
     matmul,
     matmul_4bit,
-    matmul_cublas,
-    mm_cublas,
 )
+from .backends.cpu import ops as cpu_ops
+from .backends.default import ops as default_ops
 from .nn import modules
 from .optim import adam
 
+# This is a signal for integrations with transformers/diffusers.
+# Eventually, we will remove this and check based on release version.
+features = {"multi-backend"}
+supported_torch_devices = {
+    "cuda",
+    "cpu",
+    # "mps",
+    # "xpu",
+    # "hpu",
+    # "npu",
+}
+
+if torch.cuda.is_available():
+    from .backends.cuda import ops as cuda_ops
+
 __pdoc__ = {
     "libbitsandbytes": False,
     "optim.optimizer.Optimizer8bit": False,

bitsandbytes/_ops.py (+302)
@@ -0,0 +1,302 @@
from math import prod
from typing import Optional, Sequence, Tuple

import torch

_IS_TORCH_GTE_24 = False

if hasattr(torch.library, "register_fake"):
    _IS_TORCH_GTE_24 = True
    register_fake = torch.library.register_fake
    register_kernel = torch.library.register_kernel
else:
    # PyTorch <= 2.3
    register_fake = torch.library.impl_abstract
    register_kernel = torch.library.impl
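With this shim, register_fake provides the shape-and-dtype-only meta implementations used for tracing, while the device backends imported in __init__.py (bitsandbytes.backends.cpu, bitsandbytes.backends.default, and conditionally bitsandbytes.backends.cuda) attach real kernels with register_kernel. A hedged sketch of that backend-side pattern, using a naive body rather than whatever the real backends register:

# Sketch only: how a backend module could register a CPU kernel for one of the
# schemas defined below. A real backend would dispatch to an optimized kernel.
@register_kernel("bitsandbytes::int8_vectorwise_dequant", "cpu")
def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
    return A.to(torch.float32) * stats.view(-1, 1) / 127.0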
# Higher level op: int8 matmul + dequant + bias
torch.library.define(
    "bitsandbytes::int8_scaled_mm",
    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
)


@register_fake("bitsandbytes::int8_scaled_mm")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dtype=torch.float16,
) -> torch.Tensor:
    shapeC = (*A.shape[:-1], B.shape[0])
    return torch.empty(shapeC, device=A.device, dtype=dtype)
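int8_scaled_mm is the fused entry point: int8 matmul, dequantization against the row/column absmax statistics, and an optional bias add. Only the fake (shape) function lives in this file; real kernels come from the backends. As a hedged reference for the math the op stands for (standard LLM.int8() absmax scaling, written out in plain PyTorch rather than taken from the backend code):

# Reference sketch: A_q is int8 [m, k], B_q is int8 [n, k]; row_stats/col_stats
# are the per-row absmax values of A and B used when quantizing to int8.
def int8_scaled_mm_reference(A_q, B_q, row_stats, col_stats, bias=None, dtype=torch.float16):
    C_i32 = A_q.to(torch.int32) @ B_q.to(torch.int32).t()  # what int8_linear_matmul produces
    C = C_i32.to(torch.float32) * (row_stats.view(-1, 1) * col_stats.view(1, -1)) / (127.0 * 127.0)
    if bias is not None:
        C = C + bias.to(torch.float32)
    return C.to(dtype)  # what int8_mm_dequant produces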
torch.library.define(
    "bitsandbytes::int8_linear_matmul",
    "(Tensor A, Tensor B) -> Tensor",
)


@register_fake("bitsandbytes::int8_linear_matmul")
def _(A: torch.Tensor, B: torch.Tensor):
    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
    torch._check(B.dtype == torch.int8, lambda: "B must be int8")
    shapeC = (*A.shape[:-1], B.shape[0])
    return torch.empty(shapeC, device=A.device, dtype=torch.int32)


# More info on `out` overloads:
# https://github.com/pytorch/pytorch/issues/125044
torch.library.define(
    "bitsandbytes::int8_linear_matmul.out",
    "(Tensor A, Tensor B, Tensor! out) -> ()",
)


@register_fake("bitsandbytes::int8_linear_matmul.out")
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
    shapeC = (*A.shape[:-1], B.shape[0])

    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
    torch._check(B.dtype == torch.int8, lambda: "B must be int8")
    torch._check(out.shape == shapeC, lambda: f"Expected out.shape == {shapeC}, got {out.shape}")
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == torch.int32, lambda: f"Expected out.dtype == int32, got {out.dtype}")


torch.library.define(
    "bitsandbytes::int8_vectorwise_quant",
    "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_vectorwise_quant")
def _(A: torch.Tensor, threshold=0.0):
    out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
    row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)

    if threshold == 0.0:
        return out_row, row_stats, None

    outlier_cols = torch.library.get_ctx().new_dynamic_size()

    return out_row, row_stats, A.new_empty(outlier_cols, dtype=torch.int64)


torch.library.define("bitsandbytes::int8_vectorwise_dequant", "(Tensor A, Tensor stats) -> Tensor")


@register_fake("bitsandbytes::int8_vectorwise_dequant")
def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
    return torch.empty_like(A, dtype=torch.float32)


# Default PyTorch-native implementation
@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
def _(A: torch.Tensor, stats: torch.Tensor):
    # To dequantize we divide by 127, or multiply by the reciprocal.
    return A * stats.view(-1, 1) * 7.874015718698502e-3
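The default kernel above is the only PyTorch-native implementation registered in this file; the multiplier is 1/127 at float32 precision. A self-contained, hedged round trip through the vectorwise scheme, quantizing by hand (mirroring what a backend's int8_vectorwise_quant kernel computes at threshold=0.0) so the example does not depend on a compiled backend:

import torch
import bitsandbytes  # noqa: F401  (registers the bitsandbytes:: ops)

x = torch.randn(4, 16)

# Row-wise absmax quantization: scale each row so its absmax maps to 127.
row_stats = x.abs().amax(dim=1).float()
x_q = torch.round(x * (127.0 / row_stats.view(-1, 1))).to(torch.int8)

# Recover an approximation with the default dequant kernel: x_q * row_stats / 127.
x_dq = torch.ops.bitsandbytes.int8_vectorwise_dequant(x_q, row_stats)

print((x - x_dq).abs().max())  # rounding error, at most ~row absmax / 254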
torch.library.define(
    "bitsandbytes::int8_mm_dequant",
    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor",
)


@register_fake("bitsandbytes::int8_mm_dequant")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype=torch.float16,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    torch._check(A.dtype == torch.int32, lambda: "A must be int32")
    return torch.empty_like(A, dtype=dtype)


torch.library.define(
    "bitsandbytes::int8_double_quant",
    "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_double_quant")
def _(
    A: torch.Tensor,
    threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    out_row = torch.empty_like(A, dtype=torch.int8)
    out_col = torch.empty_like(A, dtype=torch.int8)
    row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
    col_stats = torch.empty(A.shape[-1], device=A.device, dtype=torch.float32)
    outlier_n = torch.library.get_ctx().new_dynamic_size()
    outlier_cols = A.new_empty(outlier_n, dtype=torch.int64)
    return out_row, out_col, row_stats, col_stats, outlier_cols


torch.library.define(
    "bitsandbytes::dequantize_4bit",
    "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_4bit")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    return torch.empty(shape, dtype=dtype, device=A.device)


torch.library.define(
    "bitsandbytes::dequantize_4bit.out",
    "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype, Tensor! out) -> ()",
)


@register_fake("bitsandbytes::dequantize_4bit.out")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check_is_size(blocksize)
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")


torch.library.define(
    "bitsandbytes::quantize_4bit",
    "(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)",
)


@register_fake("bitsandbytes::quantize_4bit")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)

    n = A.numel()
    blocks = -(n // -blocksize)
    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
    out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
    return out, absmax
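The shape arithmetic in the quantize_4bit fake deserves a note: blocks = -(n // -blocksize) is ceiling division, and because each quantized value occupies 4 bits, the packed output holds two values per byte, hence (n + 1) // (quant_storage.itemsize * 2) elements of the storage dtype, laid out as a column. A small worked example of those formulas (the concrete sizes are made up):

# Hypothetical 4096 x 4096 weight, blocksize 64, uint8 storage (itemsize 1).
n = 4096 * 4096
blocksize = 64
itemsize = 1

blocks = -(n // -blocksize)         # 262144 blocks -> 262144 float32 absmax values
packed = (n + 1) // (itemsize * 2)  # 8388608 uint8 elements: two 4-bit values per byte
print(blocks, packed)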
torch.library.define(
    "bitsandbytes::dequantize_blockwise",
    "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_blockwise")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    return torch.empty_like(A, dtype=dtype)


torch.library.define(
    "bitsandbytes::dequantize_blockwise.out",
    "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()",
)


@register_fake("bitsandbytes::dequantize_blockwise.out")
def _(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
):
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")


torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")


@register_fake("bitsandbytes::quantize_blockwise")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    n = A.numel()
    blocks = -(n // -blocksize)
    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
    out = torch.empty_like(A, dtype=torch.uint8)
    return out, absmax


torch.library.define(
    "bitsandbytes::gemv_4bit",
    "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize) -> Tensor",
)


@register_fake("bitsandbytes::gemv_4bit")
def _(
    A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
    torch._check(
        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
    )
    torch._check(
        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
    )
    shape = (*A.shape[:-1], shapeB[0])
    return torch.empty(shape, device=A.device, dtype=A.dtype)


torch.library.define(
    "bitsandbytes::gemv_4bit.out",
    "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize, Tensor! out) -> ()",
)


@register_fake("bitsandbytes::gemv_4bit.out")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
    out: torch.Tensor,
) -> None:
    torch._check_is_size(blocksize)
    torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
    torch._check(
        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
    )
    torch._check(
        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
    )
    torch._check(
        out.shape == (*A.shape[:-1], shapeB[0]),
        lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
    )
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
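Since every op now has a schema and a fake implementation, the registrations can be exercised with PyTorch's own testing helper. A hedged sketch (torch.library.opcheck is available in PyTorch >= 2.4; the inputs are made-up samples, and the overload object rather than the packet is passed):

import torch
import bitsandbytes  # noqa: F401  (registers the bitsandbytes:: ops)

A = torch.randint(-128, 127, (4, 16), dtype=torch.int8)
stats = torch.rand(4, dtype=torch.float32)

# Validates the schema, the fake/meta implementation, and dispatcher behavior
# for this sample input.
torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_dequant.default, (A, stats))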

0 commit comments
