Skip to content

Commit

Permalink
add rsync flag option to copy files using rsync
Browse files Browse the repository at this point in the history
Signed-off-by: olalekan odukoya <[email protected]>
  • Loading branch information
olamilekan000 committed Jan 26, 2025
1 parent ccd3c0c commit a9bade5
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 22 deletions.
146 changes: 124 additions & 22 deletions cmd/limactl/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ Prefix guest filenames with the instance name and a colon.
Example: limactl copy default:/etc/os-release .
`

type copyTool string

const (
rsync copyTool = "rsync"
scp copyTool = "scp"
)

func newCopyCommand() *cobra.Command {
copyCommand := &cobra.Command{
Use: "copy SOURCE ... TARGET",
Expand Down Expand Up @@ -49,13 +56,6 @@ func copyAction(cmd *cobra.Command, args []string) error {
return err
}

arg0, err := exec.LookPath("scp")
if err != nil {
return err
}
instances := make(map[string]*store.Instance)
scpFlags := []string{}
scpArgs := []string{}
debug, err := cmd.Flags().GetBool("debug")
if err != nil {
return err
Expand All @@ -65,6 +65,48 @@ func copyAction(cmd *cobra.Command, args []string) error {
verbose = true
}

cpTool := rsync
arg0, err := exec.LookPath(string(cpTool))
if err != nil {
arg0, err = exec.LookPath(string(cpTool))
if err != nil {
return err
}
}
logrus.Infof("using copy tool %q", arg0)

var sshArgs, toolArgs []string

switch cpTool {
case scp:
sshArgs, toolArgs, err = useScp(args, verbose, recursive)
if err != nil {
return err
}
case rsync:
toolArgs, err = useRsync(args, verbose, recursive)
if err != nil {
return err
}
default:
return fmt.Errorf("invalid copy tool %q", cpTool)
}

sshCmd := exec.Command(arg0, append(sshArgs, toolArgs...)...)
sshCmd.Stdin = cmd.InOrStdin()
sshCmd.Stdout = cmd.OutOrStdout()
sshCmd.Stderr = cmd.ErrOrStderr()
logrus.Debugf("executing scp (may take a long time): %+v", sshCmd.Args)

// TODO: use syscall.Exec directly (results in losing tty?)
return sshCmd.Run()
}

func useScp(args []string, verbose, recursive bool) (sshArgs, scpArgs []string, err error) {
instances := make(map[string]*store.Instance)

scpFlags := []string{}

if verbose {
scpFlags = append(scpFlags, "-v")
} else {
Expand All @@ -74,6 +116,7 @@ func copyAction(cmd *cobra.Command, args []string) error {
if recursive {
scpFlags = append(scpFlags, "-r")
}

// this assumes that ssh and scp come from the same place, but scp has no -V
legacySSH := sshutil.DetectOpenSSHVersion("ssh").LessThan(*semver.New("8.0.0"))
for _, arg := range args {
Expand All @@ -86,12 +129,12 @@ func copyAction(cmd *cobra.Command, args []string) error {
inst, err := store.Inspect(instName)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("instance %q does not exist, run `limactl create %s` to create a new instance", instName, instName)
return nil, nil, fmt.Errorf("instance %q does not exist, run `limactl create %s` to create a new instance", instName, instName)
}
return err
return nil, nil, err
}
if inst.Status == store.StatusStopped {
return fmt.Errorf("instance %q is stopped, run `limactl start %s` to start the instance", instName, instName)
return nil, nil, fmt.Errorf("instance %q is stopped, run `limactl start %s` to start the instance", instName, instName)
}
if legacySSH {
scpFlags = append(scpFlags, "-P", fmt.Sprintf("%d", inst.SSHLocalPort))
Expand All @@ -101,11 +144,11 @@ func copyAction(cmd *cobra.Command, args []string) error {
}
instances[instName] = inst
default:
return fmt.Errorf("path %q contains multiple colons", arg)
return nil, nil, fmt.Errorf("path %q contains multiple colons", arg)
}
}
if legacySSH && len(instances) > 1 {
return errors.New("more than one (instance) host is involved in this command, this is only supported for openSSH v8.0 or higher")
return nil, nil, errors.New("more than one (instance) host is involved in this command, this is only supported for openSSH v8.0 or higher")
}
scpFlags = append(scpFlags, "-3", "--")
scpArgs = append(scpFlags, scpArgs...)
Expand All @@ -118,24 +161,83 @@ func copyAction(cmd *cobra.Command, args []string) error {
for _, inst := range instances {
sshOpts, err = sshutil.SSHOpts("ssh", inst.Dir, *inst.Config.User.Name, false, false, false, false)
if err != nil {
return err
return nil, nil, err
}
}
} else {
// Copying among multiple hosts; we can't pass in host-specific options.
sshOpts, err = sshutil.CommonOpts("ssh", false)
if err != nil {
return err
return nil, nil, err
}
}
sshArgs := sshutil.SSHArgsFromOpts(sshOpts)
sshArgs = sshutil.SSHArgsFromOpts(sshOpts)

sshCmd := exec.Command(arg0, append(sshArgs, scpArgs...)...)
sshCmd.Stdin = cmd.InOrStdin()
sshCmd.Stdout = cmd.OutOrStdout()
sshCmd.Stderr = cmd.ErrOrStderr()
logrus.Debugf("executing scp (may take a long time): %+v", sshCmd.Args)
return sshArgs, scpArgs, nil
}

// TODO: use syscall.Exec directly (results in losing tty?)
return sshCmd.Run()
func useRsync(args []string, verbose, recursive bool) ([]string, error) {
instances := make(map[string]*store.Instance)

var instName string

rsyncFlags := []string{}
rsyncArgs := []string{}

if verbose {
rsyncFlags = append(rsyncFlags, "-v", "--progress")
} else {
rsyncFlags = append(rsyncFlags, "-q")
}

if recursive {
rsyncFlags = append(rsyncFlags, "-r")
}

for _, arg := range args {
path := strings.Split(arg, ":")
switch len(path) {
case 1:
inst, ok := instances[instName]
if !ok {
return nil, fmt.Errorf("instance %q does not exist, run `limactl create %s` to create a new instance", instName, instName)
}
guestVM := fmt.Sprintf("%[email protected]:%s", *inst.Config.User.Name, path[0])
rsyncArgs = append(rsyncArgs, guestVM)
case 2:
instName = path[0]
inst, err := store.Inspect(instName)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("instance %q does not exist, run `limactl create %s` to create a new instance", instName, instName)
}
return nil, err
}
sshOpts, err := sshutil.SSHOpts("ssh", inst.Dir, *inst.Config.User.Name, false, false, false, false)
if err != nil {
return nil, err
}

sshStr := fmt.Sprintf("ssh -p %s -i %s", fmt.Sprintf("%d", inst.SSHLocalPort), extractSSHOptionField(sshOpts, "IdentityFile"))
rsyncArgs = append(rsyncArgs, "-avz", "-e", sshStr, path[1])
instances[instName] = inst
default:
return nil, fmt.Errorf("path %q contains multiple colons", arg)
}
}

rsyncArgs = append(rsyncFlags, rsyncArgs...)

return rsyncArgs, nil
}

func extractSSHOptionField(sshOpts []string, optName string) string {
for _, opt := range sshOpts {
optField := fmt.Sprintf("%s=", optName)
if strings.HasPrefix(opt, optField) {
identityFile := strings.TrimPrefix(opt, optField)
return strings.Trim(identityFile, `"`)
}
}
return ""
}
40 changes: 40 additions & 0 deletions pkg/hostagent/hostagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,41 @@ func (a *HostAgent) Info(_ context.Context) (*hostagentapi.Info, error) {
return info, nil
}

func (a *HostAgent) installPackage() error {
logrus.Debugf("installing packages")

faScript := `#!/bin/bash
if ! output=$(type rsync 2>&1); then
echo "rsync is not installed. Attempting to install..."
# Try to install rsync based on the OS
if [ -f /etc/debian_version ]; then
sudo apt-get update && sudo apt-get install -y rsync
elif [ -f /etc/alpine-release ]; then
sudo apk add rsync
elif [ -f /etc/redhat-release ]; then
sudo yum install -y rsync
elif [ -f /etc/arch-release ]; then
sudo pacman -S --noconfirm rsync
else
echo "Unsupported Linux distribution. Please install rsync manually."
fi
echo "rsync installation complete."
else
echo "rsync is already installed."
fi`
faDesc := "installing rsync"
stdout, stderr, err := ssh.ExecuteScript(a.instSSHAddress, a.sshLocalPort, a.sshConfig, faScript, faDesc)
logrus.Debugf("stdout=%q, stderr=%q, err=%v", stdout, stderr, err)
if err != nil {
err = fmt.Errorf("stdout=%q, stderr=%q: %w", stdout, stderr, err)
return err
}

return nil
}

func (a *HostAgent) startHostAgentRoutines(ctx context.Context) error {
if *a.instConfig.Plain {
logrus.Info("Running in plain mode. Mounts, port forwarding, containerd, etc. will be ignored. Guest agent will not be running.")
Expand All @@ -439,6 +474,11 @@ func (a *HostAgent) startHostAgentRoutines(ctx context.Context) error {
if err := a.waitForRequirements("essential", a.essentialRequirements()); err != nil {
errs = append(errs, err)
}

if err := a.installPackage(); err != nil {
errs = append(errs, err)
}

if *a.instConfig.SSH.ForwardAgent {
faScript := `#!/bin/bash
set -eux -o pipefail
Expand Down

0 comments on commit a9bade5

Please sign in to comment.