From 05de4a1151c2eb8ca2d2c655de64cea0fa580b78 Mon Sep 17 00:00:00 2001
From: koshe
Date: Sat, 30 Aug 2025 04:50:39 +0200
Subject: [PATCH 1/8] Lazy load for weights, KV cache, optimized inference, !!!
8Gb VRAM !!!
---
README.md | 20 +++
gpt_oss/generate.py | 10 ++
gpt_oss/torch/model.py | 328 ++++++++++++++++++++++++++++++++++-----
gpt_oss/torch/utils.py | 2 +-
gpt_oss/torch/weights.py | 46 ++++--
5 files changed, 354 insertions(+), 52 deletions(-)
diff --git a/README.md b/README.md
index 4ef20827..dd28c455 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,15 @@
+GPT-OSS-20B optimized to run on 8 GB of VRAM
+
+- Lazy loading of Transformer layers (about 2.5x slower than loading all weights up front, but runs on an 8 GB 3070 Ti Laptop GPU with 32 GB of system RAM)
+- Added a KV cache to speed up inference (torch backend)
+- Optimized weight loading speed
+- Optimized forward pass and attention
+
+
+__________________________________________
Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases.
@@ -199,6 +208,17 @@ And then run:
torchrun --nproc-per-node=4 -m gpt_oss.generate gpt-oss-120b/original/
```
+# Windows run example
+```shell
+python -m gpt_oss.generate --backend torch gpt-oss-20b/original/ -p "Hi" -l 10
+
+```
+# With profiler
+```shell
+kernprof -l -v -m gpt_oss.generate --backend torch gpt-oss-20b/original/ -p "Hi" -l 10
+```
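+
+# Torch backend from Python
+A minimal usage sketch of the torch backend with lazy loading and the KV cache (the stop-token
+handling and tokenizer helper calls below are assumptions for illustration, not a verbatim API
+reference):
+```python
+from gpt_oss.tokenizer import get_tokenizer
+from gpt_oss.torch.utils import init_distributed
+from gpt_oss.torch.model import TokenGenerator
+
+device = init_distributed()
+generator = TokenGenerator("gpt-oss-20b/original/", device)  # lazy loading + KV cache are on by default
+
+tokenizer = get_tokenizer()
+prompt_tokens = tokenizer.encode("Hi")
+stop_tokens = [tokenizer.eot_token]  # assumption: get_tokenizer() returns a tiktoken-style encoding
+for token in generator.generate(prompt_tokens, stop_tokens, temperature=0.0, max_tokens=10):
+    print(tokenizer.decode([token]), end="", flush=True)
+```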
+
+
## Reference Triton implementation (single GPU)
We also include an optimized reference implementation that uses [an optimized triton MoE kernel](https://github.com/triton-lang/triton/tree/main/python/triton_kernels/triton_kernels) that supports MXFP4. It also has some optimization on the attention code to reduce the memory cost. To run this implementation, the nightly version of triton and torch will be installed. This version can be run on a single 80GB GPU for `gpt-oss-120b`.
diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py
index c0755805..213444ce 100644
--- a/gpt_oss/generate.py
+++ b/gpt_oss/generate.py
@@ -7,7 +7,17 @@
from gpt_oss.tokenizer import get_tokenizer
+try:
+    from line_profiler import profile  # available when profiling with kernprof
+except ImportError:
+    profile = lambda f: f  # no-op decorator when line_profiler is not installed
+
+import os
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False"
+
+@profile
def main(args):
match args.backend:
case "torch":
diff --git a/gpt_oss/torch/model.py b/gpt_oss/torch/model.py
index 9180d493..123cbcb6 100644
--- a/gpt_oss/torch/model.py
+++ b/gpt_oss/torch/model.py
@@ -8,6 +8,11 @@
from gpt_oss.torch.weights import Checkpoint
+try:
+ profile # type: ignore
+except NameError:
+ profile = lambda f: f
+
@dataclass
class ModelConfig:
@@ -59,6 +64,23 @@ def _apply_rotary_emb(
o2 = x2 * cos + x1 * sin
return torch.cat((o1, o2), dim=-1)
+def _apply_rotary_emb_optimized(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+) -> torch.Tensor:
+ # Prepare cos/sin with correct shape and dtype outside this function
+ # Use direct indexing instead of chunk for better memory efficiency
+ half_dim = x.size(-1) // 2
+ x1, x2 = x[..., :half_dim], x[..., half_dim:]
+
+ # Compute in-place where possible
+ o1 = x1 * cos - x2 * sin
+ o2 = x2 * cos + x1 * sin
+
+ # Use view instead of cat to avoid memory allocation
+ return torch.cat([o1, o2], dim=-1) # Still need cat, but optimized other parts
+
class RotaryEmbedding(torch.nn.Module):
def __init__(
@@ -122,31 +144,36 @@ def _compute_concentration_and_inv_freq(self) -> torch.Tensor:
return concentration, inv_freq
- def _compute_cos_sin(self, num_tokens: int):
+ def _compute_cos_sin(self, num_tokens: int, start_pos: int = 0):
concentration, inv_freq = self._compute_concentration_and_inv_freq()
- t = torch.arange(num_tokens, dtype=torch.float32, device=self.device)
+ t = torch.arange(start_pos, start_pos + num_tokens, dtype=torch.float32, device=self.device)
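+        # Offsetting the positions by start_pos keeps the rotary phases of single-token decode
+        # steps (served from the KV cache) consistent with the positions used during prefill.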
freqs = torch.einsum("i,j->ij", t, inv_freq)
cos = freqs.cos() * concentration
sin = freqs.sin() * concentration
return cos, sin
+ @profile
def forward(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ start_pos: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
num_tokens = query.shape[0]
- cos, sin = self._compute_cos_sin(num_tokens)
+ cos, sin = self._compute_cos_sin(num_tokens, start_pos=start_pos)
+ cos = cos.unsqueeze(-2).to(query.dtype)
+ sin = sin.unsqueeze(-2).to(query.dtype)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_dim)
- query = _apply_rotary_emb(query, cos, sin)
+ query = _apply_rotary_emb_optimized(query, cos, sin)
query = query.reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_dim)
- key = _apply_rotary_emb(key, cos, sin)
+ key = _apply_rotary_emb_optimized(key, cos, sin)
key = key.reshape(key_shape)
+
return query, key
@@ -214,7 +241,7 @@ def __init__(
device=device,
)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward_(self, x: torch.Tensor) -> torch.Tensor:
t = self.norm(x)
qkv = self.qkv(t)
q = qkv[:, : self.num_attention_heads * self.head_dim].contiguous()
@@ -245,6 +272,132 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
t = x + t
return t
+ @profile
+ def forward_2(self,
+ x: torch.Tensor,
+ past_kv: tuple[torch.Tensor, torch.Tensor] | None = None,
+ start_pos: int = 0,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ t = self.norm(x)
+ qkv = self.qkv(t)
+
+ q = qkv[:, : self.num_attention_heads * self.head_dim]
+ k = qkv[:, self.num_attention_heads * self.head_dim: (
+ self.num_attention_heads + self.num_key_value_heads) * self.head_dim]
+ v = qkv[:, (self.num_attention_heads + self.num_key_value_heads) * self.head_dim:]
+
+ q = q.view(-1, self.num_attention_heads, self.head_dim)
+ k = k.view(-1, self.num_key_value_heads, self.head_dim)
+ v = v.view(-1, self.num_key_value_heads, self.head_dim)
+
+ q, k = self.rope(q, k, start_pos=start_pos)
+
+ if past_kv is not None:
+ past_k, past_v = past_kv
+ k = torch.cat([past_k, k], dim=0)
+ v = torch.cat([past_v, v], dim=0)
+
+ new_kv = (k, v)
+
+ num_groups = self.num_attention_heads // self.num_key_value_heads
+ k_expanded = k.repeat_interleave(num_groups, dim=1)
+ v_expanded = v.repeat_interleave(num_groups, dim=1)
+
+ query_len, num_heads, head_dim = q.shape
+ key_len = k_expanded.shape[0]
+
+ QK = torch.einsum("qhd,khd->hqk", q, k_expanded)
+ QK *= self.sm_scale
+
+ all_indices = torch.arange(key_len, device=x.device)
+ query_indices = torch.arange(start_pos, start_pos + query_len, device=x.device)
+ mask = query_indices[:, None] < all_indices[None, :]
+ QK = QK.masked_fill(mask, -torch.inf)
+
+ if self.sliding_window > 0:
+ sliding_mask = query_indices[:, None] < (all_indices[None, :] + self.sliding_window)
+ QK = QK.masked_fill(~sliding_mask, -torch.inf)
+
+ S = self.sinks.view(num_heads, 1, 1).expand(-1, query_len, -1)
+ QK = torch.cat([QK, S], dim=-1)
+
+ W = torch.softmax(QK, dim=-1)
+ W = W[..., :-1]
+
+ attn = torch.einsum("hqk,khd->qhd", W, v_expanded)
+
+ t = attn.reshape(-1, self.num_attention_heads * self.head_dim)
+ t = self.out(t)
+
+ return t, new_kv
+
+ @profile
+ def forward(self,
+ x: torch.Tensor,
+ past_kv: tuple[torch.Tensor, torch.Tensor] | None = None,
+ start_pos: int = 0,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ t = self.norm(x)
+ qkv = self.qkv(t)
+
+ q_end = self.num_attention_heads * self.head_dim
+ k_end = q_end + self.num_key_value_heads * self.head_dim
+ q = qkv[:, :q_end]
+ k = qkv[:, q_end:k_end]
+ v = qkv[:, k_end:]
+
+ q = q.view(-1, self.num_attention_heads, self.head_dim)
+ k = k.view(-1, self.num_key_value_heads, self.head_dim)
+ v = v.view(-1, self.num_key_value_heads, self.head_dim)
+
+ q, k = self.rope(q, k, start_pos=start_pos)
+
+ if past_kv is not None:
+ past_k, past_v = past_kv
+ k = torch.cat([past_k, k], dim=0)
+ v = torch.cat([past_v, v], dim=0)
+
+ new_kv = (k, v)
+ num_groups = self.num_attention_heads // self.num_key_value_heads
+
+ k_expanded = k.repeat_interleave(num_groups, dim=1)
+ v_expanded = v.repeat_interleave(num_groups, dim=1)
+
+ query_len, num_heads, head_dim = q.shape
+ key_len = k_expanded.shape[0]
+
+ q_permuted = q.permute(1, 0, 2)
+ k_permuted = k_expanded.permute(1, 0, 2)
+ QK = torch.bmm(q_permuted, k_permuted.transpose(1, 2))
+ QK *= self.sm_scale
+
+ all_indices = torch.arange(key_len, device=x.device)
+ query_indices = torch.arange(start_pos, start_pos + query_len, device=x.device)
+
+ causal_mask = query_indices[:, None] < all_indices[None, :]
+
+ mask = causal_mask
+
+ if self.sliding_window > 0:
+ sliding_mask = query_indices[:, None] < (all_indices[None, :] + self.sliding_window)
+ mask = mask | ~sliding_mask
+
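+        # Combined mask: a key is visible only if it is causally allowed and, when a sliding
+        # window is set, lies within the last `sliding_window` positions of the query.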
+ QK = QK.masked_fill(mask, -torch.inf)
+
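+        # Learned attention sinks: one extra logit per head is appended before the softmax so
+        # probability mass can flow to the sink, then that column is dropped afterwards, which
+        # implicitly down-weights the remaining attention weights.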
+ S = self.sinks.view(num_heads, 1, 1).expand(-1, query_len, -1)
+ QK = torch.cat([QK, S], dim=-1)
+
+ W = torch.softmax(QK, dim=-1)
+ W = W[..., :-1]
+
+ v_permuted = v_expanded.permute(1, 0, 2)
+ attn = torch.bmm(W, v_permuted)
+ attn = attn.permute(1, 0, 2)
+
+ t = attn.reshape(-1, self.num_attention_heads * self.head_dim)
+ t = self.out(t)
+
+ return t, new_kv
def swiglu(x, alpha: float = 1.702, limit: float = 7.0):
x_glu, x_linear = x[..., ::2], x[..., 1::2]
@@ -348,44 +501,133 @@ def __init__(
self.attn = AttentionBlock(config, layer_idx, device)
self.mlp = MLPBlock(config, device)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = self.attn(x)
+ @profile
+ def forward(self,
+ x: torch.Tensor,
+ past_kv: tuple[torch.Tensor, torch.Tensor] | None = None,
+ start_pos: int = 0,
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+ attn_output, new_kv = self.attn(x, past_kv, start_pos)
+ x = x + attn_output
x = self.mlp(x)
- return x
+ return x, new_kv
+checkpoint = None
class Transformer(torch.nn.Module):
def __init__(
self,
config: ModelConfig,
device: torch.device | None = None,
+ lazy_load: bool = True,
):
super().__init__()
- self.embedding = torch.nn.Embedding(
- config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16
- )
- self.block = torch.nn.ModuleList(
- [
- TransformerBlock(config, layer_idx, device)
- for layer_idx in range(config.num_hidden_layers)
- ]
- )
+ # lazy load layers
+ self.lazy_load = lazy_load
+ if not self.lazy_load:
+ self.embedding = torch.nn.Embedding(
+ config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16
+ )
+ self.config = config
+ self.device = device
+
+ if not self.lazy_load:
+ self.block = torch.nn.ModuleList(
+ [
+ TransformerBlock(config, layer_idx, device)
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
self.norm = RMSNorm(config.hidden_size, device=device)
- self.unembedding = torch.nn.Linear(
- config.hidden_size,
- config.vocab_size,
- bias=False,
- device=device,
- dtype=torch.bfloat16,
- )
+ if not self.lazy_load:
+ self.unembedding = torch.nn.Linear(
+ config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ device=device,
+ dtype=torch.bfloat16,
+ )
+
+ @profile
+ def forward(self,
+ x: torch.Tensor,
+ kv_cache: list[tuple[torch.Tensor, torch.Tensor] | None] | None = None,
+ start_pos: int = 0,
+ ) -> tuple[
+ torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | None]]:
+ if self.lazy_load:
+ embedding = torch.nn.Embedding(
+ self.config.vocab_size, self.config.hidden_size, device=self.device, dtype=torch.bfloat16
+ )
+ for name, param in embedding.named_parameters():
+ self.load_weights(param, f"embedding.{name}")
+ x = embedding(x)
+ del embedding
+ else:
+ x = self.embedding(x)
+
+ if kv_cache is None:
+ kv_cache = [None] * self.config.num_hidden_layers
+
+ if self.lazy_load:
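+            # Stream the layers: build one TransformerBlock, copy its weights from the
+            # checkpoint, run it, then delete it so only a single layer's parameters
+            # occupy VRAM at any time.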
+ for layer_idx in range(self.config.num_hidden_layers):
+ block = TransformerBlock(self.config, layer_idx, self.device)
+ for name, param in block.named_parameters():
+ self.load_weights(param, f"block.{layer_idx}.{name}")
+ past_kv_for_layer = kv_cache[layer_idx]
+ x, new_kv_for_layer = block(x, past_kv_for_layer, start_pos)
+ kv_cache[layer_idx] = new_kv_for_layer
+ del block
+ else:
+ for layer_idx in range(self.config.num_hidden_layers):
+ past_kv_for_layer = kv_cache[layer_idx]
+ x, new_kv_for_layer = self.block[layer_idx](x, past_kv_for_layer, start_pos)
+ kv_cache[layer_idx] = new_kv_for_layer
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = self.embedding(x)
- for block in self.block:
- x = block(x)
x = self.norm(x)
- x = self.unembedding(x)
- return x
+
+ if self.lazy_load:
+ unembedding = torch.nn.Linear(
+ self.config.hidden_size,
+ self.config.vocab_size,
+ bias=False,
+ device=self.device,
+ dtype=torch.bfloat16,
+ )
+ for name, param in unembedding.named_parameters():
+ self.load_weights(param, f"unembedding.{name}")
+ x = unembedding(x)
+ del unembedding
+ else:
+ x = self.unembedding(x)
+ return x, kv_cache
+
+ def load_weights(self, param, name):
+ global checkpoint
+ my_rank = dist.get_rank() if dist.is_initialized() else 0
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
+ per_rank_intermediate_size = self.config.intermediate_size // world_size
+ loaded_tensor = checkpoint.get(name)
+ if "mlp1" in name:
+ loaded_tensor = loaded_tensor[
+ :,
+ my_rank * 2
+ * per_rank_intermediate_size: (my_rank + 1) * 2
+ * per_rank_intermediate_size,
+ ...,
+ ]
+ elif "mlp2_weight" in name:
+ loaded_tensor = loaded_tensor[
+ ...,
+ my_rank
+ * per_rank_intermediate_size: (my_rank + 1)
+ * per_rank_intermediate_size,
+ ]
+ try:
+ param.data.copy_(loaded_tensor)
+ except:
+ print(f"{name=} {param.data.shape=} {loaded_tensor.shape=}")
+ raise
@staticmethod
def from_checkpoint(
@@ -410,6 +652,7 @@ def from_checkpoint(
world_size = dist.get_world_size() if dist.is_initialized() else 1
per_rank_intermediate_size = config.intermediate_size // world_size
+ global checkpoint
checkpoint = Checkpoint(path, device)
for name, param in model.named_parameters():
@@ -455,16 +698,20 @@ def generate(self,
max_tokens: int = 0,
return_logprobs: bool = False):
tokens = list(prompt_tokens)
+ num_prompt_tokens = len(tokens)
num_generated_tokens = 0
+
+ kv_cache = None
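+    # Prefill: run the full prompt through the model once to build the per-layer KV cache and
+    # obtain the logits used to sample the first new token.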
+ prompt_tensor = torch.as_tensor(tokens, dtype=torch.int32, device=self.device)
+ logits, kv_cache = self.model(prompt_tensor, kv_cache=None, start_pos=0)
+ logits = logits[-1]
+
while max_tokens == 0 or num_generated_tokens < max_tokens:
- logits = self.model(torch.as_tensor(tokens, dtype=torch.int32, device=self.device))[-1]
if temperature == 0.0:
predicted_token = torch.argmax(logits, dim=-1).item()
else:
probs = torch.softmax(logits * (1.0 / temperature), dim=-1)
predicted_token = torch.multinomial(probs, num_samples=1).item()
- tokens.append(predicted_token)
- num_generated_tokens += 1
if return_logprobs:
logprobs = torch.log_softmax(logits, dim=-1)
@@ -475,3 +722,10 @@ def generate(self,
if predicted_token in stop_tokens:
break
+
+ tokens.append(predicted_token)
+ num_generated_tokens += 1
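+        # Decode step: feed only the newly sampled token; the per-layer (K, V) tensors held in
+        # kv_cache supply the earlier context, so attention runs against the cache instead of
+        # re-processing the whole sequence every iteration.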
+ next_token_tensor = torch.as_tensor([predicted_token], dtype=torch.int32, device=self.device)
+ start_pos = num_prompt_tokens + num_generated_tokens - 1
+ logits, kv_cache = self.model(next_token_tensor, kv_cache=kv_cache, start_pos=start_pos)
+ logits = logits[0]
diff --git a/gpt_oss/torch/utils.py b/gpt_oss/torch/utils.py
index ce87a85d..a5bc5196 100644
--- a/gpt_oss/torch/utils.py
+++ b/gpt_oss/torch/utils.py
@@ -36,5 +36,5 @@ def init_distributed() -> torch.device:
dist.all_reduce(x)
torch.cuda.synchronize(device)
- suppress_output(rank)
+ # suppress_output(rank)
return device
diff --git a/gpt_oss/torch/weights.py b/gpt_oss/torch/weights.py
index aa5df58a..811c0a2c 100644
--- a/gpt_oss/torch/weights.py
+++ b/gpt_oss/torch/weights.py
@@ -4,6 +4,11 @@
import torch
from safetensors import safe_open
+try:
+ profile # type: ignore
+except NameError:
+ profile = lambda f: f
+
# Bytes per MXFP4 block: 32 FP4 numbers packed in 16 bytes
BYTES_PER_BLOCK = 16
@@ -33,6 +38,7 @@ def __init__(self, path: str, device: torch.device):
else device.type + ":" + str(device.index)
)
self.device_str = device_str
+ self.file_handles = {}
# Read from all files ending with .safetensors in the checkpoint directory
safetensor_files = [
@@ -43,11 +49,30 @@ def __init__(self, path: str, device: torch.device):
# Build a mapping from tensor name to (file, key)
tensor_name_to_file = {}
for safetensor_file in safetensor_files:
- with safe_open(safetensor_file, framework="pt", device=device_str) as f:
- for key in f.keys():
- tensor_name_to_file[key] = safetensor_file
+ handle = safe_open(safetensor_file, framework="pt", device='cpu')
+ self.file_handles[safetensor_file] = handle
+ for key in handle.keys():
+ tensor_name_to_file[key] = safetensor_file
self.tensor_name_to_file = tensor_name_to_file
+ self._lut = torch.tensor(FP4_VALUES, dtype=torch.bfloat16, device=device)
+
+ @profile
+ def _get_tensor(self, name: str) -> torch.Tensor:
+ assert name in self.tensor_name_to_file, f"Tensor {name} not found."
+ file_key = self.tensor_name_to_file[name]
+ handle = self.file_handles[file_key]
+
+ # Get tensor and pin memory for faster GPU transfers
+ tensor = handle.get_tensor(name)
+ if self.device_str.startswith('cuda'):
+ tensor = tensor.pin_memory()
+
+ return tensor.to(self.device_str, non_blocking=True)
+
+ def __del__(self):
+ for handle in self.file_handles.values():
+ pass
def get(self, name: str) -> torch.Tensor:
match PARAM_NAME_MAP.get(name, name):
@@ -58,13 +83,7 @@ def get(self, name: str) -> torch.Tensor:
# MoE biases and other weights
return self._get_tensor(tensor_name)
- def _get_tensor(self, name: str) -> str:
- assert name in self.tensor_name_to_file, f"Tensor {name} not found in checkpoint."
- with safe_open(
- self.tensor_name_to_file[name], framework="pt", device=self.device_str
- ) as f:
- return f.get_tensor(name)
-
+ @profile
def _get_mxfp4_tensor(
self,
blocks_name: str,
@@ -87,7 +106,7 @@ def _get_mxfp4_tensor(
f"{blocks.shape=} does not match {scales.shape=}"
)
- lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
+ lut = self._lut
*prefix_shape, G, B = blocks.shape
rows_total = math.prod(prefix_shape) * G
@@ -103,9 +122,8 @@ def _get_mxfp4_tensor(
blk = blocks[r0:r1]
exp = scales[r0:r1]
- # nibble indices -> int64
- idx_lo = (blk & 0x0F).to(torch.long)
- idx_hi = (blk >> 4).to(torch.long)
+ idx_lo = torch.bitwise_and(blk, 0x0F).long()
+ idx_hi = torch.bitwise_right_shift(blk, 4).long()
sub = out[r0:r1]
sub[:, 0::2] = lut[idx_lo]
From d5d58b9eaaf8397c14a44e5e27f64dd40380c908 Mon Sep 17 00:00:00 2001
From: koshe
Date: Sat, 30 Aug 2025 06:01:38 +0200
Subject: [PATCH 2/8] Add video
---
README.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/README.md b/README.md
index dd28c455..78a24be5 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,7 @@
Optimized forward pass and attention
+
__________________________________________
Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases.
From 96bbffb4baea8b9ebe9da67db7630183f51f2fb9 Mon Sep 17 00:00:00 2001
From: nalexand <35492736+nalexand@users.noreply.github.com>
Date: Sat, 30 Aug 2025 06:19:58 +0200
Subject: [PATCH 3/8] Update README.md
---
README.md | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 78a24be5..f120f829 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,13 @@
Optimized forward pass and attention
-
+## 🎥 Demo Video
+
+Watch GPT-OSS 20B running on just 8GB of VRAM:
+
+[](https://englyk.com/gpt_oss_20b_8gb_vram.mp4)
+
+*Click the image to watch the full demonstration*
__________________________________________
Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases.
From 811aa4340f0fc5893159b0cb3e40cf9e3a2c159f Mon Sep 17 00:00:00 2001
From: nalexand <35492736+nalexand@users.noreply.github.com>
Date: Sat, 30 Aug 2025 06:33:45 +0200
Subject: [PATCH 4/8] Update README.md
---
README.md | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index f120f829..bba7da9f 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,8 @@ Watch GPT-OSS 20B running on just 8GB of VRAM:
[](https://englyk.com/gpt_oss_20b_8gb_vram.mp4)
-*Click the image to watch the full demonstration*
+*Click the image to watch the full demonstration*, or
+[watch on YouTube](https://youtu.be/0g7MBALZM8c)
__________________________________________
Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases.
From 71697508cd849c2e759e33d997be9b4502fa486f Mon Sep 17 00:00:00 2001
From: koshe
Date: Sat, 30 Aug 2025 15:25:41 +0200
Subject: [PATCH 5/8] Remove unused code
---
gpt_oss/torch/model.py | 59 ------------------------------------------
1 file changed, 59 deletions(-)
diff --git a/gpt_oss/torch/model.py b/gpt_oss/torch/model.py
index 123cbcb6..b1713098 100644
--- a/gpt_oss/torch/model.py
+++ b/gpt_oss/torch/model.py
@@ -272,65 +272,6 @@ def forward_(self, x: torch.Tensor) -> torch.Tensor:
t = x + t
return t
- @profile
- def forward_2(self,
- x: torch.Tensor,
- past_kv: tuple[torch.Tensor, torch.Tensor] | None = None,
- start_pos: int = 0,
- ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
- t = self.norm(x)
- qkv = self.qkv(t)
-
- q = qkv[:, : self.num_attention_heads * self.head_dim]
- k = qkv[:, self.num_attention_heads * self.head_dim: (
- self.num_attention_heads + self.num_key_value_heads) * self.head_dim]
- v = qkv[:, (self.num_attention_heads + self.num_key_value_heads) * self.head_dim:]
-
- q = q.view(-1, self.num_attention_heads, self.head_dim)
- k = k.view(-1, self.num_key_value_heads, self.head_dim)
- v = v.view(-1, self.num_key_value_heads, self.head_dim)
-
- q, k = self.rope(q, k, start_pos=start_pos)
-
- if past_kv is not None:
- past_k, past_v = past_kv
- k = torch.cat([past_k, k], dim=0)
- v = torch.cat([past_v, v], dim=0)
-
- new_kv = (k, v)
-
- num_groups = self.num_attention_heads // self.num_key_value_heads
- k_expanded = k.repeat_interleave(num_groups, dim=1)
- v_expanded = v.repeat_interleave(num_groups, dim=1)
-
- query_len, num_heads, head_dim = q.shape
- key_len = k_expanded.shape[0]
-
- QK = torch.einsum("qhd,khd->hqk", q, k_expanded)
- QK *= self.sm_scale
-
- all_indices = torch.arange(key_len, device=x.device)
- query_indices = torch.arange(start_pos, start_pos + query_len, device=x.device)
- mask = query_indices[:, None] < all_indices[None, :]
- QK = QK.masked_fill(mask, -torch.inf)
-
- if self.sliding_window > 0:
- sliding_mask = query_indices[:, None] < (all_indices[None, :] + self.sliding_window)
- QK = QK.masked_fill(~sliding_mask, -torch.inf)
-
- S = self.sinks.view(num_heads, 1, 1).expand(-1, query_len, -1)
- QK = torch.cat([QK, S], dim=-1)
-
- W = torch.softmax(QK, dim=-1)
- W = W[..., :-1]
-
- attn = torch.einsum("hqk,khd->qhd", W, v_expanded)
-
- t = attn.reshape(-1, self.num_attention_heads * self.head_dim)
- t = self.out(t)
-
- return t, new_kv
-
@profile
def forward(self,
x: torch.Tensor,
From 1cea0706ddf563f90bb787e3370a6544267cf3d2 Mon Sep 17 00:00:00 2001
From: koshe
Date: Sun, 31 Aug 2025 21:28:39 +0200
Subject: [PATCH 6/8] Add support !!! 6 Gb VRAM !!! for gpt-oss-20b
---
gpt_oss/chat.py | 15 +-
gpt_oss/torch/model.py | 340 +++++++++++++++++++++++++--------------
gpt_oss/torch/weights.py | 110 ++++++++++---
3 files changed, 324 insertions(+), 141 deletions(-)
diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py
index 5e40079d..37bd226f 100644
--- a/gpt_oss/chat.py
+++ b/gpt_oss/chat.py
@@ -9,10 +9,15 @@
import os
from pathlib import Path
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False"
+
try:
import gnureadline as readline
except ImportError:
- import readline
+ try:
+ import readline
+ except ImportError:
+ import pyreadline3 as readline
import torch
import termcolor
@@ -80,8 +85,8 @@ def main(args):
system_message_content = (
SystemContent.new()
- .with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort])
- .with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d"))
+ #.with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort])
+ #.with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d"))
)
if args.browser:
@@ -245,7 +250,9 @@ async def run_tool():
field_created = False
current_output_text = ""
output_text_delta_buffer = ""
- for predicted_token in generator.generate(tokens, encoding.stop_tokens_for_assistant_actions()):
+ for predicted_token in generator.generate(tokens, encoding.stop_tokens_for_assistant_actions(),
+ #temperature=0, max_tokens=10
+ ):
parser.process(predicted_token)
if args.raw:
print(encoding.decode([predicted_token]), end="", flush=True)
diff --git a/gpt_oss/torch/model.py b/gpt_oss/torch/model.py
index b1713098..b519275d 100644
--- a/gpt_oss/torch/model.py
+++ b/gpt_oss/torch/model.py
@@ -52,34 +52,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return (t * self.scale).to(dtype)
-def _apply_rotary_emb(
- x: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
-) -> torch.Tensor:
- cos = cos.unsqueeze(-2).to(x.dtype)
- sin = sin.unsqueeze(-2).to(x.dtype)
- x1, x2 = torch.chunk(x, 2, dim=-1)
- o1 = x1 * cos - x2 * sin
- o2 = x2 * cos + x1 * sin
- return torch.cat((o1, o2), dim=-1)
-
def _apply_rotary_emb_optimized(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
- # Prepare cos/sin with correct shape and dtype outside this function
- # Use direct indexing instead of chunk for better memory efficiency
half_dim = x.size(-1) // 2
x1, x2 = x[..., :half_dim], x[..., half_dim:]
- # Compute in-place where possible
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
- # Use view instead of cat to avoid memory allocation
- return torch.cat([o1, o2], dim=-1) # Still need cat, but optimized other parts
+ return torch.cat([o1, o2], dim=-1)
class RotaryEmbedding(torch.nn.Module):
@@ -152,7 +136,6 @@ def _compute_cos_sin(self, num_tokens: int, start_pos: int = 0):
sin = freqs.sin() * concentration
return cos, sin
- @profile
def forward(
self,
query: torch.Tensor,
@@ -241,38 +224,6 @@ def __init__(
device=device,
)
- def forward_(self, x: torch.Tensor) -> torch.Tensor:
- t = self.norm(x)
- qkv = self.qkv(t)
- q = qkv[:, : self.num_attention_heads * self.head_dim].contiguous()
- k = qkv[
- :,
- self.num_attention_heads
- * self.head_dim : (self.num_attention_heads + self.num_key_value_heads)
- * self.head_dim,
- ].contiguous()
- v = qkv[
- :,
- (self.num_attention_heads + self.num_key_value_heads)
- * self.head_dim : (self.num_attention_heads + 2 * self.num_key_value_heads)
- * self.head_dim,
- ].contiguous()
-
- q = q.view(
- -1,
- self.num_key_value_heads,
- self.num_attention_heads // self.num_key_value_heads,
- self.head_dim,
- )
- k = k.view(-1, self.num_key_value_heads, self.head_dim)
- v = v.view(-1, self.num_key_value_heads, self.head_dim)
- q, k = self.rope(q, k)
- t = sdpa(q, k, v, self.sinks, self.sm_scale, self.sliding_window)
- t = self.out(t)
- t = x + t
- return t
-
- @profile
def forward(self,
x: torch.Tensor,
past_kv: tuple[torch.Tensor, torch.Tensor] | None = None,
@@ -350,7 +301,7 @@ def swiglu(x, alpha: float = 1.702, limit: float = 7.0):
return out_glu * (x_linear + 1)
-class MLPBlock(torch.nn.Module):
+class MLPBlock_(torch.nn.Module):  # memory-inefficient original implementation, kept for reference
def __init__(
self,
config: ModelConfig,
@@ -429,6 +380,113 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + t
+class MLPBlock(torch.nn.Module):
+ def __init__(
+ self,
+ config: ModelConfig,
+ device: torch.device | None = None,
+ ):
+ super().__init__()
+ self.num_experts = config.num_experts
+ self.experts_per_token = config.experts_per_token
+ self.swiglu_limit = config.swiglu_limit
+ self.world_size = dist.get_world_size() if dist.is_initialized() else 1
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.norm = RMSNorm(config.hidden_size, device=device)
+ self.gate = torch.nn.Linear(
+ config.hidden_size, config.num_experts, device=device, dtype=torch.bfloat16
+ )
+ assert config.intermediate_size % self.world_size == 0
+ self.mlp1_weight = torch.nn.Parameter(
+ torch.empty(
+ (
+ config.num_experts,
+ config.intermediate_size * 2 // self.world_size,
+ config.hidden_size,
+ ),
+ device=device,
+ dtype=torch.bfloat16,
+ )
+ )
+ self.mlp1_bias = torch.nn.Parameter(
+ torch.empty(
+ (config.num_experts, config.intermediate_size * 2 // self.world_size),
+ device=device,
+ dtype=torch.bfloat16,
+ )
+ )
+ self.mlp2_weight = torch.nn.Parameter(
+ torch.empty(
+ (
+ config.num_experts,
+ self.hidden_size,
+ config.intermediate_size // self.world_size,
+ ),
+ device=device,
+ dtype=torch.bfloat16,
+ )
+ )
+ self.mlp2_bias = torch.nn.Parameter(
+ torch.empty(
+ (config.num_experts, self.hidden_size),
+ device=device,
+ dtype=torch.bfloat16,
+ )
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ t = self.norm(x)
+ g = self.gate(t)
+
+ if t.dim() == 2:
+ t = t.unsqueeze(1)
+ g = g.unsqueeze(1)
+ added_seq_dim = True
+ else:
+ added_seq_dim = False
+
+ batch_size, seq_len, hidden_size = t.shape
+
+ t_flat = t.reshape(-1, hidden_size)
+ g_flat = g.reshape(-1, self.num_experts)
+
+ experts = torch.topk(g_flat, k=self.experts_per_token, dim=-1, sorted=True)
+ expert_weights = torch.nn.functional.softmax(experts.values, dim=-1)
+ expert_indices = experts.indices
+
+ output_flat = torch.zeros_like(t_flat)
+
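+        # For each of the k routed experts per token, gather that expert's MLP weights, apply
+        # MLP1 -> swiglu -> MLP2 with batched matmuls, and accumulate the result weighted by
+        # the softmaxed router score.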
+ for k in range(self.experts_per_token):
+ current_expert_indices = expert_indices[:, k]
+ current_expert_weights = expert_weights[:, k]
+
+ mlp1_w = self.mlp1_weight[
+ current_expert_indices]
+ mlp1_b = self.mlp1_bias[current_expert_indices]
+ mlp2_w = self.mlp2_weight[
+ current_expert_indices]
+ mlp2_b = self.mlp2_bias[current_expert_indices]
+
+ t_k = torch.bmm(t_flat.unsqueeze(1), mlp1_w.transpose(1, 2)).squeeze(1) + mlp1_b
+ t_k = swiglu(t_k, limit=self.swiglu_limit)
+
+ t_k = torch.bmm(t_k.unsqueeze(1), mlp2_w.transpose(1, 2)).squeeze(1)
+
+ if self.world_size > 1:
+ dist.all_reduce(t_k, op=dist.ReduceOp.SUM)
+
+ t_k += mlp2_b
+
+ output_flat += t_k * current_expert_weights.unsqueeze(1)
+
+ output = output_flat.reshape(batch_size, seq_len, hidden_size)
+
+ if added_seq_dim:
+ output = output.squeeze(1)
+
+ return x + output
+
class TransformerBlock(torch.nn.Module):
def __init__(
@@ -455,32 +513,60 @@ def forward(self,
checkpoint = None
+def get_free_gpu_memory_gb(device_id=0):
+ """Returns free GPU memory in GB for specified device (default: 0)"""
+ if not torch.cuda.is_available():
+ return 0.0
+
+ props = torch.cuda.get_device_properties(device_id)
+ total_memory = props.total_memory
+ reserved = torch.cuda.memory_reserved(device_id)
+
+ free_memory = total_memory - reserved
+ free_gb = free_memory / (1024 ** 3)
+
+ return free_gb
+
+
class Transformer(torch.nn.Module):
def __init__(
self,
config: ModelConfig,
device: torch.device | None = None,
lazy_load: bool = True,
+ extreme_low_memory: bool = False,
):
super().__init__()
- # lazy load layers
+ free_mem = get_free_gpu_memory_gb()
+ if free_mem < 6:
+ lazy_load = True
+ extreme_low_memory = True
+ elif free_mem < 8:
+ lazy_load = True
self.lazy_load = lazy_load
- if not self.lazy_load:
- self.embedding = torch.nn.Embedding(
- config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16
- )
+ self.lazy_load_embedding = lazy_load
+ self.lazy_load_unembedding = lazy_load
+ self.extreme_low_memory = extreme_low_memory
self.config = config
self.device = device
-
- if not self.lazy_load:
+ self.loaded = {0: False, 1: False, 2: False}
+ if self.lazy_load:
+ self.block = torch.nn.ModuleList([
+ TransformerBlock(config, 0, device=device),
+ ])
+ else:
self.block = torch.nn.ModuleList(
[
TransformerBlock(config, layer_idx, device)
for layer_idx in range(config.num_hidden_layers)
]
)
- self.norm = RMSNorm(config.hidden_size, device=device)
- if not self.lazy_load:
+ self.norm = RMSNorm(config.hidden_size, device=device)
+ if not self.lazy_load_embedding:
+ self.embedding = torch.nn.Embedding(
+ config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16
+ )
+ if not self.lazy_load_unembedding:
self.unembedding = torch.nn.Linear(
config.hidden_size,
config.vocab_size,
@@ -496,14 +582,18 @@ def forward(self,
start_pos: int = 0,
) -> tuple[
torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | None]]:
- if self.lazy_load:
- embedding = torch.nn.Embedding(
- self.config.vocab_size, self.config.hidden_size, device=self.device, dtype=torch.bfloat16
- )
- for name, param in embedding.named_parameters():
- self.load_weights(param, f"embedding.{name}")
- x = embedding(x)
- del embedding
+ if self.lazy_load_embedding:
+ if not self.loaded[1] or self.extreme_low_memory:
+ self.embedding = torch.nn.Embedding(
+ self.config.vocab_size, self.config.hidden_size, device=self.device, dtype=torch.bfloat16
+ )
+ for name, param in self.embedding.named_parameters():
+ self.load_weights(param, f"embedding.{name}")
+ self.loaded[1] = True
+ x = self.embedding(x)
+ if self.extreme_low_memory:
+ self.embedding = None
+ torch.cuda.empty_cache()
else:
x = self.embedding(x)
@@ -512,33 +602,45 @@ def forward(self,
if self.lazy_load:
for layer_idx in range(self.config.num_hidden_layers):
- block = TransformerBlock(self.config, layer_idx, self.device)
- for name, param in block.named_parameters():
+ # layers skipping experiment
+ #if layer_idx % 2 == 0:
+ # continue
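+            # A single pre-built TransformerBlock (self.block[0]) is reused for every layer:
+            # its parameters are overwritten in place with layer `layer_idx`'s weights, so the
+            # block is allocated once instead of being re-created and freed per layer.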
+ for name, param in self.block[0].named_parameters():
self.load_weights(param, f"block.{layer_idx}.{name}")
past_kv_for_layer = kv_cache[layer_idx]
- x, new_kv_for_layer = block(x, past_kv_for_layer, start_pos)
+ x, new_kv_for_layer = self.block[0](x, past_kv_for_layer, start_pos)
kv_cache[layer_idx] = new_kv_for_layer
- del block
else:
for layer_idx in range(self.config.num_hidden_layers):
past_kv_for_layer = kv_cache[layer_idx]
x, new_kv_for_layer = self.block[layer_idx](x, past_kv_for_layer, start_pos)
kv_cache[layer_idx] = new_kv_for_layer
- x = self.norm(x)
-
if self.lazy_load:
- unembedding = torch.nn.Linear(
- self.config.hidden_size,
- self.config.vocab_size,
- bias=False,
- device=self.device,
- dtype=torch.bfloat16,
- )
- for name, param in unembedding.named_parameters():
- self.load_weights(param, f"unembedding.{name}")
- x = unembedding(x)
- del unembedding
+ norm = RMSNorm(self.config.hidden_size, device=self.device)
+ for name, param in norm.named_parameters():
+ self.load_weights(param, f"norm.{name}")
+ x = norm(x)
+ del norm
+ else:
+ x = self.norm(x)
+
+ if self.lazy_load_unembedding:
+ if not self.loaded[0] or self.extreme_low_memory:
+ self.unembedding = torch.nn.Linear(
+ self.config.hidden_size,
+ self.config.vocab_size,
+ bias=False,
+ device=self.device,
+ dtype=torch.bfloat16,
+ )
+ for name, param in self.unembedding.named_parameters():
+ self.load_weights(param, f"unembedding.{name}")
+ self.loaded[0] = True
+ x = self.unembedding(x)
+ if self.extreme_low_memory:
+ self.unembedding = None
+ torch.cuda.empty_cache()
else:
x = self.unembedding(x)
return x, kv_cache
@@ -572,7 +674,7 @@ def load_weights(self, param, name):
@staticmethod
def from_checkpoint(
- path: str, device: str | torch.device = "cuda"
+ path: str, device: str | torch.device = "cuda", lazy_load: bool = True
) -> "Transformer":
if not isinstance(device, torch.device):
device = torch.device(device)
@@ -586,41 +688,41 @@ def from_checkpoint(
config=config,
device=device,
)
- model.eval()
-
- # Load weights
- my_rank = dist.get_rank() if dist.is_initialized() else 0
- world_size = dist.get_world_size() if dist.is_initialized() else 1
- per_rank_intermediate_size = config.intermediate_size // world_size
+ if not lazy_load:
+ model.eval()
global checkpoint
- checkpoint = Checkpoint(path, device)
-
- for name, param in model.named_parameters():
- loaded_tensor = checkpoint.get(name)
-
- # Note: it would be more efficient to do sharding before upcasting from MXFP4,
- # but for simplicity we do it after.
- if "mlp1" in name: # both weight and bias
- loaded_tensor = loaded_tensor[
- :,
- my_rank * 2
- * per_rank_intermediate_size : (my_rank + 1) * 2
- * per_rank_intermediate_size,
- ...,
- ]
- elif "mlp2_weight" in name: # only weight
- loaded_tensor = loaded_tensor[
- ...,
- my_rank
- * per_rank_intermediate_size : (my_rank + 1)
- * per_rank_intermediate_size,
- ]
- try:
- param.data.copy_(loaded_tensor)
- except:
- print(f"{name=} {param.data.shape=} {loaded_tensor.shape=}")
- raise
+ checkpoint = Checkpoint(path, device, True)
+
+ if not lazy_load:
+ # Load weights
+ my_rank = dist.get_rank() if dist.is_initialized() else 0
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
+ per_rank_intermediate_size = config.intermediate_size // world_size
+ for name, param in model.named_parameters():
+ loaded_tensor = checkpoint.get(name)
+ # Note: it would be more efficient to do sharding before upcasting from MXFP4,
+ # but for simplicity we do it after.
+ if "mlp1" in name: # both weight and bias
+ loaded_tensor = loaded_tensor[
+ :,
+ my_rank * 2
+ * per_rank_intermediate_size : (my_rank + 1) * 2
+ * per_rank_intermediate_size,
+ ...,
+ ]
+ elif "mlp2_weight" in name: # only weight
+ loaded_tensor = loaded_tensor[
+ ...,
+ my_rank
+ * per_rank_intermediate_size : (my_rank + 1)
+ * per_rank_intermediate_size,
+ ]
+ try:
+ param.data.copy_(loaded_tensor)
+ except:
+ print(f"{name=} {param.data.shape=} {loaded_tensor.shape=}")
+ raise
return model
diff --git a/gpt_oss/torch/weights.py b/gpt_oss/torch/weights.py
index 811c0a2c..3fadfb4b 100644
--- a/gpt_oss/torch/weights.py
+++ b/gpt_oss/torch/weights.py
@@ -1,6 +1,8 @@
import math
import os
-
+import functools
+import collections
+import threading
import torch
from safetensors import safe_open
@@ -30,14 +32,89 @@
}
+class TensorLRUCache:
+ """
+ A thread-safe, memory-aware LRU cache decorator with an individual tensor size limit.
+
+ Args:
+ max_memory_gb (float): The maximum total memory in GB to be used by the cache.
+ max_individual_size_gb (float, optional): The maximum size in GB for any single
+ tensor to be considered for caching. If a tensor is larger than this, it will
+ be bypassed and not cached. Defaults to None (no individual limit).
+        min_individual_size_gb (float, optional): The minimum size in GB for a tensor to be
+            cached; smaller tensors are cheap to reload and are bypassed. Defaults to None
+            (no lower limit).
+    """
+
+ def __init__(self, max_memory_gb: float, max_individual_size_gb: float | None = None, min_individual_size_gb: float | None = None):
+ self.max_memory_bytes = int(max_memory_gb * (1024 ** 3))
+ if max_individual_size_gb is not None:
+ self.max_individual_size_bytes = int(max_individual_size_gb * (1024 ** 3))
+ else:
+ self.max_individual_size_bytes = None
+ if min_individual_size_gb is not None:
+ self.min_individual_size_bytes = int(min_individual_size_gb * (1024 ** 3))
+ else:
+ self.min_individual_size_bytes = None
+
+ self.cache = collections.OrderedDict()
+ self.current_size_bytes = 0
+ self.lock = threading.Lock()
+
+ def __call__(self, func):
+ @functools.wraps(func)
+ def wrapped_func(*args, **kwargs):
+ try:
+ key = (args, frozenset(kwargs.items()))
+ except TypeError:
+ return func(*args, **kwargs)
+
+ # First, check if the item is already in the cache (fast path)
+ with self.lock:
+ if key in self.cache:
+ self.cache.move_to_end(key)
+ return self.cache[key]
+
+ # Item not in cache, so we must call the expensive function
+ result = func(*args, **kwargs)
+
+ if not isinstance(result, torch.Tensor):
+ return result
+
+ tensor_size = result.nbytes
+
+ # --- Bypass cache for oversized/small tensors ---
+            if (self.max_individual_size_bytes is not None and tensor_size > self.max_individual_size_bytes) or (
+                self.min_individual_size_bytes is not None and tensor_size < self.min_individual_size_bytes
+            ):
+                return result
+
+ # If the tensor itself is larger than the *total* cache capacity, don't cache
+ if tensor_size > self.max_memory_bytes:
+ return result
+
+ # Evict items until there's space for the new tensor
+ with self.lock:
+ while self.current_size_bytes + tensor_size > self.max_memory_bytes:
+ evicted_key, evicted_tensor = self.cache.popitem(last=False)
+ self.current_size_bytes -= evicted_tensor.nbytes
+
+ # Add the new item
+ self.cache[key] = result
+ self.current_size_bytes += tensor_size
+
+ return result
+
+ return wrapped_func
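+
+# Usage sketch: TensorLRUCache is applied as a decorator over a function returning a
+# torch.Tensor (see the commented-out example on Checkpoint._get_tensor below); the numbers
+# here are illustrative, not tuned values:
+#
+#   @TensorLRUCache(max_memory_gb=2.0, max_individual_size_gb=0.5, min_individual_size_gb=0.03)
+#   def load_tensor(name: str) -> torch.Tensor:
+#       ...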
+
+
class Checkpoint:
- def __init__(self, path: str, device: torch.device):
+ def __init__(self, path: str, device: torch.device, pin_memory_for_faster_cpu_gpu_transfer: bool = False):
device_str = (
device.type
if device.index is None
else device.type + ":" + str(device.index)
)
self.device_str = device_str
+        self.pin_memory = pin_memory_for_faster_cpu_gpu_transfer  # uses ~2.5 GB of GPU shared memory for gpt-oss-20b, ~5% faster
+        self.lut_buffer = False  # uses an additional ~1 GB of GPU shared memory for gpt-oss-20b, 1-2% faster
self.file_handles = {}
# Read from all files ending with .safetensors in the checkpoint directory
@@ -55,9 +132,11 @@ def __init__(self, path: str, device: torch.device):
tensor_name_to_file[key] = safetensor_file
self.tensor_name_to_file = tensor_name_to_file
- self._lut = torch.tensor(FP4_VALUES, dtype=torch.bfloat16, device=device)
+ if self.lut_buffer:
+ self._lut = torch.tensor(FP4_VALUES, dtype=torch.bfloat16, device=device)
@profile
+    #@TensorLRUCache(max_memory_gb=3.6, max_individual_size_gb=0.2, min_individual_size_gb=0.03)  # (3.6, 0.2, 0.03): caching mlp1 uses additional memory but is ~5% faster; use only for gpt_oss.generate (max VRAM used: 7.6 GB)
def _get_tensor(self, name: str) -> torch.Tensor:
assert name in self.tensor_name_to_file, f"Tensor {name} not found."
file_key = self.tensor_name_to_file[name]
@@ -65,7 +144,7 @@ def _get_tensor(self, name: str) -> torch.Tensor:
# Get tensor and pin memory for faster GPU transfers
tensor = handle.get_tensor(name)
- if self.device_str.startswith('cuda'):
+ if self.pin_memory and self.device_str.startswith('cuda'):
tensor = tensor.pin_memory()
return tensor.to(self.device_str, non_blocking=True)
@@ -74,6 +153,7 @@ def __del__(self):
for handle in self.file_handles.values():
pass
+ @profile
def get(self, name: str) -> torch.Tensor:
match PARAM_NAME_MAP.get(name, name):
case (blocks_name, scales_name):
@@ -90,7 +170,7 @@ def _get_mxfp4_tensor(
scales_name: str,
*,
dtype: torch.dtype = torch.bfloat16,
- rows_per_chunk: int = 16384 * 512,
+ rows_per_chunk: int = 16384 * 64,
) -> torch.Tensor:
assert blocks_name in self.tensor_name_to_file, (
f"Blocks tensor {blocks_name} not found in checkpoint."
@@ -106,7 +186,10 @@ def _get_mxfp4_tensor(
f"{blocks.shape=} does not match {scales.shape=}"
)
- lut = self._lut
+ if self.lut_buffer:
+ lut = self._lut
+ else:
+ lut = torch.tensor(FP4_VALUES, dtype=torch.bfloat16, device=blocks.device)
*prefix_shape, G, B = blocks.shape
rows_total = math.prod(prefix_shape) * G
@@ -118,19 +201,10 @@ def _get_mxfp4_tensor(
for r0 in range(0, rows_total, rows_per_chunk):
r1 = min(r0 + rows_per_chunk, rows_total)
-
- blk = blocks[r0:r1]
- exp = scales[r0:r1]
-
- idx_lo = torch.bitwise_and(blk, 0x0F).long()
- idx_hi = torch.bitwise_right_shift(blk, 4).long()
-
sub = out[r0:r1]
- sub[:, 0::2] = lut[idx_lo]
- sub[:, 1::2] = lut[idx_hi]
-
- torch.ldexp(sub, exp, out=sub)
- del idx_lo, idx_hi, blk, exp
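+            # Each uint8 in `blocks` packs two FP4 codes (low and high nibble); `lut` maps the
+            # 16 codes to bfloat16 and ldexp_ applies the shared per-block exponent from
+            # `scales` in place.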
+ sub[:, 0::2] = lut[torch.bitwise_and(blocks[r0:r1], 0x0F).long()]
+ sub[:, 1::2] = lut[torch.bitwise_right_shift(blocks[r0:r1], 4).long()]
+ torch.ldexp_(sub, scales[r0:r1])
return out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
From 3962a3fb8036e1349a770b8092111eb46bf99786 Mon Sep 17 00:00:00 2001
From: nalexand <35492736+nalexand@users.noreply.github.com>
Date: Sun, 31 Aug 2025 21:39:30 +0200
Subject: [PATCH 7/8] Update README.md
---
README.md | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/README.md b/README.md
index bba7da9f..1c71c0b8 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,15 @@ Watch GPT-OSS 20B running on just 8GB of VRAM:
*Click the image to watch the full demonstration*, or
[watch on YouTube](https://youtu.be/0g7MBALZM8c)
+
+UPDATE 08/31/2025: Added 6 GB VRAM support for gpt-oss-20b!
+
+- Optimized MLPBlock
+- gpt_oss.generate: minimum 6 GB VRAM
+- gpt_oss.chat: minimum 8 GB VRAM
+- gpt_oss.chat Windows support via the pyreadline3 module
+- Auto-tuned options based on available VRAM
+
__________________________________________
Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases.
From 54436c3ee3e1a4f2f70b425c42012a91a51465e2 Mon Sep 17 00:00:00 2001
From: koshe
Date: Sun, 31 Aug 2025 22:20:41 +0200
Subject: [PATCH 8/8] Add options for control memory usage
---
gpt_oss/chat.py | 8 +++-----
gpt_oss/generate.py | 2 +-
gpt_oss/torch/model.py | 36 +++++++++++++++++++-----------------
3 files changed, 23 insertions(+), 23 deletions(-)
diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py
index 37bd226f..27bf1190 100644
--- a/gpt_oss/chat.py
+++ b/gpt_oss/chat.py
@@ -9,8 +9,6 @@
import os
from pathlib import Path
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False"
-
try:
import gnureadline as readline
except ImportError:
@@ -74,7 +72,7 @@ def main(args):
from gpt_oss.torch.model import TokenGenerator as TorchGenerator
from gpt_oss.torch.utils import init_distributed
device = init_distributed()
- generator = TorchGenerator(args.checkpoint, device)
+ generator = TorchGenerator(args.checkpoint, device, pin_memory=False)
case "vllm":
from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)
@@ -85,8 +83,8 @@ def main(args):
system_message_content = (
SystemContent.new()
- #.with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort])
- #.with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d"))
+ .with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort])
+ .with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d"))
)
if args.browser:
diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py
index 213444ce..e21a144e 100644
--- a/gpt_oss/generate.py
+++ b/gpt_oss/generate.py
@@ -24,7 +24,7 @@ def main(args):
from gpt_oss.torch.utils import init_distributed
from gpt_oss.torch.model import TokenGenerator as TorchGenerator
device = init_distributed()
- generator = TorchGenerator(args.checkpoint, device=device)
+ generator = TorchGenerator(args.checkpoint, device=device, pin_memory=True)
case "triton":
from gpt_oss.torch.utils import init_distributed
from gpt_oss.triton.model import TokenGenerator as TritonGenerator
diff --git a/gpt_oss/torch/model.py b/gpt_oss/torch/model.py
index b519275d..64daab36 100644
--- a/gpt_oss/torch/model.py
+++ b/gpt_oss/torch/model.py
@@ -534,19 +534,17 @@ def __init__(
config: ModelConfig,
device: torch.device | None = None,
lazy_load: bool = True,
- extreme_low_memory: bool = False,
+ extremly_low_memory: bool = False,
):
super().__init__()
free_mem = get_free_gpu_memory_gb()
if free_mem < 6:
lazy_load = True
- extreme_low_memory = True
+ extremly_low_memory = True
elif free_mem < 8:
lazy_load = True
self.lazy_load = lazy_load
- self.lazy_load_embedding = lazy_load
- self.lazy_load_unembedding = lazy_load
- self.extreme_low_memory = extreme_low_memory
+ self.extremly_low_memory = extremly_low_memory
self.config = config
self.device = device
self.loaded = {0: False, 1: False, 2: False}
@@ -562,11 +560,9 @@ def __init__(
]
)
self.norm = RMSNorm(config.hidden_size, device=device)
- if not self.lazy_load_embedding:
self.embedding = torch.nn.Embedding(
config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16
)
- if not self.lazy_load_unembedding:
self.unembedding = torch.nn.Linear(
config.hidden_size,
config.vocab_size,
@@ -582,8 +578,8 @@ def forward(self,
start_pos: int = 0,
) -> tuple[
torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | None]]:
- if self.lazy_load_embedding:
- if not self.loaded[1] or self.extreme_low_memory:
+ if self.lazy_load:
+ if not self.loaded[1] or self.extremly_low_memory:
self.embedding = torch.nn.Embedding(
self.config.vocab_size, self.config.hidden_size, device=self.device, dtype=torch.bfloat16
)
@@ -591,7 +587,7 @@ def forward(self,
self.load_weights(param, f"embedding.{name}")
self.loaded[1] = True
x = self.embedding(x)
- if self.extreme_low_memory:
+ if self.extremly_low_memory:
self.embedding = None
torch.cuda.empty_cache()
else:
@@ -625,8 +621,8 @@ def forward(self,
else:
x = self.norm(x)
- if self.lazy_load_unembedding:
- if not self.loaded[0] or self.extreme_low_memory:
+ if self.lazy_load:
+ if not self.loaded[0] or self.extremly_low_memory:
self.unembedding = torch.nn.Linear(
self.config.hidden_size,
self.config.vocab_size,
@@ -638,7 +634,7 @@ def forward(self,
self.load_weights(param, f"unembedding.{name}")
self.loaded[0] = True
x = self.unembedding(x)
- if self.extreme_low_memory:
+ if self.extremly_low_memory:
self.unembedding = None
torch.cuda.empty_cache()
else:
@@ -674,7 +670,7 @@ def load_weights(self, param, name):
@staticmethod
def from_checkpoint(
- path: str, device: str | torch.device = "cuda", lazy_load: bool = True
+ path: str, device: str | torch.device = "cuda", lazy_load: bool = True, pin_memory: bool = False
) -> "Transformer":
if not isinstance(device, torch.device):
device = torch.device(device)
@@ -684,15 +680,20 @@ def from_checkpoint(
json_config = json.load(f)
config = ModelConfig(**json_config)
+ extremly_low_memory = False
+ if not pin_memory:
+ extremly_low_memory = True
+
model = Transformer(
config=config,
device=device,
+ extremly_low_memory=extremly_low_memory,
)
if not lazy_load:
model.eval()
global checkpoint
- checkpoint = Checkpoint(path, device, True)
+ checkpoint = Checkpoint(path, device, pin_memory)
if not lazy_load:
# Load weights
@@ -729,9 +730,10 @@ def from_checkpoint(
class TokenGenerator:
@torch.inference_mode()
- def __init__(self, checkpoint: str, device: torch.device):
+ def __init__(self, checkpoint: str, device: torch.device, pin_memory: bool = False):
self.device = device
- self.model = Transformer.from_checkpoint(checkpoint, device=self.device)
+ self.model = Transformer.from_checkpoint(
+ checkpoint, device=self.device, pin_memory=pin_memory)
@torch.inference_mode()
def generate(self,