Skip to content

Commit

Permalink
Add support for importing models stored in the Modelcar format
Browse files Browse the repository at this point in the history
This allows dsl.import to leverage Modelcar container images in an OCI
repository. This works by having an init container prepull the image and
then adding a sidecar container when the launcher container is running.
The Modelcar container adds a symlink to its /models directory in an
emptyDir volume that is accessible by the launcher container. Once the
launcher is done running the user code, it stops the Modelcar
containers.

This approach has the benefit of leveraging image pull secrets
configured on the Kubernetes cluster rather than require separate
credentials for importing the artifact. Additionally, no data is copied
to the emptyDir volume, so the storage cost is just pulling the Modelcar
container image on the Kubernetes worker node.

Note that once Kubernetes supports OCI images as volume mounts for
several releases, consider replacing the init container with that
approach.

This also adds a new environment variable of PIPELINE_RUN_AS_USER to
set the runAsUser on all pods created by Argo Workflows.

Resolves:
#11584

Signed-off-by: mprahl <[email protected]>
  • Loading branch information
mprahl committed Feb 12, 2025
1 parent 65d1d79 commit a380cf3
Show file tree
Hide file tree
Showing 14 changed files with 460 additions and 10 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/kfp-samples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,16 @@ jobs:
with:
k8s_version: ${{ matrix.k8s_version }}

- name: Build and upload the sample Modelcar image to Kind
run: |
docker build -f samples/v2/modelcar_import/Dockerfile -t registry.domain.local/modelcar:test .
kind --name kfp load docker-image registry.domain.local/modelcar:test
- name: Forward API port
run: ./.github/resources/scripts/forward-port.sh "kubeflow" "ml-pipeline" 8888 8888

- name: Run Samples Tests
env:
PULL_NUMBER: ${{ github.event.pull_request.number }}
run: |
./backend/src/v2/test/sample-test.sh
8 changes: 8 additions & 0 deletions backend/src/v2/compiler/argocompiler/argo.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ func Compile(jobArg *pipelinespec.PipelineJob, kubernetesSpecArg *pipelinespec.S
Entrypoint: tmplEntrypoint,
},
}

runAsUser := GetPipelineRunAsUser()
if runAsUser != nil {
wf.Spec.SecurityContext = &k8score.PodSecurityContext{
RunAsUser: GetPipelineRunAsUser(),
}
}

c := &workflowCompiler{
wf: wf,
templates: make(map[string]*wfapi.Template),
Expand Down
25 changes: 25 additions & 0 deletions backend/src/v2/compiler/argocompiler/argo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func Test_argo_compiler(t *testing.T) {
jobPath string // path of input PipelineJob to compile
platformSpecPath string // path of possible input PlatformSpec to compile
argoYAMLPath string // path of expected output argo workflow YAML
envVars map[string]string
}{
{
jobPath: "../testdata/hello_world.json",
Expand Down Expand Up @@ -67,9 +68,33 @@ func Test_argo_compiler(t *testing.T) {
platformSpecPath: "",
argoYAMLPath: "testdata/exit_handler.yaml",
},
{
jobPath: "../testdata/hello_world.json",
platformSpecPath: "",
argoYAMLPath: "testdata/hello_world_run_as_user.yaml",
envVars: map[string]string{"PIPELINE_RUN_AS_USER": "1001"},
},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%+v", tt), func(t *testing.T) {
prevEnvVars := map[string]string{}

for envVarName, envVarValue := range tt.envVars {
prevEnvVars[envVarName] = os.Getenv(envVarName)

os.Setenv(envVarName, envVarValue)
}

defer func() {
for envVarName, envVarValue := range prevEnvVars {
if envVarValue == "" {
os.Unsetenv(envVarName)
} else {
os.Setenv(envVarName, envVarValue)
}
}
}()

job, platformSpec := load(t, tt.jobPath, tt.platformSpecPath)
if *update {
wf, err := argocompiler.Compile(job, platformSpec, nil)
Expand Down
22 changes: 22 additions & 0 deletions backend/src/v2/compiler/argocompiler/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ package argocompiler
import (
"fmt"
"os"
"strconv"
"strings"

wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
"github.com/golang/glog"
"github.com/golang/protobuf/jsonpb"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/backend/src/v2/component"
Expand All @@ -36,6 +38,7 @@ const (
DriverImageEnvVar = "V2_DRIVER_IMAGE"
DefaultDriverCommand = "driver"
DriverCommandEnvVar = "V2_DRIVER_COMMAND"
PipelineRunAsUserEnvVar = "PIPELINE_RUN_AS_USER"
gcsScratchLocation = "/gcs"
gcsScratchName = "gcs-scratch"
s3ScratchLocation = "/s3"
Expand Down Expand Up @@ -101,6 +104,25 @@ func GetDriverCommand() []string {
return strings.Split(driverCommand, " ")
}

func GetPipelineRunAsUser() *int64 {
runAsUserStr := os.Getenv(PipelineRunAsUserEnvVar)
if runAsUserStr == "" {
return nil
}

runAsUser, err := strconv.ParseInt(runAsUserStr, 10, 64)
if err != nil {
glog.Error(
"Failed to parse the %s environment variable with value %s as an int64: %v",
PipelineRunAsUserEnvVar, runAsUserStr, err,
)

return nil
}

return &runAsUser
}

func (c *workflowCompiler) containerDriverTask(name string, inputs containerDriverInputs) (*wfapi.DAGTask, *containerDriverOutputs) {
dagTask := &wfapi.DAGTask{
Name: name,
Expand Down
24 changes: 20 additions & 4 deletions backend/src/v2/component/importer_launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"strings"

"github.com/kubeflow/pipelines/backend/src/v2/objectstore"

pb "github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata"
Expand Down Expand Up @@ -227,10 +229,6 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact

state := pb.Artifact_LIVE

provider, err := objectstore.ParseProviderFromPath(artifactUri)
if err != nil {
return nil, fmt.Errorf("No Provider scheme found in artifact Uri: %s", artifactUri)
}
artifact = &pb.Artifact{
TypeId: &artifactTypeId,
State: &state,
Expand All @@ -248,6 +246,24 @@ func (l *ImportLauncher) ImportSpecToMLMDArtifact(ctx context.Context) (artifact
}
}

if strings.HasPrefix(artifactUri, "oci://") {
artifactType, err := metadata.SchemaToArtifactType(schema)
if err != nil {
return nil, fmt.Errorf("converting schema to artifact type failed: %w", err)
}

if *artifactType.Name != "system.Model" {
return nil, fmt.Errorf("the %s artifact type does not support OCI registries", *artifactType.Name)
}

return artifact, nil
}

provider, err := objectstore.ParseProviderFromPath(artifactUri)
if err != nil {
return nil, fmt.Errorf("no provider scheme found in artifact URI: %s", artifactUri)
}

// Assume all imported artifacts will rely on execution environment for store provider session info
storeSessionInfo := objectstore.SessionInfo{
Provider: provider,
Expand Down
100 changes: 94 additions & 6 deletions backend/src/v2/component/launcher_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,51 @@ func NewLauncherV2(ctx context.Context, executionID int64, executorInputJSON, co
}, nil
}

// stopWaitingArtifacts will create empty files to tell Modelcar sidecar containers to stop. Any errors encountered are
// logged since this is meant as a deferred function at the end of the launcher's execution.
func stopWaitingArtifacts(artifacts map[string]*pipelinespec.ArtifactList) {
for _, artifactList := range artifacts {
if len(artifactList.Artifacts) == 0 {
continue
}

// Following the convention of downloadArtifacts in the launcher to only look at the first in the list.
inputArtifact := artifactList.Artifacts[0]

// This should ideally verify that this is also a model input artifact, but this metadata doesn't seem to
// be set on inputArtifact.
if !strings.HasPrefix(inputArtifact.Uri, "oci://") {
continue
}

localPath, err := LocalPathForURI(inputArtifact.Uri)
if err != nil {
continue
}

glog.Infof("Stopping artifact %s", inputArtifact.Uri)

launcherCompleteFile := strings.TrimSuffix(localPath, "/models") + "/launcher-complete"
_, err = os.Create(launcherCompleteFile)
if err != nil {
glog.Errorf(
"Failed to stop the artifact %s by creating %s: %v", inputArtifact.Uri, launcherCompleteFile, err,
)

continue
}
}
}

func (l *LauncherV2) Execute(ctx context.Context) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("failed to execute component: %w", err)
}
}()

defer stopWaitingArtifacts(l.executorInput.GetInputs().GetArtifacts())

// publish execution regardless the task succeeds or not
var execution *metadata.Execution
var executorOutput *pipelinespec.ExecutorOutput
Expand Down Expand Up @@ -401,6 +440,7 @@ func execute(
if err := downloadArtifacts(ctx, executorInput, bucket, bucketConfig, namespace, k8sClient); err != nil {
return nil, err
}

if err := prepareOutputFolders(executorInput); err != nil {
return nil, err
}
Expand Down Expand Up @@ -441,7 +481,7 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec
}

// Upload artifacts from local path to remote storages.
localDir, err := localPathForURI(outputArtifact.Uri)
localDir, err := LocalPathForURI(outputArtifact.Uri)
if err != nil {
glog.Warningf("Output Artifact %q does not have a recognized storage URI %q. Skipping uploading to remote storage.", name, outputArtifact.Uri)
} else {
Expand Down Expand Up @@ -477,6 +517,31 @@ func uploadOutputArtifacts(ctx context.Context, executorInput *pipelinespec.Exec
return outputArtifacts, nil
}

// waitForModelcar assumes the Modelcar has already been validated by the init container on the launcher
// pod. This waits for the Modelcar as a sidecar container to be ready.
func waitForModelcar(artifactURI string, localPath string) error {
glog.Infof("Waiting for the Modelcar %s to be available", artifactURI)

for {
_, err := os.Stat(localPath)
if err == nil {
glog.Infof("The Modelcar is now available at %s", localPath)

return nil
}

if !os.IsNotExist(err) {
return fmt.Errorf(
"failed to see if the artifact %s was ready at %s; ensure the main container and Modelcar "+
"container have the same UID (can be set with the PIPELINE_RUN_AS_USER environment variable on "+
"the API server): %v",
artifactURI, localPath, err)
}

time.Sleep(500 * time.Millisecond)
}
}

func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.ExecutorInput, defaultBucket *blob.Bucket, defaultBucketConfig *objectstore.Config, namespace string, k8sClient kubernetes.Interface) error {
// Read input artifact metadata.
nonDefaultBuckets, err := fetchNonDefaultBuckets(ctx, executorInput.GetInputs().GetArtifacts(), defaultBucketConfig, namespace, k8sClient)
Expand All @@ -491,17 +556,31 @@ func downloadArtifacts(ctx context.Context, executorInput *pipelinespec.Executor
if err != nil {
return fmt.Errorf("failed to fetch non default buckets: %w", err)
}

for name, artifactList := range executorInput.GetInputs().GetArtifacts() {
// TODO(neuromage): Support concat-based placholders for arguments.
if len(artifactList.Artifacts) == 0 {
continue
}
inputArtifact := artifactList.Artifacts[0]
localPath, err := localPathForURI(inputArtifact.Uri)

localPath, err := LocalPathForURI(inputArtifact.Uri)
if err != nil {
glog.Warningf("Input Artifact %q does not have a recognized storage URI %q. Skipping downloading to local path.", name, inputArtifact.Uri)

continue
}

// OCI artifacts are handled specially
if strings.HasPrefix(inputArtifact.Uri, "oci://") {
err := waitForModelcar(inputArtifact.Uri, localPath)
if err != nil {
return err
}

continue
}

// Copy artifact to local storage.
copyErr := func(err error) error {
return fmt.Errorf("failed to download input artifact %q from remote storage URI %q: %w", name, inputArtifact.Uri, err)
Expand Down Expand Up @@ -548,6 +627,12 @@ func fetchNonDefaultBuckets(
}
// TODO: Support multiple artifacts someday, probably through the v2 engine.
artifact := artifactList.Artifacts[0]

// OCI artifacts are handled specially
if strings.HasPrefix(artifact.Uri, "oci://") {
continue
}

// The artifact does not belong under the object store path for this run. Cases:
// 1. Artifact is cached from a different run, so it may still be in the default bucket, but under a different run id subpath
// 2. Artifact is imported from the same bucket, but from a different path (re-use the same session)
Expand Down Expand Up @@ -598,7 +683,7 @@ func getPlaceholders(executorInput *pipelinespec.ExecutorInput) (placeholders ma
key := fmt.Sprintf(`{{$.inputs.artifacts['%s'].uri}}`, name)
placeholders[key] = inputArtifact.Uri

localPath, err := localPathForURI(inputArtifact.Uri)
localPath, err := LocalPathForURI(inputArtifact.Uri)
if err != nil {
// Input Artifact does not have a recognized storage URI
continue
Expand All @@ -617,7 +702,7 @@ func getPlaceholders(executorInput *pipelinespec.ExecutorInput) (placeholders ma
outputArtifact := artifactList.Artifacts[0]
placeholders[fmt.Sprintf(`{{$.outputs.artifacts['%s'].uri}}`, name)] = outputArtifact.Uri

localPath, err := localPathForURI(outputArtifact.Uri)
localPath, err := LocalPathForURI(outputArtifact.Uri)
if err != nil {
return nil, fmt.Errorf("resolve output artifact %q's local path: %w", name, err)
}
Expand Down Expand Up @@ -720,7 +805,7 @@ func getExecutorOutputFile(path string) (*pipelinespec.ExecutorOutput, error) {
return executorOutput, nil
}

func localPathForURI(uri string) (string, error) {
func LocalPathForURI(uri string) (string, error) {
if strings.HasPrefix(uri, "gs://") {
return "/gcs/" + strings.TrimPrefix(uri, "gs://"), nil
}
Expand All @@ -730,6 +815,9 @@ func localPathForURI(uri string) (string, error) {
if strings.HasPrefix(uri, "s3://") {
return "/s3/" + strings.TrimPrefix(uri, "s3://"), nil
}
if strings.HasPrefix(uri, "oci://") {
return "/oci/" + strings.ReplaceAll(strings.TrimPrefix(uri, "oci://"), "/", "\\/") + "/models", nil
}
return "", fmt.Errorf("failed to generate local path for URI %s: unsupported storage scheme", uri)
}

Expand All @@ -747,7 +835,7 @@ func prepareOutputFolders(executorInput *pipelinespec.ExecutorInput) error {
}
outputArtifact := artifactList.Artifacts[0]

localPath, err := localPathForURI(outputArtifact.Uri)
localPath, err := LocalPathForURI(outputArtifact.Uri)
if err != nil {
return fmt.Errorf("failed to generate local storage path for output artifact %q: %w", name, err)
}
Expand Down
Loading

0 comments on commit a380cf3

Please sign in to comment.