Skip to content

Commit e07b76a

Browse files
committed
fix: validate spaces are provisioned by known providers during receipt validation
1 parent 6815e09 commit e07b76a

File tree

9 files changed

+154
-18
lines changed

9 files changed

+154
-18
lines changed

cmd/etracker/start.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,15 @@ func init() {
118118

119119
cobra.CheckErr(viper.BindEnv("consumer_table_name", "CONSUMER_TABLE_NAME"))
120120
cobra.CheckErr(viper.BindEnv("consumer_table_region", "CONSUMER_TABLE_REGION"))
121+
cobra.CheckErr(viper.BindEnv("consumer_consumer_index_name", "CONSUMER_CONSUMER_INDEX_NAME"))
121122
cobra.CheckErr(viper.BindEnv("consumer_customer_index_name", "CONSUMER_CUSTOMER_INDEX_NAME"))
123+
124+
startCmd.Flags().StringSlice(
125+
"known-providers",
126+
presets.KnownProviders,
127+
"List of known provider DIDs (defaults to presets if not specified)",
128+
)
129+
cobra.CheckErr(viper.BindPFlag("known_providers", startCmd.Flags().Lookup("known-providers")))
122130
}
123131

124132
func startService(cmd *cobra.Command, args []string) error {
@@ -164,7 +172,7 @@ func startService(cmd *cobra.Command, args []string) error {
164172

165173
consumerCfg := cfg.AWSConfig.Copy()
166174
consumerCfg.Region = cfg.ConsumerTableRegion
167-
consumerTable := consumer.NewDynamoConsumerTable(dynamodb.NewFromConfig(consumerCfg), cfg.ConsumerTableName, cfg.ConsumerCustomerIndexName)
175+
consumerTable := consumer.NewDynamoConsumerTable(dynamodb.NewFromConfig(consumerCfg), cfg.ConsumerTableName, cfg.ConsumerConsumerIndexName, cfg.ConsumerCustomerIndexName)
168176

169177
// Create service
170178
svc, err := service.New(
@@ -190,7 +198,7 @@ func startService(cmd *cobra.Command, args []string) error {
190198
interval := time.Duration(cfg.ConsolidationInterval) * time.Second
191199
batchSize := cfg.ConsolidationBatchSize
192200

193-
cons, err := consolidator.New(id, egressTable, consolidatedTable, spaceStatsTable, interval, batchSize, presolver)
201+
cons, err := consolidator.New(id, egressTable, consolidatedTable, spaceStatsTable, consumerTable, cfg.KnownProviders, interval, batchSize, presolver)
194202
if err != nil {
195203
return fmt.Errorf("creating consolidator: %w", err)
196204
}

deploy/.env.production.local.tpl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ if [ "$TF_WORKSPACE" == "prod" ]; then
88

99
CONSUMER_TABLE_NAME="prod-upload-api-consumer"
1010
CONSUMER_TABLE_REGION="us-west-2"
11+
CONSUMER_CONSUMER_INDEX_NAME="consumer"
1112
CONSUMER_CUSTOMER_INDEX_NAME="customer"
1213
else
1314
STORAGE_PROVIDER_TABLE_NAME="staging-warm-upload-api-storage-provider"
@@ -18,6 +19,7 @@ else
1819

1920
CONSUMER_TABLE_NAME="staging-warm-upload-api-consumer"
2021
CONSUMER_TABLE_REGION="us-east-2"
22+
CONSUMER_CONSUMER_INDEX_NAME="consumer"
2123
CONSUMER_CUSTOMER_INDEX_NAME="customer"
2224
fi
2325
%>
@@ -30,4 +32,5 @@ CUSTOMER_TABLE_REGION=<%= $CUSTOMER_TABLE_REGION %>
3032

3133
CONSUMER_TABLE_NAME=<%= $CONSUMER_TABLE_NAME %>
3234
CONSUMER_TABLE_REGION=<%= $CONSUMER_TABLE_REGION %>
35+
CONSUMER_CONSUMER_INDEX_NAME=<%= $CONSUMER_CONSUMER_INDEX_NAME %>
3336
CONSUMER_CUSTOMER_INDEX_NAME=<%= $CONSUMER_CUSTOMER_INDEX_NAME %>

deploy/app/external.tf

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ data "aws_iam_policy_document" "task_external_dynamodb_scan_query_document" {
4949
data.aws_dynamodb_table.storage_provider_table.arn,
5050
data.aws_dynamodb_table.customer_table.arn,
5151
data.aws_dynamodb_table.consumer_table.arn,
52+
"${data.aws_dynamodb_table.consumer_table.arn}/index/consumer",
5253
"${data.aws_dynamodb_table.consumer_table.arn}/index/customer",
5354
]
5455
}

internal/config/config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ type Config struct {
3030
CustomerTableRegion string `mapstructure:"customer_table_region" validate:"required"`
3131
ConsumerTableName string `mapstructure:"consumer_table_name" validate:"required"`
3232
ConsumerTableRegion string `mapstructure:"consumer_table_region" validate:"required"`
33+
ConsumerConsumerIndexName string `mapstructure:"consumer_consumer_index_name" validate:"required"`
3334
ConsumerCustomerIndexName string `mapstructure:"consumer_customer_index_name" validate:"required"`
35+
KnownProviders []string `mapstructure:"known_providers"`
3436
}
3537

3638
func Load(ctx context.Context) (*Config, error) {

internal/consolidator/consolidator.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"iter"
77
"net/http"
88
"net/url"
9+
"slices"
910
"strings"
1011
"time"
1112

@@ -32,6 +33,7 @@ import (
3233
"go.opentelemetry.io/otel/metric"
3334

3435
"github.com/storacha/etracker/internal/db/consolidated"
36+
"github.com/storacha/etracker/internal/db/consumer"
3537
"github.com/storacha/etracker/internal/db/egress"
3638
"github.com/storacha/etracker/internal/db/spacestats"
3739
"github.com/storacha/etracker/internal/metrics"
@@ -46,6 +48,8 @@ type Consolidator struct {
4648
egressTable egress.EgressTable
4749
consolidatedTable consolidated.ConsolidatedTable
4850
spaceStatsTable spacestats.SpaceStatsTable
51+
consumerTable consumer.ConsumerTable
52+
knownProviders []string
4953
ucantoSrv ucanto.ServerView[ucanto.Service]
5054
retrieveValidationCtx validator.ValidationContext[content.RetrieveCaveats]
5155
httpClient *http.Client
@@ -59,6 +63,8 @@ func New(
5963
egressTable egress.EgressTable,
6064
consolidatedTable consolidated.ConsolidatedTable,
6165
spaceStatsTable spacestats.SpaceStatsTable,
66+
consumerTable consumer.ConsumerTable,
67+
knownProviders []string,
6268
interval time.Duration,
6369
batchSize int,
6470
presolver validator.PrincipalResolver,
@@ -84,6 +90,8 @@ func New(
8490
egressTable: egressTable,
8591
consolidatedTable: consolidatedTable,
8692
spaceStatsTable: spaceStatsTable,
93+
consumerTable: consumerTable,
94+
knownProviders: knownProviders,
8795
retrieveValidationCtx: retrieveValidationCtx,
8896
httpClient: &http.Client{Timeout: 30 * time.Second},
8997
interval: interval,
@@ -318,7 +326,7 @@ func (c *Consolidator) ucanConsolidateHandler(
318326
continue
319327
}
320328

321-
cap, err := validateRetrievalReceipt(ctx, requesterNode, rcpt, c.retrieveValidationCtx)
329+
cap, err := validateRetrievalReceipt(ctx, requesterNode, rcpt, c.retrieveValidationCtx, c.consumerTable, c.knownProviders)
322330
if err != nil {
323331
log.Warnf("Invalid receipt: %v", err)
324332
continue
@@ -415,6 +423,8 @@ func validateRetrievalReceipt(
415423
requesterNode did.DID,
416424
rcpt receipt.AnyReceipt,
417425
validationCtx validator.ValidationContext[content.RetrieveCaveats],
426+
consumerTable consumer.ConsumerTable,
427+
knownProviders []string,
418428
) (ucan.Capability[content.RetrieveCaveats], error) {
419429
// Confirm the receipt is not a failure receipt
420430
_, x := result.Unwrap(rcpt.Out())
@@ -461,6 +471,16 @@ func validateRetrievalReceipt(
461471
return nil, fmt.Errorf("original invocation is not a %s invocation, but a %s one", content.RetrieveAbility, cap.Can())
462472
}
463473

474+
// Check the space has been provisioned by the upload service
475+
space := cap.With()
476+
consumer, err := consumerTable.Get(ctx, space)
477+
if err != nil {
478+
return nil, fmt.Errorf("failed to get consumer: %w", err)
479+
}
480+
if !slices.Contains(knownProviders, consumer.Provider.String()) {
481+
return nil, fmt.Errorf("unknown space provider %s", consumer.Provider)
482+
}
483+
464484
// Verify the delegation chain
465485
auth, verr := validator.Access(ctx, inv, validationCtx)
466486
if verr != nil {

internal/consolidator/consolidator_test.go

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
"github.com/ipfs/go-cid"
1010
"github.com/ipld/go-ipld-prime"
11+
"github.com/storacha/etracker/internal/db/consumer"
1112
"github.com/storacha/go-libstoracha/capabilities/space/content"
1213
"github.com/storacha/go-libstoracha/testutil"
1314
"github.com/storacha/go-ucanto/core/dag/blockstore"
@@ -21,13 +22,38 @@ import (
2122
"github.com/storacha/go-ucanto/core/receipt/ran"
2223
"github.com/storacha/go-ucanto/core/result"
2324
"github.com/storacha/go-ucanto/core/result/failure"
25+
"github.com/storacha/go-ucanto/did"
2426
"github.com/storacha/go-ucanto/principal/ed25519/verifier"
2527
"github.com/storacha/go-ucanto/ucan"
2628
"github.com/storacha/go-ucanto/validator"
2729
"github.com/stretchr/testify/assert"
2830
"github.com/stretchr/testify/require"
2931
)
3032

33+
var _ consumer.ConsumerTable = (*mockConsumerTable)(nil)
34+
35+
type mockConsumerTable struct {
36+
t *testing.T
37+
provider did.DID
38+
}
39+
40+
func (m *mockConsumerTable) Get(ctx context.Context, space string) (consumer.Consumer, error) {
41+
s, err := did.Parse(space)
42+
if err != nil {
43+
return consumer.Consumer{}, err
44+
}
45+
46+
return consumer.Consumer{
47+
ID: s,
48+
Provider: m.provider,
49+
Subscription: testutil.RandomCID(m.t).String(),
50+
}, nil
51+
}
52+
53+
func (m *mockConsumerTable) ListByCustomer(ctx context.Context, customer did.DID) ([]did.DID, error) {
54+
return []did.DID{}, nil
55+
}
56+
3157
func TestValidateRetrievalReceipt(t *testing.T) {
3258
vCtx := validator.NewValidationContext(
3359
testutil.Service.Verifier(),
@@ -45,6 +71,11 @@ func TestValidateRetrievalReceipt(t *testing.T) {
4571
},
4672
)
4773

74+
knownProvider, err := did.Parse("did:web:up.test.storacha.network")
75+
require.NoError(t, err)
76+
77+
consumerTable := &mockConsumerTable{t: t, provider: knownProvider}
78+
4879
space := testutil.RandomSigner(t)
4980
randBytes := testutil.RandomBytes(t, 256)
5081
blob := struct {
@@ -95,7 +126,7 @@ func TestValidateRetrievalReceipt(t *testing.T) {
95126
)
96127
require.NoError(t, err)
97128

98-
cap, err := validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx)
129+
cap, err := validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx, consumerTable, []string{knownProvider.String()})
99130
require.NoError(t, err)
100131
assert.Equal(t, content.RetrieveAbility, cap.Can())
101132
})
@@ -108,7 +139,7 @@ func TestValidateRetrievalReceipt(t *testing.T) {
108139
)
109140
require.NoError(t, err)
110141

111-
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx)
142+
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx, consumerTable, []string{knownProvider.String()})
112143
assert.ErrorContains(t, err, "receipt is a failure receipt")
113144
})
114145

@@ -121,7 +152,7 @@ func TestValidateRetrievalReceipt(t *testing.T) {
121152
)
122153
require.NoError(t, err)
123154

124-
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx)
155+
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx, consumerTable, []string{knownProvider.String()})
125156
assert.ErrorContains(t, err, "receipt is not issued by the requester node")
126157
})
127158

@@ -137,7 +168,7 @@ func TestValidateRetrievalReceipt(t *testing.T) {
137168
// Tamper with the receipt to change its result
138169
tamperReceiptResult(t, rcpt)
139170

140-
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx)
171+
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx, consumerTable, []string{knownProvider.String()})
141172
assert.ErrorContains(t, err, "receipt signature is invalid")
142173
})
143174

@@ -149,7 +180,7 @@ func TestValidateRetrievalReceipt(t *testing.T) {
149180
)
150181
require.NoError(t, err)
151182

152-
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx)
183+
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx, consumerTable, []string{knownProvider.String()})
153184
assert.ErrorContains(t, err, "original retrieve invocation must be attached to the receipt")
154185
})
155186

@@ -173,11 +204,28 @@ func TestValidateRetrievalReceipt(t *testing.T) {
173204
)
174205
require.NoError(t, err)
175206

176-
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx)
207+
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx, consumerTable, []string{knownProvider.String()})
177208
expectedErr := "original invocation is not a " + content.RetrieveAbility + " invocation, but a other/ability one"
178209
assert.ErrorContains(t, err, expectedErr)
179210
})
180211

212+
t.Run("wrong space provider", func(t *testing.T) {
213+
rcpt, err := receipt.Issue(
214+
storageNode,
215+
result.Ok[content.RetrieveOk, failure.IPLDBuilderFailure](content.RetrieveOk{}),
216+
ran.FromInvocation(inv),
217+
)
218+
require.NoError(t, err)
219+
220+
otherProvider, err := did.Parse("did:web:up.other.net")
221+
require.NoError(t, err)
222+
223+
consumerTable := &mockConsumerTable{t: t, provider: otherProvider}
224+
225+
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx, consumerTable, []string{knownProvider.String()})
226+
assert.ErrorContains(t, err, "unknown space provider")
227+
})
228+
181229
t.Run("invalid delegation chain", func(t *testing.T) {
182230
// Bob invokes on the space, but the proof is from the space to Alice
183231
bobInvokes, err := invocation.Invoke(
@@ -201,7 +249,7 @@ func TestValidateRetrievalReceipt(t *testing.T) {
201249
)
202250
require.NoError(t, err)
203251

204-
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx)
252+
_, err = validateRetrievalReceipt(context.Background(), storageNode.DID(), rcpt, vCtx, consumerTable, []string{knownProvider.String()})
205253
assert.ErrorContains(t, err, "invalid delegation chain")
206254
})
207255
}

internal/db/consumer/consumer.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ import (
66
"github.com/storacha/go-ucanto/did"
77
)
88

9+
type Consumer struct {
10+
ID did.DID
11+
Provider did.DID
12+
Subscription string
13+
}
14+
915
type ConsumerTable interface {
16+
Get(ctx context.Context, consumerID string) (Consumer, error)
1017
ListByCustomer(ctx context.Context, customerID did.DID) ([]did.DID, error)
1118
}

internal/db/consumer/dynamodb.go

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,38 @@ var _ ConsumerTable = (*DynamoConsumerTable)(nil)
1616
type DynamoConsumerTable struct {
1717
client *dynamodb.Client
1818
tableName string
19+
consumerIndexName string
1920
customerIndexName string
2021
}
2122

22-
func NewDynamoConsumerTable(client *dynamodb.Client, tableName string, customerIndexName string) *DynamoConsumerTable {
23-
return &DynamoConsumerTable{client, tableName, customerIndexName}
23+
func NewDynamoConsumerTable(client *dynamodb.Client, tableName, consumerIndexName, customerIndexName string) *DynamoConsumerTable {
24+
return &DynamoConsumerTable{client, tableName, consumerIndexName, customerIndexName}
25+
}
26+
27+
func (d *DynamoConsumerTable) Get(ctx context.Context, consumerID string) (Consumer, error) {
28+
// Query the consumer index to get the item by consumer ID
29+
result, err := d.client.Query(ctx, &dynamodb.QueryInput{
30+
TableName: aws.String(d.tableName),
31+
IndexName: aws.String(d.consumerIndexName),
32+
KeyConditionExpression: aws.String("consumer = :consumer"),
33+
ExpressionAttributeValues: map[string]types.AttributeValue{
34+
":consumer": &types.AttributeValueMemberS{Value: consumerID},
35+
},
36+
})
37+
if err != nil {
38+
return Consumer{}, fmt.Errorf("querying consumer by ID: %w", err)
39+
}
40+
41+
if len(result.Items) == 0 {
42+
return Consumer{}, fmt.Errorf("consumer not found: %s", consumerID)
43+
}
44+
45+
consumer, err := d.unmarshalConsumer(result.Items[0])
46+
if err != nil {
47+
return Consumer{}, fmt.Errorf("unmarshaling consumer: %w", err)
48+
}
49+
50+
return *consumer, nil
2451
}
2552

2653
func (d *DynamoConsumerTable) ListByCustomer(ctx context.Context, customerID did.DID) ([]did.DID, error) {
@@ -55,7 +82,7 @@ func (d *DynamoConsumerTable) ListByCustomer(ctx context.Context, customerID did
5582
if err != nil {
5683
return nil, err
5784
}
58-
consumers = append(consumers, consumer)
85+
consumers = append(consumers, consumer.ID)
5986
}
6087

6188
// Check if there are more results to fetch
@@ -70,19 +97,33 @@ func (d *DynamoConsumerTable) ListByCustomer(ctx context.Context, customerID did
7097

7198
// consumerRecord is the internal struct for unmarshaling from DynamoDB
7299
type consumerRecord struct {
73-
Consumer string `dynamodbav:"consumer"`
100+
Consumer string `dynamodbav:"consumer"`
101+
Provider string `dynamodbav:"provider,omitempty"`
102+
Subscription string `dynamodbav:"subscription,omitempty"`
74103
}
75104

76-
func (d *DynamoConsumerTable) unmarshalConsumer(item map[string]types.AttributeValue) (did.DID, error) {
105+
func (d *DynamoConsumerTable) unmarshalConsumer(item map[string]types.AttributeValue) (*Consumer, error) {
77106
var record consumerRecord
78107
if err := attributevalue.UnmarshalMap(item, &record); err != nil {
79-
return did.DID{}, fmt.Errorf("unmarshaling consumer record: %w", err)
108+
return nil, fmt.Errorf("unmarshaling consumer record: %w", err)
80109
}
81110

82111
consumerDID, err := did.Parse(record.Consumer)
83112
if err != nil {
84-
return did.DID{}, fmt.Errorf("parsing consumer DID: %w", err)
113+
return nil, fmt.Errorf("parsing consumer DID: %w", err)
114+
}
115+
116+
providerDID := did.Undef
117+
if record.Provider != "" {
118+
providerDID, err = did.Parse(record.Provider)
119+
if err != nil {
120+
return nil, fmt.Errorf("parsing provider DID: %w", err)
121+
}
85122
}
86123

87-
return consumerDID, nil
124+
return &Consumer{
125+
ID: consumerDID,
126+
Provider: providerDID,
127+
Subscription: record.Subscription,
128+
}, nil
88129
}

0 commit comments

Comments
 (0)