diff --git a/common/chat-diff-analyzer.cpp b/common/chat-diff-analyzer.cpp index e35596b93fa..414ee892f8e 100644 --- a/common/chat-diff-analyzer.cpp +++ b/common/chat-diff-analyzer.cpp @@ -287,7 +287,7 @@ void analyze_reasoning::compare_reasoning_presence() { return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())) + p.rest()); }); auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) { - return p.tag("pre", p.marker()) + p.space() + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest(); + return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest(); }); // try the more aggressive parse first, if it fails, fall back to the delimiter one auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B); @@ -297,7 +297,7 @@ void analyze_reasoning::compare_reasoning_presence() { if (result.result.success()) { if (!result.tags["pre"].empty() && !result.tags["post"].empty()) { mode = reasoning_mode::TAG_BASED; - start = trim_whitespace(result.tags["pre"]); + start = trim_leading_whitespace(result.tags["pre"]); end = trim_trailing_whitespace(result.tags["post"]); } else if (!result.tags["post"].empty()) { mode = reasoning_mode::TAG_BASED; @@ -333,7 +333,7 @@ void analyze_reasoning::compare_thinking_enabled() { if (left_trimmed.empty() && !diff.right.empty()) { if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) { if (start.empty()) { - start = right_trimmed; + start = trim_leading_whitespace(diff.right); mode = reasoning_mode::TAG_BASED; } } @@ -344,7 +344,7 @@ void analyze_reasoning::compare_thinking_enabled() { if (seg.size() >= 2 && seg[seg.size() - 1].value == left_trimmed && seg[seg.size() - 2].type == segment_type::MARKER) { start = seg[seg.size() - 2].value; } - end = left_trimmed; + end = trim_trailing_whitespace(diff.left); mode = reasoning_mode::TAG_BASED; } } @@ -363,15 +363,23 @@ void analyze_reasoning::compare_thinking_enabled() { size_t len = std::min(base.size(), anchor_len); std::string anchor = base.substr(base.size() - len); auto pos = extended.rfind(anchor); - if (pos == std::string::npos || pos + len >= extended.size()) continue; + if (pos == std::string::npos || pos + len >= extended.size()) { + continue; + } std::string extra = trim_whitespace(extended.substr(pos + len)); - if (extra.empty()) continue; + if (extra.empty()) { + continue; + } auto seg = prune_whitespace_segments(segmentize_markers(extra)); if (seg.size() == 2 && seg[0].type == segment_type::MARKER && seg[1].type == segment_type::MARKER) { - if (start.empty()) start = seg[0].value; - if (end.empty()) end = seg[1].value; + if (start.empty()) { + start = seg[0].value; + } + if (end.empty()) { + end = seg[1].value; + } mode = reasoning_mode::TAG_BASED; break; } @@ -423,7 +431,7 @@ void analyze_reasoning::compare_reasoning_scope() { LOG_DBG(ANSI_ORANGE "%s: Detected TOOLS_ONLY reasoning mode\n" ANSI_RESET, __func__); auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) { - return p.tag("pre", p.marker()) + p.space() + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())); + return p.tag("pre", p.marker() + p.space()) + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())); }); auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B); if (result.result.success()) { @@ -516,7 +524,7 @@ analyze_content::analyze_content(const common_chat_template & tmpl, const analyz // Take the more promising diff std::string pure_content = rdiff.length() > diff_tools.left.length() ? rdiff : diff_tools.left; auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) { - return p.tag("pre", p.marker()) + p.space() + p.literal(response) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest(); + return p.tag("pre", p.marker() + p.space()) + p.literal(response) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest(); }); auto result = parser_wrapped.parse_anywhere_and_extract(pure_content); start = result.tags["pre"]; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 221e6fa04e9..15ed5b2a79d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1112,6 +1112,16 @@ struct vk_op_glu_push_constants { uint32_t mode; // 0: default, 1: swapped, 2: split float alpha; // for swiglu_oai float limit; + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t ne01; + uint32_t ne02; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t ne11; + uint32_t ne12; }; struct vk_op_unary_push_constants { @@ -5044,7 +5054,7 @@ static vk_device ggml_vk_get_device(size_t idx) { } else { device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); } - vk::DeviceCreateInfo device_create_info; + vk::DeviceCreateInfo device_create_info{}; std::vector device_extensions; vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); @@ -5413,12 +5423,10 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->name = GGML_VK_NAME + std::to_string(idx); - device_create_info = { - vk::DeviceCreateFlags(), - device_queue_create_infos, - {}, - device_extensions - }; + device_create_info + .setFlags(vk::DeviceCreateFlags()) + .setQueueCreateInfos(device_queue_create_infos) + .setPEnabledExtensionNames(device_extensions); device_create_info.setPNext(&device_features2); device->device = device->physical_device.createDevice(device_create_info); @@ -11048,8 +11056,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const const float alpha = op_params_f[2]; const float limit = op_params_f[3]; - GGML_ASSERT(ggml_is_contiguous(src0)); - if (!split) { GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); } else { @@ -11067,7 +11073,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t)dst->ne[0], mode, alpha, - limit + limit, + (uint32_t)(src0->nb[1] / src0->nb[0]), + (uint32_t)(src0->nb[2] / src0->nb[0]), + (uint32_t)(src0->nb[3] / src0->nb[0]), + (uint32_t)src0->ne[1], + (uint32_t)src0->ne[2], + (uint32_t)(dst->nb[1] / dst->nb[0]), + (uint32_t)(dst->nb[2] / dst->nb[0]), + (uint32_t)(dst->nb[3] / dst->nb[0]), + (uint32_t)dst->ne[1], + (uint32_t)dst->ne[2] }); } @@ -15217,8 +15233,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous(op->src[0]) && - (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type); default: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl index 2168989340b..95298922d83 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -16,4 +16,14 @@ layout (push_constant) uniform parameter uint mode; float alpha; float limit; + uint nb01; + uint nb02; + uint nb03; + uint ne01; + uint ne02; + uint nb11; + uint nb12; + uint nb13; + uint ne11; + uint ne12; } p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl index 85cf65a9eca..359461306a5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl @@ -8,22 +8,32 @@ void main() { const uint row = i / p.ne20; const uint col = i - row * p.ne20; + const uint i3 = row / (p.ne01 * p.ne02); + const uint i2 = (row % (p.ne01 * p.ne02)) / p.ne01; + const uint i1 = row % p.ne01; + const uint src_idx = i3 * p.nb03 + i2 * p.nb02 + i1 * p.nb01 + col; + + const uint dst_i3 = row / (p.ne11 * p.ne12); + const uint dst_i2 = (row % (p.ne11 * p.ne12)) / p.ne11; + const uint dst_i1 = row % p.ne11; + const uint dst_idx = dst_i3 * p.nb13 + dst_i2 * p.nb12 + dst_i1 * p.nb11 + col; + if (p.mode == 0) { // Default const uint offset = p.ne00 / 2; - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); } else if (p.mode == 1) { // Swapped const uint offset = p.ne00 / 2; - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); } else { // Split - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); } } diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 6524cae83ad..3f1e74f7cbc 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -5,7 +5,7 @@ import sys import subprocess -HTTPLIB_VERSION = "refs/tags/v0.39.0" +HTTPLIB_VERSION = "refs/tags/v0.40.0" vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", diff --git a/tests/test-chat-auto-parser.cpp b/tests/test-chat-auto-parser.cpp index 347ad94bd84..bb23b7f2aae 100644 --- a/tests/test-chat-auto-parser.cpp +++ b/tests/test-chat-auto-parser.cpp @@ -1330,7 +1330,7 @@ static void test_nemotron_reasoning_detection(testing & t) { analysis.analyze_template(tmpl); // Check reasoning markers - t.assert_equal("reasoning_start should be ''", "", analysis.reasoning.start); + t.assert_equal("reasoning_start should be '\\n'", "\n", analysis.reasoning.start); t.assert_equal("reasoning_end should be ''", "", analysis.reasoning.end); // Check reasoning mode detection diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 74f078f5edd..a2af4e3775a 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -805,7 +805,8 @@ struct peg_test_case { common_chat_templates_inputs params; std::string input; common_chat_msg expect; - bool is_partial = false; + bool is_partial = false; + bool expect_reconstruction = false; }; struct make_peg_parser { @@ -828,6 +829,12 @@ struct make_peg_parser { } }; +// Global template filter for --template flag +static std::string g_template_filter; + +// When true, run reconstruction test on every non-partial test and report results +static bool g_force_reconstruction_test = false; + static void test_peg_parser(common_chat_templates * tmpls, const std::function & init, bool detailed_debug) { @@ -1119,10 +1126,57 @@ static void test_peg_parser(common_chat_templates * tmpls, } } } -} -// Global template filter for --template flag -static std::string g_template_filter; + // Reconstruction test: verify that appending the parsed message to the original + // messages and re-rendering the template (without generation prompt) reproduces + // the original prompt + input exactly, or as a proper prefix (the template may + // append end-of-turn tokens after the assistant message). + if ((tc.expect_reconstruction || g_force_reconstruction_test) && !tc.is_partial) { + // Start from tc.expect but copy tool call arguments from the actual parser + // output, which preserves original JSON formatting (e.g. {"arg1":1} vs {"arg1": 1}). + auto reconstruction_msg = tc.expect; + auto parsed_msg = parser.parse(tc.input, false); + for (size_t i = 0; i < reconstruction_msg.tool_calls.size() && i < parsed_msg.tool_calls.size(); i++) { + reconstruction_msg.tool_calls[i].arguments = parsed_msg.tool_calls[i].arguments; + } + common_chat_templates_inputs reconstruction_inputs = tc.params; + reconstruction_inputs.messages.push_back(reconstruction_msg); + reconstruction_inputs.add_generation_prompt = false; + + auto reconstruction_params = common_chat_templates_apply(tmpls, reconstruction_inputs); + std::string expected_text = parser.params_.prompt + tc.input; + bool match = reconstruction_params.prompt == expected_text || + (reconstruction_params.prompt.size() > expected_text.size() && + reconstruction_params.prompt.compare(0, expected_text.size(), expected_text) == 0); + if (!match && g_force_reconstruction_test && !tc.expect_reconstruction) { + // In forced mode, report mismatch but don't fail + // Find the first difference position + size_t diff_pos = 0; + size_t min_len = std::min(expected_text.size(), reconstruction_params.prompt.size()); + while (diff_pos < min_len && expected_text[diff_pos] == reconstruction_params.prompt[diff_pos]) { + diff_pos++; + } + size_t ctx_start = diff_pos > 60 ? diff_pos - 60 : 0; + size_t ctx_end_e = std::min(expected_text.size(), diff_pos + 40); + size_t ctx_end_r = std::min(reconstruction_params.prompt.size(), diff_pos + 40); + LOG_ERR("\x1b[31m[RECONSTRUCTION FAIL]\x1b[0m " + "first diff at byte %zu (expected len=%zu, reconstructed len=%zu)\n" + " expected: ...%s...\n" + " reconstructed: ...%s...\n", + diff_pos, expected_text.size(), reconstruction_params.prompt.size(), + expected_text.substr(ctx_start, ctx_end_e - ctx_start).c_str(), + reconstruction_params.prompt.substr(ctx_start, ctx_end_r - ctx_start).c_str()); + } else if (!match) { + std::string error_msg = + "Reconstruction mismatch:\n\n" + ">>> Expected (prompt + input):\n" + expected_text + + "\n\n>>> Reconstructed:\n" + reconstruction_params.prompt; + throw std::runtime_error(error_msg); + } else if (g_force_reconstruction_test) { + LOG_INF("\x1b[32m[RECONSTRUCTION OK]\x1b[0m\n"); + } + } +} // Fluent builder for PEG parser tests class peg_test_builder; @@ -1182,6 +1236,11 @@ class peg_test_builder { return *this; } + peg_test_builder & expect_reconstruction(bool val = true) { + tc_.expect_reconstruction = val; + return *this; + } + // Expect setters peg_test_builder & expect(const common_chat_msg & msg) { tc_.expect = msg; @@ -1355,16 +1414,18 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // Ministral-3-14B-Reasoning-2512 auto tst = peg_tester("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?") .expect_content("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?") + .expect_reconstruction() .run(); tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?") .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .enable_thinking(true) .expect(message_assist_thoughts) + .expect_reconstruction() .run(); tst.test(R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})") @@ -1394,6 +1455,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { "special_function", R"({"arg1": 1})", {} }, { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, }) + .expect_reconstruction() .run(); tst.test( @@ -1418,6 +1480,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_tool_calls({ { "special_function", R"({"arg1": 1})", {} }, }) + .expect_reconstruction() .run(); } @@ -1621,9 +1684,9 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // Google Gemma 2 2B - does not support tool calling auto tst = peg_tester("models/templates/google-gemma-2-2b-it.jinja"); - tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).run(); + tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).expect_reconstruction().run(); - tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).run(); + tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).expect_reconstruction().run(); } { @@ -1666,7 +1729,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // Test simple content-only template auto tst = peg_tester("models/templates/google-gemma-2-2b-it.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); } { // IBM Granite (reasoning and tool calling model) @@ -1778,7 +1841,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // Qwen3-Coder (tool calling with XML-style format) auto tst = peg_tester("models/templates/Qwen3-Coder.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst.test( "\n" @@ -1790,6 +1853,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { "") .tools({ special_function_tool }) .expect(message_assist_call) + .expect_reconstruction() .run(); tst.test( @@ -1818,6 +1882,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { "special_function", R"({"arg1": 1})", {} }, { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, }) + .expect_reconstruction() .run(); // Test with code content (multiline) @@ -1838,6 +1903,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_tool_calls({ { "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} }, }) + .expect_reconstruction() .run(); // Test with code content (asian unicode chars) @@ -1855,6 +1921,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_tool_calls({ { "python", "{\"code\": \"格\"}", {} }, }) + .expect_reconstruction() .run(); // Test with HTML tag content @@ -1876,6 +1943,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_tool_calls({ { "html", "{\"markup\": \"\\n \\n Hello!\\n \\n\"}", {} }, }) + .expect_reconstruction() .run(); // Test with TODO list (array of objects) @@ -1893,6 +1961,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_tool_calls({ { "todo_list", "{\"todos\": [{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]}", {} }, }) + .expect_reconstruction() .run(); // Test flexible optional argument ordering (2 required + 4 optional, reversed optional order) @@ -1909,6 +1978,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_tool_calls({ { "tool_2req_4opt", R"({"req1": "hello", "req2": 42, "opt4": 100, "opt2": 200})", {} }, }) + .expect_reconstruction() .run(); // Test flexible optional argument ordering (2 required + 5 optional, reversed optional order) @@ -1926,6 +1996,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_tool_calls({ { "tool_2req_5opt", R"({"req1": "world", "req2": 7, "opt5": "last", "opt3": "middle", "opt1": "first"})", {} }, }) + .expect_reconstruction() .run(); // Test flexible optional argument ordering (2 required + 5 optional, all 5 in shuffled order) @@ -1945,6 +2016,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_tool_calls({ { "tool_2req_5opt", R"({"req1": "test", "req2": 99, "opt3": "c", "opt1": "a", "opt5": "e", "opt4": 4, "opt2": 2})", {} }, }) + .expect_reconstruction() .run(); } { @@ -2025,6 +2097,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { tst.test("Hello, world!\nWhat's up?") .enable_thinking(false) .expect(message_assist) + .expect_reconstruction() .run(); // Reasoning with content (forced-open mode - input starts after ) @@ -2032,6 +2105,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .enable_thinking(true) .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) .expect(message_assist_thoughts) + .expect_reconstruction() .run(); // Tool call without reasoning @@ -2042,6 +2116,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .enable_thinking(false) .tools({ special_function_tool }) .expect(message_assist_call) + .expect_reconstruction() .run(); // Tool call with reasoning (forced-open mode) @@ -2054,6 +2129,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) .tools({ special_function_tool }) .expect(message_assist_call_thoughts) + .expect_reconstruction() .run(); tst.test( @@ -2073,6 +2149,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { "special_function", R"({"arg1": 1})", {} }, { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, }) + .expect_reconstruction() .run(); // #20650: tool with no required args, model emits name with no arg tags. @@ -2090,6 +2167,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ no_args_tool }) .expect_reasoning("Let me read the diff content.") .expect_tool_calls({{ "read_file_diff_md", "{}", {} }}) + .expect_reconstruction() .run(); } } @@ -2348,22 +2426,24 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // Kimi-K2 old template auto tst = peg_tester("models/templates/moonshotai-Kimi-K2.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst.test( "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>" "{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>") .tools({ special_function_tool }) .expect(kimi_id_special_func_tool_call) + .expect_reconstruction() .run(); // Kimi-K2-Instruct auto tst2 = peg_tester("models/templates/Kimi-K2-Instruct.jinja", detailed_debug); - tst2.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst2.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst2.test( "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>" "{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>") .tools({ special_function_tool }) .expect(kimi_id_special_func_tool_call) + .expect_reconstruction() .run(); } @@ -2459,6 +2539,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { tst.test("<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>") .tools({ special_function_tool }) .expect(message_assist_call) + .expect_reconstruction() .run(); } @@ -2467,7 +2548,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { auto tst = peg_tester("models/templates/MiniMax-M2.jinja", detailed_debug); tst.test( - "\n\n\n\n1\n\n") .tools({ special_function_tool }) .expect(message_assist_call) @@ -2517,37 +2598,41 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // mistralai-Mistral-Nemo-Instruct-2407.jinja { auto tst = peg_tester("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst.test("[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]") .tools({ special_function_tool }) .expect(message_assist_call_id) + .expect_reconstruction() .run(); } { auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.1.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst.test("{\"arg1\": 1}") .tools({ special_function_tool }) .expect(message_assist_call) + .expect_reconstruction() .run(); } // Functionary v3.2 - recipient-based format: >>>recipient\n{content} { auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.2.jinja", detailed_debug); - tst.test("all\nHello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("all\nHello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst.test("special_function\n{\"arg1\": 1}") .tools({ special_function_tool }) .expect(message_assist_call) + .expect_reconstruction() .run(); } // FireFunction { auto tst = peg_tester("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst.test(" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]") .tools({ special_function_tool }) .expect(message_assist_call) + .expect_reconstruction() .run(); } @@ -2608,10 +2693,11 @@ static void test_template_output_peg_parsers(bool detailed_debug) { { "models/templates/MiMo-VL.jinja", "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja" }) { auto tst = peg_tester(path, detailed_debug); - tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") .tools({ special_function_tool }) .expect(message_assist_call) + .expect_reconstruction() .run(); } @@ -2634,6 +2720,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .enable_thinking(true) .expect(simple_assist_msg("Hello, world!\nWhat's up?", "Here are my reasoning steps:\nI'm\nthinking")) + .expect_reconstruction() .run(); // Reasoning + Tool calls @@ -2650,42 +2737,45 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // Mistral Small 3.2 - FUNC_BRACKET_TAG format: [TOOL_CALLS]func_name[CALL_ID]id[ARGS]{...} { auto tst = peg_tester("models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst.test("[TOOL_CALLS]special_function[CALL_ID]123456789[ARGS]{\"arg1\": 1}") .tools({ special_function_tool }) .expect(message_assist_call_id) + .expect_reconstruction() .run(); } // Devstral { auto tst = peg_tester("models/templates/unsloth-mistral-Devstral-Small-2507.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst.test("[TOOL_CALLS]special_function[ARGS]{\"arg1\": 1}") .tools({ special_function_tool }) .expect(message_assist_call) + .expect_reconstruction() .run(); tst.test("Hello, world!\nWhat's up?[TOOL_CALLS]special_function[ARGS]{\"arg1\": 1}") .tools({ special_function_tool }) .expect(message_assist_call_content) + .expect_reconstruction() .run(); } { // Llama 3.1 auto tst = peg_tester("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run(); } { // Llama 3.2 auto tst = peg_tester("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run(); } { // Llama 3.3 auto tst = peg_tester("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").tools({ python_tool }).expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").tools({ python_tool }).expect(message_assist).expect_reconstruction().run(); } // GPT-OSS format tests @@ -2989,10 +3079,11 @@ static void test_template_output_peg_parsers(bool detailed_debug) { // GigaChat V3 { auto tst = peg_tester("models/templates/GigaChat3-10B-A1.8B.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst.test("<|message_sep|>\n\nfunction call<|role_sep|>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}") .tools({ special_function_tool }) .expect(message_assist_call) + .expect_reconstruction() .run(); tst.test( @@ -3001,16 +3092,18 @@ static void test_template_output_peg_parsers(bool detailed_debug) { ) .tools({ special_function_tool }) .expect(message_assist_call_content) + .expect_reconstruction() .run(); } // GigaChat V3.1 { auto tst = peg_tester("models/templates/GigaChat3.1-10B-A1.8B.jinja", detailed_debug); - tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).expect_reconstruction().run(); tst.test("<|function_call|>{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}") .tools({ special_function_tool }) .expect(message_assist_call) + .expect_reconstruction() .run(); tst.test( @@ -3019,6 +3112,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) { ) .tools({ special_function_tool }) .expect(message_assist_call_content) + .expect_reconstruction() .run(); } } @@ -3155,6 +3249,10 @@ int main(int argc, char ** argv) { detailed_debug = true; common_log_set_verbosity_thold(999); } + if (arg == "--force-reconstruction-test") { + g_force_reconstruction_test = true; + only_run_filtered = true; + } } if (only_run_filtered) { diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 1dabaeee28f..2262577e1dd 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -113,16 +113,10 @@ bool server_http_context::init(const common_params & params) { srv->set_read_timeout (params.timeout_read); srv->set_write_timeout(params.timeout_write); srv->set_socket_options([reuse_port = params.reuse_port](socket_t sock) { - int opt = 1; -#ifdef _WIN32 - const char * optval = (const char *)&opt; -#else - const void * optval = &opt; -#endif - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, optval, sizeof(opt)); + httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 1); if (reuse_port) { #ifdef SO_REUSEPORT - setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, optval, sizeof(opt)); + httplib::set_socket_opt(sock, SOL_SOCKET, SO_REUSEPORT, 1); #else LOG_WRN("%s: SO_REUSEPORT is not supported\n", __func__); #endif diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index caa87abff64..8ff1da57bb5 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -467,10 +467,6 @@ bool set_socket_opt_impl(socket_t sock, int level, int optname, optlen) == 0; } -bool set_socket_opt(socket_t sock, int level, int optname, int optval) { - return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval)); -} - bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, time_t usec) { #ifdef _WIN32 @@ -2218,7 +2214,7 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, #ifdef _WIN32 // Setting SO_REUSEADDR seems not to work well with AF_UNIX on windows, so // remove the option. - detail::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0); + set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0); #endif bool dummy; @@ -4373,6 +4369,7 @@ make_multipart_content_provider(const UploadFormDataItems &items, struct MultipartState { std::vector owned; std::vector segs; + std::vector buf = std::vector(CPPHTTPLIB_SEND_BUFSIZ); }; auto state = std::make_shared(); state->owned = std::move(owned); @@ -4381,19 +4378,49 @@ make_multipart_content_provider(const UploadFormDataItems &items, state->segs = std::move(segs); return [state](size_t offset, size_t length, DataSink &sink) -> bool { + // Buffer multiple small segments into fewer, larger writes to avoid + // excessive TCP packets when there are many form data items (#2410) + auto &buf = state->buf; + auto buf_size = buf.size(); + size_t buf_len = 0; + size_t remaining = length; + + // Find the first segment containing 'offset' size_t pos = 0; - for (const auto &seg : state->segs) { - // Loop invariant: pos <= offset (proven by advancing pos only when - // offset - pos >= seg.size, i.e., the segment doesn't contain offset) - if (seg.size > 0 && offset - pos < seg.size) { - size_t seg_offset = offset - pos; - size_t available = seg.size - seg_offset; - size_t to_write = (std::min)(available, length); - return sink.write(seg.data + seg_offset, to_write); - } + size_t seg_idx = 0; + for (; seg_idx < state->segs.size(); seg_idx++) { + const auto &seg = state->segs[seg_idx]; + if (seg.size > 0 && offset - pos < seg.size) { break; } pos += seg.size; } - return true; // past end (shouldn't be reached when content_length is exact) + + size_t seg_offset = (seg_idx < state->segs.size()) ? offset - pos : 0; + + for (; seg_idx < state->segs.size() && remaining > 0; seg_idx++) { + const auto &seg = state->segs[seg_idx]; + size_t available = seg.size - seg_offset; + size_t to_copy = (std::min)(available, remaining); + const char *src = seg.data + seg_offset; + seg_offset = 0; // only the first segment has a non-zero offset + + while (to_copy > 0) { + size_t space = buf_size - buf_len; + size_t chunk = (std::min)(to_copy, space); + std::memcpy(buf.data() + buf_len, src, chunk); + buf_len += chunk; + src += chunk; + to_copy -= chunk; + remaining -= chunk; + + if (buf_len == buf_size) { + if (!sink.write(buf.data(), buf_len)) { return false; } + buf_len = 0; + } + } + } + + if (buf_len > 0) { return sink.write(buf.data(), buf_len); } + return true; }; } @@ -5264,13 +5291,18 @@ bool setup_client_tls_session(const std::string &host, tls::ctx_t &ctx, */ void default_socket_options(socket_t sock) { - detail::set_socket_opt(sock, SOL_SOCKET, + set_socket_opt(sock, SOL_SOCKET, #ifdef SO_REUSEPORT - SO_REUSEPORT, + SO_REUSEPORT, #else - SO_REUSEADDR, + SO_REUSEADDR, #endif - 1); + 1); +} + +bool set_socket_opt(socket_t sock, int level, int optname, int optval) { + return detail::set_socket_opt_impl(sock, level, optname, &optval, + sizeof(optval)); } std::string get_bearer_token_auth(const Request &req) { @@ -7418,6 +7450,8 @@ bool Server::read_content_core( return false; } + req.body_consumed_ = true; + if (req.is_multipart_form_data()) { if (!multipart_form_data_parser.is_valid()) { res.status = StatusCode::BadRequest_400; @@ -7688,9 +7722,7 @@ bool Server::listen_internal() { detail::set_socket_opt_time(sock, SOL_SOCKET, SO_SNDTIMEO, write_timeout_sec_, write_timeout_usec_); - if (tcp_nodelay_) { - detail::set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1); - } + if (tcp_nodelay_) { set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1); } if (!task_queue->enqueue( [this, sock]() { process_and_close_socket(sock); })) { @@ -8036,8 +8068,19 @@ Server::process_request(Stream &strm, const std::string &remote_addr, return write_response(strm, close_connection, req, res); } + // RFC 9112 §6.3: Reject requests with both a non-zero Content-Length and + // any Transfer-Encoding to prevent request smuggling. Content-Length: 0 is + // tolerated for compatibility with existing clients. + if (req.get_header_value_u64("Content-Length") > 0 && + req.has_header("Transfer-Encoding")) { + connection_closed = true; + res.status = StatusCode::BadRequest_400; + return write_response(strm, close_connection, req, res); + } + // Check if the request URI doesn't exceed the limit if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + connection_closed = true; res.status = StatusCode::UriTooLong_414; output_error_log(Error::ExceedUriMaxLength, &req); return write_response(strm, close_connection, req, res); @@ -8066,6 +8109,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr, if (req.has_header("Accept")) { const auto &accept_header = req.get_header_value("Accept"); if (!detail::parse_accept_header(accept_header, req.accept_content_types)) { + connection_closed = true; res.status = StatusCode::BadRequest_400; output_error_log(Error::HTTPParsing, &req); return write_response(strm, close_connection, req, res); @@ -8075,6 +8119,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr, if (req.has_header("Range")) { const auto &range_header_value = req.get_header_value("Range"); if (!detail::parse_range_header(range_header_value, req.ranges)) { + connection_closed = true; res.status = StatusCode::RangeNotSatisfiable_416; output_error_log(Error::InvalidRangeHeader, &req); return write_response(strm, close_connection, req, res); @@ -8202,6 +8247,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr, } } #endif + auto ret = false; if (routed) { if (res.status == -1) { res.status = req.ranges.empty() ? StatusCode::OK_200 @@ -8209,6 +8255,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr, } // Serve file content by using a content provider + auto file_open_error = false; if (!res.file_content_path_.empty()) { const auto &path = res.file_content_path_; auto mm = std::make_shared(path.c_str()); @@ -8218,37 +8265,53 @@ Server::process_request(Stream &strm, const std::string &remote_addr, res.content_provider_ = nullptr; res.status = StatusCode::NotFound_404; output_error_log(Error::OpenFile, &req); - return write_response(strm, close_connection, req, res); - } + file_open_error = true; + } else { + auto content_type = res.file_content_content_type_; + if (content_type.empty()) { + content_type = detail::find_content_type( + path, file_extension_and_mimetype_map_, default_file_mimetype_); + } - auto content_type = res.file_content_content_type_; - if (content_type.empty()) { - content_type = detail::find_content_type( - path, file_extension_and_mimetype_map_, default_file_mimetype_); + res.set_content_provider( + mm->size(), content_type, + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); } - - res.set_content_provider( - mm->size(), content_type, - [mm](size_t offset, size_t length, DataSink &sink) -> bool { - sink.write(mm->data() + offset, length); - return true; - }); } - if (detail::range_error(req, res)) { + if (file_open_error) { + ret = write_response(strm, close_connection, req, res); + } else if (detail::range_error(req, res)) { res.body.clear(); res.content_length_ = 0; res.content_provider_ = nullptr; res.status = StatusCode::RangeNotSatisfiable_416; - return write_response(strm, close_connection, req, res); + ret = write_response(strm, close_connection, req, res); + } else { + ret = write_response_with_content(strm, close_connection, req, res); } - - return write_response_with_content(strm, close_connection, req, res); } else { if (res.status == -1) { res.status = StatusCode::NotFound_404; } - - return write_response(strm, close_connection, req, res); + ret = write_response(strm, close_connection, req, res); + } + + // Drain any unconsumed request body to prevent request smuggling on + // keep-alive connections. + if (!req.body_consumed_ && detail::expect_content(req)) { + int drain_status = 200; // required by read_content signature + if (!detail::read_content( + strm, req, payload_max_length_, drain_status, nullptr, + [](const char *, size_t, size_t, size_t) { return true; }, false)) { + // Body exceeds payload limit or read error — close the connection + // to prevent leftover bytes from being misinterpreted. + connection_closed = true; + } } + + return ret; } bool Server::is_valid() const { return true; } diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index ce1681fcbee..2967ddf5e50 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.39.0" -#define CPPHTTPLIB_VERSION_NUM "0x002700" +#define CPPHTTPLIB_VERSION "0.40.0" +#define CPPHTTPLIB_VERSION_NUM "0x002800" #ifdef _WIN32 #if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00 @@ -1266,6 +1266,7 @@ struct Request { bool is_multipart_form_data() const; // private members... + bool body_consumed_ = false; size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT; size_t content_length_ = 0; ContentProvider content_provider_; @@ -1475,6 +1476,8 @@ using SocketOptions = std::function; void default_socket_options(socket_t sock); +bool set_socket_opt(socket_t sock, int level, int optname, int optval); + const char *status_message(int status); std::string to_string(Error error); @@ -1564,6 +1567,13 @@ ssize_t write_headers(Stream &strm, const Headers &headers); bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, time_t usec); +size_t get_multipart_content_length(const UploadFormDataItems &items, + const std::string &boundary); + +ContentProvider +make_multipart_content_provider(const UploadFormDataItems &items, + const std::string &boundary); + } // namespace detail class Server {