diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 2c1ae67098ca4..7eb9f14c25a5b 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -768,6 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) + if (r_l[il] == nullptr) continue; // Write key type const int32_t r_type_i = (int32_t)r_l[il]->type; @@ -787,6 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: if (!s_trans) { for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) + if (s_l[il] == nullptr) continue; // Write value type const int32_t s_type_i = (int32_t)s_l[il]->type; @@ -807,6 +811,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t mem_size = size; for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) + if (s_l[il] == nullptr) continue; + const uint32_t n_embd_s = hparams.n_embd_s(); // Write value type @@ -951,6 +958,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers + if (r_l[il] == nullptr) continue; // Read type of key int32_t r_type_i_ref; @@ -978,11 +987,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell if (!s_trans) { for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers + if (s_l[il] == nullptr) continue; // Read type of value int32_t s_type_i_ref; io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); const int32_t s_type_i = (int32_t)s_l[il]->type; + if (s_type_i != s_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); return false; @@ -1005,6 +1017,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers + if(s_l[il] == nullptr) continue; + const uint32_t n_embd_s = hparams.n_embd_s(); // Read type of value