Commit 97c3d66

Implement resumable downloads
So if a download gets interrupted, we do not have to start again from scratch.

Signed-off-by: Eric Curtin <[email protected]>
1 parent 6d1c75e

File tree: 9 files changed, +378 -23 lines changed


Dockerfile

Lines changed: 3 additions & 0 deletions

```diff
@@ -19,6 +19,9 @@ WORKDIR /app
 # Copy go mod/sum first for better caching
 COPY --link go.mod go.sum ./
 
+# Copy pkg/go-containerregistry for the replace directive in go.mod
+COPY --link pkg/go-containerregistry ./pkg/go-containerregistry
+
 # Download dependencies (with cache mounts)
 RUN --mount=type=cache,target=/go/pkg/mod \
     --mount=type=cache,target=/root/.cache/go-build \
```
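The COPY comment refers to a replace directive in go.mod, which is not part of this diff. Assuming the vendored copy is a nested Go module, the directive presumably looks something like this:

```
// go.mod (assumed, not shown in this commit): resolve the fork from the
// copied directory instead of a remote module proxy.
replace github.com/docker/model-runner/pkg/go-containerregistry => ./pkg/go-containerregistry
```

Without the added COPY line, the vendored directory would be missing when the build stage runs `go mod download`, and the replace directive would fail to resolve.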

Makefile

Lines changed: 1 addition & 1 deletion

```diff
@@ -77,7 +77,7 @@ integration-tests:
     @echo "Integration tests completed!"
 
 validate:
-    find . -type f -name "*.sh" | xargs shellcheck
+    find . -type f -name "*.sh" | grep -v pkg/go-containerregistry | xargs shellcheck
 
 # Build Docker image
 docker-build:
```

pkg/distribution/distribution/client.go

Lines changed: 73 additions & 6 deletions

```diff
@@ -17,6 +17,7 @@ import (
     "github.com/docker/model-runner/pkg/distribution/registry"
     "github.com/docker/model-runner/pkg/distribution/tarball"
     "github.com/docker/model-runner/pkg/distribution/types"
+    "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1/remote"
     "github.com/docker/model-runner/pkg/inference/platform"
 )
 
@@ -140,24 +141,90 @@ func NewClient(opts ...Option) (*Client, error) {
 func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer) error {
     c.log.Infoln("Starting model pull:", utils.SanitizeForLog(reference))
 
+    // First, fetch the remote model to get the manifest
     remoteModel, err := c.registry.Model(ctx, reference)
     if err != nil {
         return fmt.Errorf("reading model from registry: %w", err)
     }
 
-    // Check for supported type
-    if err := checkCompat(remoteModel, c.log, reference); err != nil {
-        return err
-    }
-
-    // Get the remote image digest
+    // Get the remote image digest immediately to ensure we work with a consistent manifest
+    // This prevents race conditions if the tag is updated during the pull
     remoteDigest, err := remoteModel.Digest()
     if err != nil {
         c.log.Errorln("Failed to get remote image digest:", err)
         return fmt.Errorf("getting remote image digest: %w", err)
     }
     c.log.Infoln("Remote model digest:", remoteDigest.String())
 
+    // Check for incomplete downloads and prepare resume offsets
+    layers, err := remoteModel.Layers()
+    if err != nil {
+        return fmt.Errorf("getting layers: %w", err)
+    }
+
+    // Build a map of digest -> resume offset for layers with incomplete downloads
+    resumeOffsets := make(map[string]int64)
+    for _, layer := range layers {
+        digest, err := layer.Digest()
+        if err != nil {
+            c.log.Warnf("Failed to get layer digest: %v", err)
+            continue
+        }
+
+        // Check if there's an incomplete download for this layer (use DiffID for uncompressed models)
+        diffID, err := layer.DiffID()
+        if err != nil {
+            c.log.Warnf("Failed to get layer diffID: %v", err)
+            continue
+        }
+
+        incompleteSize, err := c.store.GetIncompleteSize(diffID)
+        if err != nil {
+            c.log.Warnf("Failed to check incomplete size for layer %s: %v", digest, err)
+            continue
+        }
+
+        if incompleteSize > 0 {
+            c.log.Infof("Found incomplete download for layer %s: %d bytes", digest, incompleteSize)
+            resumeOffsets[digest.String()] = incompleteSize
+        }
+    }
+
+    // If we have any incomplete downloads, create a new context with resume offsets
+    // and re-fetch using the digest to ensure we're resuming the same manifest
+    if len(resumeOffsets) > 0 {
+        c.log.Infof("Resuming %d interrupted layer download(s)", len(resumeOffsets))
+        ctx = remote.WithResumeOffsets(ctx, resumeOffsets)
+        // Re-fetch the model using the digest reference to prevent race conditions
+        // Extract repository name from the original reference and construct digest reference
+        repository := reference
+        // Find the last occurrence of : or @ (tag or digest separator)
+        // We need to search after the last / to avoid matching port separators
+        if lastSlash := strings.LastIndex(reference, "/"); lastSlash != -1 {
+            // Search for : or @ after the last slash
+            suffix := reference[lastSlash:]
+            if idx := strings.IndexAny(suffix, ":@"); idx != -1 {
+                repository = reference[:lastSlash+idx]
+            }
+        } else {
+            // No slash found, search from the beginning (e.g., "image:tag")
+            if idx := strings.IndexAny(reference, ":@"); idx != -1 {
+                repository = reference[:idx]
+            }
+        }
+        digestReference := repository + "@" + remoteDigest.String()
+        c.log.Infof("Re-fetching model with digest reference: %s", utils.SanitizeForLog(digestReference))
+        remoteModel, err = c.registry.Model(ctx, digestReference)
+        if err != nil {
+            return fmt.Errorf("reading model from registry with resume context: %w", err)
+        }
+    }
+
+    // Check for supported type
+    if err := checkCompat(remoteModel, c.log, reference); err != nil {
+        return err
+    }
+
     // Check if model exists in local store
     localModel, err := c.store.Read(remoteDigest.String())
     if err == nil {
```
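The reference-parsing step above is easy to get wrong around registry ports, so here is the same logic lifted into a standalone, runnable sketch (repositoryOf and the sample references are illustrative, not part of the commit):

```go
package main

import (
	"fmt"
	"strings"
)

// repositoryOf mirrors the parsing in PullModel: strip the tag or digest
// from a reference, searching only after the last "/" so that a registry
// port separator such as ":5000" is never mistaken for a tag separator.
func repositoryOf(reference string) string {
	repository := reference
	if lastSlash := strings.LastIndex(reference, "/"); lastSlash != -1 {
		suffix := reference[lastSlash:]
		if idx := strings.IndexAny(suffix, ":@"); idx != -1 {
			repository = reference[:lastSlash+idx]
		}
	} else if idx := strings.IndexAny(reference, ":@"); idx != -1 {
		repository = reference[:idx]
	}
	return repository
}

func main() {
	for _, ref := range []string{
		"ai/smollm2:latest",                           // tag
		"registry.example.com:5000/ai/smollm2:latest", // registry port plus tag
		"ai/smollm2@sha256:0123abcd",                  // digest reference
		"smollm2",                                     // bare name, nothing to strip
	} {
		fmt.Printf("%-45s -> %s\n", ref, repositoryOf(ref))
	}
}
```

Appending `"@" + remoteDigest.String()` to the result then pins the retry to the exact manifest that was inspected, even if the tag has moved in the meantime.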

pkg/distribution/internal/progress/reader.go

Lines changed: 13 additions & 0 deletions

```diff
@@ -24,6 +24,19 @@ func NewReader(r io.Reader, updates chan<- v1.Update) io.Reader {
     }
 }
 
+// NewReaderWithOffset returns a reader that reports progress starting from an initial offset.
+// This is useful for resuming interrupted downloads.
+func NewReaderWithOffset(r io.Reader, updates chan<- v1.Update, initialOffset int64) io.Reader {
+    if updates == nil {
+        return r
+    }
+    return &Reader{
+        Reader:       r,
+        ProgressChan: updates,
+        Total:        initialOffset,
+    }
+}
+
 func (pr *Reader) Read(p []byte) (int, error) {
     n, err := pr.Reader.Read(p)
     pr.Total += int64(n)
```
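The offset matters because these totals feed a user-visible progress bar: if 1 MiB of a blob already sits in the .incomplete file, counting should start at 1 MiB so the bar never jumps backwards on resume. A small sketch of the intended semantics; it assumes caller code living inside the distribution module (the progress package is internal) and that Read publishes v1.Update values on ProgressChan, which this diff does not show:

```go
package distribution // illustrative placement; the progress package is internal

import (
	"fmt"
	"io"
	"strings"

	"github.com/docker/model-runner/pkg/distribution/internal/progress"
	v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1"
)

// resumeProgressDemo is a hypothetical helper, not part of the commit.
func resumeProgressDemo() {
	updates := make(chan v1.Update, 16)
	done := make(chan struct{})
	go func() {
		for u := range updates {
			fmt.Printf("progress: %+v\n", u) // totals begin at the offset
		}
		close(done)
	}()

	const alreadyOnDisk = 1 << 20 // bytes recovered from the .incomplete file

	// body stands in for the remainder of the blob, which the HTTP layer
	// would fetch with a Range request starting at alreadyOnDisk.
	body := strings.NewReader("...rest of blob...")

	// Counting starts at alreadyOnDisk, so reported totals never regress.
	r := progress.NewReaderWithOffset(body, updates, alreadyOnDisk)
	_, _ = io.Copy(io.Discard, r)

	close(updates)
	<-done
}
```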

pkg/distribution/internal/store/blobs.go

Lines changed: 166 additions & 6 deletions

```diff
@@ -6,6 +6,7 @@ import (
     "os"
     "path/filepath"
     "strings"
+    "time"
 
     "github.com/docker/model-runner/pkg/distribution/internal/progress"
 
@@ -78,6 +79,12 @@ type blob interface {
     Uncompressed() (io.ReadCloser, error)
 }
 
+// layerWithDigest extends blob to include the Digest method
+type layerWithDigest interface {
+    blob
+    Digest() (v1.Hash, error)
+}
+
 // writeLayer writes the layer blob to the store.
 // It returns true when a new blob was created and the blob's DiffID.
 func (s *LocalStore) writeLayer(layer blob, updates chan<- v1.Update) (bool, v1.Hash, error) {
@@ -94,13 +101,28 @@ func (s *LocalStore) writeLayer(layer blob, updates chan<- v1.Update) (bool, v1.Hash, error) {
         return false, hash, nil
     }
 
+    // Check if we're resuming an incomplete download
+    incompleteSize, err := s.GetIncompleteSize(hash)
+    if err != nil {
+        return false, v1.Hash{}, fmt.Errorf("check incomplete size: %w", err)
+    }
+
     lr, err := layer.Uncompressed()
     if err != nil {
         return false, v1.Hash{}, fmt.Errorf("get blob contents: %w", err)
     }
     defer lr.Close()
-    r := progress.NewReader(lr, updates)
 
+    // Wrap the reader with progress reporting, accounting for already downloaded bytes
+    var r io.Reader
+    if incompleteSize > 0 {
+        r = progress.NewReaderWithOffset(lr, updates, incompleteSize)
+    } else {
+        r = progress.NewReader(lr, updates)
+    }
+
+    // WriteBlob will handle appending to incomplete files
+    // The HTTP layer will handle resuming via Range headers
     if err := s.WriteBlob(hash, r); err != nil {
         return false, hash, err
     }
@@ -109,6 +131,7 @@ func (s *LocalStore) writeLayer(layer blob, updates chan<- v1.Update) (bool, v1.Hash, error) {
 
 // WriteBlob writes the blob to the store, reporting progress to the given channel.
 // If the blob is already in the store, it is a no-op and the blob is not consumed from the reader.
+// If an incomplete download exists, it will be resumed by appending to the existing file.
 func (s *LocalStore) WriteBlob(diffID v1.Hash, r io.Reader) error {
     hasBlob, err := s.hasBlob(diffID)
     if err != nil {
@@ -122,21 +145,83 @@ func (s *LocalStore) WriteBlob(diffID v1.Hash, r io.Reader) error {
     if err != nil {
         return fmt.Errorf("get blob path: %w", err)
     }
-    f, err := createFile(incompletePath(path))
-    if err != nil {
-        return fmt.Errorf("create blob file: %w", err)
+
+    incompletePath := incompletePath(path)
+
+    // Check if we're resuming a partial download
+    var f *os.File
+    var isResume bool
+    if _, err := os.Stat(incompletePath); err == nil {
+        // Before resuming, verify that the incomplete file isn't already complete
+        existingFile, err := os.Open(incompletePath)
+        if err != nil {
+            return fmt.Errorf("open incomplete file for verification: %w", err)
+        }
+
+        computedHash, _, err := v1.SHA256(existingFile)
+        existingFile.Close()
+
+        if err == nil && computedHash.String() == diffID.String() {
+            // File is already complete, just rename it
+            if err := os.Rename(incompletePath, path); err != nil {
+                return fmt.Errorf("rename completed blob file: %w", err)
+            }
+            return nil
+        }
+
+        // File is incomplete or corrupt, try to resume
+        isResume = true
+        f, err = os.OpenFile(incompletePath, os.O_WRONLY|os.O_APPEND, 0644)
+        if err != nil {
+            return fmt.Errorf("open incomplete blob file for resume: %w", err)
+        }
+    } else {
+        // New download: create file
+        f, err = createFile(incompletePath)
+        if err != nil {
+            return fmt.Errorf("create blob file: %w", err)
+        }
     }
-    defer os.Remove(incompletePath(path))
     defer f.Close()
 
     if _, err := io.Copy(f, r); err != nil {
+        // If we were resuming and copy failed, the incomplete file might be corrupt
+        if isResume {
+            _ = os.Remove(incompletePath)
+        }
         return fmt.Errorf("copy blob %q to store: %w", diffID.String(), err)
     }
 
     f.Close() // Rename will fail on Windows if the file is still open.
-    if err := os.Rename(incompletePath(path), path); err != nil {
+
+    // For resumed downloads, verify the complete file's hash before finalizing
+    // (For new downloads, the stream was already verified during download)
+    if isResume {
+        completeFile, err := os.Open(incompletePath)
+        if err != nil {
+            return fmt.Errorf("open completed file for verification: %w", err)
+        }
+        defer completeFile.Close()
+
+        computedHash, _, err := v1.SHA256(completeFile)
+        if err != nil {
+            return fmt.Errorf("compute hash of completed file: %w", err)
+        }
+
+        if computedHash.String() != diffID.String() {
+            // The resumed download is corrupt, remove it so we can start fresh next time
+            _ = os.Remove(incompletePath)
+            return fmt.Errorf("hash mismatch after download: got %s, want %s", computedHash, diffID)
+        }
+    }
+
+    if err := os.Rename(incompletePath, path); err != nil {
         return fmt.Errorf("rename blob file: %w", err)
     }
+
+    // Only remove incomplete file if rename succeeded (though rename should have moved it)
+    // This is a safety cleanup in case rename didn't remove the source
+    os.Remove(incompletePath)
     return nil
 }
 
@@ -160,6 +245,25 @@ func (s *LocalStore) hasBlob(hash v1.Hash) (bool, error) {
     return false, nil
 }
 
+// GetIncompleteSize returns the size of an incomplete blob if it exists, or 0 if it doesn't.
+func (s *LocalStore) GetIncompleteSize(hash v1.Hash) (int64, error) {
+    path, err := s.blobPath(hash)
+    if err != nil {
+        return 0, fmt.Errorf("get blob path: %w", err)
+    }
+
+    incompletePath := incompletePath(path)
+    stat, err := os.Stat(incompletePath)
+    if err != nil {
+        if os.IsNotExist(err) {
+            return 0, nil
+        }
+        return 0, fmt.Errorf("stat incomplete file: %w", err)
+    }
+
+    return stat.Size(), nil
+}
+
 // createFile is a wrapper around os.Create that creates any parent directories as needed.
 func createFile(path string) (*os.File, error) {
     if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
@@ -201,3 +305,59 @@ func (s *LocalStore) writeConfigFile(mdl v1.Image) (bool, error) {
     }
     return true, nil
 }
+
+// CleanupStaleIncompleteFiles removes incomplete download files that haven't been modified
+// for more than the specified duration. This prevents disk space leaks from abandoned downloads.
+func (s *LocalStore) CleanupStaleIncompleteFiles(maxAge time.Duration) error {
+    blobsPath := s.blobsDir()
+    if _, err := os.Stat(blobsPath); os.IsNotExist(err) {
+        // Blobs directory doesn't exist yet, nothing to clean up
+        return nil
+    }
+
+    var cleanedCount int
+    var cleanupErrors []error
+
+    // Walk through the blobs directory looking for .incomplete files
+    err := filepath.Walk(blobsPath, func(path string, info os.FileInfo, err error) error {
+        if err != nil {
+            // Continue walking even if we encounter errors on individual files
+            return nil
+        }
+
+        // Skip directories
+        if info.IsDir() {
+            return nil
+        }
+
+        // Only process .incomplete files
+        if !strings.HasSuffix(path, ".incomplete") {
+            return nil
+        }
+
+        // Check if file is older than maxAge
+        if time.Since(info.ModTime()) > maxAge {
+            if removeErr := os.Remove(path); removeErr != nil {
+                cleanupErrors = append(cleanupErrors, fmt.Errorf("failed to remove stale incomplete file %s: %w", path, removeErr))
+            } else {
+                cleanedCount++
+            }
+        }
+
+        return nil
+    })
+
+    if err != nil {
+        return fmt.Errorf("walking blobs directory: %w", err)
+    }
+
+    if len(cleanupErrors) > 0 {
+        return fmt.Errorf("encountered %d errors during cleanup (cleaned %d files): %v", len(cleanupErrors), cleanedCount, cleanupErrors[0])
+    }
+
+    if cleanedCount > 0 {
+        fmt.Printf("Cleaned up %d stale incomplete download file(s)\n", cleanedCount)
+    }
+
+    return nil
+}
```
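Note that WriteBlob itself only appends to the partial file; per the comment in writeLayer, the actual byte-skipping happens in the HTTP layer via Range headers, inside the vendored go-containerregistry fork whose diff is not shown in this view. A minimal sketch of what such a resume request might look like (the function and its handling are illustrative, not the fork's actual code):

```go
package blobresume

import (
	"context"
	"fmt"
	"io"
	"net/http"
)

// openBlobFrom fetches a blob starting at offset using an HTTP Range request.
func openBlobFrom(ctx context.Context, url string, offset int64) (io.ReadCloser, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, err
	}
	if offset > 0 {
		// Ask the registry for the remainder of the blob only.
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-", offset))
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	if offset > 0 && resp.StatusCode != http.StatusPartialContent {
		// The server ignored the Range header; appending its full response
		// to the partial file would corrupt the blob, so fail instead.
		resp.Body.Close()
		return nil, fmt.Errorf("server does not support resume: %s", resp.Status)
	}
	return resp.Body, nil
}
```

The hash verification in WriteBlob is the backstop for exactly this class of failure: if the appended bytes do not line up, the resumed file fails the SHA-256 check and is discarded.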

pkg/distribution/internal/store/store.go

Lines changed: 8 additions & 0 deletions

```diff
@@ -6,6 +6,7 @@ import (
     "io"
     "os"
     "path/filepath"
+    "time"
 
     v1 "github.com/docker/model-runner/pkg/go-containerregistry/pkg/v1"
 
@@ -81,6 +82,13 @@ func (s *LocalStore) initialize() error {
         }
     }
 
+    // Clean up stale incomplete files (older than 7 days)
+    // This prevents disk space leaks from abandoned downloads
+    if err := s.CleanupStaleIncompleteFiles(7 * 24 * time.Hour); err != nil {
+        // Log the error but don't fail initialization
+        fmt.Printf("Warning: failed to clean up stale incomplete files: %v\n", err)
+    }
+
     return nil
 }
```
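A quick way to exercise the cleanup wired in above is to plant a backdated .incomplete file and check that it disappears. The newTestStore helper and the blobs/sha256 layout are assumptions for illustration, not part of the commit:

```go
package store

import (
	"os"
	"path/filepath"
	"testing"
	"time"
)

func TestCleanupStaleIncompleteFiles(t *testing.T) {
	dir := t.TempDir()
	s := newTestStore(t, dir) // hypothetical helper returning a *LocalStore rooted at dir

	// Plant a partial download and backdate it past the 7-day cutoff.
	stale := filepath.Join(dir, "blobs", "sha256", "deadbeef.incomplete") // layout assumed
	if err := os.MkdirAll(filepath.Dir(stale), 0o777); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(stale, []byte("partial"), 0o644); err != nil {
		t.Fatal(err)
	}
	old := time.Now().Add(-8 * 24 * time.Hour)
	if err := os.Chtimes(stale, old, old); err != nil {
		t.Fatal(err)
	}

	if err := s.CleanupStaleIncompleteFiles(7 * 24 * time.Hour); err != nil {
		t.Fatal(err)
	}
	if _, err := os.Stat(stale); !os.IsNotExist(err) {
		t.Fatalf("expected stale .incomplete file to be removed, err=%v", err)
	}
}
```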
