From 05de4a1151c2eb8ca2d2c655de64cea0fa580b78 Mon Sep 17 00:00:00 2001 From: koshe Date: Sat, 30 Aug 2025 04:50:39 +0200 Subject: [PATCH 1/8] Lazy load for weights, KV cache, optimized inference, !!! 8Gb VRAM !!! --- README.md | 20 +++ gpt_oss/generate.py | 10 ++ gpt_oss/torch/model.py | 328 ++++++++++++++++++++++++++++++++++----- gpt_oss/torch/utils.py | 2 +- gpt_oss/torch/weights.py | 46 ++++-- 5 files changed, 354 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index 4ef20827..dd28c455 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,15 @@


+

GPT-OSS-20B optimized to run on 8GB VRAM

+ + +__________________________________________ Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases. @@ -199,6 +208,17 @@ And then run: torchrun --nproc-per-node=4 -m gpt_oss.generate gpt-oss-120b/original/ ``` +# Windows run example +```shell +python -m gpt_oss.generate --backend torch gpt-oss-20b/original/ -p "Hi" -l 10 + +``` +#with profiler +```shell +kernprof -l -v -m gpt_oss.generate --backend torch gpt-oss-20b/original/ -p "Hi" -l 10 +``` + + ## Reference Triton implementation (single GPU) We also include an optimized reference implementation that uses [an optimized triton MoE kernel](https://github.com/triton-lang/triton/tree/main/python/triton_kernels/triton_kernels) that supports MXFP4. It also has some optimization on the attention code to reduce the memory cost. To run this implementation, the nightly version of triton and torch will be installed. This version can be run on a single 80GB GPU for `gpt-oss-120b`. diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py index c0755805..213444ce 100644 --- a/gpt_oss/generate.py +++ b/gpt_oss/generate.py @@ -7,7 +7,17 @@ from gpt_oss.tokenizer import get_tokenizer +from line_profiler import profile +try: + profile # type: ignore +except NameError: + profile = lambda f: f + +import os +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False" + +@profile def main(args): match args.backend: case "torch": diff --git a/gpt_oss/torch/model.py b/gpt_oss/torch/model.py index 9180d493..123cbcb6 100644 --- a/gpt_oss/torch/model.py +++ b/gpt_oss/torch/model.py @@ -8,6 +8,11 @@ from gpt_oss.torch.weights import Checkpoint +try: + profile # type: ignore +except NameError: + profile = lambda f: f + @dataclass class ModelConfig: @@ -59,6 +64,23 @@ def _apply_rotary_emb( o2 = x2 * cos + x1 * sin return torch.cat((o1, o2), dim=-1) +def _apply_rotary_emb_optimized( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> torch.Tensor: + # Prepare cos/sin with correct shape and dtype outside this function + # Use direct indexing instead of chunk for better memory efficiency + half_dim = x.size(-1) // 2 + x1, x2 = x[..., :half_dim], x[..., half_dim:] + + # Compute in-place where possible + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + + # Use view instead of cat to avoid memory allocation + return torch.cat([o1, o2], dim=-1) # Still need cat, but optimized other parts + class RotaryEmbedding(torch.nn.Module): def __init__( @@ -122,31 +144,36 @@ def _compute_concentration_and_inv_freq(self) -> torch.Tensor: return concentration, inv_freq - def _compute_cos_sin(self, num_tokens: int): + def _compute_cos_sin(self, num_tokens: int, start_pos: int = 0): concentration, inv_freq = self._compute_concentration_and_inv_freq() - t = torch.arange(num_tokens, dtype=torch.float32, device=self.device) + t = torch.arange(start_pos, start_pos + num_tokens, dtype=torch.float32, device=self.device) freqs = torch.einsum("i,j->ij", t, inv_freq) cos = freqs.cos() * concentration sin = freqs.sin() * concentration return cos, sin + @profile def forward( - self, - query: torch.Tensor, - key: torch.Tensor, + self, + query: torch.Tensor, + key: torch.Tensor, + start_pos: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: num_tokens = query.shape[0] - cos, sin = self._compute_cos_sin(num_tokens) + cos, sin = self._compute_cos_sin(num_tokens, start_pos=start_pos) + cos = cos.unsqueeze(-2).to(query.dtype) + sin = 
sin.unsqueeze(-2).to(query.dtype) query_shape = query.shape query = query.view(num_tokens, -1, self.head_dim) - query = _apply_rotary_emb(query, cos, sin) + query = _apply_rotary_emb_optimized(query, cos, sin) query = query.reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_dim) - key = _apply_rotary_emb(key, cos, sin) + key = _apply_rotary_emb_optimized(key, cos, sin) key = key.reshape(key_shape) + return query, key @@ -214,7 +241,7 @@ def __init__( device=device, ) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward_(self, x: torch.Tensor) -> torch.Tensor: t = self.norm(x) qkv = self.qkv(t) q = qkv[:, : self.num_attention_heads * self.head_dim].contiguous() @@ -245,6 +272,132 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: t = x + t return t + @profile + def forward_2(self, + x: torch.Tensor, + past_kv: tuple[torch.Tensor, torch.Tensor] | None = None, + start_pos: int = 0, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + t = self.norm(x) + qkv = self.qkv(t) + + q = qkv[:, : self.num_attention_heads * self.head_dim] + k = qkv[:, self.num_attention_heads * self.head_dim: ( + self.num_attention_heads + self.num_key_value_heads) * self.head_dim] + v = qkv[:, (self.num_attention_heads + self.num_key_value_heads) * self.head_dim:] + + q = q.view(-1, self.num_attention_heads, self.head_dim) + k = k.view(-1, self.num_key_value_heads, self.head_dim) + v = v.view(-1, self.num_key_value_heads, self.head_dim) + + q, k = self.rope(q, k, start_pos=start_pos) + + if past_kv is not None: + past_k, past_v = past_kv + k = torch.cat([past_k, k], dim=0) + v = torch.cat([past_v, v], dim=0) + + new_kv = (k, v) + + num_groups = self.num_attention_heads // self.num_key_value_heads + k_expanded = k.repeat_interleave(num_groups, dim=1) + v_expanded = v.repeat_interleave(num_groups, dim=1) + + query_len, num_heads, head_dim = q.shape + key_len = k_expanded.shape[0] + + QK = torch.einsum("qhd,khd->hqk", q, k_expanded) + QK *= self.sm_scale + + all_indices = torch.arange(key_len, device=x.device) + query_indices = torch.arange(start_pos, start_pos + query_len, device=x.device) + mask = query_indices[:, None] < all_indices[None, :] + QK = QK.masked_fill(mask, -torch.inf) + + if self.sliding_window > 0: + sliding_mask = query_indices[:, None] < (all_indices[None, :] + self.sliding_window) + QK = QK.masked_fill(~sliding_mask, -torch.inf) + + S = self.sinks.view(num_heads, 1, 1).expand(-1, query_len, -1) + QK = torch.cat([QK, S], dim=-1) + + W = torch.softmax(QK, dim=-1) + W = W[..., :-1] + + attn = torch.einsum("hqk,khd->qhd", W, v_expanded) + + t = attn.reshape(-1, self.num_attention_heads * self.head_dim) + t = self.out(t) + + return t, new_kv + + @profile + def forward(self, + x: torch.Tensor, + past_kv: tuple[torch.Tensor, torch.Tensor] | None = None, + start_pos: int = 0, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + t = self.norm(x) + qkv = self.qkv(t) + + q_end = self.num_attention_heads * self.head_dim + k_end = q_end + self.num_key_value_heads * self.head_dim + q = qkv[:, :q_end] + k = qkv[:, q_end:k_end] + v = qkv[:, k_end:] + + q = q.view(-1, self.num_attention_heads, self.head_dim) + k = k.view(-1, self.num_key_value_heads, self.head_dim) + v = v.view(-1, self.num_key_value_heads, self.head_dim) + + q, k = self.rope(q, k, start_pos=start_pos) + + if past_kv is not None: + past_k, past_v = past_kv + k = torch.cat([past_k, k], dim=0) + v = torch.cat([past_v, v], dim=0) + + new_kv = (k, v) + num_groups = 
self.num_attention_heads // self.num_key_value_heads + + k_expanded = k.repeat_interleave(num_groups, dim=1) + v_expanded = v.repeat_interleave(num_groups, dim=1) + + query_len, num_heads, head_dim = q.shape + key_len = k_expanded.shape[0] + + q_permuted = q.permute(1, 0, 2) + k_permuted = k_expanded.permute(1, 0, 2) + QK = torch.bmm(q_permuted, k_permuted.transpose(1, 2)) + QK *= self.sm_scale + + all_indices = torch.arange(key_len, device=x.device) + query_indices = torch.arange(start_pos, start_pos + query_len, device=x.device) + + causal_mask = query_indices[:, None] < all_indices[None, :] + + mask = causal_mask + + if self.sliding_window > 0: + sliding_mask = query_indices[:, None] < (all_indices[None, :] + self.sliding_window) + mask = mask | ~sliding_mask + + QK = QK.masked_fill(mask, -torch.inf) + + S = self.sinks.view(num_heads, 1, 1).expand(-1, query_len, -1) + QK = torch.cat([QK, S], dim=-1) + + W = torch.softmax(QK, dim=-1) + W = W[..., :-1] + + v_permuted = v_expanded.permute(1, 0, 2) + attn = torch.bmm(W, v_permuted) + attn = attn.permute(1, 0, 2) + + t = attn.reshape(-1, self.num_attention_heads * self.head_dim) + t = self.out(t) + + return t, new_kv def swiglu(x, alpha: float = 1.702, limit: float = 7.0): x_glu, x_linear = x[..., ::2], x[..., 1::2] @@ -348,44 +501,133 @@ def __init__( self.attn = AttentionBlock(config, layer_idx, device) self.mlp = MLPBlock(config, device) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.attn(x) + @profile + def forward(self, + x: torch.Tensor, + past_kv: tuple[torch.Tensor, torch.Tensor] | None = None, + start_pos: int = 0, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + attn_output, new_kv = self.attn(x, past_kv, start_pos) + x = x + attn_output x = self.mlp(x) - return x + return x, new_kv +checkpoint = None class Transformer(torch.nn.Module): def __init__( self, config: ModelConfig, device: torch.device | None = None, + lazy_load: bool = True, ): super().__init__() - self.embedding = torch.nn.Embedding( - config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16 - ) - self.block = torch.nn.ModuleList( - [ - TransformerBlock(config, layer_idx, device) - for layer_idx in range(config.num_hidden_layers) - ] - ) + # lazy load layers + self.lazy_load = lazy_load + if not self.lazy_load: + self.embedding = torch.nn.Embedding( + config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16 + ) + self.config = config + self.device = device + + if not self.lazy_load: + self.block = torch.nn.ModuleList( + [ + TransformerBlock(config, layer_idx, device) + for layer_idx in range(config.num_hidden_layers) + ] + ) self.norm = RMSNorm(config.hidden_size, device=device) - self.unembedding = torch.nn.Linear( - config.hidden_size, - config.vocab_size, - bias=False, - device=device, - dtype=torch.bfloat16, - ) + if not self.lazy_load: + self.unembedding = torch.nn.Linear( + config.hidden_size, + config.vocab_size, + bias=False, + device=device, + dtype=torch.bfloat16, + ) + + @profile + def forward(self, + x: torch.Tensor, + kv_cache: list[tuple[torch.Tensor, torch.Tensor] | None] | None = None, + start_pos: int = 0, + ) -> tuple[ + torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | None]]: + if self.lazy_load: + embedding = torch.nn.Embedding( + self.config.vocab_size, self.config.hidden_size, device=self.device, dtype=torch.bfloat16 + ) + for name, param in embedding.named_parameters(): + self.load_weights(param, f"embedding.{name}") + x = embedding(x) + del embedding + else: + x = 
self.embedding(x) + + if kv_cache is None: + kv_cache = [None] * self.config.num_hidden_layers + + if self.lazy_load: + for layer_idx in range(self.config.num_hidden_layers): + block = TransformerBlock(self.config, layer_idx, self.device) + for name, param in block.named_parameters(): + self.load_weights(param, f"block.{layer_idx}.{name}") + past_kv_for_layer = kv_cache[layer_idx] + x, new_kv_for_layer = block(x, past_kv_for_layer, start_pos) + kv_cache[layer_idx] = new_kv_for_layer + del block + else: + for layer_idx in range(self.config.num_hidden_layers): + past_kv_for_layer = kv_cache[layer_idx] + x, new_kv_for_layer = self.block[layer_idx](x, past_kv_for_layer, start_pos) + kv_cache[layer_idx] = new_kv_for_layer - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.embedding(x) - for block in self.block: - x = block(x) x = self.norm(x) - x = self.unembedding(x) - return x + + if self.lazy_load: + unembedding = torch.nn.Linear( + self.config.hidden_size, + self.config.vocab_size, + bias=False, + device=self.device, + dtype=torch.bfloat16, + ) + for name, param in unembedding.named_parameters(): + self.load_weights(param, f"unembedding.{name}") + x = unembedding(x) + del unembedding + else: + x = self.unembedding(x) + return x, kv_cache + + def load_weights(self, param, name): + global checkpoint + my_rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + per_rank_intermediate_size = self.config.intermediate_size // world_size + loaded_tensor = checkpoint.get(name) + if "mlp1" in name: + loaded_tensor = loaded_tensor[ + :, + my_rank * 2 + * per_rank_intermediate_size: (my_rank + 1) * 2 + * per_rank_intermediate_size, + ..., + ] + elif "mlp2_weight" in name: + loaded_tensor = loaded_tensor[ + ..., + my_rank + * per_rank_intermediate_size: (my_rank + 1) + * per_rank_intermediate_size, + ] + try: + param.data.copy_(loaded_tensor) + except: + print(f"{name=} {param.data.shape=} {loaded_tensor.shape=}") + raise @staticmethod def from_checkpoint( @@ -410,6 +652,7 @@ def from_checkpoint( world_size = dist.get_world_size() if dist.is_initialized() else 1 per_rank_intermediate_size = config.intermediate_size // world_size + global checkpoint checkpoint = Checkpoint(path, device) for name, param in model.named_parameters(): @@ -455,16 +698,20 @@ def generate(self, max_tokens: int = 0, return_logprobs: bool = False): tokens = list(prompt_tokens) + num_prompt_tokens = len(tokens) num_generated_tokens = 0 + + kv_cache = None + prompt_tensor = torch.as_tensor(tokens, dtype=torch.int32, device=self.device) + logits, kv_cache = self.model(prompt_tensor, kv_cache=None, start_pos=0) + logits = logits[-1] + while max_tokens == 0 or num_generated_tokens < max_tokens: - logits = self.model(torch.as_tensor(tokens, dtype=torch.int32, device=self.device))[-1] if temperature == 0.0: predicted_token = torch.argmax(logits, dim=-1).item() else: probs = torch.softmax(logits * (1.0 / temperature), dim=-1) predicted_token = torch.multinomial(probs, num_samples=1).item() - tokens.append(predicted_token) - num_generated_tokens += 1 if return_logprobs: logprobs = torch.log_softmax(logits, dim=-1) @@ -475,3 +722,10 @@ def generate(self, if predicted_token in stop_tokens: break + + tokens.append(predicted_token) + num_generated_tokens += 1 + next_token_tensor = torch.as_tensor([predicted_token], dtype=torch.int32, device=self.device) + start_pos = num_prompt_tokens + num_generated_tokens - 1 + logits, kv_cache = self.model(next_token_tensor, 
kv_cache=kv_cache, start_pos=start_pos) + logits = logits[0] diff --git a/gpt_oss/torch/utils.py b/gpt_oss/torch/utils.py index ce87a85d..a5bc5196 100644 --- a/gpt_oss/torch/utils.py +++ b/gpt_oss/torch/utils.py @@ -36,5 +36,5 @@ def init_distributed() -> torch.device: dist.all_reduce(x) torch.cuda.synchronize(device) - suppress_output(rank) + # suppress_output(rank) return device diff --git a/gpt_oss/torch/weights.py b/gpt_oss/torch/weights.py index aa5df58a..811c0a2c 100644 --- a/gpt_oss/torch/weights.py +++ b/gpt_oss/torch/weights.py @@ -4,6 +4,11 @@ import torch from safetensors import safe_open +try: + profile # type: ignore +except NameError: + profile = lambda f: f + # Bytes per MXFP4 block: 32 FP4 numbers packed in 16 bytes BYTES_PER_BLOCK = 16 @@ -33,6 +38,7 @@ def __init__(self, path: str, device: torch.device): else device.type + ":" + str(device.index) ) self.device_str = device_str + self.file_handles = {} # Read from all files ending with .safetensors in the checkpoint directory safetensor_files = [ @@ -43,11 +49,30 @@ def __init__(self, path: str, device: torch.device): # Build a mapping from tensor name to (file, key) tensor_name_to_file = {} for safetensor_file in safetensor_files: - with safe_open(safetensor_file, framework="pt", device=device_str) as f: - for key in f.keys(): - tensor_name_to_file[key] = safetensor_file + handle = safe_open(safetensor_file, framework="pt", device='cpu') + self.file_handles[safetensor_file] = handle + for key in handle.keys(): + tensor_name_to_file[key] = safetensor_file self.tensor_name_to_file = tensor_name_to_file + self._lut = torch.tensor(FP4_VALUES, dtype=torch.bfloat16, device=device) + + @profile + def _get_tensor(self, name: str) -> torch.Tensor: + assert name in self.tensor_name_to_file, f"Tensor {name} not found." + file_key = self.tensor_name_to_file[name] + handle = self.file_handles[file_key] + + # Get tensor and pin memory for faster GPU transfers + tensor = handle.get_tensor(name) + if self.device_str.startswith('cuda'): + tensor = tensor.pin_memory() + + return tensor.to(self.device_str, non_blocking=True) + + def __del__(self): + for handle in self.file_handles.values(): + pass def get(self, name: str) -> torch.Tensor: match PARAM_NAME_MAP.get(name, name): @@ -58,13 +83,7 @@ def get(self, name: str) -> torch.Tensor: # MoE biases and other weights return self._get_tensor(tensor_name) - def _get_tensor(self, name: str) -> str: - assert name in self.tensor_name_to_file, f"Tensor {name} not found in checkpoint." 
- with safe_open( - self.tensor_name_to_file[name], framework="pt", device=self.device_str - ) as f: - return f.get_tensor(name) - + @profile def _get_mxfp4_tensor( self, blocks_name: str, @@ -87,7 +106,7 @@ def _get_mxfp4_tensor( f"{blocks.shape=} does not match {scales.shape=}" ) - lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) + lut = self._lut *prefix_shape, G, B = blocks.shape rows_total = math.prod(prefix_shape) * G @@ -103,9 +122,8 @@ def _get_mxfp4_tensor( blk = blocks[r0:r1] exp = scales[r0:r1] - # nibble indices -> int64 - idx_lo = (blk & 0x0F).to(torch.long) - idx_hi = (blk >> 4).to(torch.long) + idx_lo = torch.bitwise_and(blk, 0x0F).long() + idx_hi = torch.bitwise_right_shift(blk, 4).long() sub = out[r0:r1] sub[:, 0::2] = lut[idx_lo] From d5d58b9eaaf8397c14a44e5e27f64dd40380c908 Mon Sep 17 00:00:00 2001 From: koshe Date: Sat, 30 Aug 2025 06:01:38 +0200 Subject: [PATCH 2/8] Add video --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index dd28c455..78a24be5 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@
  • Optimized forward pass and attention
  • + __________________________________________ Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases. From 96bbffb4baea8b9ebe9da67db7630183f51f2fb9 Mon Sep 17 00:00:00 2001 From: nalexand <35492736+nalexand@users.noreply.github.com> Date: Sat, 30 Aug 2025 06:19:58 +0200 Subject: [PATCH 3/8] Update README.md --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 78a24be5..f120f829 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,13 @@
  • Optimized forward pass and attention
  • - +## 🎥 Demo Video + +Watch GPT-OSS 20B running on just 8GB of VRAM: + +[![GPT-OSS 20B Demo](https://englyk.com/gpt-oss-20b-8gb-vram.jpg)](https://englyk.com/gpt_oss_20b_8gb_vram.mp4) + +*Click the image to watch the full demonstration* __________________________________________ Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases. From 811aa4340f0fc5893159b0cb3e40cf9e3a2c159f Mon Sep 17 00:00:00 2001 From: nalexand <35492736+nalexand@users.noreply.github.com> Date: Sat, 30 Aug 2025 06:33:45 +0200 Subject: [PATCH 4/8] Update README.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f120f829..bba7da9f 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,8 @@ Watch GPT-OSS 20B running on just 8GB of VRAM: [![GPT-OSS 20B Demo](https://englyk.com/gpt-oss-20b-8gb-vram.jpg)](https://englyk.com/gpt_oss_20b_8gb_vram.mp4) -*Click the image to watch the full demonstration* +*Click the image to watch the full demonstration* or - +[Watch on YouTube](https://youtu.be/0g7MBALZM8c) __________________________________________ Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases. From 71697508cd849c2e759e33d997be9b4502fa486f Mon Sep 17 00:00:00 2001 From: koshe Date: Sat, 30 Aug 2025 15:25:41 +0200 Subject: [PATCH 5/8] Remove unused code --- gpt_oss/torch/model.py | 59 ------------------------------------------ 1 file changed, 59 deletions(-) diff --git a/gpt_oss/torch/model.py b/gpt_oss/torch/model.py index 123cbcb6..b1713098 100644 --- a/gpt_oss/torch/model.py +++ b/gpt_oss/torch/model.py @@ -272,65 +272,6 @@ def forward_(self, x: torch.Tensor) -> torch.Tensor: t = x + t return t - @profile - def forward_2(self, - x: torch.Tensor, - past_kv: tuple[torch.Tensor, torch.Tensor] | None = None, - start_pos: int = 0, - ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - t = self.norm(x) - qkv = self.qkv(t) - - q = qkv[:, : self.num_attention_heads * self.head_dim] - k = qkv[:, self.num_attention_heads * self.head_dim: ( - self.num_attention_heads + self.num_key_value_heads) * self.head_dim] - v = qkv[:, (self.num_attention_heads + self.num_key_value_heads) * self.head_dim:] - - q = q.view(-1, self.num_attention_heads, self.head_dim) - k = k.view(-1, self.num_key_value_heads, self.head_dim) - v = v.view(-1, self.num_key_value_heads, self.head_dim) - - q, k = self.rope(q, k, start_pos=start_pos) - - if past_kv is not None: - past_k, past_v = past_kv - k = torch.cat([past_k, k], dim=0) - v = torch.cat([past_v, v], dim=0) - - new_kv = (k, v) - - num_groups = self.num_attention_heads // self.num_key_value_heads - k_expanded = k.repeat_interleave(num_groups, dim=1) - v_expanded = v.repeat_interleave(num_groups, dim=1) - - query_len, num_heads, head_dim = q.shape - key_len = k_expanded.shape[0] - - QK = torch.einsum("qhd,khd->hqk", q, k_expanded) - QK *= self.sm_scale - - all_indices = torch.arange(key_len, device=x.device) - query_indices = torch.arange(start_pos, start_pos + query_len, device=x.device) - mask = query_indices[:, None] < all_indices[None, :] - QK = QK.masked_fill(mask, -torch.inf) - - if self.sliding_window > 0: - sliding_mask = query_indices[:, None] < (all_indices[None, :] + self.sliding_window) - QK = QK.masked_fill(~sliding_mask, -torch.inf) - - S = self.sinks.view(num_heads, 1, 
1).expand(-1, query_len, -1) - QK = torch.cat([QK, S], dim=-1) - - W = torch.softmax(QK, dim=-1) - W = W[..., :-1] - - attn = torch.einsum("hqk,khd->qhd", W, v_expanded) - - t = attn.reshape(-1, self.num_attention_heads * self.head_dim) - t = self.out(t) - - return t, new_kv - @profile def forward(self, x: torch.Tensor, From 1cea0706ddf563f90bb787e3370a6544267cf3d2 Mon Sep 17 00:00:00 2001 From: koshe Date: Sun, 31 Aug 2025 21:28:39 +0200 Subject: [PATCH 6/8] Add support !!! 6 Gb VRAM !!! for gpt-oss-20b --- gpt_oss/chat.py | 15 +- gpt_oss/torch/model.py | 340 +++++++++++++++++++++++++-------------- gpt_oss/torch/weights.py | 110 ++++++++++--- 3 files changed, 324 insertions(+), 141 deletions(-) diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py index 5e40079d..37bd226f 100644 --- a/gpt_oss/chat.py +++ b/gpt_oss/chat.py @@ -9,10 +9,15 @@ import os from pathlib import Path +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False" + try: import gnureadline as readline except ImportError: - import readline + try: + import readline + except ImportError: + import pyreadline3 as readline import torch import termcolor @@ -80,8 +85,8 @@ def main(args): system_message_content = ( SystemContent.new() - .with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort]) - .with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d")) + #.with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort]) + #.with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d")) ) if args.browser: @@ -245,7 +250,9 @@ async def run_tool(): field_created = False current_output_text = "" output_text_delta_buffer = "" - for predicted_token in generator.generate(tokens, encoding.stop_tokens_for_assistant_actions()): + for predicted_token in generator.generate(tokens, encoding.stop_tokens_for_assistant_actions(), + #temperature=0, max_tokens=10 + ): parser.process(predicted_token) if args.raw: print(encoding.decode([predicted_token]), end="", flush=True) diff --git a/gpt_oss/torch/model.py b/gpt_oss/torch/model.py index b1713098..b519275d 100644 --- a/gpt_oss/torch/model.py +++ b/gpt_oss/torch/model.py @@ -52,34 +52,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return (t * self.scale).to(dtype) -def _apply_rotary_emb( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, -) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - x1, x2 = torch.chunk(x, 2, dim=-1) - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - return torch.cat((o1, o2), dim=-1) - def _apply_rotary_emb_optimized( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> torch.Tensor: - # Prepare cos/sin with correct shape and dtype outside this function - # Use direct indexing instead of chunk for better memory efficiency half_dim = x.size(-1) // 2 x1, x2 = x[..., :half_dim], x[..., half_dim:] - # Compute in-place where possible o1 = x1 * cos - x2 * sin o2 = x2 * cos + x1 * sin - # Use view instead of cat to avoid memory allocation - return torch.cat([o1, o2], dim=-1) # Still need cat, but optimized other parts + return torch.cat([o1, o2], dim=-1) class RotaryEmbedding(torch.nn.Module): @@ -152,7 +136,6 @@ def _compute_cos_sin(self, num_tokens: int, start_pos: int = 0): sin = freqs.sin() * concentration return cos, sin - @profile def forward( self, query: torch.Tensor, @@ -241,38 +224,6 @@ def __init__( device=device, ) - def forward_(self, x: torch.Tensor) -> torch.Tensor: - t = self.norm(x) - qkv = self.qkv(t) - q = qkv[:, : self.num_attention_heads * 
self.head_dim].contiguous() - k = qkv[ - :, - self.num_attention_heads - * self.head_dim : (self.num_attention_heads + self.num_key_value_heads) - * self.head_dim, - ].contiguous() - v = qkv[ - :, - (self.num_attention_heads + self.num_key_value_heads) - * self.head_dim : (self.num_attention_heads + 2 * self.num_key_value_heads) - * self.head_dim, - ].contiguous() - - q = q.view( - -1, - self.num_key_value_heads, - self.num_attention_heads // self.num_key_value_heads, - self.head_dim, - ) - k = k.view(-1, self.num_key_value_heads, self.head_dim) - v = v.view(-1, self.num_key_value_heads, self.head_dim) - q, k = self.rope(q, k) - t = sdpa(q, k, v, self.sinks, self.sm_scale, self.sliding_window) - t = self.out(t) - t = x + t - return t - - @profile def forward(self, x: torch.Tensor, past_kv: tuple[torch.Tensor, torch.Tensor] | None = None, @@ -350,7 +301,7 @@ def swiglu(x, alpha: float = 1.702, limit: float = 7.0): return out_glu * (x_linear + 1) -class MLPBlock(torch.nn.Module): +class MLPBlock_(torch.nn.Module): # memory unefficient def __init__( self, config: ModelConfig, @@ -429,6 +380,113 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + t +class MLPBlock(torch.nn.Module): + def __init__( + self, + config: ModelConfig, + device: torch.device | None = None, + ): + super().__init__() + self.num_experts = config.num_experts + self.experts_per_token = config.experts_per_token + self.swiglu_limit = config.swiglu_limit + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm = RMSNorm(config.hidden_size, device=device) + self.gate = torch.nn.Linear( + config.hidden_size, config.num_experts, device=device, dtype=torch.bfloat16 + ) + assert config.intermediate_size % self.world_size == 0 + self.mlp1_weight = torch.nn.Parameter( + torch.empty( + ( + config.num_experts, + config.intermediate_size * 2 // self.world_size, + config.hidden_size, + ), + device=device, + dtype=torch.bfloat16, + ) + ) + self.mlp1_bias = torch.nn.Parameter( + torch.empty( + (config.num_experts, config.intermediate_size * 2 // self.world_size), + device=device, + dtype=torch.bfloat16, + ) + ) + self.mlp2_weight = torch.nn.Parameter( + torch.empty( + ( + config.num_experts, + self.hidden_size, + config.intermediate_size // self.world_size, + ), + device=device, + dtype=torch.bfloat16, + ) + ) + self.mlp2_bias = torch.nn.Parameter( + torch.empty( + (config.num_experts, self.hidden_size), + device=device, + dtype=torch.bfloat16, + ) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + t = self.norm(x) + g = self.gate(t) + + if t.dim() == 2: + t = t.unsqueeze(1) + g = g.unsqueeze(1) + added_seq_dim = True + else: + added_seq_dim = False + + batch_size, seq_len, hidden_size = t.shape + + t_flat = t.reshape(-1, hidden_size) + g_flat = g.reshape(-1, self.num_experts) + + experts = torch.topk(g_flat, k=self.experts_per_token, dim=-1, sorted=True) + expert_weights = torch.nn.functional.softmax(experts.values, dim=-1) + expert_indices = experts.indices + + output_flat = torch.zeros_like(t_flat) + + for k in range(self.experts_per_token): + current_expert_indices = expert_indices[:, k] + current_expert_weights = expert_weights[:, k] + + mlp1_w = self.mlp1_weight[ + current_expert_indices] + mlp1_b = self.mlp1_bias[current_expert_indices] + mlp2_w = self.mlp2_weight[ + current_expert_indices] + mlp2_b = self.mlp2_bias[current_expert_indices] + + t_k = torch.bmm(t_flat.unsqueeze(1), 
mlp1_w.transpose(1, 2)).squeeze(1) + mlp1_b + t_k = swiglu(t_k, limit=self.swiglu_limit) + + t_k = torch.bmm(t_k.unsqueeze(1), mlp2_w.transpose(1, 2)).squeeze(1) + + if self.world_size > 1: + dist.all_reduce(t_k, op=dist.ReduceOp.SUM) + + t_k += mlp2_b + + output_flat += t_k * current_expert_weights.unsqueeze(1) + + output = output_flat.reshape(batch_size, seq_len, hidden_size) + + if added_seq_dim: + output = output.squeeze(1) + + return x + output + class TransformerBlock(torch.nn.Module): def __init__( @@ -455,32 +513,60 @@ def forward(self, checkpoint = None +def get_free_gpu_memory_gb(device_id=0): + """Returns free GPU memory in GB for specified device (default: 0)""" + if not torch.cuda.is_available(): + return 0.0 + + props = torch.cuda.get_device_properties(device_id) + total_memory = props.total_memory + reserved = torch.cuda.memory_reserved(device_id) + + free_memory = total_memory - reserved + free_gb = free_memory / (1024 ** 3) + + return free_gb + + class Transformer(torch.nn.Module): def __init__( self, config: ModelConfig, device: torch.device | None = None, lazy_load: bool = True, + extreme_low_memory: bool = False, ): super().__init__() - # lazy load layers + free_mem = get_free_gpu_memory_gb() + if free_mem < 6: + lazy_load = True + extreme_low_memory = True + elif free_mem < 8: + lazy_load = True self.lazy_load = lazy_load - if not self.lazy_load: - self.embedding = torch.nn.Embedding( - config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16 - ) + self.lazy_load_embedding = lazy_load + self.lazy_load_unembedding = lazy_load + self.extreme_low_memory = extreme_low_memory self.config = config self.device = device - - if not self.lazy_load: + self.loaded = {0: False, 1: False, 2: False} + if self.lazy_load: + self.block = torch.nn.ModuleList([ + TransformerBlock(config, 0, device=device), + ]) + else: self.block = torch.nn.ModuleList( [ TransformerBlock(config, layer_idx, device) for layer_idx in range(config.num_hidden_layers) ] ) - self.norm = RMSNorm(config.hidden_size, device=device) - if not self.lazy_load: + self.norm = RMSNorm(config.hidden_size, device=device) + if not self.lazy_load_embedding: + self.embedding = torch.nn.Embedding( + config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16 + ) + if not self.lazy_load_unembedding: self.unembedding = torch.nn.Linear( config.hidden_size, config.vocab_size, @@ -496,14 +582,18 @@ def forward(self, start_pos: int = 0, ) -> tuple[ torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | None]]: - if self.lazy_load: - embedding = torch.nn.Embedding( - self.config.vocab_size, self.config.hidden_size, device=self.device, dtype=torch.bfloat16 - ) - for name, param in embedding.named_parameters(): - self.load_weights(param, f"embedding.{name}") - x = embedding(x) - del embedding + if self.lazy_load_embedding: + if not self.loaded[1] or self.extreme_low_memory: + self.embedding = torch.nn.Embedding( + self.config.vocab_size, self.config.hidden_size, device=self.device, dtype=torch.bfloat16 + ) + for name, param in self.embedding.named_parameters(): + self.load_weights(param, f"embedding.{name}") + self.loaded[1] = True + x = self.embedding(x) + if self.extreme_low_memory: + self.embedding = None + torch.cuda.empty_cache() else: x = self.embedding(x) @@ -512,33 +602,45 @@ def forward(self, if self.lazy_load: for layer_idx in range(self.config.num_hidden_layers): - block = TransformerBlock(self.config, layer_idx, self.device) - for name, param in block.named_parameters(): + # layers skipping 
experiment + #if layer_idx % 2 == 0: + # continue + for name, param in self.block[0].named_parameters(): self.load_weights(param, f"block.{layer_idx}.{name}") past_kv_for_layer = kv_cache[layer_idx] - x, new_kv_for_layer = block(x, past_kv_for_layer, start_pos) + x, new_kv_for_layer = self.block[0](x, past_kv_for_layer, start_pos) kv_cache[layer_idx] = new_kv_for_layer - del block else: for layer_idx in range(self.config.num_hidden_layers): past_kv_for_layer = kv_cache[layer_idx] x, new_kv_for_layer = self.block[layer_idx](x, past_kv_for_layer, start_pos) kv_cache[layer_idx] = new_kv_for_layer - x = self.norm(x) - if self.lazy_load: - unembedding = torch.nn.Linear( - self.config.hidden_size, - self.config.vocab_size, - bias=False, - device=self.device, - dtype=torch.bfloat16, - ) - for name, param in unembedding.named_parameters(): - self.load_weights(param, f"unembedding.{name}") - x = unembedding(x) - del unembedding + norm = RMSNorm(self.config.hidden_size, device=self.device) + for name, param in norm.named_parameters(): + self.load_weights(param, f"norm.{name}") + x = norm(x) + del norm + else: + x = self.norm(x) + + if self.lazy_load_unembedding: + if not self.loaded[0] or self.extreme_low_memory: + self.unembedding = torch.nn.Linear( + self.config.hidden_size, + self.config.vocab_size, + bias=False, + device=self.device, + dtype=torch.bfloat16, + ) + for name, param in self.unembedding.named_parameters(): + self.load_weights(param, f"unembedding.{name}") + self.loaded[0] = True + x = self.unembedding(x) + if self.extreme_low_memory: + self.unembedding = None + torch.cuda.empty_cache() else: x = self.unembedding(x) return x, kv_cache @@ -572,7 +674,7 @@ def load_weights(self, param, name): @staticmethod def from_checkpoint( - path: str, device: str | torch.device = "cuda" + path: str, device: str | torch.device = "cuda", lazy_load: bool = True ) -> "Transformer": if not isinstance(device, torch.device): device = torch.device(device) @@ -586,41 +688,41 @@ def from_checkpoint( config=config, device=device, ) - model.eval() - - # Load weights - my_rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - per_rank_intermediate_size = config.intermediate_size // world_size + if not lazy_load: + model.eval() global checkpoint - checkpoint = Checkpoint(path, device) - - for name, param in model.named_parameters(): - loaded_tensor = checkpoint.get(name) - - # Note: it would be more efficient to do sharding before upcasting from MXFP4, - # but for simplicity we do it after. 
- if "mlp1" in name: # both weight and bias - loaded_tensor = loaded_tensor[ - :, - my_rank * 2 - * per_rank_intermediate_size : (my_rank + 1) * 2 - * per_rank_intermediate_size, - ..., - ] - elif "mlp2_weight" in name: # only weight - loaded_tensor = loaded_tensor[ - ..., - my_rank - * per_rank_intermediate_size : (my_rank + 1) - * per_rank_intermediate_size, - ] - try: - param.data.copy_(loaded_tensor) - except: - print(f"{name=} {param.data.shape=} {loaded_tensor.shape=}") - raise + checkpoint = Checkpoint(path, device, True) + + if not lazy_load: + # Load weights + my_rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + per_rank_intermediate_size = config.intermediate_size // world_size + for name, param in model.named_parameters(): + loaded_tensor = checkpoint.get(name) + # Note: it would be more efficient to do sharding before upcasting from MXFP4, + # but for simplicity we do it after. + if "mlp1" in name: # both weight and bias + loaded_tensor = loaded_tensor[ + :, + my_rank * 2 + * per_rank_intermediate_size : (my_rank + 1) * 2 + * per_rank_intermediate_size, + ..., + ] + elif "mlp2_weight" in name: # only weight + loaded_tensor = loaded_tensor[ + ..., + my_rank + * per_rank_intermediate_size : (my_rank + 1) + * per_rank_intermediate_size, + ] + try: + param.data.copy_(loaded_tensor) + except: + print(f"{name=} {param.data.shape=} {loaded_tensor.shape=}") + raise return model diff --git a/gpt_oss/torch/weights.py b/gpt_oss/torch/weights.py index 811c0a2c..3fadfb4b 100644 --- a/gpt_oss/torch/weights.py +++ b/gpt_oss/torch/weights.py @@ -1,6 +1,8 @@ import math import os - +import functools +import collections +import threading import torch from safetensors import safe_open @@ -30,14 +32,89 @@ } +class TensorLRUCache: + """ + A thread-safe, memory-aware LRU cache decorator with an individual tensor size limit. + + Args: + max_memory_gb (float): The maximum total memory in GB to be used by the cache. + max_individual_size_gb (float, optional): The maximum size in GB for any single + tensor to be considered for caching. If a tensor is larger than this, it will + be bypassed and not cached. Defaults to None (no individual limit). 
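+
+    Example (illustrative only; the function name below is hypothetical — see the
+    commented-out decorator on Checkpoint._get_tensor later in this patch for the
+    real call site and the same size limits):
+
+        @TensorLRUCache(max_memory_gb=3.6, max_individual_size_gb=0.2, min_individual_size_gb=0.03)
+        def load_tensor(name: str) -> torch.Tensor:
+            ...  # expensive safetensors read; tensors within the size limits are cached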
+ """ + + def __init__(self, max_memory_gb: float, max_individual_size_gb: float | None = None, min_individual_size_gb: float | None = None): + self.max_memory_bytes = int(max_memory_gb * (1024 ** 3)) + if max_individual_size_gb is not None: + self.max_individual_size_bytes = int(max_individual_size_gb * (1024 ** 3)) + else: + self.max_individual_size_bytes = None + if min_individual_size_gb is not None: + self.min_individual_size_bytes = int(min_individual_size_gb * (1024 ** 3)) + else: + self.min_individual_size_bytes = None + + self.cache = collections.OrderedDict() + self.current_size_bytes = 0 + self.lock = threading.Lock() + + def __call__(self, func): + @functools.wraps(func) + def wrapped_func(*args, **kwargs): + try: + key = (args, frozenset(kwargs.items())) + except TypeError: + return func(*args, **kwargs) + + # First, check if the item is already in the cache (fast path) + with self.lock: + if key in self.cache: + self.cache.move_to_end(key) + return self.cache[key] + + # Item not in cache, so we must call the expensive function + result = func(*args, **kwargs) + + if not isinstance(result, torch.Tensor): + return result + + tensor_size = result.nbytes + + # --- Bypass cache for oversized/small tensors --- + if self.max_individual_size_bytes is not None and tensor_size > self.max_individual_size_bytes or self.min_individual_size_bytes > tensor_size: + #print(tensor_size) + #print(*args) + return result + + # If the tensor itself is larger than the *total* cache capacity, don't cache + if tensor_size > self.max_memory_bytes: + return result + + # Evict items until there's space for the new tensor + with self.lock: + while self.current_size_bytes + tensor_size > self.max_memory_bytes: + evicted_key, evicted_tensor = self.cache.popitem(last=False) + self.current_size_bytes -= evicted_tensor.nbytes + + # Add the new item + self.cache[key] = result + self.current_size_bytes += tensor_size + + return result + + return wrapped_func + + class Checkpoint: - def __init__(self, path: str, device: torch.device): + def __init__(self, path: str, device: torch.device, pin_memory_for_faster_cpu_gpu_transfer: bool = False): device_str = ( device.type if device.index is None else device.type + ":" + str(device.index) ) self.device_str = device_str + self.pin_memory = pin_memory_for_faster_cpu_gpu_transfer # use gpu shared memory 2.5Gb for gpt-oss-20b 5% faster + self.lut_buffer = False # use additional gpu shared memory 1Gb for gpt-oss-20b, 1-2% faster self.file_handles = {} # Read from all files ending with .safetensors in the checkpoint directory @@ -55,9 +132,11 @@ def __init__(self, path: str, device: torch.device): tensor_name_to_file[key] = safetensor_file self.tensor_name_to_file = tensor_name_to_file - self._lut = torch.tensor(FP4_VALUES, dtype=torch.bfloat16, device=device) + if self.lut_buffer: + self._lut = torch.tensor(FP4_VALUES, dtype=torch.bfloat16, device=device) @profile + #@TensorLRUCache(max_memory_gb=3.6, max_individual_size_gb=0.2, min_individual_size_gb=0.03) # (3.6, 0.2, 0.03) cache for mlp1 use additional memory but 5% faster, use it only for gpt_oss.generate (Max vram used: 7.6Gb) def _get_tensor(self, name: str) -> torch.Tensor: assert name in self.tensor_name_to_file, f"Tensor {name} not found." 
file_key = self.tensor_name_to_file[name] @@ -65,7 +144,7 @@ def _get_tensor(self, name: str) -> torch.Tensor: # Get tensor and pin memory for faster GPU transfers tensor = handle.get_tensor(name) - if self.device_str.startswith('cuda'): + if self.pin_memory and self.device_str.startswith('cuda'): tensor = tensor.pin_memory() return tensor.to(self.device_str, non_blocking=True) @@ -74,6 +153,7 @@ def __del__(self): for handle in self.file_handles.values(): pass + @profile def get(self, name: str) -> torch.Tensor: match PARAM_NAME_MAP.get(name, name): case (blocks_name, scales_name): @@ -90,7 +170,7 @@ def _get_mxfp4_tensor( scales_name: str, *, dtype: torch.dtype = torch.bfloat16, - rows_per_chunk: int = 16384 * 512, + rows_per_chunk: int = 16384 * 64, ) -> torch.Tensor: assert blocks_name in self.tensor_name_to_file, ( f"Blocks tensor {blocks_name} not found in checkpoint." @@ -106,7 +186,10 @@ def _get_mxfp4_tensor( f"{blocks.shape=} does not match {scales.shape=}" ) - lut = self._lut + if self.lut_buffer: + lut = self._lut + else: + lut = torch.tensor(FP4_VALUES, dtype=torch.bfloat16, device=blocks.device) *prefix_shape, G, B = blocks.shape rows_total = math.prod(prefix_shape) * G @@ -118,19 +201,10 @@ def _get_mxfp4_tensor( for r0 in range(0, rows_total, rows_per_chunk): r1 = min(r0 + rows_per_chunk, rows_total) - - blk = blocks[r0:r1] - exp = scales[r0:r1] - - idx_lo = torch.bitwise_and(blk, 0x0F).long() - idx_hi = torch.bitwise_right_shift(blk, 4).long() - sub = out[r0:r1] - sub[:, 0::2] = lut[idx_lo] - sub[:, 1::2] = lut[idx_hi] - - torch.ldexp(sub, exp, out=sub) - del idx_lo, idx_hi, blk, exp + sub[:, 0::2] = lut[torch.bitwise_and(blocks[r0:r1], 0x0F).long()] + sub[:, 1::2] = lut[torch.bitwise_right_shift(blocks[r0:r1], 4).long()] + torch.ldexp_(sub, scales[r0:r1]) return out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) From 3962a3fb8036e1349a770b8092111eb46bf99786 Mon Sep 17 00:00:00 2001 From: nalexand <35492736+nalexand@users.noreply.github.com> Date: Sun, 31 Aug 2025 21:39:30 +0200 Subject: [PATCH 7/8] Update README.md --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index bba7da9f..1c71c0b8 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,15 @@ Watch GPT-OSS 20B running on just 8GB of VRAM: *Click the image to watch the full demonstration* or - [Watch on YouTube](https://youtu.be/0g7MBALZM8c) + +

    UPDATE: 08/31/2025 - Added 6 GB VRAM support for gpt-oss-20b !!!

    + +- Optimized MLPBlock +- gpt_oss.generate min 6 Gb VRAM +- gpt_oss.chat min 8 Gb VRAM +- gpt_oss.chat windows support with pyreadline3 module +- auto tune options for awailable VRAM + __________________________________________ Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases. From 54436c3ee3e1a4f2f70b425c42012a91a51465e2 Mon Sep 17 00:00:00 2001 From: koshe Date: Sun, 31 Aug 2025 22:20:41 +0200 Subject: [PATCH 8/8] Add options for control memory usage --- gpt_oss/chat.py | 8 +++----- gpt_oss/generate.py | 2 +- gpt_oss/torch/model.py | 36 +++++++++++++++++++----------------- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py index 37bd226f..27bf1190 100644 --- a/gpt_oss/chat.py +++ b/gpt_oss/chat.py @@ -9,8 +9,6 @@ import os from pathlib import Path -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False" - try: import gnureadline as readline except ImportError: @@ -74,7 +72,7 @@ def main(args): from gpt_oss.torch.model import TokenGenerator as TorchGenerator from gpt_oss.torch.utils import init_distributed device = init_distributed() - generator = TorchGenerator(args.checkpoint, device) + generator = TorchGenerator(args.checkpoint, device, pin_memory=False) case "vllm": from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2) @@ -85,8 +83,8 @@ def main(args): system_message_content = ( SystemContent.new() - #.with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort]) - #.with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d")) + .with_reasoning_effort(REASONING_EFFORT[args.reasoning_effort]) + .with_conversation_start_date(datetime.datetime.now().strftime("%Y-%m-%d")) ) if args.browser: diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py index 213444ce..e21a144e 100644 --- a/gpt_oss/generate.py +++ b/gpt_oss/generate.py @@ -24,7 +24,7 @@ def main(args): from gpt_oss.torch.utils import init_distributed from gpt_oss.torch.model import TokenGenerator as TorchGenerator device = init_distributed() - generator = TorchGenerator(args.checkpoint, device=device) + generator = TorchGenerator(args.checkpoint, device=device, pin_memory=True) case "triton": from gpt_oss.torch.utils import init_distributed from gpt_oss.triton.model import TokenGenerator as TritonGenerator diff --git a/gpt_oss/torch/model.py b/gpt_oss/torch/model.py index b519275d..64daab36 100644 --- a/gpt_oss/torch/model.py +++ b/gpt_oss/torch/model.py @@ -534,19 +534,17 @@ def __init__( config: ModelConfig, device: torch.device | None = None, lazy_load: bool = True, - extreme_low_memory: bool = False, + extremly_low_memory: bool = False, ): super().__init__() free_mem = get_free_gpu_memory_gb() if free_mem < 6: lazy_load = True - extreme_low_memory = True + extremly_low_memory = True elif free_mem < 8: lazy_load = True self.lazy_load = lazy_load - self.lazy_load_embedding = lazy_load - self.lazy_load_unembedding = lazy_load - self.extreme_low_memory = extreme_low_memory + self.extremly_low_memory = extremly_low_memory self.config = config self.device = device self.loaded = {0: False, 1: False, 2: False} @@ -562,11 +560,9 @@ def __init__( ] ) self.norm = RMSNorm(config.hidden_size, device=device) - if not self.lazy_load_embedding: self.embedding = torch.nn.Embedding( config.vocab_size, config.hidden_size, device=device, dtype=torch.bfloat16 ) 
- if not self.lazy_load_unembedding: self.unembedding = torch.nn.Linear( config.hidden_size, config.vocab_size, @@ -582,8 +578,8 @@ def forward(self, start_pos: int = 0, ) -> tuple[ torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | None]]: - if self.lazy_load_embedding: - if not self.loaded[1] or self.extreme_low_memory: + if self.lazy_load: + if not self.loaded[1] or self.extremly_low_memory: self.embedding = torch.nn.Embedding( self.config.vocab_size, self.config.hidden_size, device=self.device, dtype=torch.bfloat16 ) @@ -591,7 +587,7 @@ def forward(self, self.load_weights(param, f"embedding.{name}") self.loaded[1] = True x = self.embedding(x) - if self.extreme_low_memory: + if self.extremly_low_memory: self.embedding = None torch.cuda.empty_cache() else: @@ -625,8 +621,8 @@ def forward(self, else: x = self.norm(x) - if self.lazy_load_unembedding: - if not self.loaded[0] or self.extreme_low_memory: + if self.lazy_load: + if not self.loaded[0] or self.extremly_low_memory: self.unembedding = torch.nn.Linear( self.config.hidden_size, self.config.vocab_size, @@ -638,7 +634,7 @@ def forward(self, self.load_weights(param, f"unembedding.{name}") self.loaded[0] = True x = self.unembedding(x) - if self.extreme_low_memory: + if self.extremly_low_memory: self.unembedding = None torch.cuda.empty_cache() else: @@ -674,7 +670,7 @@ def load_weights(self, param, name): @staticmethod def from_checkpoint( - path: str, device: str | torch.device = "cuda", lazy_load: bool = True + path: str, device: str | torch.device = "cuda", lazy_load: bool = True, pin_memory: bool = False ) -> "Transformer": if not isinstance(device, torch.device): device = torch.device(device) @@ -684,15 +680,20 @@ def from_checkpoint( json_config = json.load(f) config = ModelConfig(**json_config) + extremly_low_memory = False + if not pin_memory: + extremly_low_memory = True + model = Transformer( config=config, device=device, + extremly_low_memory=extremly_low_memory, ) if not lazy_load: model.eval() global checkpoint - checkpoint = Checkpoint(path, device, True) + checkpoint = Checkpoint(path, device, pin_memory) if not lazy_load: # Load weights @@ -729,9 +730,10 @@ def from_checkpoint( class TokenGenerator: @torch.inference_mode() - def __init__(self, checkpoint: str, device: torch.device): + def __init__(self, checkpoint: str, device: torch.device, pin_memory: bool = False): self.device = device - self.model = Transformer.from_checkpoint(checkpoint, device=self.device) + self.model = Transformer.from_checkpoint( + checkpoint, device=self.device, pin_memory=pin_memory) @torch.inference_mode() def generate(self,
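Taken together, these patches settle on a standard incremental-decoding loop: prefill the whole prompt once, then feed a single token per step while carrying the per-layer KV cache and a running `start_pos`. Below is a minimal sketch of that loop, assuming a `model` with the patched `Transformer.forward(x, kv_cache, start_pos)` signature that returns `(logits, kv_cache)` as in the `TokenGenerator.generate` changes above; the helper name `decode_with_kv_cache` is illustrative, not part of the patches.

```python
import torch

@torch.inference_mode()
def decode_with_kv_cache(model, prompt_tokens, max_tokens, device, temperature=0.0):
    # Prefill: run the full prompt once and keep the per-layer KV cache.
    tokens = list(prompt_tokens)
    prompt = torch.as_tensor(tokens, dtype=torch.int32, device=device)
    logits, kv_cache = model(prompt, kv_cache=None, start_pos=0)
    logits = logits[-1]  # next-token logits at the last prompt position

    for _ in range(max_tokens):
        if temperature == 0.0:
            next_token = torch.argmax(logits, dim=-1).item()
        else:
            probs = torch.softmax(logits * (1.0 / temperature), dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
        tokens.append(next_token)

        # Decode step: feed only the new token; start_pos is its absolute position,
        # so rotary embeddings and the causal mask line up with the cached keys/values.
        step_input = torch.as_tensor([next_token], dtype=torch.int32, device=device)
        logits, kv_cache = model(step_input, kv_cache=kv_cache, start_pos=len(tokens) - 1)
        logits = logits[0]
    return tokens
```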