
bug fix: handle saving/loading null layers in recurrent memory #14675


Open · wants to merge 2 commits into base: master
15 changes: 15 additions & 0 deletions src/llama-memory-recurrent.cpp
@@ -768,6 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
// Iterate and write all the keys first, each row is a cell
// Get whole range at a time
for (uint32_t il = 0; il < n_layer; ++il) {
// skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
if (r_l[il] == nullptr) continue;

// Write key type
const int32_t r_type_i = (int32_t)r_l[il]->type;
@@ -787,6 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::

if (!s_trans) {
for (uint32_t il = 0; il < n_layer; ++il) {
// skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
if (s_l[il] == nullptr) continue;

// Write value type
const int32_t s_type_i = (int32_t)s_l[il]->type;
@@ -807,6 +811,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
// When v is transposed, we also need the element size and get the element ranges from each row
const uint32_t mem_size = size;
for (uint32_t il = 0; il < n_layer; ++il) {
// skip null layers (read_data will handle this by checking "r_l" and "s_l" for null)
if (s_l[il] == nullptr) continue;

const uint32_t n_embd_s = hparams.n_embd_s();

// Write value type
@@ -951,6 +958,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell

// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
for (uint32_t il = 0; il < n_layer; ++il) {
// skip null layers
if (r_l[il] == nullptr) continue;

// Read type of key
int32_t r_type_i_ref;
@@ -978,11 +987,14 @@

if (!s_trans) {
for (uint32_t il = 0; il < n_layer; ++il) {
// skip null layers
if (s_l[il] == nullptr) continue;

// Read type of value
int32_t s_type_i_ref;
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
const int32_t s_type_i = (int32_t)s_l[il]->type;

if (s_type_i != s_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
return false;
@@ -1005,6 +1017,9 @@
} else {
// For each layer, read the values for each cell (transposed)
for (uint32_t il = 0; il < n_layer; ++il) {
// skip null layers
if(s_l[il] == nullptr) continue;
Collaborator
Suggested change (add a space between "if" and the opening parenthesis):
if(s_l[il] == nullptr) continue;
if (s_l[il] == nullptr) continue;


const uint32_t n_embd_s = hparams.n_embd_s();

// Read type of value
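For reference, below is a minimal, self-contained sketch of the pattern this PR relies on. The types and function names are hypothetical simplifications, not the actual llama.cpp serialization interfaces: the writer skips null layers, and the reader applies the same null check so that the per-layer records in the serialized state stay aligned with the non-null layers.

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical simplified tensor type; stands in for the per-layer r_l / s_l entries.
struct tensor {
    int32_t type;
};

// Writer: emit one record per non-null layer, skipping null layers entirely.
static void write_layer_types(const std::vector<tensor *> & layers, std::vector<int32_t> & stream) {
    for (size_t il = 0; il < layers.size(); ++il) {
        if (layers[il] == nullptr) continue; // skip null layers
        stream.push_back(layers[il]->type);  // "write" the per-layer record
    }
}

// Reader: apply the same null check so the k-th record in the stream lines up
// with the k-th non-null layer; otherwise the layer indices drift out of sync.
static bool read_layer_types(const std::vector<tensor *> & layers, const std::vector<int32_t> & stream) {
    size_t pos = 0;
    for (size_t il = 0; il < layers.size(); ++il) {
        if (layers[il] == nullptr) continue; // skip null layers
        if (pos >= stream.size()) {
            return false; // truncated stream
        }
        const int32_t type_ref = stream[pos++];
        if (type_ref != layers[il]->type) {
            std::fprintf(stderr, "mismatched type (%d != %d, layer %zu)\n", layers[il]->type, type_ref, il);
            return false;
        }
    }
    return pos == stream.size();
}

int main() {
    tensor a{1}, b{2};
    // The null entry stands in for a layer with no recurrent state (illustrative assumption).
    std::vector<tensor *> layers = { &a, nullptr, &b };

    std::vector<int32_t> stream;
    write_layer_types(layers, stream);                 // writes 2 records, not 3
    const bool ok = read_layer_types(layers, stream);  // reads them back in sync
    std::printf("round trip %s\n", ok ? "ok" : "failed");
    return ok ? 0 : 1;
}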