@@ -73,6 +73,11 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph,
7373}
7474
7575GgmlOvDecoder::GgmlOvDecoder (struct  ggml_cgraph * cgraph) {
76+     if  (getenv (" GGML_OPENVINO_DUMP_CGRAPH"  )) {
77+         std::string filename = " cgraph.txt"  ;
78+         dump_cgraph (cgraph, filename);
79+     }
80+ 
7681    m_cgraph = cgraph;
7782    for  (int  node_n = 0 ; node_n < cgraph->n_nodes ; node_n++) {
7883        auto * cur_node = cgraph->nodes [node_n];
@@ -173,49 +178,46 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
173178            break ;
174179        }
175180        case  GGML_OP_CONT: {
176-             if  (ggml_nelements (node->src [0 ]) == ggml_nelements (node->src [0 ]->view_src )) {
177-                 //  The input comes from a PERMUTE
178-                 m_op_case = 1 ;
179-             } else  {
180-                 //  The input comes from a VIEW which is subtensor
181-                 m_op_case = 2 ;
182-             }
183-             break ;
184-         }
185-         case  GGML_OP_SET_ROWS: {
186-             if  (std::string (node->name ).find (" cache_k"  ) == 0 ) {
181+             if  (node->src [0 ]->op  == GGML_OP_PERMUTE) {
187182                m_op_case = 1 ;
188-             } else  {
183+             } else  if  (node-> src [ 0 ]-> op  == GGML_OP_TRANSPOSE)  {
189184                m_op_case = 2 ;
185+             } else  if  (node->src [0 ]->op  == GGML_OP_VIEW) {
186+                 //  The input comes from a VIEW which is subtensor
187+                 m_op_case = 3 ;
190188            }
191189            break ;
192190        }
193191        case  GGML_OP_PERMUTE: {
194-             if  (node->src [0 ]->view_src  == nullptr ) {
195-                 //  Permute Qcur
192+             if  (node->src [0 ]->op  != GGML_OP_VIEW) {
196193                m_op_case = 1 ;
197194            } else  if  (ggml_is_contiguous (node->src [0 ])) {
198195                //  Permute cache_k (view)
199196                m_op_case = 2 ;
200197            } else  {
201-                 //  Permute cache_v (view)
198+                 //  Permute cache_v (view), deprecated, cache_v will also fall to case 2
199+                 m_op_case = 3 ;
200+             }
201+             break ;
202+         }
203+         case  GGML_OP_MUL_MAT: {
204+             if  (node->src [0 ]->op  == GGML_OP_CONT && node->src [0 ]->src [0 ]->op  == GGML_OP_TRANSPOSE) {
205+                 m_op_case = 2 ;
206+             } else  if  (node->src [0 ]->op  == GGML_OP_VIEW && node->src [1 ]->op  == GGML_OP_VIEW) {
207+                 //  test-backend-ops case
202208                m_op_case = 3 ;
203209            }
204210            break ;
205211        }
206212        case  GGML_OP_GET_ROWS: {
207213            if  (node->src [1 ]->op  == GGML_OP_VIEW) {
208214                m_op_case = 2 ;
209-             } else  {
210-                 m_op_case = 1 ;
211215            }
212216            break ;
213217        }
214218        case  GGML_OP_ROPE: {
215219            if  (node->src [0 ]->op  == GGML_OP_VIEW) {
216220                m_op_case = 2 ;
217-             } else  {
218-                 m_op_case = 1 ;
219221            }
220222            break ;
221223        }
@@ -270,19 +272,9 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
270272    } else  if  (name.find (" cache_k"  ) == 0 ) {
271273        input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
272274    } else  if  (name.find (" cache_v"  ) == 0 ) {
273-         input_shape = ov::PartialShape{m_num_heads_kv, m_head_size, m_context_size };
275+         input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size };
274276    } else  if  (const  auto * op = get_tensor_used_op (src); op && op->op  == GGML_OP_SET_ROWS) {
275-         input_shape = ov::PartialShape{1 , 1 , -1 };
276-         if  (m_is_static) {
277-             if  (m_is_first_token) {
278-                 //  Dummy static shape, since the indices are not used in this case
279-                 input_shape = ov::PartialShape{1 };
280-             } else  if  (std::string (op->name ).find (" cache_k"  ) == 0 ) {
281-                 input_shape = ov::PartialShape{1 , 1 , 1 };
282-             } else  {
283-                 input_shape = ov::PartialShape{1 , 1 , m_num_heads_kv * m_head_size};
284-             }
285-         }
277+         input_shape = ov::PartialShape{1 , 1 , m_is_static ? 1  : -1 };
286278    } else  if  (src->op  == GGML_OP_VIEW) {
287279        //  This case is added to make test-backend-ops work
288280        input_shape = ov::PartialShape{get_shape (src->view_src )};
0 commit comments