Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions cpp/serve/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class ModelImpl : public ModelObj {
// Begin forward with the sequence ids and new lengths.
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_);

ObjectRef embeddings_dref_or_nd;
if (!embeddings->IsInstance<DRefObj>()) {
Expand Down Expand Up @@ -372,7 +372,7 @@ class ModelImpl : public ModelObj {
// Begin forward with the sequence ids and new lengths.
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_);

// args: embeddings, logit_pos, kv_cache, params
ObjectRef result{nullptr};
Expand Down Expand Up @@ -422,7 +422,7 @@ class ModelImpl : public ModelObj {
// Begin forward with the sequence ids and new lengths.
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(std::vector<int64_t>(/*n=*/seq_ids.size(), /*v=*/1));
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_);

ObjectRef embeddings_dref_or_nd;
if (!embeddings->IsInstance<DRefObj>()) {
Expand Down Expand Up @@ -501,7 +501,7 @@ class ModelImpl : public ModelObj {
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_,
token_tree_parent_ptr_tuple);

ObjectRef embeddings_dref_or_nd;
Expand Down Expand Up @@ -564,7 +564,7 @@ class ModelImpl : public ModelObj {
// Begin forward with the sequence ids and new lengths.
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(std::vector<int64_t>(/*n=*/seq_ids.size(), /*v=*/1));
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_);

// args: embeddings, kv_cache, params
ObjectRef result{nullptr};
Expand Down Expand Up @@ -624,7 +624,7 @@ class ModelImpl : public ModelObj {
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_,
token_tree_parent_ptr_tuple);

ObjectRef embeddings_dref_or_nd;
Expand Down Expand Up @@ -712,7 +712,7 @@ class ModelImpl : public ModelObj {
IntTuple seq_ids_tuple(seq_ids);
IntTuple lengths_tuple(lengths.begin(), lengths.end());
IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,
ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple, seqlen_padding_factor_,
token_tree_parent_ptr_tuple);

// args: embeddings, logit_pos, kv_cache, params
Expand Down Expand Up @@ -827,7 +827,7 @@ class ModelImpl : public ModelObj {

// Run KV receive preparation.
ObjectRef ret;
ret = ft_.kv_cache_disagg_prepare_recv_func_(kv_cache_, seq_id, length).cast<ObjectRef>();
ret = ft_.kv_cache_disagg_prepare_recv_func_(kv_cache_, seq_id, length, seqlen_padding_factor_).cast<ObjectRef>();
IntTuple compressed_kv_append_metadata;
if (ft_.use_disco) {
compressed_kv_append_metadata = Downcast<DRef>(ret)->DebugGetFromRemote(0).cast<IntTuple>();
Expand Down
Loading