102102from jetstream .core .metrics .prometheus import JetstreamMetricsCollector
103103import numpy as np
104104
105- log_level = os .getenv ("LOG_LEVEL" , "WARNING" ).upper ()
105+ from jax .experimental import layout as jax_layout
106+ DLL = jax_layout .DeviceLocalLayout
107+ Layout = jax_layout .Layout
108+
109+ log_level = os .getenv ("LOG_LEVEL" , "DEBUG" ).upper ()
106110
107111logger = logging .getLogger ("JetstreamLogger" )
108112logger .propagate = False
@@ -405,6 +409,26 @@ def __init__(
405409
406410 self ._jax_padding = jax_padding
407411
412+ ##### Auto layout compile for interleaved engine
413+ self ._generate_executables = [None for _ in self ._generate_engines ]
414+ self ._cached_insert = [None for _ in self ._generate_engines ]
415+ self ._cached_prefill = [None for _ in self ._prefill_engines ]
416+ if self ._interleaved_mode :
417+ for idx in range (len (self ._generate_engines )):
418+ logger .debug ("Compiling interleaved engine {}" .format (idx ))
419+ engine = self ._generate_engines [idx ]
420+ params = self ._generate_params [idx ]
421+ engine , params , gen_fn , prefill_fn , insert_fn = self ._auto_layout_compile (engine , params )
422+
423+ self ._prefill_engines [idx ] = engine
424+ self ._generate_engines [idx ] = engine
425+ self ._prefill_params [idx ] = params
426+ self ._generate_params [idx ] = params
427+ self ._cached_prefill [idx ] = prefill_fn
428+ self ._cached_insert [idx ] = insert_fn
429+ self ._generate_executables [idx ] = gen_fn
430+
431+
408432 # Create all threads
409433 self ._prefill_threads = [
410434 JetThread (
@@ -670,6 +694,56 @@ def _do_chunked_prefill(
670694
671695 return prefill_result , first_token
672696
697+ def _auto_layout_compile (self , engine , params ):
698+ logger .debug ("Compiling generate function" )
699+ generate_executable , params , decode_state_executable = engine .aot_compile (
700+ params , pass_rng_shape = False
701+ )
702+ decode_state = decode_state_executable (None )
703+
704+ # prefill
705+ interesting_buckets = [
706+ 64 ,
707+ 128 ,
708+ 256 ,
709+ 512 ,
710+ 1024 ,
711+ ]
712+
713+ cached_prefill = {}
714+ cached_insert = {}
715+ for length in interesting_buckets :
716+ i32_scalar = jax .ShapeDtypeStruct ((), int )
717+ logger .debug ("Compiling prefill: %d" , length )
718+ input_data = jax .ShapeDtypeStruct ((length ,), jax .numpy .dtype ("int32" ))
719+
720+ cached_prefill [length ] = (
721+ jax .jit (
722+ engine .prefill_aot ,
723+ in_shardings = (engine .param_layouts , None , None ),
724+ out_shardings = (Layout (DLL .AUTO ), Layout (DLL .AUTO )),
725+ ).lower (params , input_data , i32_scalar )
726+ ).compile (compiler_options = None )
727+
728+ logger .debug ("Generate dummy prefix: %d" , length )
729+ dummy_tokens = jax .numpy .ones (shape = (length ,), dtype = jax .numpy .dtype ("int32" ))
730+ prefix_shapes = jax .eval_shape (engine .prefill_aot , params , dummy_tokens , 1 )
731+
732+ logger .debug ("Compiling insert: %d" , length )
733+ prefill_output_layout , _ = cached_prefill [length ].output_layouts
734+ logger .debug ("Prefill output layout: {}" .format (prefill_output_layout ))
735+ logger .debug ("Prefix shapes: {}" .format (prefix_shapes ))
736+ i32_scalar = jax .ShapeDtypeStruct ((), int )
737+ cached_insert [length ] = (
738+ jax .jit (
739+ engine .insert ,
740+ in_shardings = (prefill_output_layout , engine .decode_state_layouts , None ),
741+ out_shardings = (engine .decode_state_layouts ),
742+ donate_argnames = ("decode_state" ),
743+ ).lower (prefix_shapes [0 ], engine .decode_state_shapes , i32_scalar )
744+ ).compile (compiler_options = None )
745+ return engine , params , generate_executable , cached_prefill , cached_insert
746+
673747 def _prefill_thread (self , idx : int ):
674748 """Thread which runs in the background performing prefills."""
675749 logger .info ("Spinning up prefill thread %d." , idx )
@@ -683,6 +757,12 @@ def _prefill_thread(self, idx: int):
683757 thread_name = f"Prefill thread { idx } "
684758 ThreadDebugLog (thread_name , f"Prefill params { idx } loaded." )
685759
760+ if not self .interleaved :
761+ prefill_engine , prefill_params , gen_fn , prefill_fn , insert_fn = self ._auto_layout_compile (
762+ prefill_engine , prefill_params
763+ )
764+ self ._cached_prefill [idx ] = prefill_fn
765+
686766 while self .live :
687767 my_transfer_backlog = self ._transfer_backlogs [idx ]
688768 # The prefill thread can just sleep until it has work to do.
@@ -759,10 +839,11 @@ def _prefill_thread(self, idx: int):
759839 )
760840 else :
761841 # Compute new kv cache for the prefill_content.
762- prefill_result , first_token = prefill_engine .prefill (
763- params = final_prefill_params ,
764- padded_tokens = padded_tokens ,
765- true_length = true_length ,
842+ assert padded_tokens .shape [0 ] in self ._cached_prefill [idx ]
843+ prefill_result , first_token = self ._cached_prefill [idx ][padded_tokens .shape [0 ]](
844+ final_prefill_params ,
845+ padded_tokens ,
846+ true_length ,
766847 )
767848
768849 request .complete = np .zeros (
@@ -967,10 +1048,11 @@ def _insert_if_possible(
9671048 else :
9681049 break
9691050
970- decode_state = generate_engine .insert (
1051+ length = new_request .prefill_result ['cache' ]['decoder' ]['layers_0' ]['self_attention' ]['KVCache_0' ]['cache_prefill_segment_id' ].value .shape [1 ]
1052+ decode_state = self ._cached_insert [idx ][length ](
9711053 new_request .prefill_result ,
9721054 decode_state ,
973- slot = slot ,
1055+ slot ,
9741056 # request_id=new_request.request_id,
9751057 )
9761058 ThreadDebugLog (
@@ -1115,9 +1197,15 @@ def _generate_thread(self, idx: int):
11151197 # Keep track of what step tokens were generated at.
11161198 generate_timestep = 0
11171199 # State to store things like running kv cache in.
1118- decode_state = generate_engine .init_decode_state ()
1119-
1200+ decode_state = self .decode_state
11201201 generate_params = self ._generate_params [idx ]
1202+
1203+ if not self .interleaved :
1204+ generate_engine , generate_params , gen_fn , prefill_fn , insert_fn = self ._auto_layout_compile (
1205+ generate_engine , generate_params
1206+ )
1207+ self ._generate_executables [idx ] = gen_fn
1208+
11211209 thread_name = f"Generate thread { idx } "
11221210 ThreadDebugLog (thread_name , f"Generate params { idx } loaded." )
11231211 time_of_last_generate = time .time ()
@@ -1178,8 +1266,8 @@ def _generate_thread(self, idx: int):
11781266 ), "At this point we must have some requests inserted into the slots."
11791267
11801268 # Now we actually take a generate step on requests in the slots.
1181- decode_state , sampled_tokens = generate_engine . generate (
1182- generate_params , decode_state
1269+ decode_state , sampled_tokens = self . _generate_executables [ idx ] (
1270+ generate_params , decode_state , None
11831271 )
11841272 sampled_tokens .copy_to_host_async ()
11851273 # Respond to detokenization backpressure.
0 commit comments