
JetStream changes for JAX-based implementation of unified_lora_params for decoding a batch of multiple different LoRA adapters. #222

Open · wants to merge 26 commits into base: main
Commits (26)
f0f295a  Extra logging for understanding the workflow (aman2930, Jan 6, 2025)
610fcea  Updating checkpoint conversion script to support LoRA weights convers… (aman2930, Jan 8, 2025)
50deb3e  Cleaning up of loggings and some refactoring to make the script work … (aman2930, Jan 22, 2025)
7426ea7  1) Added MultiAdapterManager service proto along with the methods Lis… (aman2930, Jan 27, 2025)
fb88eca  1) Implemented adapter_tensorstore module to store and manage the ada… (aman2930, Feb 18, 2025)
3c6fcbd  1) Implemented a new Service API proto to align with OpenAI completio… (aman2930, Feb 24, 2025)
316c490  Adding following metrics into JetStream server: (aman2930, Feb 26, 2025)
e4d875a  Refactoring and cleaning of the JetStream server code. (aman2930, Mar 6, 2025)
eb74d86  Refactoring part-2. (aman2930, Mar 6, 2025)
a41e4cd  Refactor part-3. (aman2930, Mar 6, 2025)
26b1f37  Merging main to amangu-lora. (aman2930, Mar 6, 2025)
febaed1  1) Adding more comments at applying LoRA on Prefill params path. (aman2930, Mar 6, 2025)
ed66fdf  Fixing TypeCheck errors. (aman2930, Mar 6, 2025)
a6a5cd1  Fixing linting error. (aman2930, Mar 7, 2025)
e4d22bf  JetStream changes for Jax based implementation of unified_lora_params… (aman2930, Mar 7, 2025)
bd67171  Adding documentations. (aman2930, Mar 7, 2025)
1059978  Adding more doc strings. (aman2930, Mar 10, 2025)
5f679a9  - Created separate adapter_tensorstore for each engine. (aman2930, Mar 13, 2025)
5bda29b  Refactoring and fixing lint errors. (aman2930, Mar 13, 2025)
6ce324c  Merging 'main' to 'amangu-lora' (aman2930, Mar 13, 2025)
fe1511c  Changes to resolve comments on the PR. (aman2930, Mar 13, 2025)
f0da2b9  Fixing Unit test errors. (aman2930, Mar 13, 2025)
2c2850b  Merge amangu-lora to amangu-lora-3 branch (aman2930, Mar 14, 2025)
453db81  Fixing some failures due to merge conflicts. (aman2930, Mar 15, 2025)
e58aa50  Merge branch 'main' into amangu-lora-3 (aman2930, Apr 29, 2025)
a38b686  Fixing some missed changes in last merge with main. (aman2930, May 3, 2025)
20 changes: 19 additions & 1 deletion jetstream/core/orchestrator.py
@@ -1030,6 +1030,10 @@ def _insert_if_possible(
# Check if there are any free my_slots. We don't want to block here since
# we can still generate if we can't insert. We do this in a while loop to
# insert as many sequences as possible.
adapter_tensorstore = None
if self._generate_adapterstore:
  adapter_tensorstore = self._generate_adapterstore[idx]

while True:
  my_slots_size = my_slots.qsize()

@@ -1100,6 +1104,11 @@ def _insert_if_possible(
slot=slot,
# request_id=new_request.request_id,
)

if adapter_tensorstore:
  adapter_tensorstore.insert_adapter_in_cache(
      new_request.adapter_id, slot)

ThreadDebugLog(
thread_name,
f"Generate slice {idx} filled slot {slot} at step "
@@ -1239,6 +1248,10 @@ def _generate_thread(self, idx: int):
my_generate_backlog = self._generate_backlogs[idx]
my_detokenize_backlog = self._detokenize_backlogs[idx]

adapter_tensorstore = None
if self._generate_adapterstore and idx < len(self._generate_adapterstore):
  adapter_tensorstore = self._generate_adapterstore[idx]

# Keep track of what step tokens were generated at.
generate_timestep = 0
# State to store things like running kv cache in.
@@ -1304,9 +1317,14 @@ def _generate_thread(self, idx: int):
my_slots.qsize() < max_concurrent_decodes
), "At this point we must have some requests inserted into the slots."

decoding_adapters_cache = None

if adapter_tensorstore:
  decoding_adapters_cache = adapter_tensorstore.decoding_adapters_cache

# Now we actually take a generate step on requests in the slots.
decode_state, sampled_tokens = generate_engine.generate(
generate_params, decode_state
generate_params, decode_state, decoding_adapters_cache
)
sampled_tokens.copy_to_host_async()
# Respond to detokenization backpressure.
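The orchestrator change above threads the per-engine decoding_adapters_cache into generate_engine.generate(). As a minimal sketch of the idea only (not the actual JetStream/MaxText implementation; the function name apply_unified_lora, the stacked-tensor layout, and all shapes are assumptions), one way a generate step could consume a unified stack of per-slot LoRA weights so that a single batched decode step serves sequences using different adapters:

```python
import jax
import jax.numpy as jnp


def apply_unified_lora(x, base_w, lora_a, lora_b, slot_ids):
  """Applies per-slot LoRA deltas to one projection during a decode step.

  x:        [batch, hidden]            current-step activations
  base_w:   [hidden, out]              frozen base projection weight
  lora_a:   [num_slots, hidden, rank]  stacked A matrices, one row per slot
  lora_b:   [num_slots, rank, out]     stacked B matrices, one row per slot
  slot_ids: [batch]                    decode slot that owns each sequence
  """
  a = lora_a[slot_ids]  # gather each sequence's adapter: [batch, hidden, rank]
  b = lora_b[slot_ids]  # [batch, rank, out]
  delta = jnp.einsum("bh,bhr,bro->bo", x, a, b)  # per-sequence LoRA update
  return x @ base_w + delta


# Tiny smoke test: all-zero rows behave as "no adapter" for those slots.
hidden, out, rank, num_slots, batch = 8, 8, 2, 4, 3
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (batch, hidden))
base_w = jax.random.normal(key, (hidden, out))
lora_a = jnp.zeros((num_slots, hidden, rank))
lora_b = jnp.zeros((num_slots, rank, out))
y = apply_unified_lora(x, base_w, lora_a, lora_b, jnp.array([0, 1, 3]))
assert y.shape == (batch, out)
```

Keeping the adapters stacked and indexed by decode slot is what lets the single generate call above stay batched instead of splitting the batch per adapter.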
1 change: 1 addition & 0 deletions jetstream/engine/engine_api.py
@@ -211,6 +211,7 @@ def generate(
params: Params,
decode_state: DecodeState,
sampler: Optional[Callable[[Any], Any]] = None,
lora_params: Params = None,
) -> Tuple[DecodeState, ResultTokens]:
"""Generates tokens for each sequence being decoded in parallel.

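The engine_api.py change makes lora_params an optional argument to Engine.generate, so engines without LoRA support can ignore it. For context, a sketch of how a per-engine store could map decode slots to adapter weights and expose them as one structure to pass through that argument; this is illustrative only and does not reproduce the PR's actual adapter_tensorstore API or storage layout:

```python
import jax.numpy as jnp


class AdapterSlotCache:
  """Illustrative per-engine cache mapping decode slots to LoRA weights."""

  def __init__(self, num_slots: int, hidden: int, out: int, rank: int):
    # One stacked (A, B) row per decode slot; an all-zero row means "no adapter".
    self._lora_a = jnp.zeros((num_slots, hidden, rank))
    self._lora_b = jnp.zeros((num_slots, rank, out))
    self._loaded = {}  # adapter_id -> (A, B) weights already loaded

  def register_adapter(self, adapter_id: str, lora_a, lora_b):
    # Pre-load an adapter's weights so a slot can be filled cheaply later.
    self._loaded[adapter_id] = (jnp.asarray(lora_a), jnp.asarray(lora_b))

  def insert_adapter_in_cache(self, adapter_id: str, slot: int):
    # Mirror of the orchestrator hook above: when a request lands in `slot`,
    # copy its adapter's weights into the row owned by that slot.
    a, b = self._loaded[adapter_id]
    self._lora_a = self._lora_a.at[slot].set(a)
    self._lora_b = self._lora_b.at[slot].set(b)

  @property
  def decoding_adapters_cache(self):
    # Unified per-slot LoRA params handed to generate() as `lora_params`.
    return {"lora_a": self._lora_a, "lora_b": self._lora_b}
```

With a cache like this, the orchestrator call in the diff above, generate(generate_params, decode_state, decoding_adapters_cache), hands the engine one stacked tensor per LoRA matrix, matching the unified_lora_params idea in the PR title.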