From 12ecec013ea96cfefcf802985c2bf4f562c1f439 Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Mon, 14 Jul 2025 18:52:00 +0800
Subject: [PATCH 1/3] Update llama-memory-recurrent.cpp

handle saving/loading null layers in recurrent memory
---
 src/llama-memory-recurrent.cpp | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 2c1ae67098ca4..bab5bfe45ef35 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -769,6 +769,11 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
+        if (r_l[il] == nullptr) {
+            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+            continue;
+        }
+
         // Write key type
         const int32_t r_type_i = (int32_t)r_l[il]->type;
         io.write(&r_type_i, sizeof(r_type_i));
@@ -788,6 +793,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // special key to handle null layers
+            if (s_l[il] == nullptr) {
+                // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+                continue;
+            }
+
             // Write value type
             const int32_t s_type_i = (int32_t)s_l[il]->type;
             io.write(&s_type_i, sizeof(s_type_i));
@@ -807,6 +818,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t mem_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // special key to handle null layers
+            if (s_l[il] == nullptr) {
+                // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+                continue;
+            }
+
             const uint32_t n_embd_s = hparams.n_embd_s();
 
             // Write value type
@@ -951,6 +968,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
+        // skip null layers
+        if(r_l[il] == nullptr) continue;
 
         // Read type of key
         int32_t r_type_i_ref;
@@ -978,11 +997,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers
+            if(s_l[il] == nullptr) continue;
 
             // Read type of value
             int32_t s_type_i_ref;
             io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
             const int32_t s_type_i = (int32_t)s_l[il]->type;
+
             if (s_type_i != s_type_i_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
                 return false;
             }
@@ -1005,6 +1027,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
+            // skip null layers
+            if(s_l[il] == nullptr) continue;
+
             const uint32_t n_embd_s = hparams.n_embd_s();
 
             // Read type of value

From 8974f228367610257a33113c99e7278110b8f18a Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Mon, 14 Jul 2025 19:07:40 +0800
Subject: [PATCH 2/3] fixed styling issues and updated comments

---
 src/llama-memory-recurrent.cpp | 26 ++++++++------------------
 1 file changed, 8 insertions(+), 18 deletions(-)

diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index bab5bfe45ef35..7eb9f14c25a5b 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -768,11 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-
-        if (r_l[il] == nullptr) {
-            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
-            continue;
-        }
+        // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+        if (r_l[il] == nullptr) continue;
 
         // Write key type
         const int32_t r_type_i = (int32_t)r_l[il]->type;
@@ -792,12 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-
-            // special key to handle null layers
-            if (s_l[il] == nullptr) {
-                // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
-                continue;
-            }
+            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+            if (s_l[il] == nullptr) continue;
 
             // Write value type
             const int32_t s_type_i = (int32_t)s_l[il]->type;
@@ -818,11 +811,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t mem_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            // special key to handle null layers
-            if (s_l[il] == nullptr) {
-                // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
-                continue;
-            }
+            // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
+            if (s_l[il] == nullptr) continue;
 
             const uint32_t n_embd_s = hparams.n_embd_s();
 
@@ -969,7 +959,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
         // skip null layers
-        if(r_l[il] == nullptr) continue;
+        if (r_l[il] == nullptr) continue;
 
         // Read type of key
         int32_t r_type_i_ref;
@@ -998,7 +988,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
     if (!s_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
             // skip null layers
-            if(s_l[il] == nullptr) continue;
+            if (s_l[il] == nullptr) continue;
 
             // Read type of value
             int32_t s_type_i_ref;

From 14de5a5bab2fe6f698a0ac76248c6e4b1c9ffcf9 Mon Sep 17 00:00:00 2001
From: l3utterfly
Date: Wed, 23 Jul 2025 12:08:22 +0800
Subject: [PATCH 3/3] fix styling issue
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sigbjørn Skjæret
---
 src/llama-memory-recurrent.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 7eb9f14c25a5b..bc76a57178d9a 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -1018,7 +1018,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
             // skip null layers
-            if(s_l[il] == nullptr) continue;
+            if (s_l[il] == nullptr) continue;
 
             const uint32_t n_embd_s = hparams.n_embd_s();