PyTorch Custom Operator Integration #1544

Merged
Merged 35 commits into main from customop-refactoring on Mar 25, 2025
Changes from all commits
Commits (35)
6268912
Sketch out first custom op registration
matthewdouglas Jan 28, 2025
04e1bc6
Add note
matthewdouglas Jan 28, 2025
d5df4c6
Merge branch 'main' into customop-refactoring
matthewdouglas Feb 11, 2025
04482ff
Initial int8 op registration
matthewdouglas Feb 12, 2025
2813571
Cleanup some deprecated functions.
matthewdouglas Feb 12, 2025
4ad1d9e
Int8 ops updates; tests
matthewdouglas Feb 13, 2025
e9c79cf
Implement 4bit quant/dequant ops
matthewdouglas Feb 14, 2025
9d0f459
Fix nested quant
matthewdouglas Feb 17, 2025
f360a08
cleanup
matthewdouglas Feb 17, 2025
45ead33
Test improvements
matthewdouglas Feb 18, 2025
6aeea81
Clean up and improve tests
matthewdouglas Feb 18, 2025
cbd1670
Add higher level custom op for int8 matmul + dequant + bias
matthewdouglas Feb 25, 2025
db07f4e
Add gemv 4bit custom op
matthewdouglas Feb 26, 2025
23eba7a
Cleanup
matthewdouglas Feb 26, 2025
2d5b2cc
Implement out kwarg overloads for custom ops
matthewdouglas Mar 7, 2025
6172770
Update PyTorch minimum to 2.1
matthewdouglas Mar 7, 2025
242c602
Deprecation updates
matthewdouglas Mar 8, 2025
25368bc
Deprecation updates
matthewdouglas Mar 8, 2025
32345e4
merge main
Titus-von-Koeller Mar 10, 2025
2b85100
Cleanup; rename int8_linear_dequant -> int8_scaled_mm
matthewdouglas Mar 13, 2025
aacd408
Merge branch 'customop-refactoring' of https://github.com/TimDettmers…
matthewdouglas Mar 13, 2025
a61c0fa
Bump min pytorch to 2.2
matthewdouglas Mar 13, 2025
fd74c06
cleanup
matthewdouglas Mar 13, 2025
587120a
Test reorganization
matthewdouglas Mar 13, 2025
975c356
Remove deprecated supports_igemmlt
matthewdouglas Mar 13, 2025
da40911
More cleanup
matthewdouglas Mar 13, 2025
0b04376
Merge branch 'main' into customop-refactoring
matthewdouglas Mar 14, 2025
11e2e92
Cleanup obsolete C++/CUDA code
matthewdouglas Mar 14, 2025
b599401
Cleanup
matthewdouglas Mar 14, 2025
c703d8d
Create 'default' backend for fallback op implementations; initial CPU…
matthewdouglas Mar 17, 2025
431819d
Stub out for multi-platform
matthewdouglas Mar 17, 2025
fa188f6
Fix serialization tests for torch>=2.6.0
matthewdouglas Mar 25, 2025
2015127
Add example for torch.compile e2e inference
matthewdouglas Mar 25, 2025
0a11fae
Test update
matthewdouglas Mar 25, 2025
dcc2c16
Merge branch 'main' into customop-refactoring
matthewdouglas Mar 25, 2025
25 changes: 21 additions & 4 deletions bitsandbytes/__init__.py
@@ -3,18 +3,35 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import research, utils

import torch

from . import _ops, research, utils
from .autograd._functions import (
MatmulLtState,
bmm_cublas,
matmul,
matmul_4bit,
matmul_cublas,
mm_cublas,
)
from .backends.cpu import ops as cpu_ops
from .backends.default import ops as default_ops
from .nn import modules
from .optim import adam

# This is a signal for integrations with transformers/diffusers.
# Eventually, we will remove this and check based on release version.
features = {"multi-backend"}
supported_torch_devices = {
"cuda",
"cpu",
# "mps",
# "xpu",
# "hpu",
# "npu",
}

if torch.cuda.is_available():
from .backends.cuda import ops as cuda_ops

__pdoc__ = {
"libbitsandbytes": False,
"optim.optimizer.Optimizer8bit": False,
(remainder of file not shown)
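
The new module-level features flag and supported_torch_devices set in __init__.py give downstream integrations such as transformers and diffusers a cheap capability check. A minimal sketch of how an integration might consume them; the _check_bnb_backend helper is hypothetical and not part of this PR:

import bitsandbytes as bnb

def _check_bnb_backend(device_type: str) -> bool:
    # Hypothetical integration-side guard: True only if this bitsandbytes build
    # advertises the multi-backend feature flag and lists the requested device type.
    return (
        "multi-backend" in getattr(bnb, "features", set())
        and device_type in getattr(bnb, "supported_torch_devices", set())
    )
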
302 changes: 302 additions & 0 deletions bitsandbytes/_ops.py
@@ -0,0 +1,302 @@
from math import prod
from typing import Optional, Sequence, Tuple

import torch

_IS_TORCH_GTE_24 = False

if hasattr(torch.library, "register_fake"):
_IS_TORCH_GTE_24 = True
register_fake = torch.library.register_fake
register_kernel = torch.library.register_kernel
else:
# PyTorch <= 2.3
register_fake = torch.library.impl_abstract
register_kernel = torch.library.impl


# Higher level op: int8 matmul + dequant + bias
torch.library.define(
"bitsandbytes::int8_scaled_mm",
"(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
)


@register_fake("bitsandbytes::int8_scaled_mm")
def _(
A: torch.Tensor,
B: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dtype=torch.float16,
) -> torch.Tensor:
shapeC = (*A.shape[:-1], B.shape[0])
return torch.empty(shapeC, device=A.device, dtype=dtype)


torch.library.define(
"bitsandbytes::int8_linear_matmul",
"(Tensor A, Tensor B) -> Tensor",
)


@register_fake("bitsandbytes::int8_linear_matmul")
def _(A: torch.Tensor, B: torch.Tensor):
torch._check(A.dtype == torch.int8, lambda: "A must be int8")
torch._check(B.dtype == torch.int8, lambda: "B must be int8")
shapeC = (*A.shape[:-1], B.shape[0])
return torch.empty(shapeC, device=A.device, dtype=torch.int32)


# More info on `out` overloads:
# https://github.com/pytorch/pytorch/issues/125044
torch.library.define(
"bitsandbytes::int8_linear_matmul.out",
"(Tensor A, Tensor B, Tensor! out) -> ()",
)


@register_fake("bitsandbytes::int8_linear_matmul.out")
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
shapeC = (*A.shape[:-1], B.shape[0])

torch._check(A.dtype == torch.int8, lambda: "A must be int8")
torch._check(B.dtype == torch.int8, lambda: "B must be int8")
torch._check(out.shape == shapeC, lambda: f"Expected out.shape == {shapeC}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == torch.int32, lambda: f"Expected out.dtype == int32, got {out.dtype}")
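
# Illustrative sketch, not part of this file's diff: calling the mutating `.out` overload
# defined above with a caller-provided buffer. Assumes A and B are int8 with shapes
# (m, k) and (n, k), and that a backend kernel is registered for their device.
def _example_int8_matmul_out(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Preallocate the int32 result for C = A @ B.t(), then let the op write into it.
    out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32)
    torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out)
    return out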


torch.library.define(
"bitsandbytes::int8_vectorwise_quant",
"(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_vectorwise_quant")
def _(A: torch.Tensor, threshold=0.0):
out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)

if threshold == 0.0:
return out_row, row_stats, None

outlier_cols = torch.library.get_ctx().new_dynamic_size()

return out_row, row_stats, A.new_empty(outlier_cols, dtype=torch.int64)


torch.library.define("bitsandbytes::int8_vectorwise_dequant", "(Tensor A, Tensor stats) -> Tensor")


@register_fake("bitsandbytes::int8_vectorwise_dequant")
def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
torch._check(A.dtype == torch.int8, lambda: "A must be int8")
return torch.empty_like(A, dtype=torch.float32)


# Default PyTorch-native implementation
@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
def _(A: torch.Tensor, stats: torch.Tensor):
# To dequantize we divide by 127, or multiply by the reciprocal.
return A * stats.view(-1, 1) * 7.874015718698502e-3


torch.library.define(
"bitsandbytes::int8_mm_dequant",
"(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor",
)


@register_fake("bitsandbytes::int8_mm_dequant")
def _(
A: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
dtype=torch.float16,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
torch._check(A.dtype == torch.int32, lambda: "A must be int32")
return torch.empty_like(A, dtype=dtype)


torch.library.define(
"bitsandbytes::int8_double_quant",
"(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_double_quant")
def _(
A: torch.Tensor,
threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
out_row = torch.empty_like(A, dtype=torch.int8)
out_col = torch.empty_like(A, dtype=torch.int8)
row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
col_stats = torch.empty(A.shape[-1], device=A.device, dtype=torch.float32)
outlier_n = torch.library.get_ctx().new_dynamic_size()
outlier_cols = A.new_empty(outlier_n, dtype=torch.int64)
return out_row, out_col, row_stats, col_stats, outlier_cols
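
# Illustrative sketch, not part of this file's diff: the int8 inference path composed
# from the ops defined above. The helper is hypothetical; it assumes x and W are
# float16/bfloat16/float32 and that a backend kernel is registered for their device.
def _example_int8_linear(x: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    # Row-wise int8 quantization; with the default threshold=0.0 no outlier columns
    # are extracted, so the third return value is None.
    W_int8, W_stats, _ = torch.ops.bitsandbytes.int8_vectorwise_quant(W)
    x_int8, x_stats, _ = torch.ops.bitsandbytes.int8_vectorwise_quant(x)
    # Fused int8 matmul + dequantization (+ optional bias): computes x @ W.t() with
    # int32 accumulation, rescales by the absmax statistics, and returns float16 by default.
    return torch.ops.bitsandbytes.int8_scaled_mm(x_int8, W_int8, x_stats, W_stats)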


torch.library.define(
"bitsandbytes::dequantize_4bit",
"(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_4bit")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
return torch.empty(shape, dtype=dtype, device=A.device)


torch.library.define(
"bitsandbytes::dequantize_4bit.out",
"(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype, Tensor! out) -> ()",
)


@register_fake("bitsandbytes::dequantize_4bit.out")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check_is_size(blocksize)
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")


torch.library.define(
"bitsandbytes::quantize_4bit",
"(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)",
)


@register_fake("bitsandbytes::quantize_4bit")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)

n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
return out, absmax
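
# Illustrative sketch, not part of this file's diff: a blockwise 4-bit quantize/dequantize
# roundtrip built from the ops above. The blocksize and quant_type values are illustrative;
# assumes W is float16/bfloat16/float32 on a device with a registered backend kernel.
def _example_4bit_roundtrip(W: torch.Tensor, blocksize: int = 64) -> torch.Tensor:
    # Blockwise NF4 quantization: returns the packed uint8 tensor and per-block absmax.
    packed, absmax = torch.ops.bitsandbytes.quantize_4bit(W, blocksize, "nf4", torch.uint8)
    # Dequantize back to W's shape and dtype; the shape is passed explicitly because the
    # packed tensor does not carry it.
    return torch.ops.bitsandbytes.dequantize_4bit(packed, absmax, blocksize, "nf4", W.shape, W.dtype)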


torch.library.define(
"bitsandbytes::dequantize_blockwise",
"(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_blockwise")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
return torch.empty_like(A, dtype=dtype)


torch.library.define(
"bitsandbytes::dequantize_blockwise.out",
"(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()",
)


@register_fake("bitsandbytes::dequantize_blockwise.out")
def _(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
):
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")


torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")


@register_fake("bitsandbytes::quantize_blockwise")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)
return out, absmax


torch.library.define(
"bitsandbytes::gemv_4bit",
"(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize) -> Tensor",
)


@register_fake("bitsandbytes::gemv_4bit")
def _(
A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
)
torch._check(
B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
shape = (*A.shape[:-1], shapeB[0])
return torch.empty(shape, device=A.device, dtype=A.dtype)


torch.library.define(
"bitsandbytes::gemv_4bit.out",
"(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize, Tensor! out) -> ()",
)


@register_fake("bitsandbytes::gemv_4bit.out")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
out: torch.Tensor,
) -> None:
torch._check_is_size(blocksize)
torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
torch._check(
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
)
torch._check(
B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
)
torch._check(
out.shape == (*A.shape[:-1], shapeB[0]),
lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
)
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
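
Commit 2015127 adds an end-to-end torch.compile inference example. A minimal sketch of the idea, assuming a CUDA build of bitsandbytes and using the existing bitsandbytes.nn.Linear4bit module; the layer size and dtypes here are arbitrary:

import torch
import bitsandbytes as bnb

# Toy 4-bit linear layer; moving it to CUDA quantizes the weight into packed 4-bit storage.
layer = bnb.nn.Linear4bit(4096, 4096, compute_dtype=torch.bfloat16).cuda()

# Because the quantized matmuls are now registered as torch custom ops with fake (meta)
# implementations, the module is intended to be traceable by torch.compile.
compiled = torch.compile(layer)

x = torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda")
with torch.inference_mode():
    y = compiled(x)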