@@ -1350,7 +1350,7 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void send_embedding(llama_client_slot &slot)
+    void send_embedding(llama_client_slot &slot, const llama_batch &batch)
     {
         task_result res;
         res.id = slot.task_id;
@@ -1372,10 +1372,38 @@ struct llama_server_context
         else
         {
             const float *data = llama_get_embeddings(ctx);
-            std::vector<float> embedding(data, data + n_embd);
+            std::vector<float> embd_res(n_embd, 0.0f);
+            std::vector<std::vector<float>> embedding;
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                }
+
+                if (embd == NULL) {
+                    LOG("failed to get embeddings");
+
+                    continue;
+                }
+
+                // normalize only when there is pooling
+                // TODO: configurable
+                if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
+                    common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                    embedding.push_back(embd_res);
+                } else {
+                    embedding.push_back({ embd, embd + n_embd });
+                }
+            }
+
+            // OAI compat
             res.result_json = json
             {
-                {"embedding", embedding},
+                {"embedding", embedding[0]},
             };
         }
         queue_results.send(res);
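
Aside, not part of the patch: a minimal sketch of what the common_embd_normalize(embd, embd_res.data(), n_embd, 2) call above is assumed to compute, with the trailing 2 selecting the Euclidean (L2) norm. The helper name l2_normalize_sketch is hypothetical, not a llama.cpp function.

// Hypothetical stand-in for common_embd_normalize(inp, out, n, 2):
// scale the n-dimensional input so its Euclidean (L2) norm becomes 1.
#include <cmath>
#include <cstdio>

static void l2_normalize_sketch(const float * inp, float * out, int n) {
    double sum = 0.0;
    for (int i = 0; i < n; ++i) {
        sum += (double) inp[i] * inp[i];   // accumulate squared magnitudes
    }
    const double norm = std::sqrt(sum);
    for (int i = 0; i < n; ++i) {
        out[i] = norm > 0.0 ? (float)(inp[i] / norm) : 0.0f;
    }
}

int main() {
    float v[3]   = {3.0f, 0.0f, 4.0f};
    float out[3] = {0.0f, 0.0f, 0.0f};
    l2_normalize_sketch(v, out, 3);
    printf("%.2f %.2f %.2f\n", out[0], out[1], out[2]); // 0.60 0.00 0.80
}
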
@@ -1996,7 +2024,7 @@ struct llama_server_context
                 // prompt evaluated for embedding
                 if (slot.embedding)
                 {
-                    send_embedding(slot);
+                    send_embedding(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
                     continue;
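
Aside, not part of the patch: the "OAI compat" comment above refers to keeping the response's single-vector "embedding" field, so only embedding[0], the first vector collected for the slot, is serialized even when pooling is disabled and the loop has gathered one vector per token. A minimal sketch of the resulting JSON shape, assuming nlohmann::json as used by the server; the values are illustrative:

#include <cstdio>
#include <vector>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // stand-in for the per-slot embeddings collected in send_embedding
    std::vector<std::vector<float>> embedding = {{0.5f, -0.25f, 0.125f}};

    // same construction as in the patch: only the first vector is kept,
    // preserving the pre-existing single-embedding response shape
    json res = json {
        {"embedding", embedding[0]},
    };

    printf("%s\n", res.dump().c_str()); // {"embedding":[0.5,-0.25,0.125]}
}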