From cc7b78c35dee01c483386100186f248bb19a1b35 Mon Sep 17 00:00:00 2001 From: Joshua Jiahua Hong Date: Mon, 25 Aug 2025 04:00:25 -0400 Subject: [PATCH] Add sequence padding to BeginForward --- cpp/serve/model.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/serve/model.cc b/cpp/serve/model.cc index 11c9f03995..22686588bf 100644 --- a/cpp/serve/model.cc +++ b/cpp/serve/model.cc @@ -262,7 +262,7 @@ class ModelImpl : public ModelObj { // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_); ObjectRef embeddings_dref_or_nd; if (!embeddings->IsInstance()) { @@ -372,7 +372,7 @@ class ModelImpl : public ModelObj { // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_); // args: embeddings, logit_pos, kv_cache, params ObjectRef result{nullptr}; @@ -422,7 +422,7 @@ class ModelImpl : public ModelObj { // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(std::vector(/*n=*/seq_ids.size(), /*v=*/1)); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_); ObjectRef embeddings_dref_or_nd; if (!embeddings->IsInstance()) { @@ -501,7 +501,7 @@ class ModelImpl : public ModelObj { IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_, token_tree_parent_ptr_tuple); ObjectRef embeddings_dref_or_nd; @@ -564,7 +564,7 @@ class ModelImpl : public ModelObj { // Begin forward with the sequence ids and new lengths. IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(std::vector(/*n=*/seq_ids.size(), /*v=*/1)); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple); + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_); // args: embeddings, kv_cache, params ObjectRef result{nullptr}; @@ -624,7 +624,7 @@ class ModelImpl : public ModelObj { IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_, token_tree_parent_ptr_tuple); ObjectRef embeddings_dref_or_nd; @@ -712,7 +712,7 @@ class ModelImpl : public ModelObj { IntTuple seq_ids_tuple(seq_ids); IntTuple lengths_tuple(lengths.begin(), lengths.end()); IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr); - ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, + ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_, token_tree_parent_ptr_tuple); // args: embeddings, logit_pos, kv_cache, params @@ -827,7 +827,7 @@ class ModelImpl : public ModelObj { // Run KV receive preparation. ObjectRef ret; - ret = ft_.kv_cache_disagg_prepare_recv_func_(kv_cache_, seq_id, length).cast(); + ret = ft_.kv_cache_disagg_prepare_recv_func_(kv_cache_, seq_id, length, seqlen_padding_factor_).cast(); IntTuple compressed_kv_append_metadata; if (ft_.use_disco) { compressed_kv_append_metadata = Downcast(ret)->DebugGetFromRemote(0).cast();