@@ -2,20 +2,27 @@ package worker
 
 import (
     "context"
+    "encoding/json"
     "errors"
     "fmt"
+    "io"
     "log/slog"
     "strings"
     "sync"
     "time"
 
     "github.com/docker/cli/opts"
+    "github.com/docker/docker/api/types"
     "github.com/docker/docker/api/types/container"
     "github.com/docker/docker/api/types/filters"
+    "github.com/docker/docker/api/types/image"
     "github.com/docker/docker/api/types/mount"
+    "github.com/docker/docker/api/types/network"
     docker "github.com/docker/docker/client"
     "github.com/docker/docker/errdefs"
+    "github.com/docker/docker/pkg/jsonmessage"
     "github.com/docker/go-connections/nat"
+    ocispec "github.com/opencontainers/image-spec/specs-go/v1"
 )
 
 const containerModelDir = "/models"
@@ -27,7 +34,8 @@ const optFlagsContainerTimeout = 5 * time.Minute
 const containerRemoveTimeout = 30 * time.Second
 const containerCreatorLabel = "creator"
 const containerCreator = "ai-worker"
-const containerWatchInterval = 10 * time.Second
+
+var containerWatchInterval = 10 * time.Second
 
 // This only works right now on a single GPU because if there is another container
 // using the GPU we stop it so we don't have to worry about having enough ports
@@ -57,41 +65,76 @@ var livePipelineToImage = map[string]string{
     "noop":  "livepeer/ai-runner:live-app-noop",
 }
 
+// DockerClient is an interface for the Docker client, allowing for mocking in tests.
+// NOTE: ensure any docker.Client methods used in this package are added here.
+type DockerClient interface {
+    ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, containerName string) (container.CreateResponse, error)
+    ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error)
+    ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error)
+    ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
+    ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
+    ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
+    ImageInspectWithRaw(ctx context.Context, imageID string) (types.ImageInspect, []byte, error)
+    ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error)
+}
+
+// Compile-time assertion to ensure docker.Client implements DockerClient.
+var _ DockerClient = (*docker.Client)(nil)
+
+// Global function references, declared as variables to allow for mocking in tests.
+var dockerWaitUntilRunningFunc = dockerWaitUntilRunning
+
 type DockerManager struct {
     defaultImage string
     gpus         []string
     modelDir     string
 
-    dockerClient *docker.Client
+    dockerClient DockerClient
     // gpu ID => container name
     gpuContainers map[string]string
     // container name => container
     containers map[string]*RunnerContainer
     mu         *sync.Mutex
 }
 
-func NewDockerManager(defaultImage string, gpus []string, modelDir string) (*DockerManager, error) {
-    dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation())
-    if err != nil {
-        return nil, err
-    }
-
+func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
     ctx, cancel := context.WithTimeout(context.Background(), containerTimeout)
-    if err := removeExistingContainers(ctx, dockerClient); err != nil {
+    if err := removeExistingContainers(ctx, client); err != nil {
         cancel()
         return nil, err
     }
     cancel()
 
-    return &DockerManager{
+    manager := &DockerManager{
         defaultImage:  defaultImage,
         gpus:          gpus,
         modelDir:      modelDir,
-        dockerClient:  dockerClient,
+        dockerClient:  client,
         gpuContainers: make(map[string]string),
         containers:    make(map[string]*RunnerContainer),
         mu:            &sync.Mutex{},
-    }, nil
+    }
+
+    return manager, nil
+}
+
+// EnsureImageAvailable ensures the container image is available locally for the given pipeline and model ID.
+func (m *DockerManager) EnsureImageAvailable(ctx context.Context, pipeline string, modelID string) error {
+    imageName, err := m.getContainerImageName(pipeline, modelID)
+    if err != nil {
+        return err
+    }
+
+    // Pull the image if it is not available locally.
+    if !m.isImageAvailable(ctx, pipeline, modelID) {
+        slog.Info(fmt.Sprintf("Pulling image for pipeline %s and modelID %s: %s", pipeline, modelID, imageName))
+        err = m.pullImage(ctx, imageName)
+        if err != nil {
+            return err
+        }
+    }
+
+    return nil
 }
 
 func (m *DockerManager) Warm(ctx context.Context, pipeline string, modelID string, optimizationFlags OptimizationFlags) error {
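A quick sketch of what the new DockerClient seam buys in tests (the mockDockerClient type, the test name, and the image/pipeline strings below are illustrative assumptions, not code from this diff). Embedding the interface means a mock only implements the methods a test exercises; any other call panics on the nil embed, which makes unexpected daemon calls loud:

// Test-only mock: embeds DockerClient so only the methods under test
// need implementations; calling anything else panics (nil embed).
type mockDockerClient struct {
    DockerClient
    inspectErr error
}

// ImageInspectWithRaw reports the configured error, letting a test flip
// isImageAvailable between true and false without a Docker daemon.
func (m *mockDockerClient) ImageInspectWithRaw(ctx context.Context, imageID string) (types.ImageInspect, []byte, error) {
    return types.ImageInspect{}, nil, m.inspectErr
}

func TestIsImageAvailable(t *testing.T) {
    // Build the manager literal directly to skip removeExistingContainers;
    // "unknown-pipeline" is assumed to have no pipelineToImage entry, so the
    // lookup falls back to defaultImage.
    m := &DockerManager{defaultImage: "default:latest", dockerClient: &mockDockerClient{}}
    if !m.isImageAvailable(context.Background(), "unknown-pipeline", "any-model") {
        t.Fatal("expected image to be reported as available")
    }
}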
@@ -157,6 +200,24 @@ func (m *DockerManager) returnContainer(rc *RunnerContainer) {
     m.containers[rc.Name] = rc
 }
 
+// getContainerImageName returns the image name for the given pipeline and model ID.
+// Returns an error if the image is not found for "live-video-to-video".
+func (m *DockerManager) getContainerImageName(pipeline, modelID string) (string, error) {
+    if pipeline == "live-video-to-video" {
+        // We currently use the model ID as the live pipeline name for legacy reasons.
+        if image, ok := livePipelineToImage[modelID]; ok {
+            return image, nil
+        }
+        return "", fmt.Errorf("no container image found for live pipeline %s", modelID)
+    }
+
+    if image, ok := pipelineToImage[pipeline]; ok {
+        return image, nil
+    }
+
+    return m.defaultImage, nil
+}
+
 // HasCapacity checks if an unused managed container exists or if a GPU is available for a new container.
 func (m *DockerManager) HasCapacity(ctx context.Context, pipeline, modelID string) bool {
     m.mu.Lock()
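For reference, the helper resolves in this order: live pipelines via livePipelineToImage (keyed by model ID), then pipelineToImage, then the default image. A short illustration (the "unknown-pipeline" and "bogus" names are assumed non-entries):

img, _ := m.getContainerImageName("live-video-to-video", "noop") // "livepeer/ai-runner:live-app-noop"
img, _ = m.getContainerImageName("unknown-pipeline", "any")      // m.defaultImage
_, err := m.getContainerImageName("live-video-to-video", "bogus") // error: no container image found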
@@ -169,11 +230,57 @@ func (m *DockerManager) HasCapacity(ctx context.Context, pipeline, modelID strin
         }
     }
 
+    // TODO: This can be removed if we optimize the selection algorithm.
+    // Currently, only CreateContainer errors can trigger orchestrator reselection.
+    if !m.isImageAvailable(ctx, pipeline, modelID) {
+        return false
+    }
+
     // Check for available GPU to allocate for a new container for the requested model.
     _, err := m.allocGPU(ctx)
     return err == nil
 }
 
+// isImageAvailable checks if the specified image is available locally.
+func (m *DockerManager) isImageAvailable(ctx context.Context, pipeline string, modelID string) bool {
+    imageName, err := m.getContainerImageName(pipeline, modelID)
+    if err != nil {
+        slog.Error(err.Error())
+        return false
+    }
+
+    _, _, err = m.dockerClient.ImageInspectWithRaw(ctx, imageName)
+    if err != nil {
+        slog.Error(fmt.Sprintf("Image for pipeline %s and modelID %s is not available locally: %s", pipeline, modelID, imageName))
+    }
+    return err == nil
+}
+
+// pullImage pulls the specified image from the registry.
+func (m *DockerManager) pullImage(ctx context.Context, imageName string) error {
+    reader, err := m.dockerClient.ImagePull(ctx, imageName, image.PullOptions{})
+    if err != nil {
+        return fmt.Errorf("failed to pull image: %w", err)
+    }
+    defer reader.Close()
+
+    // Display progress messages from the ImagePull reader.
+    decoder := json.NewDecoder(reader)
+    for {
+        var progress jsonmessage.JSONMessage
+        if err := decoder.Decode(&progress); err == io.EOF {
+            break
+        } else if err != nil {
+            return fmt.Errorf("error decoding progress message: %w", err)
+        }
+        if progress.Status != "" && progress.Progress != nil {
+            slog.Info(fmt.Sprintf("%s: %s", progress.Status, progress.Progress.String()))
+        }
+    }
+
+    return nil
+}
+
 func (m *DockerManager) createContainer(ctx context.Context, pipeline string, modelID string, keepWarm bool, optimizationFlags OptimizationFlags) (*RunnerContainer, error) {
     gpu, err := m.allocGPU(ctx)
     if err != nil {
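One note on pullImage: the decode loop logs Status and Progress but never inspects JSONMessage.Error, so a failure the daemon reports mid-stream would end the loop at io.EOF and return nil. A hedged alternative sketch, using the helper from the jsonmessage package this diff already imports (writing to os.Stdout is an assumption, not this PR's logging style):

// DisplayJSONMessagesStream renders pull progress and returns a non-nil
// error when the stream carries a message with its Error field set.
if err := jsonmessage.DisplayJSONMessagesStream(reader, os.Stdout, 0, false, nil); err != nil {
    return fmt.Errorf("error pulling image %s: %w", imageName, err)
}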
@@ -183,15 +290,9 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
     // NOTE: We currently allow only one container per GPU for each pipeline.
     containerHostPort := containerHostPorts[pipeline][:3] + gpu
     containerName := dockerContainerName(pipeline, modelID, containerHostPort)
-    containerImage := m.defaultImage
-    if pipelineSpecificImage, ok := pipelineToImage[pipeline]; ok {
-        containerImage = pipelineSpecificImage
-    } else if pipeline == "live-video-to-video" {
-        // We currently use the model ID as the live pipeline name for legacy reasons
-        containerImage = livePipelineToImage[modelID]
-        if containerImage == "" {
-            return nil, fmt.Errorf("no container image found for live pipeline %s", modelID)
-        }
+    containerImage, err := m.getContainerImageName(pipeline, modelID)
+    if err != nil {
+        return nil, err
     }
 
     slog.Info("Starting managed container", slog.String("gpu", gpu), slog.String("name", containerName), slog.String("modelID", modelID), slog.String("containerImage", containerImage))
@@ -258,7 +359,7 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
     cancel()
 
     cctx, cancel = context.WithTimeout(ctx, containerTimeout)
-    if err := dockerWaitUntilRunning(cctx, m.dockerClient, resp.ID, pollingInterval); err != nil {
+    if err := dockerWaitUntilRunningFunc(cctx, m.dockerClient, resp.ID, pollingInterval); err != nil {
         cancel()
         dockerRemoveContainer(m.dockerClient, resp.ID)
         return nil, err
@@ -390,7 +491,7 @@ func (m *DockerManager) watchContainer(rc *RunnerContainer, borrowCtx context.Co
     }
 }
 
-func removeExistingContainers(ctx context.Context, client *docker.Client) error {
+func removeExistingContainers(ctx context.Context, client DockerClient) error {
     filters := filters.NewArgs(filters.Arg("label", containerCreatorLabel+"="+containerCreator))
     containers, err := client.ContainerList(ctx, container.ListOptions{All: true, Filters: filters})
     if err != nil {
@@ -416,7 +517,7 @@ func dockerContainerName(pipeline string, modelID string, suffix ...string) stri
     return fmt.Sprintf("%s_%s", pipeline, sanitizedModelID)
 }
 
-func dockerRemoveContainer(client *docker.Client, containerID string) error {
+func dockerRemoveContainer(client DockerClient, containerID string) error {
     ctx, cancel := context.WithTimeout(context.Background(), containerRemoveTimeout)
     defer cancel()
 
@@ -449,7 +550,7 @@ func dockerRemoveContainer(client *docker.Client, containerID string) error {
     }
 }
 
-func dockerWaitUntilRunning(ctx context.Context, client *docker.Client, containerID string, pollingInterval time.Duration) error {
+func dockerWaitUntilRunning(ctx context.Context, client DockerClient, containerID string, pollingInterval time.Duration) error {
     ticker := time.NewTicker(pollingInterval)
     defer ticker.Stop()
 
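With the constructor change above, call sites now build the real client and inject it. A sketch of the expected caller, grounded in the lines this diff removed from NewDockerManager (the surrounding variable names are assumed caller context):

// Construct the concrete Docker client and hand it to the manager.
dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation())
if err != nil {
    return nil, err
}
manager, err := NewDockerManager(defaultImage, gpus, modelDir, dockerClient)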