diff --git a/python/sgl_jax/srt/layers/attention/base_attn_backend.py b/python/sgl_jax/srt/layers/attention/base_attn_backend.py
index 9a6562c6..0c2a8746 100644
--- a/python/sgl_jax/srt/layers/attention/base_attn_backend.py
+++ b/python/sgl_jax/srt/layers/attention/base_attn_backend.py
@@ -17,7 +17,7 @@ class AttentionBackend(nnx.Module):
     """The base class of attention backends"""

     @abstractmethod
-    def get_forward_metadata(self, batch: ModelWorkerBatch, mesh: Mesh):
+    def get_forward_metadata(self, batch: ModelWorkerBatch):
         """Init the metadata for a forward pass and return it"""
         raise NotImplementedError()

diff --git a/python/sgl_jax/srt/layers/attention/flashattention_backend.py b/python/sgl_jax/srt/layers/attention/flashattention_backend.py
index c0bd5e53..0865d07c 100644
--- a/python/sgl_jax/srt/layers/attention/flashattention_backend.py
+++ b/python/sgl_jax/srt/layers/attention/flashattention_backend.py
@@ -4,7 +4,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from jax.sharding import Mesh
+from jax.sharding import NamedSharding
 from jax.sharding import PartitionSpec as P
 from jax.tree_util import register_pytree_node_class

@@ -75,6 +75,7 @@ def __init__(
         vmem_limit_bytes: int = 64 * (1 << 20),  # 64MB
         page_size: int = 1,
         kv_partition_axis: str = "tensor",
+        mesh: jax.sharding.Mesh = None,
     ):
         self.vmem_limit_bytes = vmem_limit_bytes
         self.num_heads = num_attn_heads
@@ -86,8 +87,9 @@ def __init__(
         self.page_size = page_size
         self.kv_partition_axis = kv_partition_axis
         self.forward_metadata = FlashAttentionMetadata()
+        self.mesh = mesh

-    def get_forward_metadata(self, batch: ModelWorkerBatch, mesh: Mesh):
+    def get_forward_metadata(self, batch: ModelWorkerBatch):
         """Return the metadata for a forward pass."""

         metadata = FlashAttentionMetadata()
@@ -148,8 +150,10 @@ def get_forward_metadata(self, batch: ModelWorkerBatch, mesh: Mesh):
             metadata.seq_lens,
             metadata.distribution,
         ) = device_array(
-            mesh,
             (num_seqs, cu_q_lens, cu_kv_lens, page_indices, seq_lens, distribution),
+            sharding=(
+                NamedSharding(self.mesh, P()) if jax.process_count() == 1 else None
+            ),
         )

         return metadata
@@ -255,7 +259,6 @@ def _ragged_paged_attention_with_fused_kv(*args):
             updated_kv_cache_fused,
         ) = jax.shard_map(  # Fused KV kernel handles cache updates internally
             _ragged_paged_attention_with_fused_kv,
-            mesh=jax.sharding.get_abstract_mesh(),
             in_specs=in_specs,
             out_specs=out_specs,
             check_vma=False,
diff --git a/python/sgl_jax/srt/layers/attention/native_backend.py b/python/sgl_jax/srt/layers/attention/native_backend.py
index 307afb96..a282a469 100644
--- a/python/sgl_jax/srt/layers/attention/native_backend.py
+++ b/python/sgl_jax/srt/layers/attention/native_backend.py
@@ -39,7 +39,7 @@ def tree_unflatten(cls, aux_data, children):
             num_attn_heads=aux_data["num_heads"], num_kv_heads=aux_data["num_kv_heads"]
         )

-    def get_forward_metadata(self, batch: ModelWorkerBatch, mesh: Mesh):
+    def get_forward_metadata(self, batch: ModelWorkerBatch):
         """Init the metadata for a forward pass and return it."""
         return None

diff --git a/python/sgl_jax/srt/layers/logits_processor.py b/python/sgl_jax/srt/layers/logits_processor.py
index 0357afa4..c4625573 100644
--- a/python/sgl_jax/srt/layers/logits_processor.py
+++ b/python/sgl_jax/srt/layers/logits_processor.py
@@ -7,7 +7,8 @@
 import jax.numpy as jnp
 import numpy as np
 from flax import nnx
-from jax.sharding import Mesh
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
 from jax.tree_util import register_pytree_node_class

 from sgl_jax.srt.layers.embeddings import Embed
@@ -194,7 +195,7 @@ def from_model_worker_batch(cls, batch: ModelWorkerBatch, mesh: Mesh = None):
             ) = False
             extend_logprob_pruned_lens_cpu = extend_seq_lens_cpu = None

-        mesh = mesh if mesh is not None else jax.sharding.get_abstract_mesh()
+        sharding = NamedSharding(mesh, P()) if jax.process_count() == 1 else None

         return cls(
             forward_mode=batch.forward_mode,
@@ -202,7 +203,7 @@ def from_model_worker_batch(cls, batch: ModelWorkerBatch, mesh: Mesh = None):
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_token_ids_logprob=extend_token_ids_logprob,
-            extend_seq_lens=device_array(mesh, batch.extend_seq_lens),
+            extend_seq_lens=device_array(batch.extend_seq_lens, sharding=sharding),
             extend_seq_lens_cpu=extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=(
                 batch.extend_logprob_start_lens if batch.return_logprob else None
@@ -211,8 +212,7 @@ def from_model_worker_batch(cls, batch: ModelWorkerBatch, mesh: Mesh = None):
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
             extend_input_logprob_token_ids_device=device_array(
-                mesh,
-                batch.extend_input_logprob_token_ids,
+                batch.extend_input_logprob_token_ids, sharding=sharding
             ),
         )

@@ -279,14 +279,12 @@ def __call__(
             pruned_states = jnp.concat(pruned_states)

             sample_indices = device_array(
-                self.mesh,
                 np.array(
                     sample_indices,
                     dtype=jnp.int64,
                 ),
             )
             input_logprob_indices = device_array(
-                self.mesh,
                 np.array(input_logprob_indices, dtype=jnp.int64),
             )

@@ -324,7 +322,6 @@ def __call__(
             # Normalize the logprob w/o temperature, top-p
             pruned_lens = device_array(
-                self.mesh,
                 np.array(
                     logits_metadata.extend_logprob_pruned_lens_cpu,
                 ),
             )
@@ -362,7 +359,7 @@ def __call__(
                 input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None

             input_token_logprobs = input_logprobs[
-                device_array(self.mesh, np.arange(input_logprobs.shape[0])),
+                device_array(np.arange(input_logprobs.shape[0])),
                 logits_metadata.extend_input_logprob_token_ids_device,
             ]

diff --git a/python/sgl_jax/srt/managers/detokenizer_manager.py b/python/sgl_jax/srt/managers/detokenizer_manager.py
index b34cf2d8..396f4902 100644
--- a/python/sgl_jax/srt/managers/detokenizer_manager.py
+++ b/python/sgl_jax/srt/managers/detokenizer_manager.py
@@ -299,7 +299,7 @@ def run_detokenizer_process(
     port_args: PortArgs,
 ):
     kill_itself_when_parent_died()
-    setproctitle.setproctitle("sglang::detokenizer")
+    setproctitle.setproctitle("sglang-jax::detokenizer")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()

diff --git a/python/sgl_jax/srt/managers/tp_worker.py b/python/sgl_jax/srt/managers/tp_worker.py
index 6622b41d..27811807 100644
--- a/python/sgl_jax/srt/managers/tp_worker.py
+++ b/python/sgl_jax/srt/managers/tp_worker.py
@@ -19,6 +19,7 @@
     ModelWorkerBatch,
     global_server_args_dict,
 )
+from sgl_jax.srt.managers.utils import resolve_future_token_ids, set_future_token_ids
 from sgl_jax.srt.mem_cache.memory_pool import ReqToTokenPool
 from sgl_jax.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -184,11 +185,11 @@ def normalize_token_paddings(self):
         self.precompile_token_paddings = normalized_token_paddings

-    def run_precompile(self):
-        self.precompile_extend()
-        self.precompile_decode()
+    def run_precompile(self, future_token_ids_map=None):
+        self.precompile_extend(future_token_ids_map)
+        self.precompile_decode(future_token_ids_map)

-    def precompile_extend(self):
+    def precompile_extend(self, future_token_ids_map=None):
         start_time = time.perf_counter()
         logger.info(
             f"[EXTEND] Begin to precompile bs_paddings={self.precompile_bs_paddings[-1:]} token_paddings={self.precompile_token_paddings}"
         )
@@ -211,12 +212,22 @@ def precompile_extend(self):
                     ForwardMode.EXTEND,
                     self.precompile_cache_loc_paddings[-1],
                 )
+                model_worker_batch.forward_batch = ForwardBatch.init_new(
+                    model_worker_batch, self.model_runner
+                )
+                if future_token_ids_map is not None:
+                    model_worker_batch.forward_batch.input_ids = (
+                        resolve_future_token_ids(
+                            model_worker_batch.forward_batch.input_ids,
+                            future_token_ids_map,
+                        )
+                    )
                 self.forward_batch_generation(model_worker_batch, None, True)
         end_time = time.perf_counter()
         logger.info("[EXTEND] Precompile finished in %.0f secs", end_time - start_time)

-    def precompile_decode(self):
+    def precompile_decode(self, future_token_ids_map=None):
         start_time = time.perf_counter()
         logger.info(
             f"[DECODE] Begin to precompile bs_paddings={self.precompile_bs_paddings}"
         )
@@ -236,9 +247,21 @@ def precompile_decode(self):
                 sampling_metadata = SamplingMetadata.from_model_worker_batch(
                     model_worker_batch, 0, self.mesh
                 )
-                self.forward_batch_generation(
+                model_worker_batch.forward_batch = ForwardBatch.init_new(
+                    model_worker_batch, self.model_runner
+                )
+                if future_token_ids_map is not None:
+                    model_worker_batch.forward_batch.input_ids = (
+                        resolve_future_token_ids(
+                            model_worker_batch.forward_batch.input_ids,
+                            future_token_ids_map,
+                        )
+                    )
+                _, next_token_ids, _ = self.forward_batch_generation(
                     model_worker_batch, None, False, sampling_metadata
                 )
+                if future_token_ids_map is not None:
+                    set_future_token_ids(future_token_ids_map, 0, next_token_ids)
         end_time = time.perf_counter()
         logger.info("[DECODE] Precompile finished in %.0f secs", end_time - start_time)

@@ -365,7 +388,7 @@ def forward_batch_generation(
         if forward_metadata is None:
             forward_metadata = (
                 self.worker.model_runner.attn_backend.get_forward_metadata(
-                    model_worker_batch, self.mesh
+                    model_worker_batch
                 )
             )

diff --git a/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py b/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py
index e5dc0928..7769c2ab 100644
--- a/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py
@@ -4,7 +4,6 @@
 import logging
 import signal
 import threading
-import time
 from queue import Queue
 from typing import Optional, Tuple

@@ -16,34 +15,14 @@
 from sgl_jax.srt.managers.schedule_batch import ModelWorkerBatch
 from sgl_jax.srt.managers.tp_worker import ModelWorker
+from sgl_jax.srt.managers.utils import resolve_future_token_ids, set_future_token_ids
 from sgl_jax.srt.sampling.sampling_batch_info import SamplingMetadata
 from sgl_jax.srt.server_args import ServerArgs
-from sgl_jax.srt.utils.jax_utils import device_array
 from sgl_jax.utils import get_exception_traceback

 logger = logging.getLogger(__name__)


-@jax.jit
-def resolve_future_token_ids(input_ids, future_token_ids_map):
-    return jnp.where(
-        input_ids < 0,
-        future_token_ids_map[jnp.clip(-input_ids, a_min=0)],
-        input_ids,
-    )
-
-
-@jax.jit
-def set_future_token_ids(future_token_ids_map, future_token_ids_ct, next_token_ids):
-    # The start index must be a tuple, one element per dimension of the array.
-    start_indices = (future_token_ids_ct + 1,)
-
-    # jax.lax.dynamic_update_slice is the correct tool for this job.
-    return jax.lax.dynamic_update_slice(
-        future_token_ids_map, next_token_ids, start_indices
-    )
-
-
 class ModelWorkerClient:
     """A tensor parallel model worker."""


@@ -187,7 +166,7 @@ def forward_batch_generation(
         )

         forward_metadata = self.worker.model_runner.attn_backend.get_forward_metadata(
-            model_worker_batch, self.mesh
+            model_worker_batch
         )

         # Push a new batch to the queue (JAX handles synchronization automatically)
@@ -215,35 +194,7 @@ def forward_batch_generation(
         return None, future_next_token_ids, 0

     def run_precompile(self):
-        start_time = time.perf_counter()
-        logger.info(
-            f"[ModelWorkerClient] Begins to run resolve_future_token_ids precompile."
-        )
-        (
-            precompile_token_paddings,
-            precompile_bs_paddings,
-            _,
-        ) = self.get_precompile_paddings()
-        max_padding_bs, _ = self.get_max_padded_size()
-        bs_paddings = sorted(set(precompile_bs_paddings + [max_padding_bs]))
-        token_paddings = sorted(set(bs_paddings + precompile_token_paddings))
-        for token_padding in token_paddings:
-            input_ids = device_array(
-                self.worker.mesh, jnp.arange(0, token_padding, dtype=jnp.int32)
-            )
-            resolve_future_token_ids(input_ids, self.future_token_ids_map)
-        for bs_padding in bs_paddings:
-            input_ids = device_array(
-                self.worker.mesh, jnp.arange(0, bs_padding, dtype=jnp.int32)
-            )
-            set_future_token_ids(
-                self.future_token_ids_map, self.future_token_ids_ct, input_ids
-            )
-        end_time = time.perf_counter()
-        logger.info(
-            f"[ModelWorkerClient] Completes resolve_future_token_ids precompile. Time cost: {end_time - start_time} seconds"
-        )
-        self.worker.run_precompile()
+        self.worker.run_precompile(self.future_token_ids_map)

     def __delete__(self):
         self.input_queue.put((None, None, None, None))
diff --git a/python/sgl_jax/srt/managers/utils.py b/python/sgl_jax/srt/managers/utils.py
index feeb32ed..b14255b1 100644
--- a/python/sgl_jax/srt/managers/utils.py
+++ b/python/sgl_jax/srt/managers/utils.py
@@ -1,6 +1,9 @@
 import logging
 from typing import Optional

+import jax
+from jax import numpy as jnp
+
 from sgl_jax.srt.managers.schedule_batch import Req

 logger = logging.getLogger(__name__)
@@ -37,3 +40,20 @@ def validate_input_length(
         return error_msg

     return None
+
+
+@jax.jit
+def resolve_future_token_ids(input_ids, future_token_ids_map):
+    return jnp.where(
+        input_ids < 0,
+        future_token_ids_map[jnp.clip(-input_ids, a_min=0)],
+        input_ids,
+    )
+
+
+@jax.jit
+def set_future_token_ids(future_token_ids_map, future_token_ids_ct, next_token_ids):
+    start_indices = (future_token_ids_ct + 1,)
+    return jax.lax.dynamic_update_slice(
+        future_token_ids_map, next_token_ids, start_indices
+    )
diff --git a/python/sgl_jax/srt/mem_cache/memory_pool.py b/python/sgl_jax/srt/mem_cache/memory_pool.py
index 467fd89a..9f6bcf3e 100644
--- a/python/sgl_jax/srt/mem_cache/memory_pool.py
+++ b/python/sgl_jax/srt/mem_cache/memory_pool.py
@@ -481,35 +481,6 @@ def _set_fused_kv_buffer(
     )


-def _set_kv_buffer(
-    k: jax.Array,
-    v: jax.Array,
-    loc: jax.Array,
-    k_cache: jax.Array,
-    v_cache: jax.Array,
-    page_size: int,
-    kv_partition_axis: str = "tensor",
-):
-    """
-    k: jax.Array, # [total_tokens, num_heads, head_dim]
-    v: jax.Array, # [total_tokens, num_heads, head_dim]
-    loc: jax.Array, # [total_tokens] total_tokens is the padding tokens, if the value is -1, it means the token is padding
-    k_cache: jax.Array,
-    v_cache: jax.Array,
-    """
-    k_cache, v_cache = update_kv_cache(
-        k,
-        v,
-        loc,
-        k_cache,
-        v_cache,
-        page_size=page_size,
-        kv_partition_axis=kv_partition_axis,
-    )
-
-    return k_cache, v_cache
-
-
 def update_fused_kv_cache(
     fused_kv: jax.Array,  # [total_tokens, num_kv_heads * 2, head_dim]
     loc: jax.Array,  # [total_tokens], -1 for padding
@@ -539,40 +510,6 @@ def update_fused_kv_cache(
     )


-def update_kv_cache(
-    k: jax.Array,  # [total_tokens, num_heads, head_dim]
-    v: jax.Array,  # [total_tokens, num_heads, head_dim]
-    loc: jax.Array,  # [total_tokens], -1 for padding
-    k_cache: jax.Array,
-    v_cache: jax.Array,
-    page_size: int = 1,
-    kv_partition_axis: str = "tensor",
-):
-    """
-    Main KV cache update function that chooses between vectorized and token-by-token approaches.
-
-    Args:
-        k: Key tensor [total_tokens, num_heads, head_dim]
-        v: Value tensor [total_tokens, num_heads, head_dim]
-        loc: Location indices [total_tokens], -1 for padding tokens
-        k_cache: Key cache buffer
-        v_cache: Value cache buffer
-        use_vectorized: Whether to use vectorized (True) or token-by-token (False) approach
-
-    Returns:
-        Updated k_cache and v_cache
-    """
-    return update_kv_cache_vectorized(
-        k,
-        v,
-        loc,
-        k_cache,
-        v_cache,
-        page_size=page_size,
-        kv_partition_axis=kv_partition_axis,
-    )
-
-
 def kv_cache_update_kernel(
     # Prefetch
     slices_ref,  # [3, padded_num_slices], list of (kv_cache_start, new_kv_start, slice_len)
@@ -672,10 +609,7 @@ def kv_cache_update(
     num_slices_per_block: int = 8,
     kv_partition_axis: str = "tensor",
 ):
-    mesh = jax.sharding.get_abstract_mesh()
-
     @jax.shard_map(
-        mesh=mesh,
         in_specs=(
             P(
                 None, kv_partition_axis, None
@@ -783,6 +717,7 @@ def update_kv_cache_vectorized(
     v_cache: jax.Array,
     page_size: int,
     kv_partition_axis: str = "tensor",
+    mesh: jax.sharding.Mesh = None,
 ):
     """
     Vectorized KV cache update that handles padding and supports page_size > 1
diff --git a/python/sgl_jax/srt/model_executor/forward_batch_info.py b/python/sgl_jax/srt/model_executor/forward_batch_info.py
index bff82d37..34d60b92 100644
--- a/python/sgl_jax/srt/model_executor/forward_batch_info.py
+++ b/python/sgl_jax/srt/model_executor/forward_batch_info.py
@@ -26,6 +26,8 @@

 logger = logging.getLogger(__name__)

+from jax.sharding import NamedSharding, PartitionSpec
+
 from sgl_jax.srt.utils.jax_utils import device_array

 if TYPE_CHECKING:
@@ -245,7 +247,6 @@ def init_new(
             extend_prefix_lens,
             extend_seq_lens,
         ) = device_array(
-            model_runner.mesh,
             (
                 batch.input_ids,
                 batch.seq_lens,
@@ -257,7 +258,13 @@ def init_new(
                 batch.extend_prefix_lens,
                 batch.extend_seq_lens,
             ),
+            sharding=(
+                NamedSharding(model_runner.mesh, PartitionSpec())
+                if jax.process_count() == 1
+                else None
+            ),
         )
+
         obj = cls(
             bid=batch.bid,
             forward_mode=batch.forward_mode,
diff --git a/python/sgl_jax/srt/model_executor/model_runner.py b/python/sgl_jax/srt/model_executor/model_runner.py
index 0132dce0..c8ab62fa 100644
--- a/python/sgl_jax/srt/model_executor/model_runner.py
+++ b/python/sgl_jax/srt/model_executor/model_runner.py
@@ -353,6 +353,7 @@ def _get_attention_backend(self):
                 self.num_kv_heads,
                 self.model_config.head_dim,
                 page_size=self.page_size,
+                mesh=self.mesh,
             )
         else:
             raise ValueError(
@@ -405,7 +406,7 @@ def _forward_raw(
     ) -> Tuple[LogitsProcessorOutput, int]:
         # for compatibility, 0.6.3 need to use use_mesh. set_mesh is not have __entry__ attribute.
         # on jax 0.7.1, we need to use set_mesh.
-        # with self.mesh, jax.sharding.set_mesh(self.mesh):
+        # with jax.sharding.set_mesh(self.mesh):
         with jax.sharding.use_mesh(self.mesh):
             if (
                 forward_batch.forward_mode.is_decode()
diff --git a/python/sgl_jax/srt/models/qwen3_moe.py b/python/sgl_jax/srt/models/qwen3_moe.py
index f4aaa08f..d5801704 100644
--- a/python/sgl_jax/srt/models/qwen3_moe.py
+++ b/python/sgl_jax/srt/models/qwen3_moe.py
@@ -180,8 +180,6 @@ def __init__(
             self.is_moe_layer = False
             self.moe_gate = None
         else:
-            if mesh is None:
-                mesh = jax.sharding.get_abstract_mesh()
             num_experts = getattr(config, "num_experts", 128)
             num_experts_per_tok = getattr(config, "num_experts_per_tok", 8)
             moe_intermediate_size = getattr(config, "moe_intermediate_size", 768)
diff --git a/python/sgl_jax/srt/sampling/sampling_batch_info.py b/python/sgl_jax/srt/sampling/sampling_batch_info.py
index b881d679..7dde2154 100644
--- a/python/sgl_jax/srt/sampling/sampling_batch_info.py
+++ b/python/sgl_jax/srt/sampling/sampling_batch_info.py
@@ -2,9 +2,9 @@

 import dataclasses
 import logging
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, List, Optional

-from jax.sharding import Mesh
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from jax.tree_util import register_pytree_node_class

 from sgl_jax.srt.sampling.sampling_params import TOP_K_ALL
@@ -113,7 +113,12 @@ def from_model_worker_batch(

         (temperatures_device, top_ps_device, top_ks_device, min_ps_device) = (
             device_array(
-                mesh, (padded_temperatures, padded_top_ps, padded_top_ks, padded_min_ps)
+                (padded_temperatures, padded_top_ps, padded_top_ks, padded_min_ps),
+                sharding=(
+                    NamedSharding(mesh, PartitionSpec())
+                    if jax.process_count() == 1
+                    else None
+                ),
             )
         )

@@ -270,12 +275,10 @@ def merge_bias_tensor(

         if lhs is None:
             lhs = device_array(
-                mesh,
                 jnp.full((bs1, *shape), fill_value=default, dtype=dtype),
             )
         if rhs is None:
             rhs = device_array(
-                mesh,
                 jnp.full((bs2, *shape), fill_value=default, dtype=dtype),
             )
         return jnp.concat([lhs, rhs])
diff --git a/python/sgl_jax/srt/utils/jax_utils.py b/python/sgl_jax/srt/utils/jax_utils.py
index f7bc20bb..cfd94aa9 100644
--- a/python/sgl_jax/srt/utils/jax_utils.py
+++ b/python/sgl_jax/srt/utils/jax_utils.py
@@ -72,13 +72,11 @@ def get_available_device_memory(device, distributed=False, empty_cache=True):
         # Use pmap to find the minimum available memory across all devices.
         mesh = jax.make_mesh((jax.process_count(), 4), ("node", "device"))

-        with jax.sharding.set_mesh(mesh=mesh):
-
-            @jax.shard_map(
-                mesh=mesh, in_specs=PartitionSpec(None), out_specs=PartitionSpec(None)
-            )
-            def _get_available_memory_distributed(a):
-                return jax.lax.pmin(a, axis_name="node")
+        @jax.shard_map(
+            mesh=mesh, in_specs=PartitionSpec(None), out_specs=PartitionSpec(None)
+        )
+        def _get_available_memory_distributed(a):
+            return jax.lax.pmin(a, axis_name="node")

         # We broadcast the local min memory to all devices and then find the global min.
         # i64 dtype cannot be all-reduce min
@@ -93,7 +91,5 @@ def _get_available_memory_distributed(a):
     return int(free_gpu_memory * (1 << 10))


-def device_array(mesh, *data, sharding=None, **kwargs) -> jax.Array:
-    if sharding is None:
-        sharding = NamedSharding(mesh, PartitionSpec())
+def device_array(*data, sharding=None, **kwargs) -> jax.Array:
     return jax.device_put(*data, device=sharding, **kwargs)
diff --git a/python/sgl_jax/test/mem_cache/test_kv_cache.py b/python/sgl_jax/test/mem_cache/test_kv_cache.py
index e2ee07ab..d280dac4 100644
--- a/python/sgl_jax/test/mem_cache/test_kv_cache.py
+++ b/python/sgl_jax/test/mem_cache/test_kv_cache.py
@@ -3,11 +3,11 @@

 import jax
 import jax.numpy as jnp
-import numpy as np
-from jax.sharding import AxisType, Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P

-from sgl_jax.srt.mem_cache.memory_pool import update_kv_cache
+from sgl_jax.srt.mem_cache.memory_pool import (
+    update_kv_cache_vectorized as update_kv_cache,
+)
 from sgl_jax.srt.utils.mesh_utils import create_device_mesh

 mesh = create_device_mesh(ici_parallelism=[1, -1, 1, 1], dcn_parallelism=[1, 1, 1, 1])
diff --git a/python/sgl_jax/test/model_executor/test_model_runner.py b/python/sgl_jax/test/model_executor/test_model_runner.py
index 4d283e48..abb0e8fa 100644
--- a/python/sgl_jax/test/model_executor/test_model_runner.py
+++ b/python/sgl_jax/test/model_executor/test_model_runner.py
@@ -227,7 +227,8 @@ def test_forward(self):
         extend_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         with self.mesh:
             extend_output = self.model_runner.forward(
-                extend_batch, LogitsMetadata.from_model_worker_batch(model_worker_batch)
+                extend_batch,
+                LogitsMetadata.from_model_worker_batch(model_worker_batch, self.mesh),
             )

         # Verify forward_pass_id incremented
@@ -265,7 +266,9 @@ def test_forward(self):
             with self.mesh:
                 decode_output = self.model_runner.forward(
                     current_batch,
-                    LogitsMetadata.from_model_worker_batch(model_worker_batch),
+                    LogitsMetadata.from_model_worker_batch(
+                        model_worker_batch, self.mesh
+                    ),
                 )
             decode_outputs.append(decode_output)

diff --git a/python/sgl_jax/test/test_flashattention.py b/python/sgl_jax/test/test_flashattention.py
index d2c21f71..14f3a239 100644
--- a/python/sgl_jax/test/test_flashattention.py
+++ b/python/sgl_jax/test/test_flashattention.py
@@ -271,9 +271,7 @@ def align_to_size(l, size, value=0):
         extend_prefix_lens=extend_prefix_lens,
         extend_seq_lens=extend_seq_lens,
     )
-    fb.attn_backend.forward_metadata = attention_backend.get_forward_metadata(
-        mwb, mesh=mesh
-    )
+    fb.attn_backend.forward_metadata = attention_backend.get_forward_metadata(mwb)
     return fb, q, k, v


diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh
index 9c504b4f..31ee6372 100755
--- a/scripts/killall_sglang.sh
+++ b/scripts/killall_sglang.sh
@@ -5,7 +5,7 @@ set -euxo pipefail
 uv pip install tpu-info
 tpu-info

 # Clean SGLang processes
-pgrep -f 'sglang::|sgl_jax::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt|sgl_jax\.launch_server|sgl_jax\.srt|sgl_jax\.bench|sgl_jax\.data_parallel' | xargs -r kill -9 || true
+pgrep -f 'sglang::|sglang-jax::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt|sgl_jax\.launch_server|sgl_jax\.srt|sgl_jax\.bench|sgl_jax\.data_parallel' | xargs -r kill -9 || true
 # Clean all GPU processes if any argument is provided
 if [ $# -gt 0 ]; then
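Two illustrative sketches follow; neither is part of the diff above. First, the reworked `device_array` helper in `python/sgl_jax/srt/utils/jax_utils.py` no longer takes a mesh: call sites now pass an explicit `sharding` (a replicated `NamedSharding` in the single-process case, `None` otherwise, in which case `jax.device_put` falls back to default placement). A minimal usage sketch; the one-axis mesh construction here is an assumption for the example, not repo code:

```python
# Sketch only: exercises device_array() as redefined in jax_utils.py above.
import jax
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec

from sgl_jax.srt.utils.jax_utils import device_array

# Assumed one-axis mesh purely for illustration.
mesh = jax.make_mesh((jax.device_count(),), ("tensor",))

# Single-process path used by the call sites above: replicate across the mesh.
replicated = NamedSharding(mesh, PartitionSpec())
seq_lens = device_array(np.array([3, 5, 8], dtype=np.int32), sharding=replicated)

# Multi-process path: sharding=None falls through to plain jax.device_put(data).
seq_lens_default = device_array(np.array([3, 5, 8], dtype=np.int32), sharding=None)
```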
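Second, the future-token-id helpers that moved from `tp_worker_overlap_thread.py` into `python/sgl_jax/srt/managers/utils.py`: the overlap worker encodes not-yet-computed tokens as negative placeholder ids, `set_future_token_ids` stores the real ids once a decode step finishes, and `resolve_future_token_ids` swaps them back in. A small walk-through with made-up values:

```python
# Sketch only: demonstrates the helpers added to managers/utils.py above.
import jax.numpy as jnp

from sgl_jax.srt.managers.utils import resolve_future_token_ids, set_future_token_ids

# Small placeholder buffer; slot 0 stays unused because placeholder -k resolves
# to index k, and writes start at future_token_ids_ct + 1.
future_token_ids_map = jnp.zeros((8,), dtype=jnp.int32)

# Store the token ids produced by a (hypothetical) previous decode step.
next_token_ids = jnp.array([11, 22, 33], dtype=jnp.int32)
future_token_ids_map = set_future_token_ids(future_token_ids_map, 0, next_token_ids)

# A later batch references those results through negative placeholder ids.
input_ids = jnp.array([5, -1, -2, 7, -3], dtype=jnp.int32)
print(resolve_future_token_ids(input_ids, future_token_ids_map))
# Expected: [ 5 11 22  7 33]
```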