Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/preparation/shards/car.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func init() {

noRootsHeader = buf.Bytes()
hasher := newShardHashState()
defer hasher.reset()
_, err = hasher.Write(noRootsHeader)
if err != nil {
panic(fmt.Sprintf("failed to hash CAR header: %v", err))
Expand Down
22 changes: 15 additions & 7 deletions pkg/preparation/shards/shardhashstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,32 +67,40 @@ func (s *shardHashState) marshal() ([]byte, []byte, error) {
if err != nil {
return nil, nil, fmt.Errorf("marshaling digest hash state: %w", err)
}
s.digestHash.Reset()
pieceCIDState, err := s.commpCalc.MarshalBinary()
if err != nil {
return nil, nil, fmt.Errorf("marshaling piece CID state: %w", err)
}
s.commpCalc.Reset()
return digestState, pieceCIDState, nil
}

func (s *shardHashState) finalize(shardSize uint64) (multihash.Multihash, cid.Cid, error) {
func (s *shardHashState) reset() {
s.digestHash.Reset()
s.commpCalc.Reset()
}

type shardHashes struct {
shardDigest multihash.Multihash
pieceCID cid.Cid
}

func (s *shardHashState) finalize(shardSize uint64) (shardHashes, error) {
shardDigest, err := multihash.Encode(s.digestHash.Sum(nil), multihash.SHA2_256)
if err != nil {
return nil, cid.Undef, fmt.Errorf("encoding shard digest: %w", err)
return shardHashes{}, fmt.Errorf("encoding shard digest: %w", err)
}

// If the shard is too small to form a piece CID, return undefined.
if shardSize < types.MinPiecePayload {
return shardDigest, cid.Undef, nil
return shardHashes{shardDigest: shardDigest, pieceCID: cid.Undef}, nil
}

pieceDigest := s.commpCalc.Sum(nil)

pieceCID, err := commcid.DataCommitmentToPieceCidv2(pieceDigest, shardSize)
if err != nil {
return nil, cid.Undef, fmt.Errorf("computing piece CID: %w", err)
return shardHashes{}, fmt.Errorf("computing piece CID: %w", err)
}

return shardDigest, pieceCID, nil
return shardHashes{shardDigest: shardDigest, pieceCID: pieceCID}, nil
}
54 changes: 26 additions & 28 deletions pkg/preparation/shards/shards.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,54 +141,52 @@ func (a API) addNodeToDigestState(ctx context.Context, shard *model.Shard, node
return digestStateUpdate{}, fmt.Errorf("expected %d bytes for node %s, got %d", node.Size(), node.CID(), len(data))
}

hasher, err := a.updatedShardHashState(ctx, shard)
if err != nil {
return digestStateUpdate{}, fmt.Errorf("getting updated shard %s hasher: %w", shard.ID(), err)
}

err = a.ShardEncoder.WriteNode(ctx, node, data, hasher)
if err != nil {
return digestStateUpdate{}, fmt.Errorf("writing node %s to shard %s digest state: %w", node.CID(), shard.ID(), err)
}
return withUpdatedShardHashState(ctx, a, shard, func(hasher *shardHashState) (digestStateUpdate, error) {
err := a.ShardEncoder.WriteNode(ctx, node, data, hasher)
if err != nil {
return digestStateUpdate{}, fmt.Errorf("writing node %s to shard %s digest state: %w", node.CID(), shard.ID(), err)
}

digestState, pieceCIDState, err := hasher.marshal()
if err != nil {
return digestStateUpdate{}, fmt.Errorf("marshaling shard %s digest state: %w", shard.ID(), err)
}
digestState, pieceCIDState, err := hasher.marshal()
if err != nil {
return digestStateUpdate{}, fmt.Errorf("marshaling shard %s digest state: %w", shard.ID(), err)
}

return digestStateUpdate{
digestStateUpTo: shard.Size() + a.ShardEncoder.NodeEncodingLength(node),
digestState: digestState,
pieceCIDState: pieceCIDState,
}, nil
return digestStateUpdate{
digestStateUpTo: shard.Size() + a.ShardEncoder.NodeEncodingLength(node),
digestState: digestState,
pieceCIDState: pieceCIDState,
}, nil
})
}

func (a API) updatedShardHashState(ctx context.Context, shard *model.Shard) (*shardHashState, error) {
func withUpdatedShardHashState[T any](ctx context.Context, a API, shard *model.Shard, action func(*shardHashState) (T, error)) (T, error) {
h, err := fromShard(shard)
defer h.reset()
if err != nil {
return nil, fmt.Errorf("getting shard %s hasher: %w", shard.ID(), err)
var zero T
return zero, fmt.Errorf("getting shard %s hasher: %w", shard.ID(), err)
}

if shard.DigestStateUpTo() < shard.Size() {
err := a.fastWriteShard(ctx, shard.ID(), shard.DigestStateUpTo(), h)
if err != nil {
return nil, fmt.Errorf("hashing remaining data for shard %s: %w", shard.ID(), err)
var zero T
return zero, fmt.Errorf("hashing remaining data for shard %s: %w", shard.ID(), err)
}
}

return h, nil
return action(h)
}

func (a API) finalizeShardDigests(ctx context.Context, shard *model.Shard) error {
h, err := a.updatedShardHashState(ctx, shard)
if err != nil {
return fmt.Errorf("getting updated shard %s hasher: %w", shard.ID(), err)
}
shardDigest, pieceCID, err := h.finalize(shard.Size())
sh, err := withUpdatedShardHashState(ctx, a, shard, func(hasher *shardHashState) (shardHashes, error) {
return hasher.finalize(shard.Size())
})
if err != nil {
return fmt.Errorf("finalizing digests for shard %s: %w", shard.ID(), err)
}
if err := shard.Close(shardDigest, pieceCID); err != nil {
if err := shard.Close(sh.shardDigest, sh.pieceCID); err != nil {
return err
}
return a.Repo.UpdateShard(ctx, shard)
Expand Down