Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3591,6 +3591,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.draft.p_min = std::stof(value);
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_P_MIN"));
add_opt(common_arg(
{"--spec-draft-backend-sampling"},
{"--no-spec-draft-backend-sampling"},
string_format("offload draft sampling to the backend (default: %s)",
params.speculative.draft.backend_sampling ? "enabled" : "disabled"),
[](common_params & params, bool value) {
params.speculative.draft.backend_sampling = value;
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_BACKEND_SAMPLING"));
add_opt(common_arg(
{"--spec-draft-device", "-devd", "--device-draft"}, "<dev1,dev2,..>",
"comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"
Expand Down
2 changes: 2 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ struct common_params_speculative_draft {
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.0f; // minimum speculative decoding probability (greedy)

bool backend_sampling = true; // offload draft sampling to the backend (default: on)

common_params_model mparams;

llama_context * ctx_tgt = nullptr;
Expand Down
44 changes: 37 additions & 7 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
};

static std::string common_speculative_get_devices_str(const std::vector<ggml_backend_dev_t> & devices) {
if (devices.empty()) {
return "default";
}

std::string result;
for (size_t i = 0; i < devices.size(); i++) {
if (i > 0) result += ", ";
if (devices[i] == nullptr) {
continue;
}
if (!result.empty()) result += ", ";
result += ggml_backend_dev_name(devices[i]);
}
return result;
return result.empty() ? "default" : result;
}

struct common_speculative_config {
Expand Down Expand Up @@ -414,6 +413,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {

std::vector<common_sampler_ptr> smpls;

// backend sampler chain per seq, attached to ctx_dft
std::vector<llama_sampler *> backend_chains;

int32_t n_embd = 0;

// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
Expand Down Expand Up @@ -445,7 +447,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));

LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
this->params.n_gpu_layers,
ggml_type_name(this->params.cache_type_k),
Expand All @@ -469,6 +471,22 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
}

// offload draft sampling to the backend
backend_chains.assign(n_seq, nullptr);
if (this->params.backend_sampling) {
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));

if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
llama_sampler_free(chain);
chain = nullptr;
}
backend_chains[seq_id] = chain;
}
}

llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);

Expand All @@ -484,6 +502,18 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
}

~common_speculative_impl_draft_mtp() override {
auto * ctx_dft = this->params.ctx_dft;
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) backend_chains.size(); ++seq_id) {
if (backend_chains[seq_id] == nullptr) {
continue;
}
if (ctx_dft) {
llama_set_sampler(ctx_dft, seq_id, nullptr);
}
llama_sampler_free(backend_chains[seq_id]);
}
backend_chains.clear();

if (batch.token != nullptr) {
free(batch.token);
batch.token = nullptr;
Expand Down
Loading
Loading