diff --git a/cli/module_build.go b/cli/module_build.go index 23e7097d2c2..b1f496fbc80 100644 --- a/cli/module_build.go +++ b/cli/module_build.go @@ -784,17 +784,35 @@ func (c *viamClient) ensureModuleRegisteredInCloud( return nil } -func (c *viamClient) inferOrgIDFromManifest(manifest ModuleManifest) (string, error) { - moduleID, err := parseModuleID(manifest.ModuleID) +func (c *viamClient) getOrgIDForPart(part *apppb.RobotPart) (string, error) { + robot, err := c.client.GetRobot(c.c.Context, &apppb.GetRobotRequest{ + Id: part.GetRobot(), + }) if err != nil { return "", err } - org, err := getOrgByModuleIDPrefix(c, moduleID.prefix) + + location, err := c.client.GetLocation(c.c.Context, &apppb.GetLocationRequest{ + LocationId: robot.Robot.GetLocation(), + }) if err != nil { return "", err } - return org.GetId(), nil + // use the primary org id for the machine as the reload + // module org + var orgID string + for _, org := range location.Location.Organizations { + if org.Primary { + orgID = org.GetOrganizationId() + break + } + } + if orgID == "" { + orgID = location.Location.Organizations[0].GetOrganizationId() + } + + return orgID, nil } func (c *viamClient) triggerCloudReloadBuild( @@ -814,16 +832,10 @@ func (c *viamClient) triggerCloudReloadBuild( return "", err } - orgID, err := c.inferOrgIDFromManifest(manifest) - if err != nil { - return "", err - } - part, err := c.getRobotPart(partID) if err != nil { return "", err } - if part.Part == nil { return "", fmt.Errorf("part with id=%s not found", partID) } @@ -832,6 +844,11 @@ func (c *viamClient) triggerCloudReloadBuild( return "", errors.New("unable to determine platform for part") } + orgID, err := c.getOrgIDForPart(part.Part) + if err != nil { + return "", err + } + // App expects `BuildInfo` as the first request platform := part.Part.UserSuppliedInfo.Fields["platform"].GetStringValue() req := &buildpb.StartReloadBuildRequest{ @@ -901,10 +918,11 @@ func getNextReloadBuildUploadRequest(file *os.File) (*buildpb.StartReloadBuildRe // moduleCloudBuildInfo contains information needed to download a cloud build artifact. type moduleCloudBuildInfo struct { - ID string + ModuleID string Version string Platform string ArchivePath string // Path to the temporary archive that should be deleted after download + OrgID string } // moduleCloudReload triggers a cloud build and returns info needed to download the artifact. @@ -921,6 +939,18 @@ func (c *viamClient) moduleCloudReload( return nil, err } + part, err := c.getRobotPart(partID) + if err != nil { + return nil, err + } + if part.Part == nil { + return nil, fmt.Errorf("part with id=%s not found", partID) + } + orgID, err := c.getOrgIDForPart(part.Part) + if err != nil { + return nil, err + } + // ensure that the module has been registered in the cloud moduleID, err := parseModuleID(manifest.ModuleID) if err != nil { @@ -940,11 +970,6 @@ func (c *viamClient) moduleCloudReload( return nil, err } - id := ctx.String(generalFlagID) - if id == "" { - id = manifest.ModuleID - } - if err := pm.Start("archive"); err != nil { return nil, err } @@ -1011,13 +1036,19 @@ func (c *viamClient) moduleCloudReload( // Return build info so the caller can download the artifact with a spinner return &moduleCloudBuildInfo{ - ID: id, + ModuleID: manifest.ModuleID, + OrgID: orgID, Version: getReloadVersion(reloadVersionPrefix, partID), Platform: platform, ArchivePath: archivePath, }, nil } +// IsReloadVersion checks if the version is a reload version. +func IsReloadVersion(version string) bool { + return strings.HasPrefix(version, reloadVersionPrefix) +} + // ReloadModuleLocalAction builds a module locally, configures it on a robot, and starts or restarts it. func ReloadModuleLocalAction(c *cli.Context, args reloadModuleArgs) error { return reloadModuleAction(c, args, false) @@ -1174,7 +1205,8 @@ func reloadModuleActionInner( return err } downloadArgs := downloadModuleFlags{ - ID: buildInfo.ID, + ModuleID: buildInfo.ModuleID, + OrgID: buildInfo.OrgID, Version: buildInfo.Version, Platform: buildInfo.Platform, Destination: ".", diff --git a/cli/module_build_test.go b/cli/module_build_test.go index 5a8cfef414c..f9fc5e5027d 100644 --- a/cli/module_build_test.go +++ b/cli/module_build_test.go @@ -11,6 +11,7 @@ import ( "time" v1 "go.viam.com/api/app/build/v1" + apppb "go.viam.com/api/app/v1" "go.viam.com/test" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -480,3 +481,192 @@ func TestRetryableCopyToPart(t *testing.T) { test.That(t, errMsg, test.ShouldContainSubstring, "run the RDK as root") }) } + +func TestIsReloadVersion(t *testing.T) { + tests := []struct { + name string + version string + expected bool + }{ + { + name: "reload version with part ID", + version: "reload-abc123", + expected: true, + }, + { + name: "reload version simple", + version: "reload", + expected: true, + }, + { + name: "reload-source version", + version: "reload-source-abc123", + expected: true, + }, + { + name: "normal semver version", + version: "1.2.3", + expected: false, + }, + { + name: "latest version", + version: "latest", + expected: false, + }, + { + name: "empty version", + version: "", + expected: false, + }, + { + name: "version containing reload but not prefix", + version: "v1.0.0-reload", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsReloadVersion(tt.version) + test.That(t, result, test.ShouldEqual, tt.expected) + }) + } +} + +func TestGetOrgIDForPart(t *testing.T) { + t.Run("returns primary org ID", func(t *testing.T) { + expectedOrgID := "primary-org-123" + secondaryOrgID := "secondary-org-456" + robotID := "robot-abc" + locationID := "location-xyz" + + mockClient := &inject.AppServiceClient{ + GetRobotFunc: func(ctx context.Context, req *apppb.GetRobotRequest, + opts ...grpc.CallOption, + ) (*apppb.GetRobotResponse, error) { + return &apppb.GetRobotResponse{ + Robot: &apppb.Robot{ + Id: robotID, + Location: locationID, + }, + }, nil + }, + GetLocationFunc: func(ctx context.Context, req *apppb.GetLocationRequest, + opts ...grpc.CallOption, + ) (*apppb.GetLocationResponse, error) { + test.That(t, req.LocationId, test.ShouldEqual, locationID) + return &apppb.GetLocationResponse{ + Location: &apppb.Location{ + Id: locationID, + Organizations: []*apppb.LocationOrganization{ + {OrganizationId: secondaryOrgID, Primary: false}, + {OrganizationId: expectedOrgID, Primary: true}, + }, + }, + }, nil + }, + } + + _, vc, _, _ := setup(mockClient, nil, &inject.BuildServiceClient{}, map[string]any{}, "token") + + part := &apppb.RobotPart{ + Robot: robotID, + } + orgID, err := vc.getOrgIDForPart(part) + test.That(t, err, test.ShouldBeNil) + test.That(t, orgID, test.ShouldEqual, expectedOrgID) + }) + + t.Run("falls back to first org when no primary", func(t *testing.T) { + firstOrgID := "first-org-123" + secondOrgID := "second-org-456" + robotID := "robot-abc" + locationID := "location-xyz" + + mockClient := &inject.AppServiceClient{ + GetRobotFunc: func(ctx context.Context, req *apppb.GetRobotRequest, + opts ...grpc.CallOption, + ) (*apppb.GetRobotResponse, error) { + return &apppb.GetRobotResponse{ + Robot: &apppb.Robot{ + Id: robotID, + Location: locationID, + }, + }, nil + }, + GetLocationFunc: func(ctx context.Context, req *apppb.GetLocationRequest, + opts ...grpc.CallOption, + ) (*apppb.GetLocationResponse, error) { + return &apppb.GetLocationResponse{ + Location: &apppb.Location{ + Id: locationID, + Organizations: []*apppb.LocationOrganization{ + {OrganizationId: firstOrgID, Primary: false}, + {OrganizationId: secondOrgID, Primary: false}, + }, + }, + }, nil + }, + } + + _, vc, _, _ := setup(mockClient, nil, &inject.BuildServiceClient{}, map[string]any{}, "token") + + part := &apppb.RobotPart{ + Robot: robotID, + } + orgID, err := vc.getOrgIDForPart(part) + test.That(t, err, test.ShouldBeNil) + test.That(t, orgID, test.ShouldEqual, firstOrgID) + }) + + t.Run("returns error when GetRobot fails", func(t *testing.T) { + mockClient := &inject.AppServiceClient{ + GetRobotFunc: func(ctx context.Context, req *apppb.GetRobotRequest, + opts ...grpc.CallOption, + ) (*apppb.GetRobotResponse, error) { + return nil, errors.New("robot not found") + }, + } + + _, vc, _, _ := setup(mockClient, nil, &inject.BuildServiceClient{}, map[string]any{}, "token") + + part := &apppb.RobotPart{ + Robot: "robot-abc", + } + _, err := vc.getOrgIDForPart(part) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "robot not found") + }) + + t.Run("returns error when GetLocation fails", func(t *testing.T) { + robotID := "robot-abc" + locationID := "location-xyz" + + mockClient := &inject.AppServiceClient{ + GetRobotFunc: func(ctx context.Context, req *apppb.GetRobotRequest, + opts ...grpc.CallOption, + ) (*apppb.GetRobotResponse, error) { + return &apppb.GetRobotResponse{ + Robot: &apppb.Robot{ + Id: robotID, + Location: locationID, + }, + }, nil + }, + GetLocationFunc: func(ctx context.Context, req *apppb.GetLocationRequest, + opts ...grpc.CallOption, + ) (*apppb.GetLocationResponse, error) { + return nil, errors.New("location not found") + }, + } + + _, vc, _, _ := setup(mockClient, nil, &inject.BuildServiceClient{}, map[string]any{}, "token") + + part := &apppb.RobotPart{ + Robot: robotID, + } + _, err := vc.getOrgIDForPart(part) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err.Error(), test.ShouldContainSubstring, "location not found") + }) +} diff --git a/cli/module_registry.go b/cli/module_registry.go index 9ab373b2575..970b66872a2 100644 --- a/cli/module_registry.go +++ b/cli/module_registry.go @@ -1060,13 +1060,14 @@ func getNextModuleUploadRequest(file *os.File) (*apppb.UploadModuleFileRequest, type downloadModuleFlags struct { Destination string - ID string + ModuleID string + OrgID string Version string Platform string } func (c *viamClient) downloadModuleAction(ctx *cli.Context, flags downloadModuleFlags) (string, error) { - moduleID := flags.ID + moduleID := flags.ModuleID if moduleID == "" { manifest, err := loadManifest(defaultManifestFilename) if err != nil { @@ -1074,46 +1075,59 @@ func (c *viamClient) downloadModuleAction(ctx *cli.Context, flags downloadModule } moduleID = manifest.ModuleID } + req := &apppb.GetModuleRequest{ModuleId: moduleID} res, err := c.client.GetModule(ctx.Context, req) if err != nil { return "", err } - if len(res.Module.Versions) == 0 { - return "", errors.New("module has 0 uploaded versions, nothing to download") - } requestedVersion := flags.Version - var ver *apppb.VersionHistory - if requestedVersion == "latest" { - ver = res.Module.Versions[len(res.Module.Versions)-1] - } else { - for _, iVer := range res.Module.Versions { - if iVer.Version == requestedVersion { - ver = iVer - break + platform := flags.Platform + + // if not reload version, validate module versions + var fullVersion string + var packageID string + if !IsReloadVersion(requestedVersion) { + if len(res.Module.Versions) == 0 { + return "", errors.New("module has 0 uploaded versions, nothing to download") + } + + var ver *apppb.VersionHistory + if requestedVersion == "latest" { + ver = res.Module.Versions[len(res.Module.Versions)-1] + } else { + for _, iVer := range res.Module.Versions { + if iVer.Version == requestedVersion { + ver = iVer + break + } + } + if ver == nil { + return "", fmt.Errorf("version %s not found in versions for module", requestedVersion) } } - if ver == nil { - return "", fmt.Errorf("version %s not found in versions for module", requestedVersion) + if len(ver.Files) == 0 { + return "", fmt.Errorf("version %s has 0 files uploaded", ver.Version) } - } - if len(ver.Files) == 0 { - return "", fmt.Errorf("version %s has 0 files uploaded", ver.Version) - } - platform := flags.Platform - if platform == "" { - platform = fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH) - infof(ctx.App.ErrWriter, "using default platform %s", platform) - } - if !slices.ContainsFunc(ver.Files, func(file *apppb.Uploads) bool { return file.Platform == platform }) { - return "", fmt.Errorf("platform %s not present for version %s", platform, ver.Version) + + if platform == "" { + platform = fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH) + infof(ctx.App.ErrWriter, "using default platform %s", platform) + } + if !slices.ContainsFunc(ver.Files, func(file *apppb.Uploads) bool { return file.Platform == platform }) { + return "", fmt.Errorf("platform %s not present for version %s", platform, ver.Version) + } + fullVersion = fmt.Sprintf("%s-%s", ver.Version, strings.ReplaceAll(platform, "/", "-")) + packageID = strings.ReplaceAll(moduleID, ":", "/") + } else { + fullVersion = fmt.Sprintf("%s-%s", requestedVersion, strings.ReplaceAll(platform, "/", "-")) + packageID = fmt.Sprintf("%s/%s", flags.OrgID, res.Module.Name) } include := true packageType := packagespb.PackageType_PACKAGE_TYPE_MODULE // note: this is working around a GetPackage quirk where platform messes with version - fullVersion := fmt.Sprintf("%s-%s", ver.Version, strings.ReplaceAll(platform, "/", "-")) pkg, err := c.packageClient.GetPackage(ctx.Context, &packagespb.GetPackageRequest{ - Id: strings.ReplaceAll(moduleID, ":", "/"), + Id: packageID, Version: fullVersion, IncludeUrl: &include, Type: &packageType,