diff --git a/pkg/preparation/shards/car.go b/pkg/preparation/shards/car.go index 27b9e8fd..037628f2 100644 --- a/pkg/preparation/shards/car.go +++ b/pkg/preparation/shards/car.go @@ -39,6 +39,7 @@ func init() { noRootsHeader = buf.Bytes() hasher := newShardHashState() + defer hasher.reset() _, err = hasher.Write(noRootsHeader) if err != nil { panic(fmt.Sprintf("failed to hash CAR header: %v", err)) diff --git a/pkg/preparation/shards/shardhashstate.go b/pkg/preparation/shards/shardhashstate.go index cc22453b..8c216c63 100644 --- a/pkg/preparation/shards/shardhashstate.go +++ b/pkg/preparation/shards/shardhashstate.go @@ -67,32 +67,40 @@ func (s *shardHashState) marshal() ([]byte, []byte, error) { if err != nil { return nil, nil, fmt.Errorf("marshaling digest hash state: %w", err) } - s.digestHash.Reset() pieceCIDState, err := s.commpCalc.MarshalBinary() if err != nil { return nil, nil, fmt.Errorf("marshaling piece CID state: %w", err) } - s.commpCalc.Reset() return digestState, pieceCIDState, nil } -func (s *shardHashState) finalize(shardSize uint64) (multihash.Multihash, cid.Cid, error) { +func (s *shardHashState) reset() { + s.digestHash.Reset() + s.commpCalc.Reset() +} + +type shardHashes struct { + shardDigest multihash.Multihash + pieceCID cid.Cid +} + +func (s *shardHashState) finalize(shardSize uint64) (shardHashes, error) { shardDigest, err := multihash.Encode(s.digestHash.Sum(nil), multihash.SHA2_256) if err != nil { - return nil, cid.Undef, fmt.Errorf("encoding shard digest: %w", err) + return shardHashes{}, fmt.Errorf("encoding shard digest: %w", err) } // If the shard is too small to form a piece CID, return undefined. 
if shardSize < types.MinPiecePayload { - return shardDigest, cid.Undef, nil + return shardHashes{shardDigest: shardDigest, pieceCID: cid.Undef}, nil } pieceDigest := s.commpCalc.Sum(nil) pieceCID, err := commcid.DataCommitmentToPieceCidv2(pieceDigest, shardSize) if err != nil { - return nil, cid.Undef, fmt.Errorf("computing piece CID: %w", err) + return shardHashes{}, fmt.Errorf("computing piece CID: %w", err) } - return shardDigest, pieceCID, nil + return shardHashes{shardDigest: shardDigest, pieceCID: pieceCID}, nil } diff --git a/pkg/preparation/shards/shards.go b/pkg/preparation/shards/shards.go index bb5af71a..d90e523f 100644 --- a/pkg/preparation/shards/shards.go +++ b/pkg/preparation/shards/shards.go @@ -141,54 +141,52 @@ func (a API) addNodeToDigestState(ctx context.Context, shard *model.Shard, node return digestStateUpdate{}, fmt.Errorf("expected %d bytes for node %s, got %d", node.Size(), node.CID(), len(data)) } - hasher, err := a.updatedShardHashState(ctx, shard) - if err != nil { - return digestStateUpdate{}, fmt.Errorf("getting updated shard %s hasher: %w", shard.ID(), err) - } - - err = a.ShardEncoder.WriteNode(ctx, node, data, hasher) - if err != nil { - return digestStateUpdate{}, fmt.Errorf("writing node %s to shard %s digest state: %w", node.CID(), shard.ID(), err) - } + return withUpdatedShardHashState(ctx, a, shard, func(hasher *shardHashState) (digestStateUpdate, error) { + err := a.ShardEncoder.WriteNode(ctx, node, data, hasher) + if err != nil { + return digestStateUpdate{}, fmt.Errorf("writing node %s to shard %s digest state: %w", node.CID(), shard.ID(), err) + } - digestState, pieceCIDState, err := hasher.marshal() - if err != nil { - return digestStateUpdate{}, fmt.Errorf("marshaling shard %s digest state: %w", shard.ID(), err) - } + digestState, pieceCIDState, err := hasher.marshal() + if err != nil { + return digestStateUpdate{}, fmt.Errorf("marshaling shard %s digest state: %w", shard.ID(), err) + } - return digestStateUpdate{ - 
digestStateUpTo: shard.Size() + a.ShardEncoder.NodeEncodingLength(node), - digestState: digestState, - pieceCIDState: pieceCIDState, - }, nil + return digestStateUpdate{ + digestStateUpTo: shard.Size() + a.ShardEncoder.NodeEncodingLength(node), + digestState: digestState, + pieceCIDState: pieceCIDState, + }, nil + }) } -func (a API) updatedShardHashState(ctx context.Context, shard *model.Shard) (*shardHashState, error) { +func withUpdatedShardHashState[T any](ctx context.Context, a API, shard *model.Shard, action func(*shardHashState) (T, error)) (T, error) { h, err := fromShard(shard) if err != nil { - return nil, fmt.Errorf("getting shard %s hasher: %w", shard.ID(), err) + var zero T + return zero, fmt.Errorf("getting shard %s hasher: %w", shard.ID(), err) } + defer h.reset() if shard.DigestStateUpTo() < shard.Size() { err := a.fastWriteShard(ctx, shard.ID(), shard.DigestStateUpTo(), h) if err != nil { - return nil, fmt.Errorf("hashing remaining data for shard %s: %w", shard.ID(), err) + var zero T + return zero, fmt.Errorf("hashing remaining data for shard %s: %w", shard.ID(), err) } } - return h, nil + return action(h) } func (a API) finalizeShardDigests(ctx context.Context, shard *model.Shard) error { - h, err := a.updatedShardHashState(ctx, shard) - if err != nil { - return fmt.Errorf("getting updated shard %s hasher: %w", shard.ID(), err) - } - shardDigest, pieceCID, err := h.finalize(shard.Size()) + sh, err := withUpdatedShardHashState(ctx, a, shard, func(hasher *shardHashState) (shardHashes, error) { + return hasher.finalize(shard.Size()) + }) if err != nil { return fmt.Errorf("finalizing digests for shard %s: %w", shard.ID(), err) } - if err := shard.Close(shardDigest, pieceCID); err != nil { + if err := shard.Close(sh.shardDigest, sh.pieceCID); err != nil { return err } return a.Repo.UpdateShard(ctx, shard)