@@ -536,6 +536,12 @@ struct llama_server_context
             return false;
         }
 
+        // Enable reranking if embeddings are enabled - moved after context initialization
+        if (params.embedding) {
+            params.reranking = true;
+            LOG_INFO("Reranking enabled (embeddings are enabled)", {});
+        }
+
         if (multimodal) {
             const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
             const int n_embd_llm = llama_model_n_embd(model);
@@ -1424,11 +1430,16 @@ struct llama_server_context
 
             float score = -1e6f; // Default score if we fail to get embeddings
 
-            if (!params.rerank)
+            if (!params.reranking)
             {
                 LOG_WARNING("reranking disabled", {
-                    {"params.rerank", params.rerank},
-                });
+                    {"params.reranking", params.reranking},
+                });
+            }
+            else if (ctx == nullptr)
+            {
+                LOG_ERR("context is null, cannot perform reranking");
+                res.error = true;
             }
             else
             {
@@ -1455,7 +1466,7 @@ struct llama_server_context
             res.result_json = json
             {
                 {"score", score},
-                {"tokens", slot.n_prompt_tokens}
+                {"tokens", slot.num_prompt_tokens}
             };
 
             queue_results.send(res);
@@ -2547,7 +2558,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
         json data = parse_options(true, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
         while (true)
         {
             task_result result = llama.queue_results.recv(task_id);
@@ -2601,7 +2612,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
         json data = parse_options(false, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
         std::string completion_text;
         task_result result = llama.queue_results.recv(task_id);
         if (!result.error && result.stop) {
@@ -2638,7 +2649,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
         json data = parse_options(false, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0 }, {"image_data", ""} }, false, true, -1);
+        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0 }, {"image_data", ""} }, false, true, false, -1);
         // get the result
         task_result result = llama.queue_results.recv(task_id);
         // std::cout << "Embedding result JSON" << result.result_json.dump() << std::endl;
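Note: the extra "false" threaded through the call sites above corresponds to a new reranking flag on request_completion; the declaration change itself is outside these hunks. A rough sketch of what the updated signature presumably looks like, with parameter names inferred from how the call sites (and the Rerank handler below, which passes true in that position) use the arguments:

    // Hypothetical sketch, not part of this diff: the new rerank flag sits between
    // the embedding flag and the multitask id, so pre-existing completion and
    // embedding call sites pass false to keep their previous behaviour.
    void request_completion(int task_id, json data, bool infill, bool embedding, bool rerank, int multitask_id);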
@@ -2670,6 +2681,46 @@ class BackendServiceImpl final : public backend::Backend::Service {
         return grpc::Status::OK;
     }
 
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        // Create a JSON object with the query and documents
+        json data = {
+            {"prompt", request->query()},
+            {"documents", request->documents()},
+            {"top_n", request->top_n()}
+        };
+
+        // Generate a new task ID
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+
+        // Queue the task with reranking mode enabled
+        llama.request_completion(task_id, data, false, false, true, -1);
+
+        // Get the result
+        task_result result = llama.queue_results.recv(task_id);
+        llama.queue_results.remove_waiting_task_id(task_id);
+
+        if (!result.error && result.stop) {
+            // Set usage information
+            backend::Usage* usage = rerankResult->mutable_usage();
+            usage->set_total_tokens(result.result_json.value("tokens", 0));
+            usage->set_prompt_tokens(result.result_json.value("tokens", 0));
+
+            // Get the score from the result
+            float score = result.result_json.value("score", 0.0f);
+
+            // Create document results for each input document
+            for (int i = 0; i < request->documents_size(); i++) {
+                backend::DocumentResult* doc_result = rerankResult->add_results();
+                doc_result->set_index(i);
+                doc_result->set_text(request->documents(i));
+                doc_result->set_relevance_score(score);
+            }
+        }
+
+        return grpc::Status::OK;
+    }
+
     grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
         llama_client_slot* active_slot = llama.get_active_slot();
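For reference, a minimal, hypothetical client-side call of the new Rerank RPC, assuming the usual gRPC/protobuf C++ codegen for the backend proto (the message and field names query, documents and top_n are inferred from the accessors used in the handler above; note that, as written, the handler assigns the same score to every returned document):

    #include <cstdio>
    #include <memory>
    #include <grpcpp/grpcpp.h>
    #include "backend.grpc.pb.h"  // assumed name of the generated header

    // Sends a single rerank request to a running backend and prints the results.
    void rerank_example(const std::shared_ptr<grpc::Channel>& channel) {
        std::unique_ptr<backend::Backend::Stub> stub = backend::Backend::NewStub(channel);

        backend::RerankRequest request;
        request.set_query("what is gRPC?");
        request.add_documents("gRPC is a high-performance RPC framework.");
        request.add_documents("Paris is the capital of France.");
        request.set_top_n(2);

        backend::RerankResult reply;
        grpc::ClientContext ctx;
        grpc::Status status = stub->Rerank(&ctx, request, &reply);

        if (status.ok()) {
            for (const auto& doc : reply.results()) {
                std::printf("index=%d score=%f\n", doc.index(), doc.relevance_score());
            }
        }
    }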