diff --git a/collections/indexes/multi.go b/collections/indexes/multi.go index 3554346906cb..b890cd0e698b 100644 --- a/collections/indexes/multi.go +++ b/collections/indexes/multi.go @@ -119,6 +119,14 @@ func (m *Multi[ReferenceKey, PrimaryKey, Value]) Iterate(ctx context.Context, ra return (MultiIterator[ReferenceKey, PrimaryKey])(iter), err } +func (m *Multi[ReferenceKey, PrimaryKey, Value]) IterateRaw( + ctx context.Context, start, end []byte, order collections.Order, +) ( + iter collections.Iterator[collections.Pair[ReferenceKey, PrimaryKey], collections.NoValue], err error, +) { + return m.refKeys.IterateRaw(ctx, start, end, order) +} + func (m *Multi[ReferenceKey, PrimaryKey, Value]) Walk( ctx context.Context, ranger collections.Ranger[collections.Pair[ReferenceKey, PrimaryKey]], diff --git a/collections/indexes/multi_test.go b/collections/indexes/multi_test.go index a0b1313b8568..a64ef725b8d7 100644 --- a/collections/indexes/multi_test.go +++ b/collections/indexes/multi_test.go @@ -110,3 +110,110 @@ func TestMultiUnchecked(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte{}, rawValue) } + +func TestMulti_PairPrimaryKey(t *testing.T) { + sk, ctx := deps() + schema := collections.NewSchemaBuilder(sk) + + mi := NewMulti(schema, + collections.NewPrefix(1), "multi_index", + collections.Uint64Key, + collections.PairKeyCodec(collections.Uint64Key, collections.Uint64Key), + func(pk collections.Pair[uint64, uint64], _ collections.NoValue) (uint64, error) { + return pk.K2(), nil + }, + ) + + // we create two reference keys for primary key 1 and 2 associated with "milan" + require.NoError(t, mi.Reference(ctx, collections.Join(uint64(1), uint64(1)), collections.NoValue{}, func() (collections.NoValue, error) { return collections.NoValue{}, collections.ErrNotFound })) + require.NoError(t, mi.Reference(ctx, collections.Join(uint64(2), uint64(1)), collections.NoValue{}, func() (collections.NoValue, error) { return collections.NoValue{}, collections.ErrNotFound })) + + iter, err := mi.MatchExact(ctx, uint64(1)) + require.NoError(t, err) + pks, err := iter.PrimaryKeys() + require.NoError(t, err) + expectedPks := []collections.Pair[uint64, uint64]{ + collections.Join(uint64(1), uint64(1)), + collections.Join(uint64(2), uint64(1)), + } + require.Equal(t, expectedPks, pks) + + rawIter, err := mi.IterateRaw(ctx, nil, nil, collections.OrderAscending) + require.NoError(t, err) + defer iter.Close() + + count := 0 + for ; rawIter.Valid(); rawIter.Next() { + key, err := rawIter.Key() + require.NoError(t, err) + require.Equal(t, uint64(1), key.K1()) + expectedKey := collections.Join(uint64(count+1), uint64(1)) + require.Equal(t, expectedKey, key.K2()) + count++ + } + require.Equal(t, 2, count) +} +func TestMulti_IterateRaw(t *testing.T) { + sk, ctx := deps() + schema := collections.NewSchemaBuilder(sk) + + mi := NewMulti(schema, collections.NewPrefix(1), "multi_index", collections.StringKey, collections.Uint64Key, func(_ uint64, value company) (string, error) { + return value.City, nil + }) + + // Insert some test data + company1 := company{City: "milan"} + company2 := company{City: "new york"} + company3 := company{City: "milan"} + companies := []company{company1, company2, company3} + + for i, c := range companies { + ref := uint64(i) + 1 + err := mi.Reference(ctx, ref, c, func() (company, error) { return c, nil }) + require.NoError(t, err) + } + + // Test IterateRaw with ascending order + iter, err := mi.IterateRaw(ctx, nil, nil, collections.OrderAscending) + require.NoError(t, err) + defer iter.Close() + + var count int + for ; iter.Valid(); iter.Next() { + key, err := iter.Key() + require.NoError(t, err) + require.NotEmpty(t, key.K1()) + require.NotEmpty(t, key.K2()) + count++ + } + require.Equal(t, 3, count) + + // Test IterateRaw with descending order + iter, err = mi.IterateRaw(ctx, nil, nil, collections.OrderDescending) + require.NoError(t, err) + defer iter.Close() + + count = 0 + for ; iter.Valid(); iter.Next() { + key, err := iter.Key() + require.NoError(t, err) + require.NotEmpty(t, key.K1()) + require.NotEmpty(t, key.K2()) + count++ + } + require.Equal(t, 3, count) + + // Test with specific range - use MatchExact to get the correct keys + matchIter, err := mi.MatchExact(ctx, "milan") + require.NoError(t, err) + defer matchIter.Close() + + count = 0 + for ; matchIter.Valid(); matchIter.Next() { + fullKey, err := matchIter.FullKey() + require.NoError(t, err) + require.Equal(t, "milan", fullKey.K1()) + count++ + } + require.Equal(t, 2, count) +}