@@ -96,9 +96,9 @@ func NewMockServer() *MockServer {
96
96
// createDockerManager creates a DockerManager with a mock DockerClient.
97
97
func createDockerManager (mockDockerClient * MockDockerClient ) * DockerManager {
98
98
return & DockerManager {
99
- defaultImage : "default-image" ,
100
99
gpus : []string {"gpu0" },
101
100
modelDir : "/models" ,
101
+ overrides : ImageOverrides {Default : "default-image" },
102
102
dockerClient : mockDockerClient ,
103
103
gpuContainers : make (map [string ]string ),
104
104
containers : make (map [string ]* RunnerContainer ),
@@ -110,10 +110,10 @@ func TestNewDockerManager(t *testing.T) {
110
110
mockDockerClient := new (MockDockerClient )
111
111
112
112
createAndVerifyManager := func () * DockerManager {
113
- manager , err := NewDockerManager ("default-image" , []string {"gpu0" }, "/models" , mockDockerClient )
113
+ manager , err := NewDockerManager (ImageOverrides { Default : "default-image" } , []string {"gpu0" }, "/models" , mockDockerClient )
114
114
require .NoError (t , err )
115
115
require .NotNil (t , manager )
116
- require .Equal (t , "default-image" , manager .defaultImage )
116
+ require .Equal (t , "default-image" , manager .overrides . Default )
117
117
require .Equal (t , []string {"gpu0" }, manager .gpus )
118
118
require .Equal (t , "/models" , manager .modelDir )
119
119
require .Equal (t , mockDockerClient , manager .dockerClient )
@@ -301,47 +301,130 @@ func TestDockerManager_returnContainer(t *testing.T) {
301
301
302
302
func TestDockerManager_getContainerImageName (t * testing.T ) {
303
303
mockDockerClient := new (MockDockerClient )
304
- manager := createDockerManager (mockDockerClient )
304
+ dockerManager := createDockerManager (mockDockerClient )
305
305
306
306
tests := []struct {
307
307
name string
308
+ setup func (* DockerManager , * MockDockerClient )
308
309
pipeline string
309
310
modelID string
310
311
expectedImage string
311
312
expectError bool
312
313
}{
313
314
{
314
315
name : "live-video-to-video with valid modelID" ,
316
+ setup : func (dockerManager * DockerManager , mockDockerClient * MockDockerClient ) {},
315
317
pipeline : "live-video-to-video" ,
316
318
modelID : "streamdiffusion" ,
317
319
expectedImage : "livepeer/ai-runner:live-app-streamdiffusion" ,
318
320
expectError : false ,
319
321
},
320
322
{
321
323
name : "live-video-to-video with invalid modelID" ,
324
+ setup : func (dockerManager * DockerManager , mockDockerClient * MockDockerClient ) {},
322
325
pipeline : "live-video-to-video" ,
323
326
modelID : "invalid-model" ,
324
327
expectError : true ,
325
328
},
326
329
{
327
330
name : "valid pipeline" ,
331
+ setup : func (dockerManager * DockerManager , mockDockerClient * MockDockerClient ) {},
328
332
pipeline : "text-to-speech" ,
329
333
modelID : "" ,
330
334
expectedImage : "livepeer/ai-runner:text-to-speech" ,
331
335
expectError : false ,
332
336
},
333
337
{
334
338
name : "invalid pipeline" ,
339
+ setup : func (dockerManager * DockerManager , mockDockerClient * MockDockerClient ) {},
335
340
pipeline : "invalid-pipeline" ,
336
341
modelID : "" ,
337
342
expectedImage : "default-image" ,
338
343
expectError : false ,
339
344
},
345
+ {
346
+ name : "override default image" ,
347
+ setup : func (dockerManager * DockerManager , mockDockerClient * MockDockerClient ) {
348
+ dockerManager .overrides = ImageOverrides {
349
+ Default : "custom-image" ,
350
+ }
351
+ },
352
+ pipeline : "" ,
353
+ modelID : "" ,
354
+ expectedImage : "custom-image" ,
355
+ expectError : false ,
356
+ },
357
+ {
358
+ name : "override batch image" ,
359
+ setup : func (dockerManager * DockerManager , mockDockerClient * MockDockerClient ) {
360
+ dockerManager .overrides = ImageOverrides {
361
+ Batch : map [string ]string {
362
+ "text-to-speech" : "custom-image" ,
363
+ },
364
+ }
365
+ },
366
+ pipeline : "text-to-speech" ,
367
+ modelID : "" ,
368
+ expectedImage : "custom-image" ,
369
+ expectError : false ,
370
+ },
371
+ {
372
+ name : "override live image" ,
373
+ setup : func (dockerManager * DockerManager , mockDockerClient * MockDockerClient ) {
374
+ dockerManager .overrides = ImageOverrides {
375
+ Live : map [string ]string {
376
+ "streamdiffusion" : "custom-image" ,
377
+ },
378
+ }
379
+ },
380
+ pipeline : "live-video-to-video" ,
381
+ modelID : "streamdiffusion" ,
382
+ expectedImage : "custom-image" ,
383
+ expectError : false ,
384
+ },
385
+ {
386
+ name : "non-overridden batch image" ,
387
+ setup : func (dockerManager * DockerManager , mockDockerClient * MockDockerClient ) {
388
+ dockerManager .overrides = ImageOverrides {
389
+ Default : "default-image" ,
390
+ Batch : map [string ]string {
391
+ "text-to-speech" : "custom-batch-image" ,
392
+ },
393
+ Live : map [string ]string {
394
+ "streamdiffusion" : "custom-live-image" ,
395
+ },
396
+ }
397
+ },
398
+ pipeline : "audio-to-text" ,
399
+ modelID : "" ,
400
+ expectedImage : "livepeer/ai-runner:audio-to-text" ,
401
+ expectError : false ,
402
+ },
403
+ {
404
+ name : "non-overridden live image" ,
405
+ setup : func (dockerManager * DockerManager , mockDockerClient * MockDockerClient ) {
406
+ dockerManager .overrides = ImageOverrides {
407
+ Default : "default-image" ,
408
+ Batch : map [string ]string {
409
+ "text-to-speech" : "custom-batch-image" ,
410
+ },
411
+ Live : map [string ]string {
412
+ "streamdiffusion" : "custom-live-image" ,
413
+ },
414
+ }
415
+ },
416
+ pipeline : "live-video-to-video" ,
417
+ modelID : "comfyui" ,
418
+ expectedImage : "livepeer/ai-runner:live-app-comfyui" ,
419
+ expectError : false ,
420
+ },
340
421
}
341
422
342
423
for _ , tt := range tests {
343
424
t .Run (tt .name , func (t * testing.T ) {
344
- image , err := manager .getContainerImageName (tt .pipeline , tt .modelID )
425
+ tt .setup (dockerManager , mockDockerClient )
426
+
427
+ image , err := dockerManager .getContainerImageName (tt .pipeline , tt .modelID )
345
428
if tt .expectError {
346
429
require .Error (t , err )
347
430
require .Equal (t , fmt .Sprintf ("no container image found for live pipeline %s" , tt .modelID ), err .Error ())
@@ -500,7 +583,7 @@ func TestDockerManager_createContainer(t *testing.T) {
500
583
dockerManager .gpus = []string {gpu }
501
584
dockerManager .gpuContainers = make (map [string ]string )
502
585
dockerManager .containers = make (map [string ]* RunnerContainer )
503
- dockerManager .defaultImage = containerImage
586
+ dockerManager .overrides . Default = containerImage
504
587
505
588
mockDockerClient .On ("ContainerCreate" , mock .Anything , mock .Anything , mock .Anything , mock .Anything , mock .Anything , mock .Anything ).Return (container.CreateResponse {ID : containerID }, nil )
506
589
mockDockerClient .On ("ContainerStart" , mock .Anything , containerID , mock .Anything ).Return (nil )
0 commit comments