Commit 8fea82e

wire to grpc
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 01e2e3d commit 8fea82e

1 file changed: +58 -7 lines changed


backend/cpp/llama/grpc-server.cpp (+58 -7)
@@ -536,6 +536,12 @@ struct llama_server_context
             return false;
         }
 
+        // Enable reranking if embeddings are enabled - moved after context initialization
+        if (params.embedding) {
+            params.reranking = true;
+            LOG_INFO("Reranking enabled (embeddings are enabled)", {});
+        }
+
         if (multimodal) {
             const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
             const int n_embd_llm = llama_model_n_embd(model);
@@ -1424,11 +1430,16 @@ struct llama_server_context
 
         float score = -1e6f; // Default score if we fail to get embeddings
 
-        if (!params.rerank)
+        if (!params.reranking)
         {
             LOG_WARNING("reranking disabled", {
-                {"params.rerank", params.rerank},
-            });
+                {"params.reranking", params.reranking},
+            });
+        }
+        else if (ctx == nullptr)
+        {
+            LOG_ERR("context is null, cannot perform reranking");
+            res.error = true;
         }
         else
         {
@@ -1455,7 +1466,7 @@ struct llama_server_context
             res.result_json = json
             {
                 {"score", score},
-                {"tokens", slot.n_prompt_tokens}
+                {"tokens", slot.num_prompt_tokens}
             };
 
             queue_results.send(res);
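
Note on the result payload: the rerank slot hands its output back through result_json with only the two fields shown above, the relevance score and the prompt-token count (whose field name is fixed in this hunk). The gRPC handler added further below reads them back with nlohmann::json's value() accessor. A minimal, self-contained sketch of that round trip; the task/queue plumbing is elided and the field names are exactly those in the hunks:

#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::json;

int main() {
    // What the rerank slot publishes (see the hunk above): a relevance score
    // plus the number of prompt tokens that were evaluated.
    json result_json = {
        {"score", 0.87f},
        {"tokens", 42}
    };

    // How the gRPC handler reads it back: json::value() returns the supplied
    // default when the key is missing, so a failed slot still yields a
    // well-formed (if zeroed) response.
    float score  = result_json.value("score", 0.0f);
    int   tokens = result_json.value("tokens", 0);

    std::cout << "score=" << score << " tokens=" << tokens << std::endl;
    return 0;
}
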
@@ -2547,7 +2558,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
         json data = parse_options(true, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
         while (true)
         {
             task_result result = llama.queue_results.recv(task_id);
@@ -2601,7 +2612,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
         json data = parse_options(false, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
         std::string completion_text;
         task_result result = llama.queue_results.recv(task_id);
         if (!result.error && result.stop) {
@@ -2638,7 +2649,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
         json data = parse_options(false, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, -1);
+        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, false, -1);
         // get the result
         task_result result = llama.queue_results.recv(task_id);
         //std::cout << "Embedding result JSON" << result.result_json.dump() << std::endl;
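
The three call sites above each gain one extra boolean just before the trailing -1 (the multitask id), and the new Rerank handler below passes true in that position. The declaration of request_completion is not part of this diff; a sketch of the shape the calls imply (parameter names are assumptions, only the added flag is established by the diff):

// Presumed shape of the updated helper on llama_server_context (sketch only):
// the extra boolean routes the task to the reranking path instead of the
// completion or embedding paths.
void request_completion(int task_id, json data,
                        bool infill,      // fill-in-the-middle request
                        bool embedding,   // embeddings-only request
                        bool rerank,      // new: score a query against documents
                        int multitask_id);
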
@@ -2670,6 +2681,46 @@ class BackendServiceImpl final : public backend::Backend::Service {
         return grpc::Status::OK;
     }
 
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        // Create a JSON object with the query and documents
+        json data = {
+            {"prompt", request->query()},
+            {"documents", request->documents()},
+            {"top_n", request->top_n()}
+        };
+
+        // Generate a new task ID
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+
+        // Queue the task with reranking mode enabled
+        llama.request_completion(task_id, data, false, false, true, -1);
+
+        // Get the result
+        task_result result = llama.queue_results.recv(task_id);
+        llama.queue_results.remove_waiting_task_id(task_id);
+
+        if (!result.error && result.stop) {
+            // Set usage information
+            backend::Usage* usage = rerankResult->mutable_usage();
+            usage->set_total_tokens(result.result_json.value("tokens", 0));
+            usage->set_prompt_tokens(result.result_json.value("tokens", 0));
+
+            // Get the score from the result
+            float score = result.result_json.value("score", 0.0f);
+
+            // Create document results for each input document
+            for (int i = 0; i < request->documents_size(); i++) {
+                backend::DocumentResult* doc_result = rerankResult->add_results();
+                doc_result->set_index(i);
+                doc_result->set_text(request->documents(i));
+                doc_result->set_relevance_score(score);
+            }
+        }
+
+        return grpc::Status::OK;
+    }
+
     grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
         llama_client_slot* active_slot = llama.get_active_slot();
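
For reference, a caller reaches the new handler through the stub generated from LocalAI's backend.proto. A minimal client sketch, assuming standard grpc++ codegen for the backend.Backend service; the address is illustrative, and the setter names are inferred from the query/documents/top_n and results/relevance_score accessors used in the handler above:

#include <grpcpp/grpcpp.h>
#include <iostream>
#include "backend.grpc.pb.h"  // generated from LocalAI's backend.proto (assumed include path)

int main() {
    // Connect to the llama gRPC backend (address is illustrative).
    auto channel = grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials());
    auto stub = backend::Backend::NewStub(channel);

    backend::RerankRequest request;
    request.set_query("What is the capital of France?");
    request.add_documents("Paris is the capital of France.");
    request.add_documents("Berlin is the capital of Germany.");
    request.set_top_n(2);

    backend::RerankResult reply;
    grpc::ClientContext ctx;
    grpc::Status status = stub->Rerank(&ctx, request, &reply);

    if (status.ok()) {
        // Each input document comes back with an index, its text, and a score.
        for (const auto& doc : reply.results()) {
            std::cout << doc.index() << ": " << doc.relevance_score()
                      << " (" << doc.text() << ")" << std::endl;
        }
        std::cout << "prompt tokens: " << reply.usage().prompt_tokens() << std::endl;
    } else {
        std::cerr << "Rerank failed: " << status.error_message() << std::endl;
    }
    return 0;
}
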
