diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index fb5dd3430389..22b8576e73e1 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -217,6 +217,7 @@ struct llama_client_slot
 
     bool infill = false;
     bool embedding = false;
+    bool reranker = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -535,6 +536,12 @@ struct llama_server_context
             return false;
         }
 
+        // Enable reranking if embeddings are enabled - moved after context initialization
+        if (params.embedding) {
+            params.reranking = true;
+            LOG_INFO("Reranking enabled (embeddings are enabled)", {});
+        }
+
         if (multimodal) {
             const int n_embd_clip = clip_n_mmproj_embd(clp_ctx);
             const int n_embd_llm = llama_model_n_embd(model);
@@ -1413,7 +1420,59 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
+    void send_rerank(llama_client_slot &slot, const llama_batch & batch)
+    {
+        task_result res;
+        res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
+        res.error = false;
+        res.stop = true;
+
+        float score = -1e6f; // Default score if we fail to get embeddings
+
+        if (!params.reranking)
+        {
+            LOG_WARNING("reranking disabled", {
+                {"params.reranking", params.reranking},
+            });
+        }
+        else if (ctx == nullptr)
+        {
+            LOG_ERR("context is null, cannot perform reranking");
+            res.error = true;
+        }
+        else
+        {
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                }
+
+                if (embd == NULL) {
+                    LOG("failed to get embeddings");
+                    continue;
+                }
+
+                score = embd[0];
+            }
+        }
+
+        // Format result as JSON similar to the embedding function
+        res.result_json = json
+        {
+            {"score", score},
+            {"tokens", slot.num_prompt_tokens}
+        };
+
+        queue_results.send(res);
+    }
+
+    void request_completion(int task_id, json data, bool infill, bool embedding, bool rerank, int multitask_id)
     {
         task_server task;
         task.id = task_id;
@@ -1421,6 +1480,7 @@ struct llama_server_context
         task.data = std::move(data);
         task.infill_mode = infill;
         task.embedding_mode = embedding;
+        task.reranking_mode = rerank;
         task.type = TASK_TYPE_COMPLETION;
         task.multitask_id = multitask_id;
 
@@ -1552,7 +1612,7 @@ struct llama_server_context
             subtask_data["prompt"] = subtask_data["prompt"][i];
 
             // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multiprompt_task.reranking_mode, multitask_id);
         }
     }
 
@@ -1591,6 +1651,7 @@ struct llama_server_context
 
                 slot->infill = task.infill_mode;
                 slot->embedding = task.embedding_mode;
+                slot->reranker = task.reranking_mode;
                 slot->task_id = task.id;
                 slot->multitask_id = task.multitask_id;
 
@@ -2034,6 +2095,14 @@ struct llama_server_context
                     continue;
                 }
 
+                if (slot.reranker)
+                {
+                    send_rerank(slot, batch_view);
+                    slot.release();
+                    slot.i_batch = -1;
+                    continue;
+                }
+
                 completion_token_output result;
                 const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
 
@@ -2489,7 +2558,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
         json data = parse_options(true, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
         while (true)
         {
             task_result result = llama.queue_results.recv(task_id);
@@ -2543,7 +2612,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
         json data = parse_options(false, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, data, false, false, -1);
+        llama.request_completion(task_id, data, false, false, false, -1);
         std::string completion_text;
         task_result result = llama.queue_results.recv(task_id);
         if (!result.error && result.stop) {
@@ -2580,7 +2649,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
         json data = parse_options(false, request, llama);
         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, -1);
+        llama.request_completion(task_id, { {"prompt", data["embeddings"]}, { "n_predict", 0}, {"image_data", ""} }, false, true, false, -1);
         // get the result
         task_result result = llama.queue_results.recv(task_id);
         //std::cout << "Embedding result JSON" << result.result_json.dump() << std::endl;
@@ -2612,6 +2681,46 @@ class BackendServiceImpl final : public backend::Backend::Service {
         return grpc::Status::OK;
     }
 
+    grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) {
+        // Create a JSON object with the query and documents
+        json data = {
+            {"prompt", request->query()},
+            {"documents", request->documents()},
+            {"top_n", request->top_n()}
+        };
+
+        // Generate a new task ID
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+
+        // Queue the task with reranking mode enabled
+        llama.request_completion(task_id, data, false, false, true, -1);
+
+        // Get the result
+        task_result result = llama.queue_results.recv(task_id);
+        llama.queue_results.remove_waiting_task_id(task_id);
+
+        if (!result.error && result.stop) {
+            // Set usage information
+            backend::Usage* usage = rerankResult->mutable_usage();
+            usage->set_total_tokens(result.result_json.value("tokens", 0));
+            usage->set_prompt_tokens(result.result_json.value("tokens", 0));
+
+            // Get the score from the result
+            float score = result.result_json.value("score", 0.0f);
+
+            // Create document results for each input document
+            for (int i = 0; i < request->documents_size(); i++) {
+                backend::DocumentResult* doc_result = rerankResult->add_results();
+                doc_result->set_index(i);
+                doc_result->set_text(request->documents(i));
+                doc_result->set_relevance_score(score);
+            }
+        }
+
+        return grpc::Status::OK;
+    }
+
     grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
         llama_client_slot* active_slot = llama.get_active_slot();
 
diff --git a/backend/cpp/llama/utils.hpp b/backend/cpp/llama/utils.hpp
index 198b6f265957..d79b63daa170 100644
--- a/backend/cpp/llama/utils.hpp
+++ b/backend/cpp/llama/utils.hpp
@@ -61,6 +61,7 @@ struct task_server {
     json data;
     bool infill_mode = false;
    bool embedding_mode = false;
+    bool reranking_mode = false;
     int multitask_id = -1;
 };
 
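For reviewers who want to exercise the new endpoint end-to-end, here is a minimal client-side sketch against the `Rerank` RPC added above. It is not part of the patch: it assumes the standard gRPC/protobuf C++ codegen for LocalAI's `backend.proto` (so `set_query()`, `add_documents()`, `set_top_n()` and the `DocumentResult` accessors mirror the getters used in the server handler), and the listen address, generated header name, and sample strings are placeholders.

```cpp
// Hypothetical client for the new Rerank RPC. Assumes the llama backend gRPC
// server is already running with an embedding/reranking-capable model loaded.
#include <iostream>
#include <memory>

#include <grpcpp/grpcpp.h>
#include "backend.grpc.pb.h"  // assumed name of the header generated from backend.proto

int main() {
    // Address is a placeholder; use whatever address the backend was started with.
    auto channel = grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials());
    std::unique_ptr<backend::Backend::Stub> stub = backend::Backend::NewStub(channel);

    backend::RerankRequest request;
    request.set_query("What is quantum entanglement?");
    request.add_documents("Entanglement correlates the states of two particles.");
    request.add_documents("The stock market closed higher on Friday.");
    request.set_top_n(2);

    backend::RerankResult response;
    grpc::ClientContext context;
    grpc::Status status = stub->Rerank(&context, request, &response);
    if (!status.ok()) {
        std::cerr << "Rerank failed: " << status.error_message() << std::endl;
        return 1;
    }

    // Each DocumentResult carries the document index, its text, and the relevance score.
    for (const auto & doc : response.results()) {
        std::cout << doc.index() << " (" << doc.relevance_score() << "): " << doc.text() << std::endl;
    }
    std::cout << "prompt tokens: " << response.usage().prompt_tokens() << std::endl;
    return 0;
}
```

Note that in this revision every returned document carries the same score: the handler applies the single `score` value from the task result to all documents, so the sketch is mainly useful for checking wiring rather than ranking quality.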