2 changes: 1 addition & 1 deletion python/sgl_jax/srt/layers/attention/base_attn_backend.py
@@ -17,7 +17,7 @@ class AttentionBackend(nnx.Module):
"""The base class of attention backends"""

@abstractmethod
- def get_forward_metadata(self, batch: ModelWorkerBatch, mesh: Mesh):
+ def get_forward_metadata(self, batch: ModelWorkerBatch):
"""Init the metadata for a forward pass and return it"""
raise NotImplementedError()

11 changes: 7 additions & 4 deletions python/sgl_jax/srt/layers/attention/flashattention_backend.py
@@ -4,7 +4,7 @@
import jax
import jax.numpy as jnp
import numpy as np
- from jax.sharding import Mesh
+ from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jax.tree_util import register_pytree_node_class

@@ -75,6 +75,7 @@ def __init__(
vmem_limit_bytes: int = 64 * (1 << 20), # 64MB
page_size: int = 1,
kv_partition_axis: str = "tensor",
+ mesh: jax.sharding.Mesh = None,
):
self.vmem_limit_bytes = vmem_limit_bytes
self.num_heads = num_attn_heads
@@ -86,8 +87,9 @@
self.page_size = page_size
self.kv_partition_axis = kv_partition_axis
self.forward_metadata = FlashAttentionMetadata()
+ self.mesh = mesh

- def get_forward_metadata(self, batch: ModelWorkerBatch, mesh: Mesh):
+ def get_forward_metadata(self, batch: ModelWorkerBatch):
"""Return the metadata for a forward pass."""
metadata = FlashAttentionMetadata()

@@ -148,8 +150,10 @@ def get_forward_metadata(self, batch: ModelWorkerBatch, mesh: Mesh):
metadata.seq_lens,
metadata.distribution,
) = device_array(
- mesh,
(num_seqs, cu_q_lens, cu_kv_lens, page_indices, seq_lens, distribution),
+ sharding=(
+ NamedSharding(self.mesh, P()) if jax.process_count() == 1 else None
+ ),
)
return metadata

@@ -255,7 +259,6 @@ def _ragged_paged_attention_with_fused_kv(*args):
updated_kv_cache_fused,
) = jax.shard_map( # Fused KV kernel handles cache updates internally
_ragged_paged_attention_with_fused_kv,
- mesh=jax.sharding.get_abstract_mesh(),
in_specs=in_specs,
out_specs=out_specs,
check_vma=False,
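Note on the sharding argument introduced in flashattention_backend.py above: metadata arrays are now placed with an explicitly replicated NamedSharding(self.mesh, P()) when running single-process, and with sharding=None (default placement) under multi-process execution, instead of threading the mesh through get_forward_metadata. A minimal sketch of how a device_array-style helper could honor such an argument (illustrative only; the real helper lives in sgl_jax.srt.utils.jax_utils and its exact behavior is assumed here):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def device_array_sketch(x, sharding=None):
    # Put a host array on device; with an explicit Sharding the array is laid
    # out accordingly, otherwise JAX falls back to its default placement.
    return jax.device_put(x, sharding) if sharding is not None else jax.device_put(x)

mesh = Mesh(np.array(jax.devices()), ("tensor",))
replicated = NamedSharding(mesh, P())  # P() with no named axes = fully replicated
seq_lens = device_array_sketch(np.array([3, 5, 7], dtype=np.int32), sharding=replicated)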
2 changes: 1 addition & 1 deletion python/sgl_jax/srt/layers/attention/native_backend.py
@@ -39,7 +39,7 @@ def tree_unflatten(cls, aux_data, children):
num_attn_heads=aux_data["num_heads"], num_kv_heads=aux_data["num_kv_heads"]
)

- def get_forward_metadata(self, batch: ModelWorkerBatch, mesh: Mesh):
+ def get_forward_metadata(self, batch: ModelWorkerBatch):
"""Init the metadata for a forward pass and return it."""
return None

15 changes: 6 additions & 9 deletions python/sgl_jax/srt/layers/logits_processor.py
@@ -7,7 +7,8 @@
import jax.numpy as jnp
import numpy as np
from flax import nnx
- from jax.sharding import Mesh
+ from jax.sharding import Mesh, NamedSharding
+ from jax.sharding import PartitionSpec as P
from jax.tree_util import register_pytree_node_class

from sgl_jax.srt.layers.embeddings import Embed
@@ -194,15 +195,15 @@ def from_model_worker_batch(cls, batch: ModelWorkerBatch, mesh: Mesh = None):
) = False
extend_logprob_pruned_lens_cpu = extend_seq_lens_cpu = None

- mesh = mesh if mesh is not None else jax.sharding.get_abstract_mesh()
+ sharding = NamedSharding(mesh, P()) if jax.process_count() == 1 else None

return cls(
forward_mode=batch.forward_mode,
capture_hidden_mode=batch.capture_hidden_mode,
extend_return_logprob=extend_return_logprob,
extend_return_top_logprob=extend_return_top_logprob,
extend_token_ids_logprob=extend_token_ids_logprob,
- extend_seq_lens=device_array(mesh, batch.extend_seq_lens),
+ extend_seq_lens=device_array(batch.extend_seq_lens, sharding=sharding),
extend_seq_lens_cpu=extend_seq_lens_cpu,
extend_logprob_start_lens_cpu=(
batch.extend_logprob_start_lens if batch.return_logprob else None
@@ -211,8 +212,7 @@ def from_model_worker_batch(cls, batch: ModelWorkerBatch, mesh: Mesh = None):
top_logprobs_nums=batch.top_logprobs_nums,
token_ids_logprobs=batch.token_ids_logprobs,
extend_input_logprob_token_ids_device=device_array(
- mesh,
- batch.extend_input_logprob_token_ids,
+ batch.extend_input_logprob_token_ids, sharding=sharding
),
)

@@ -279,14 +279,12 @@ def __call__(

pruned_states = jnp.concat(pruned_states)
sample_indices = device_array(
- self.mesh,
np.array(
sample_indices,
dtype=jnp.int64,
),
)
input_logprob_indices = device_array(
- self.mesh,
np.array(input_logprob_indices, dtype=jnp.int64),
)

@@ -324,7 +322,6 @@ def __call__(

# Normalize the logprob w/o temperature, top-p
pruned_lens = device_array(
- self.mesh,
np.array(
logits_metadata.extend_logprob_pruned_lens_cpu,
),
@@ -362,7 +359,7 @@ def __call__(
input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None

input_token_logprobs = input_logprobs[
- device_array(self.mesh, np.arange(input_logprobs.shape[0])),
+ device_array(np.arange(input_logprobs.shape[0])),
logits_metadata.extend_input_logprob_token_ids_device,
]

2 changes: 1 addition & 1 deletion python/sgl_jax/srt/managers/detokenizer_manager.py
@@ -299,7 +299,7 @@ def run_detokenizer_process(
port_args: PortArgs,
):
kill_itself_when_parent_died()
- setproctitle.setproctitle("sglang::detokenizer")
+ setproctitle.setproctitle("sglang-jax::detokenizer")
configure_logger(server_args)
parent_process = psutil.Process().parent()

37 changes: 30 additions & 7 deletions python/sgl_jax/srt/managers/tp_worker.py
@@ -19,6 +19,7 @@
ModelWorkerBatch,
global_server_args_dict,
)
+ from sgl_jax.srt.managers.utils import resolve_future_token_ids, set_future_token_ids
from sgl_jax.srt.mem_cache.memory_pool import ReqToTokenPool
from sgl_jax.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
@@ -184,11 +185,11 @@ def normalize_token_paddings(self):

self.precompile_token_paddings = normalized_token_paddings

- def run_precompile(self):
- self.precompile_extend()
- self.precompile_decode()
+ def run_precompile(self, future_token_ids_map=None):
+ self.precompile_extend(future_token_ids_map)
+ self.precompile_decode(future_token_ids_map)

- def precompile_extend(self):
+ def precompile_extend(self, future_token_ids_map=None):
start_time = time.perf_counter()
logger.info(
f"[EXTEND] Begin to precompile bs_paddings={self.precompile_bs_paddings[-1:]} token_paddings={self.precompile_token_paddings}"
@@ -211,12 +212,22 @@ def precompile_extend(self):
ForwardMode.EXTEND,
self.precompile_cache_loc_paddings[-1],
)
+ model_worker_batch.forward_batch = ForwardBatch.init_new(
+ model_worker_batch, self.model_runner
+ )
+ if future_token_ids_map is not None:
+ model_worker_batch.forward_batch.input_ids = (
+ resolve_future_token_ids(
+ model_worker_batch.forward_batch.input_ids,
+ future_token_ids_map,
+ )
+ )
self.forward_batch_generation(model_worker_batch, None, True)

end_time = time.perf_counter()
logger.info("[EXTEND] Precompile finished in %.0f secs", end_time - start_time)

- def precompile_decode(self):
+ def precompile_decode(self, future_token_ids_map=None):
start_time = time.perf_counter()
logger.info(
f"[DECODE] Begin to precompile bs_paddings={self.precompile_bs_paddings}"
@@ -236,9 +247,21 @@ def precompile_decode(self):
sampling_metadata = SamplingMetadata.from_model_worker_batch(
model_worker_batch, 0, self.mesh
)
- self.forward_batch_generation(
+ model_worker_batch.forward_batch = ForwardBatch.init_new(
+ model_worker_batch, self.model_runner
+ )
+ if future_token_ids_map is not None:
+ model_worker_batch.forward_batch.input_ids = (
+ resolve_future_token_ids(
+ model_worker_batch.forward_batch.input_ids,
+ future_token_ids_map,
+ )
+ )
+ _, next_token_ids, _ = self.forward_batch_generation(
model_worker_batch, None, False, sampling_metadata
)
+ if future_token_ids_map is not None:
+ set_future_token_ids(future_token_ids_map, 0, next_token_ids)

end_time = time.perf_counter()
logger.info("[DECODE] Precompile finished in %.0f secs", end_time - start_time)
@@ -365,7 +388,7 @@ def forward_batch_generation(
if forward_metadata is None:
forward_metadata = (
self.worker.model_runner.attn_backend.get_forward_metadata(
- model_worker_batch, self.mesh
+ model_worker_batch
)
)

55 changes: 3 additions & 52 deletions python/sgl_jax/srt/managers/tp_worker_overlap_thread.py
@@ -4,7 +4,6 @@
import logging
import signal
import threading
- import time
from queue import Queue
from typing import Optional, Tuple

@@ -16,34 +15,14 @@

from sgl_jax.srt.managers.schedule_batch import ModelWorkerBatch
from sgl_jax.srt.managers.tp_worker import ModelWorker
+ from sgl_jax.srt.managers.utils import resolve_future_token_ids, set_future_token_ids
from sgl_jax.srt.sampling.sampling_batch_info import SamplingMetadata
from sgl_jax.srt.server_args import ServerArgs
- from sgl_jax.srt.utils.jax_utils import device_array
from sgl_jax.utils import get_exception_traceback

logger = logging.getLogger(__name__)


- @jax.jit
- def resolve_future_token_ids(input_ids, future_token_ids_map):
- return jnp.where(
- input_ids < 0,
- future_token_ids_map[jnp.clip(-input_ids, a_min=0)],
- input_ids,
- )
-
-
- @jax.jit
- def set_future_token_ids(future_token_ids_map, future_token_ids_ct, next_token_ids):
- # The start index must be a tuple, one element per dimension of the array.
- start_indices = (future_token_ids_ct + 1,)
-
- # jax.lax.dynamic_update_slice is the correct tool for this job.
- return jax.lax.dynamic_update_slice(
- future_token_ids_map, next_token_ids, start_indices
- )
-
-
class ModelWorkerClient:
"""A tensor parallel model worker."""

@@ -187,7 +166,7 @@ def forward_batch_generation(
)

forward_metadata = self.worker.model_runner.attn_backend.get_forward_metadata(
- model_worker_batch, self.mesh
+ model_worker_batch
)

# Push a new batch to the queue (JAX handles synchronization automatically)
@@ -215,35 +194,7 @@ def forward_batch_generation(
return None, future_next_token_ids, 0

def run_precompile(self):
- start_time = time.perf_counter()
- logger.info(
- f"[ModelWorkerClient] Begins to run resolve_future_token_ids precompile."
- )
- (
- precompile_token_paddings,
- precompile_bs_paddings,
- _,
- ) = self.get_precompile_paddings()
- max_padding_bs, _ = self.get_max_padded_size()
- bs_paddings = sorted(set(precompile_bs_paddings + [max_padding_bs]))
- token_paddings = sorted(set(bs_paddings + precompile_token_paddings))
- for token_padding in token_paddings:
- input_ids = device_array(
- self.worker.mesh, jnp.arange(0, token_padding, dtype=jnp.int32)
- )
- resolve_future_token_ids(input_ids, self.future_token_ids_map)
- for bs_padding in bs_paddings:
- input_ids = device_array(
- self.worker.mesh, jnp.arange(0, bs_padding, dtype=jnp.int32)
- )
- set_future_token_ids(
- self.future_token_ids_map, self.future_token_ids_ct, input_ids
- )
- end_time = time.perf_counter()
- logger.info(
- f"[ModelWorkerClient] Completes resolve_future_token_ids precompile. Time cost: {end_time - start_time} seconds"
- )
- self.worker.run_precompile()
+ self.worker.run_precompile(self.future_token_ids_map)

def __delete__(self):
self.input_queue.put((None, None, None, None))
20 changes: 20 additions & 0 deletions python/sgl_jax/srt/managers/utils.py
@@ -1,6 +1,9 @@
import logging
from typing import Optional

+ import jax
+ from jax import numpy as jnp
+
from sgl_jax.srt.managers.schedule_batch import Req

logger = logging.getLogger(__name__)
@@ -37,3 +40,20 @@ def validate_input_length(
return error_msg

return None
+
+
+ @jax.jit
+ def resolve_future_token_ids(input_ids, future_token_ids_map):
+ return jnp.where(
+ input_ids < 0,
+ future_token_ids_map[jnp.clip(-input_ids, a_min=0)],
+ input_ids,
+ )
+
+
+ @jax.jit
+ def set_future_token_ids(future_token_ids_map, future_token_ids_ct, next_token_ids):
+ start_indices = (future_token_ids_ct + 1,)
+ return jax.lax.dynamic_update_slice(
+ future_token_ids_map, next_token_ids, start_indices
+ )
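The two jitted helpers added to managers/utils.py above back the overlap worker's future-token mechanism: decode batches carry negative placeholder ids, and resolve_future_token_ids swaps placeholder -k for slot k of future_token_ids_map once set_future_token_ids has written the previously sampled tokens there (starting at future_token_ids_ct + 1). A small round-trip sketch (illustrative only; the buffer size and token values are made up, and it assumes sgl_jax is importable):

import jax.numpy as jnp

from sgl_jax.srt.managers.utils import resolve_future_token_ids, set_future_token_ids

# Token buffer; slot 0 is effectively a dummy because positive ids keep their own value.
future_map = jnp.zeros(8, dtype=jnp.int32)

# With future_token_ids_ct = 0, the two sampled tokens land in slots 1 and 2.
future_map = set_future_token_ids(future_map, 0, jnp.array([11, 12], dtype=jnp.int32))

# -1 and -2 are placeholders for those slots; 5 is an ordinary token id.
input_ids = jnp.array([5, -1, -2], dtype=jnp.int32)
print(resolve_future_token_ids(input_ids, future_map))  # expected: [ 5 11 12]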