@@ -423,17 +423,17 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
423423 // TODO(ankit): Address this issue and implement the fix at the appropriate layer.
424424 FillTensor (" beam_idx" , ov::element::i32 , {1 }, 0 );
425425
426- if (src_idx_val .size () > 0 ) {
427- ov::Tensor src_idx_tensor = ov::Tensor (ov::element::i32 , {src_idx_val .size ()});
428- for (int i = 0 ; i < src_idx_val .size (); ++i) {
429- src_idx_tensor.data <int32_t >()[i] = int32_t (src_idx_val [i]);
426+ if (kv_src_indices .size () > 0 ) {
427+ ov::Tensor src_idx_tensor = ov::Tensor (ov::element::i32 , {kv_src_indices .size ()});
428+ for (int i = 0 ; i < kv_src_indices .size (); ++i) {
429+ src_idx_tensor.data <int32_t >()[i] = int32_t (kv_src_indices [i]);
430430 }
431431 ovInfReq.set_tensor (" src_idx" , src_idx_tensor);
432- ov::Tensor dst_idx_tensor = ov::Tensor (ov::element::i32 , {1 , 32 , dst_idx_val .size (), 96 });
433- for (int i = 0 ; i < dst_idx_val .size (); ++i) {
432+ ov::Tensor dst_idx_tensor = ov::Tensor (ov::element::i32 , {1 , 32 , kv_dst_indices .size (), 96 });
433+ for (int i = 0 ; i < kv_dst_indices .size (); ++i) {
434434 for (int j = 0 ; j < 32 ; ++j) {
435435 for (int k = 0 ; k < 96 ; ++k) {
436- dst_idx_tensor.data <int32_t >()[(j * dst_idx_val .size () + i) * 96 + k] = int32_t (dst_idx_val [i]);
436+ dst_idx_tensor.data <int32_t >()[(j * kv_dst_indices .size () + i) * 96 + k] = int32_t (kv_dst_indices [i]);
437437 }
438438 }
439439 }
@@ -492,25 +492,12 @@ void StatefulOVInferRequest::ReorderKVCache(const std::vector<size_t>& src_indic
492492 LOGS_DEFAULT (INFO) << log_tag << " ReorderKVCache: Reordering OpenVINO-internal KVCache state with "
493493 << src_indices.size () << " index pairs" ;
494494
495- // set beam_idx and dst_idx based on provided values
496- src_idx_val.clear ();
497- dst_idx_val.clear ();
495+ kv_src_indices.clear ();
496+ kv_dst_indices.clear ();
498497 for (int i = 0 ; i < src_indices.size (); ++i) {
499- src_idx_val .emplace_back (src_indices[i]);
500- dst_idx_val .emplace_back (dst_indices[i]);
498+ kv_src_indices .emplace_back (src_indices[i]);
499+ kv_dst_indices .emplace_back (dst_indices[i]);
501500 }
502- /*
503- // Retrieve KVCache states and reorder them based on the provided indices
504- auto states = ovInfReq.query_state();
505-
506- for (auto& state : states) {
507- auto start_time = std::chrono::high_resolution_clock::now();
508- state.gather_by_axis(src_indices, dst_indices);
509- auto end_time = std::chrono::high_resolution_clock::now();
510- auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
511- LOGS_DEFAULT(INFO) << log_tag << "gather_by_axis: " << duration << " microseconds";
512- }
513- */
514501}
515502
516503void StatefulOVInferRequest::RewindKVCache (size_t index) {
0 commit comments