Commits (23)
d452eec  Added files for "The Annotated nanoTabPFN" (StatMixedML, Aug 5, 2025)
f6a29be  Added files for "The Annotated nanoTabPFN" (StatMixedML, Aug 5, 2025)
316a3f4  Added files for "The Annotated nanoTabPFN" (StatMixedML, Aug 5, 2025)
60e79e7  Updated files for "The Annotated nanoTabPFN" (StatMixedML, Aug 5, 2025)
97ae370  Updated files for "The Annotated nanoTabPFN" (StatMixedML, Aug 5, 2025)
44b585c  Updated files for "The Annotated nanoTabPFN" (StatMixedML, Aug 5, 2025)
6b56d96  Updated files for "The Annotated nanoTabPFN" (StatMixedML, Aug 6, 2025)
e8f5d5a  Updated files for "The Annotated nanoTabPFN" (StatMixedML, Aug 6, 2025)
2f4bae1  Updated files for "The Annotated nanoTabPFN" (StatMixedML, Aug 6, 2025)
8e9210f  Updated "The Annotated nanoTabPFN" (StatMixedML, Aug 11, 2025)
8ff9676  Updated "The Annotated nanoTabPFN" (StatMixedML, Aug 11, 2025)
62f404a  Updated "The Annotated nanoTabPFN" (StatMixedML, Aug 14, 2025)
cd5fdb8  Updated "The Annotated nanoTabPFN" (StatMixedML, Aug 29, 2025)
88b999c  Updated "The Annotated nanoTabPFN" (StatMixedML, Sep 12, 2025)
2c2f109  Enhance explanation of transformer benefits and TabPFN design (StatMixedML, Sep 12, 2025)
e8e6c9d  Merge pull request #1 from PriorLabs/main (StatMixedML, Sep 17, 2025)
f84b06e  Merge remote-tracking branch 'origin/main' (StatMixedML, Sep 17, 2025)
670abfb  feat(attention): implement lightweight MultiheadAttention using PyTor… (StatMixedML, Sep 17, 2025)
c1982a4  feat(attention): implement lightweight MultiheadAttention using PyTor… (StatMixedML, Sep 17, 2025)
a32675c  removed annotated files (StatMixedML, Sep 17, 2025)
c701183  Updated bias arguments in attention.py to allow for variable settings… (StatMixedML, Sep 18, 2025)
a5afa17  Updated bias arguments in attention.py to allow for variable settings… (StatMixedML, Sep 18, 2025)
e815dfe  Updated bias arguments in attention.py to allow for variable settings… (StatMixedML, Sep 18, 2025)
127 changes: 127 additions & 0 deletions nanotabpfn/attention.py
@@ -0,0 +1,127 @@
# attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

class MultiheadAttention(nn.Module):
"""
Minimal Multi-Head Attention using PyTorch's scaled_dot_product_attention (SDPA).

    This implementation benefits from PyTorch's automatic backend dispatch:
    - On CUDA with half-precision dtypes (fp16, bf16) and a supported head
      dimension (typically head_dim <= 128), it can use Flash Attention
      kernels for maximum efficiency.
    - Otherwise, it falls back to the memory-efficient or math backend.

Tensor shape notation:
B = Batch size
T = Sequence length
E = Embedding dimension
H = Number of attention heads
D = Per-head dimension (D = E / H)

Parameters
----------
embed_dim : int
Input/output embedding size (E).
num_heads : int
Number of attention heads (H). Must divide embed_dim.
batch_first : bool, default True
If True, input/output is (B, T, E). If False, (T, B, E).
qkv_bias : bool, default False
Include bias terms in the q/k/v projections.
out_proj_bias : bool, default False
Include bias term in the output projection.
device, dtype : Optional
Device and dtype.
"""

def __init__(
self,
embed_dim: int,
num_heads: int,
batch_first: bool = True,
qkv_bias: bool = False,
out_proj_bias: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.batch_first = batch_first

fw = {"device": device, "dtype": dtype}
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias, **fw)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias, **fw)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias, **fw)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=out_proj_bias, **fw)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> Tuple[torch.Tensor, None]:
"""
Compute multi-head attention.

Uses PyTorch's scaled_dot_product_attention (SDPA), which
automatically dispatches to the **Flash Attention kernel** when available.

Args
----
query : Tensor
(B, Tq, E) if batch_first else (Tq, B, E)
key : Tensor
(B, Tk, E) if batch_first else (Tk, B, E)
value : Tensor
(B, Tk, E) if batch_first else (Tk, B, E)

Returns
-------
attn_output : Tensor
Same layout as input (batch_first preserved).
None :
Placeholder for attention weights (not computed).
"""
if not self.batch_first:
# convert (T, B, E) -> (B, T, E)
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)

# Allow for different sequence lengths in query and key/value
B, Tq, _ = query.shape
Tk = key.shape[1]

# Linear projections
q = self.q_proj(query) # (B, Tq, E)
k = self.k_proj(key) # (B, Tk, E)
v = self.v_proj(value) # (B, Tk, E)

# (B, T, E) -> (B, H, T, D), where D = E / H
H, D = self.num_heads, self.head_dim
q = q.view(B, Tq, H, D).transpose(1, 2) # (B, H, Tq, D)
k = k.view(B, Tk, H, D).transpose(1, 2) # (B, H, Tk, D)
v = v.view(B, Tk, H, D).transpose(1, 2) # (B, H, Tk, D)

# SDPA: Flash Attention efficiency when available
attn = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
) # (B, H, Tq, D)

# (B, H, Tq, D) -> (B, Tq, E)
attn = attn.transpose(1, 2).contiguous().view(B, Tq, H * D)
out = self.out_proj(attn) # (B, Tq, E)

if not self.batch_first:
# convert back (B, T, E) -> (T, B, E)
out = out.transpose(0, 1)
# None placeholder for attention weights (not computed)
return out, None
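For reference, a minimal usage sketch (not part of the diff) of the module above; the shapes and values are illustrative, and cross-attention with different query and key/value lengths is assumed only to exercise the Tq/Tk handling described in the docstring.

import torch
from nanotabpfn.attention import MultiheadAttention

B, Tq, Tk, E, H = 2, 8, 32, 128, 4
attn = MultiheadAttention(embed_dim=E, num_heads=H, batch_first=True)

query = torch.randn(B, Tq, E)   # (B, Tq, E)
key = torch.randn(B, Tk, E)     # (B, Tk, E)
value = torch.randn(B, Tk, E)   # (B, Tk, E)

out, weights = attn(query, key, value)
print(out.shape)   # torch.Size([2, 8, 128]); layout (B, Tq, E) is preserved
print(weights)     # None; attention weights are not computed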
5 changes: 3 additions & 2 deletions nanotabpfn/interface.py
@@ -24,8 +24,8 @@ def init_model_from_state_dict_file(file_path):
embedding_size = state_dict['feature_encoder.linear_layer.weight'].shape[0]
mlp_hidden_size = state_dict['decoder.linear1.weight'].shape[0]
num_outputs = state_dict['decoder.linear2.weight'].shape[0]
-    num_layers = sum('self_attn_between_datapoints.in_proj_weight' in k for k in state_dict)
-    num_heads = state_dict['transformer_encoder.transformer_blocks.0.self_attn_between_datapoints.in_proj_weight'].shape[1]//64
+    num_layers = sum('self_attn_between_datapoints.q_proj.weight' in k for k in state_dict)
+    num_heads = state_dict['transformer_encoder.transformer_blocks.0.self_attn_between_datapoints.q_proj.weight'].shape[1]//64
model = NanoTabPFNModel(
num_attention_heads=num_heads,
embedding_size=embedding_size,
@@ -174,3 +174,4 @@ def predict(self, X_test: np.array) -> np.array:
preds = self.normalized_dist.mean(logits)

return preds.cpu().numpy()

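A small sketch (illustrative only, not part of the diff) of how the updated loader infers the architecture from checkpoint keys; the two blocks and 512-dim weights are made-up values matching the key pattern above, and the // 64 in the loader hard-codes an assumed per-head dimension of 64.

import torch

# Hypothetical state dict with two transformer blocks and 512-dim q_proj weights.
state_dict = {
    'transformer_encoder.transformer_blocks.0.self_attn_between_datapoints.q_proj.weight': torch.empty(512, 512),
    'transformer_encoder.transformer_blocks.1.self_attn_between_datapoints.q_proj.weight': torch.empty(512, 512),
}
# One q_proj.weight key per transformer block -> layer count.
num_layers = sum('self_attn_between_datapoints.q_proj.weight' in k for k in state_dict)
# Embedding size divided by the assumed head_dim of 64 -> head count.
num_heads = state_dict['transformer_encoder.transformer_blocks.0.self_attn_between_datapoints.q_proj.weight'].shape[1] // 64
print(num_layers, num_heads)  # 2 8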
6 changes: 4 additions & 2 deletions nanotabpfn/model.py
@@ -1,8 +1,9 @@
import torch
from torch import nn
import torch.nn.functional as F
-from torch.nn.modules.transformer import MultiheadAttention, Linear, LayerNorm
-from typing import Tuple, Union
+from torch.nn.modules.transformer import Linear, LayerNorm
+from .attention import MultiheadAttention
+from typing import Tuple


class NanoTabPFNModel(nn.Module):
@@ -223,3 +224,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
(torch.Tensor) a tensor of shape (batch_size, num_rows, num_outputs)
"""
return self.linear2(F.gelu(self.linear1(x)))
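Since the model now imports the SDPA-based MultiheadAttention as a drop-in for torch.nn.MultiheadAttention, a parity sketch (not part of the diff, and only one way to check the swap) is shown below; it assumes bias-free projections on both sides and copies the reference module's packed in_proj_weight into the separate q/k/v projections.

import torch
from torch import nn
from nanotabpfn.attention import MultiheadAttention

torch.manual_seed(0)
B, T, E, H = 2, 16, 64, 4

ref = nn.MultiheadAttention(E, H, bias=False, batch_first=True)
custom = MultiheadAttention(E, H, batch_first=True, qkv_bias=False, out_proj_bias=False)

with torch.no_grad():
    # nn.MultiheadAttention packs q/k/v into one (3E, E) matrix; split and copy it.
    q_w, k_w, v_w = ref.in_proj_weight.chunk(3, dim=0)
    custom.q_proj.weight.copy_(q_w)
    custom.k_proj.weight.copy_(k_w)
    custom.v_proj.weight.copy_(v_w)
    custom.out_proj.weight.copy_(ref.out_proj.weight)

x = torch.randn(B, T, E)
ref_out, _ = ref(x, x, x, need_weights=False)
custom_out, _ = custom(x, x, x)
print(torch.allclose(ref_out, custom_out, atol=1e-5))  # expected to print True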