Skip to content

Commit 5fc1474

Browse files
RUFFY-369rickstaavictorges
authored
feat(worker): replace defaultImage with overrides struct (#293)
This commit introduces an `overrides` struct to replace the `defaultImage`. The new struct allows overriding both the default image and pipeline-specific images. This enhancement enables orchestrators and developers to specify custom images for specific pipelines, providing greater flexibility and configurability. --------- Co-authored-by: Rick Staa <[email protected]> Co-authored-by: Victor Elias <[email protected]>
1 parent a11302b commit 5fc1474

File tree

3 files changed

+114
-18
lines changed

3 files changed

+114
-18
lines changed

worker/docker.go

+23-10
Original file line numberDiff line numberDiff line change
@@ -56,21 +56,27 @@ var containerHostPorts = map[string]string{
5656
"live-video-to-video": "8900",
5757
}
5858

59-
// Mapping for per pipeline container images.
59+
// Default pipeline container image mapping to use if no overrides are provided.
60+
var defaultBaseImage = "livepeer/ai-runner:latest"
6061
var pipelineToImage = map[string]string{
6162
"segment-anything-2": "livepeer/ai-runner:segment-anything-2",
6263
"text-to-speech": "livepeer/ai-runner:text-to-speech",
6364
"audio-to-text": "livepeer/ai-runner:audio-to-text",
6465
"llm": "livepeer/ai-runner:llm",
6566
}
66-
6767
var livePipelineToImage = map[string]string{
6868
"streamdiffusion": "livepeer/ai-runner:live-app-streamdiffusion",
6969
"comfyui": "livepeer/ai-runner:live-app-comfyui",
7070
"segment_anything_2": "livepeer/ai-runner:live-app-segment_anything_2",
7171
"noop": "livepeer/ai-runner:live-app-noop",
7272
}
7373

74+
type ImageOverrides struct {
75+
Default string `json:"default"`
76+
Batch map[string]string `json:"batch"`
77+
Live map[string]string `json:"live"`
78+
}
79+
7480
// DockerClient is an interface for the Docker client, allowing for mocking in tests.
7581
// NOTE: ensure any docker.Client methods used in this package are added.
7682
type DockerClient interface {
@@ -91,9 +97,9 @@ var _ DockerClient = (*docker.Client)(nil)
9197
var dockerWaitUntilRunningFunc = dockerWaitUntilRunning
9298

9399
type DockerManager struct {
94-
defaultImage string
95-
gpus []string
96-
modelDir string
100+
gpus []string
101+
modelDir string
102+
overrides ImageOverrides
97103

98104
dockerClient DockerClient
99105
// gpu ID => container name
@@ -103,7 +109,7 @@ type DockerManager struct {
103109
mu *sync.Mutex
104110
}
105111

106-
func NewDockerManager(defaultImage string, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
112+
func NewDockerManager(overrides ImageOverrides, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
107113
ctx, cancel := context.WithTimeout(context.Background(), containerTimeout)
108114
if err := removeExistingContainers(ctx, client); err != nil {
109115
cancel()
@@ -112,9 +118,9 @@ func NewDockerManager(defaultImage string, gpus []string, modelDir string, clien
112118
cancel()
113119

114120
manager := &DockerManager{
115-
defaultImage: defaultImage,
116121
gpus: gpus,
117122
modelDir: modelDir,
123+
overrides: overrides,
118124
dockerClient: client,
119125
gpuContainers: make(map[string]string),
120126
containers: make(map[string]*RunnerContainer),
@@ -215,17 +221,24 @@ func (m *DockerManager) returnContainer(rc *RunnerContainer) {
215221
func (m *DockerManager) getContainerImageName(pipeline, modelID string) (string, error) {
216222
if pipeline == "live-video-to-video" {
217223
// We currently use the model ID as the live pipeline name for legacy reasons.
218-
if image, ok := livePipelineToImage[modelID]; ok {
224+
if image, ok := m.overrides.Live[modelID]; ok {
225+
return image, nil
226+
} else if image, ok := livePipelineToImage[modelID]; ok {
219227
return image, nil
220228
}
221229
return "", fmt.Errorf("no container image found for live pipeline %s", modelID)
222230
}
223231

224-
if image, ok := pipelineToImage[pipeline]; ok {
232+
if image, ok := m.overrides.Batch[pipeline]; ok {
233+
return image, nil
234+
} else if image, ok := pipelineToImage[pipeline]; ok {
225235
return image, nil
226236
}
227237

228-
return m.defaultImage, nil
238+
if m.overrides.Default != "" {
239+
return m.overrides.Default, nil
240+
}
241+
return defaultBaseImage, nil
229242
}
230243

231244
// HasCapacity checks if an unused managed container exists or if a GPU is available for a new container.

worker/docker_test.go

+89-6
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ func NewMockServer() *MockServer {
9696
// createDockerManager creates a DockerManager with a mock DockerClient.
9797
func createDockerManager(mockDockerClient *MockDockerClient) *DockerManager {
9898
return &DockerManager{
99-
defaultImage: "default-image",
10099
gpus: []string{"gpu0"},
101100
modelDir: "/models",
101+
overrides: ImageOverrides{Default: "default-image"},
102102
dockerClient: mockDockerClient,
103103
gpuContainers: make(map[string]string),
104104
containers: make(map[string]*RunnerContainer),
@@ -110,10 +110,10 @@ func TestNewDockerManager(t *testing.T) {
110110
mockDockerClient := new(MockDockerClient)
111111

112112
createAndVerifyManager := func() *DockerManager {
113-
manager, err := NewDockerManager("default-image", []string{"gpu0"}, "/models", mockDockerClient)
113+
manager, err := NewDockerManager(ImageOverrides{Default: "default-image"}, []string{"gpu0"}, "/models", mockDockerClient)
114114
require.NoError(t, err)
115115
require.NotNil(t, manager)
116-
require.Equal(t, "default-image", manager.defaultImage)
116+
require.Equal(t, "default-image", manager.overrides.Default)
117117
require.Equal(t, []string{"gpu0"}, manager.gpus)
118118
require.Equal(t, "/models", manager.modelDir)
119119
require.Equal(t, mockDockerClient, manager.dockerClient)
@@ -301,47 +301,130 @@ func TestDockerManager_returnContainer(t *testing.T) {
301301

302302
func TestDockerManager_getContainerImageName(t *testing.T) {
303303
mockDockerClient := new(MockDockerClient)
304-
manager := createDockerManager(mockDockerClient)
304+
dockerManager := createDockerManager(mockDockerClient)
305305

306306
tests := []struct {
307307
name string
308+
setup func(*DockerManager, *MockDockerClient)
308309
pipeline string
309310
modelID string
310311
expectedImage string
311312
expectError bool
312313
}{
313314
{
314315
name: "live-video-to-video with valid modelID",
316+
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
315317
pipeline: "live-video-to-video",
316318
modelID: "streamdiffusion",
317319
expectedImage: "livepeer/ai-runner:live-app-streamdiffusion",
318320
expectError: false,
319321
},
320322
{
321323
name: "live-video-to-video with invalid modelID",
324+
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
322325
pipeline: "live-video-to-video",
323326
modelID: "invalid-model",
324327
expectError: true,
325328
},
326329
{
327330
name: "valid pipeline",
331+
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
328332
pipeline: "text-to-speech",
329333
modelID: "",
330334
expectedImage: "livepeer/ai-runner:text-to-speech",
331335
expectError: false,
332336
},
333337
{
334338
name: "invalid pipeline",
339+
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {},
335340
pipeline: "invalid-pipeline",
336341
modelID: "",
337342
expectedImage: "default-image",
338343
expectError: false,
339344
},
345+
{
346+
name: "override default image",
347+
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
348+
dockerManager.overrides = ImageOverrides{
349+
Default: "custom-image",
350+
}
351+
},
352+
pipeline: "",
353+
modelID: "",
354+
expectedImage: "custom-image",
355+
expectError: false,
356+
},
357+
{
358+
name: "override batch image",
359+
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
360+
dockerManager.overrides = ImageOverrides{
361+
Batch: map[string]string{
362+
"text-to-speech": "custom-image",
363+
},
364+
}
365+
},
366+
pipeline: "text-to-speech",
367+
modelID: "",
368+
expectedImage: "custom-image",
369+
expectError: false,
370+
},
371+
{
372+
name: "override live image",
373+
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
374+
dockerManager.overrides = ImageOverrides{
375+
Live: map[string]string{
376+
"streamdiffusion": "custom-image",
377+
},
378+
}
379+
},
380+
pipeline: "live-video-to-video",
381+
modelID: "streamdiffusion",
382+
expectedImage: "custom-image",
383+
expectError: false,
384+
},
385+
{
386+
name: "non-overridden batch image",
387+
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
388+
dockerManager.overrides = ImageOverrides{
389+
Default: "default-image",
390+
Batch: map[string]string{
391+
"text-to-speech": "custom-batch-image",
392+
},
393+
Live: map[string]string{
394+
"streamdiffusion": "custom-live-image",
395+
},
396+
}
397+
},
398+
pipeline: "audio-to-text",
399+
modelID: "",
400+
expectedImage: "livepeer/ai-runner:audio-to-text",
401+
expectError: false,
402+
},
403+
{
404+
name: "non-overridden live image",
405+
setup: func(dockerManager *DockerManager, mockDockerClient *MockDockerClient) {
406+
dockerManager.overrides = ImageOverrides{
407+
Default: "default-image",
408+
Batch: map[string]string{
409+
"text-to-speech": "custom-batch-image",
410+
},
411+
Live: map[string]string{
412+
"streamdiffusion": "custom-live-image",
413+
},
414+
}
415+
},
416+
pipeline: "live-video-to-video",
417+
modelID: "comfyui",
418+
expectedImage: "livepeer/ai-runner:live-app-comfyui",
419+
expectError: false,
420+
},
340421
}
341422

342423
for _, tt := range tests {
343424
t.Run(tt.name, func(t *testing.T) {
344-
image, err := manager.getContainerImageName(tt.pipeline, tt.modelID)
425+
tt.setup(dockerManager, mockDockerClient)
426+
427+
image, err := dockerManager.getContainerImageName(tt.pipeline, tt.modelID)
345428
if tt.expectError {
346429
require.Error(t, err)
347430
require.Equal(t, fmt.Sprintf("no container image found for live pipeline %s", tt.modelID), err.Error())
@@ -500,7 +583,7 @@ func TestDockerManager_createContainer(t *testing.T) {
500583
dockerManager.gpus = []string{gpu}
501584
dockerManager.gpuContainers = make(map[string]string)
502585
dockerManager.containers = make(map[string]*RunnerContainer)
503-
dockerManager.defaultImage = containerImage
586+
dockerManager.overrides.Default = containerImage
504587

505588
mockDockerClient.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(container.CreateResponse{ID: containerID}, nil)
506589
mockDockerClient.On("ContainerStart", mock.Anything, containerID, mock.Anything).Return(nil)

worker/worker.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ type Worker struct {
5151
mu *sync.Mutex
5252
}
5353

54-
func NewWorker(defaultImage string, gpus []string, modelDir string) (*Worker, error) {
54+
func NewWorker(imageOverrides ImageOverrides, gpus []string, modelDir string) (*Worker, error) {
5555
dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation())
5656
if err != nil {
5757
return nil, err
5858
}
5959

60-
manager, err := NewDockerManager(defaultImage, gpus, modelDir, dockerClient)
60+
manager, err := NewDockerManager(imageOverrides, gpus, modelDir, dockerClient)
6161
if err != nil {
6262
return nil, err
6363
}

0 commit comments

Comments
 (0)