Skip to content

Commit e4fa894

Browse files
authored
fix(llama.cpp): correctly handle embeddings in batches (#4957)
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 69caccf commit e4fa894

File tree

1 file changed

+32
-4
lines changed

1 file changed

+32
-4
lines changed

backend/cpp/llama/grpc-server.cpp

+32 -4
Original file line number | Diff line number | Diff line change
@@ -1350,7 +1350,7 @@ struct llama_server_context
13501350
queue_results.send(res);
13511351
}
13521352

1353-
void send_embedding(llama_client_slot &slot)
1353+
void send_embedding(llama_client_slot &slot, const llama_batch & batch)
13541354
{
13551355
task_result res;
13561356
res.id = slot.task_id;
@@ -1372,10 +1372,38 @@ struct llama_server_context
13721372
else
13731373
{
13741374
const float *data = llama_get_embeddings(ctx);
1375-
std::vector<float> embedding(data, data + n_embd);
1375+
std::vector<float> embd_res(n_embd, 0.0f);
1376+
std::vector<std::vector<float>> embedding;
1377+
for (int i = 0; i < batch.n_tokens; ++i) {
1378+
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
1379+
continue;
1380+
}
1381+
1382+
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
1383+
if (embd == NULL) {
1384+
embd = llama_get_embeddings_ith(ctx, i);
1385+
}
1386+
1387+
if (embd == NULL) {
1388+
LOG("failed to get embeddings");
1389+
1390+
continue;
1391+
}
1392+
1393+
// normalize only when there is pooling
1394+
// TODO: configurable
1395+
if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
1396+
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
1397+
embedding.push_back(embd_res);
1398+
} else {
1399+
embedding.push_back({ embd, embd + n_embd });
1400+
}
1401+
}
1402+
1403+
// OAI compat
13761404
res.result_json = json
13771405
{
1378-
{"embedding", embedding },
1406+
{"embedding", embedding[0] },
13791407
};
13801408
}
13811409
queue_results.send(res);
@@ -1996,7 +2024,7 @@ struct llama_server_context
19962024
// prompt evaluated for embedding
19972025
if (slot.embedding)
19982026
{
1999-
send_embedding(slot);
2027+
send_embedding(slot, batch_view);
20002028
slot.release();
20012029
slot.i_batch = -1;
20022030
continue;

0 commit comments

Comments
 (0)