diff --git a/go.mod b/go.mod index 13718eb..cc27737 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ toolchain go1.25.1 require ( github.com/charmbracelet/glamour v0.10.0 github.com/joho/godotenv v1.5.1 + github.com/sergi/go-diff v1.4.0 github.com/stretchr/testify v1.11.1 github.com/xeipuuv/gojsonschema v1.2.0 ) diff --git a/go.sum b/go.sum index 7233122..474c1e0 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,11 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -56,8 +61,11 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw= +github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= @@ -84,7 +92,10 @@ golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/core/runtime/apply_patch_command.go b/internal/core/runtime/apply_patch_command.go new file mode 100644 index 0000000..6c23b26 --- /dev/null +++ b/internal/core/runtime/apply_patch_command.go @@ -0,0 +1,665 @@ +package runtime + +import ( + "context" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + "unicode" + + "github.com/sergi/go-diff/diffmatchpatch" +) + +const applyPatchUsage = "Usage: apply_patch [--respect-whitespace]\n\n" + + "Reads a *** Begin Patch block from the command string and applies it to the workspace.\n" + + " --ignore-whitespace, -w Match hunks without considering whitespace differences (default).\n" + + " --respect-whitespace, -W Require whitespace to match before applying hunks.\n" + +type applyPatchOptions struct { + ignoreWhitespace bool +} + +type patchOperation struct { + typeLabel string + path string + hunks []patchHunk +} + +type patchHunk struct { + header string + before []string + after []string + rawPatchLines []string +} + +type hunkStatus struct { + number int + status string +} + +type applyResult struct { + status string + path string +} + +type fileState struct { + path string + relativePath string + lines []string + normalizedLines []string + originalContent string + originalEndsWithNew *bool + touched bool + cursor int + hunkStatuses []hunkStatus + isNew bool + options applyPatchOptions +} + +type hunkApplyError struct { + err error + relativePath string + original string + hunkStatuses []hunkStatus + failedHunk patchHunk +} + +func (e *hunkApplyError) Error() string { + if e == nil { + return "" + } + return e.err.Error() +} + +func (e *hunkApplyError) Unwrap() error { + if e == nil { + return nil + } + return e.err +} + +func newApplyPatchCommand() InternalCommandHandler { + return func(ctx context.Context, req InternalCommandRequest) (PlanObservationPayload, error) { + opts, help, err := parseApplyPatchInvocation(req.Raw) + if err != nil { + return PlanObservationPayload{}, err + } + if help { + return PlanObservationPayload{Stdout: applyPatchUsage}, nil + } + + operations, err := parsePatchOperations(req.Raw) + if err != nil { + return PlanObservationPayload{}, err + } + if len(operations) == 0 { + return PlanObservationPayload{}, errors.New("apply_patch: no patch operations detected") + } + + baseDir, err := resolveCommandDirectory(req.Step.Command.Cwd) + if err != nil { + return PlanObservationPayload{}, err + } + + results, applyErr := applyPatchOperations(ctx, baseDir, operations, opts) + if applyErr != nil { + formatted := formatApplyPatchError(applyErr) + payload := PlanObservationPayload{Stderr: formatted, Details: formatted} + code := 1 + payload.ExitCode = &code + return payload, applyErr + } + + if len(results) == 0 { + zero := 0 + return PlanObservationPayload{Stdout: "No changes applied.", ExitCode: &zero}, nil + } + + sort.Slice(results, func(i, j int) bool { + return results[i].path < results[j].path + }) + lines := []string{"Success. Updated the following files:"} + for _, res := range results { + lines = append(lines, fmt.Sprintf("%s %s", res.status, res.path)) + } + zero := 0 + return PlanObservationPayload{Stdout: strings.Join(lines, "\n"), ExitCode: &zero}, nil + } +} + +func parseApplyPatchInvocation(raw string) (applyPatchOptions, bool, error) { + trimmed := strings.TrimSpace(raw) + line := trimmed + if idx := strings.Index(trimmed, "\n"); idx >= 0 { + line = trimmed[:idx] + } + tokens := strings.Fields(line) + opts := applyPatchOptions{ignoreWhitespace: true} + if len(tokens) == 0 { + return opts, false, errors.New("apply_patch: command invocation missing") + } + for _, token := range tokens[1:] { + if strings.HasPrefix(token, "<<") { + break + } + switch token { + case "--ignore-whitespace", "-w": + opts.ignoreWhitespace = true + case "--respect-whitespace", "-W", "--no-ignore-whitespace": + opts.ignoreWhitespace = false + case "--help", "-h": + return opts, true, nil + default: + return opts, false, fmt.Errorf("apply_patch: unknown option: %s", token) + } + } + return opts, false, nil +} + +func parsePatchOperations(raw string) ([]patchOperation, error) { + if !strings.Contains(raw, "*** Begin Patch") { + return nil, errors.New("apply_patch: no patch provided via stdin") + } + lines := strings.Split(strings.ReplaceAll(raw, "\r\n", "\n"), "\n") + var operations []patchOperation + var currentOp *patchOperation + var currentHunk *patchHunk + inside := false + + flushHunk := func() error { + if currentHunk == nil { + return nil + } + if currentOp == nil { + return errors.New("apply_patch: encountered hunk without active operation") + } + parsed, err := parseHunk(currentHunk, currentOp.path) + if err != nil { + return err + } + currentOp.hunks = append(currentOp.hunks, parsed) + currentHunk = nil + return nil + } + + flushOp := func() error { + if currentOp == nil { + return nil + } + if err := flushHunk(); err != nil { + return err + } + if len(currentOp.hunks) == 0 { + return fmt.Errorf("apply_patch: no hunks provided for %s", currentOp.path) + } + operations = append(operations, *currentOp) + currentOp = nil + return nil + } + + for _, line := range lines { + switch { + case line == "*** Begin Patch": + inside = true + continue + case line == "*** End Patch": + if inside { + if err := flushOp(); err != nil { + return nil, err + } + } + inside = false + continue + case !inside: + continue + case strings.HasPrefix(line, "*** "): + if err := flushOp(); err != nil { + return nil, err + } + switch { + case strings.HasPrefix(line, "*** Update File:"): + path := strings.TrimSpace(strings.TrimPrefix(line, "*** Update File:")) + currentOp = &patchOperation{typeLabel: "update", path: path} + case strings.HasPrefix(line, "*** Add File:"): + path := strings.TrimSpace(strings.TrimPrefix(line, "*** Add File:")) + currentOp = &patchOperation{typeLabel: "add", path: path} + default: + return nil, fmt.Errorf("apply_patch: unsupported patch directive: %s", line) + } + continue + case strings.HasPrefix(line, "@@"): + if err := flushHunk(); err != nil { + return nil, err + } + currentHunk = &patchHunk{header: line} + continue + } + + if currentOp == nil { + if strings.TrimSpace(line) == "" { + continue + } + return nil, fmt.Errorf("apply_patch: diff content appeared before a file directive: %q", line) + } + + if currentHunk == nil { + currentHunk = &patchHunk{} + } + currentHunk.rawPatchLines = append(currentHunk.rawPatchLines, line) + } + + if inside { + return nil, errors.New("apply_patch: missing *** End Patch terminator") + } + if err := flushOp(); err != nil { + return nil, err + } + return operations, nil +} + +func parseHunk(raw *patchHunk, path string) (patchHunk, error) { + if raw == nil { + return patchHunk{}, errors.New("apply_patch: missing hunk data") + } + var before []string + var after []string + var rawLines []string + if raw.header != "" { + rawLines = append(rawLines, raw.header) + } + for _, line := range raw.rawPatchLines { + switch { + case strings.HasPrefix(line, "+"): + after = append(after, strings.TrimPrefix(line, "+")) + case strings.HasPrefix(line, "-"): + before = append(before, strings.TrimPrefix(line, "-")) + case strings.HasPrefix(line, " "): + value := strings.TrimPrefix(line, " ") + before = append(before, value) + after = append(after, value) + case line == "\\ No newline at end of file": + // ignore marker + continue + default: + return patchHunk{}, fmt.Errorf("apply_patch: unsupported hunk line in %s: %q", path, line) + } + rawLines = append(rawLines, line) + } + return patchHunk{header: raw.header, before: before, after: after, rawPatchLines: rawLines}, nil +} + +func resolveCommandDirectory(cwd string) (string, error) { + if strings.TrimSpace(cwd) == "" { + return os.Getwd() + } + if filepath.IsAbs(cwd) { + return cwd, nil + } + base, err := os.Getwd() + if err != nil { + return "", err + } + return filepath.Join(base, cwd), nil +} + +func applyPatchOperations(ctx context.Context, baseDir string, operations []patchOperation, opts applyPatchOptions) ([]applyResult, error) { + fileStates := make(map[string]*fileState) + for _, op := range operations { + if err := ctx.Err(); err != nil { + return nil, err + } + state, err := ensureFileState(baseDir, op, opts, fileStates) + if err != nil { + return nil, err + } + state.cursor = 0 + state.touched = false + state.hunkStatuses = state.hunkStatuses[:0] + for idx, hunk := range op.hunks { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if err := applyHunk(state, hunk); err != nil { + return nil, enhanceHunkError(err, state, hunk, idx+1) + } + state.hunkStatuses = append(state.hunkStatuses, hunkStatus{number: idx + 1, status: "applied"}) + state.touched = true + } + } + + var results []applyResult + for _, state := range fileStates { + if !state.touched { + continue + } + if err := writeFileState(state); err != nil { + return nil, err + } + status := "M" + if state.isNew { + status = "A" + } + results = append(results, applyResult{status: status, path: filepath.ToSlash(state.relativePath)}) + } + return results, nil +} + +func ensureFileState(baseDir string, op patchOperation, opts applyPatchOptions, cache map[string]*fileState) (*fileState, error) { + if strings.TrimSpace(op.path) == "" { + return nil, errors.New("apply_patch: file path missing in patch operation") + } + absPath := op.path + if !filepath.IsAbs(absPath) { + absPath = filepath.Join(baseDir, op.path) + } + absPath = filepath.Clean(absPath) + if state, ok := cache[absPath]; ok { + state.options = opts + if opts.ignoreWhitespace { + state.normalizedLines = normalizeLines(state.lines, opts) + } else { + state.normalizedLines = nil + } + return state, nil + } + info, err := os.Stat(absPath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + if op.typeLabel != "add" { + return nil, fmt.Errorf("apply_patch: failed to read %s: %v", op.path, err) + } + state := &fileState{ + path: absPath, + relativePath: op.path, + lines: []string{}, + normalizedLines: normalizeLines(nil, opts), + originalContent: "", + originalEndsWithNew: nil, + cursor: 0, + hunkStatuses: nil, + isNew: true, + options: opts, + } + cache[absPath] = state + return state, nil + } + return nil, fmt.Errorf("apply_patch: failed to stat %s: %v", op.path, err) + } + if info.IsDir() { + return nil, fmt.Errorf("apply_patch: %s is a directory", op.path) + } + if op.typeLabel == "add" { + return nil, fmt.Errorf("apply_patch: cannot add %s because it already exists", op.path) + } + content, err := os.ReadFile(absPath) + if err != nil { + return nil, fmt.Errorf("apply_patch: failed to read %s: %v", op.path, err) + } + normalized := strings.ReplaceAll(string(content), "\r\n", "\n") + lines := strings.Split(normalized, "\n") + state := &fileState{ + path: absPath, + relativePath: op.path, + lines: lines, + normalizedLines: nil, + originalContent: string(content), + originalEndsWithNew: ptrBool(strings.HasSuffix(normalized, "\n")), + cursor: 0, + hunkStatuses: nil, + options: opts, + } + if opts.ignoreWhitespace { + state.normalizedLines = normalizeLines(lines, opts) + } + cache[absPath] = state + return state, nil +} + +func writeFileState(state *fileState) error { + if state == nil { + return nil + } + content := strings.Join(state.lines, "\n") + if state.originalEndsWithNew != nil { + if *state.originalEndsWithNew && !strings.HasSuffix(content, "\n") { + content += "\n" + } else if !*state.originalEndsWithNew && strings.HasSuffix(content, "\n") { + content = strings.TrimSuffix(content, "\n") + } + } + if err := os.MkdirAll(filepath.Dir(state.path), 0o755); err != nil { + return fmt.Errorf("apply_patch: failed to create parent directory for %s: %v", state.relativePath, err) + } + if err := os.WriteFile(state.path, []byte(content), 0o666); err != nil { + return fmt.Errorf("apply_patch: failed to write %s: %v", state.relativePath, err) + } + return nil +} + +func applyHunk(state *fileState, hunk patchHunk) error { + if len(hunk.before) == 0 { + insertion := len(state.lines) + if insertion > 0 && state.lines[insertion-1] == "" { + insertion-- + } + state.lines = splice(state.lines, insertion, 0, hunk.after) + updateNormalizedLines(state, insertion, 0, hunk.after) + state.cursor = insertion + len(hunk.after) + return nil + } + + matchIndex := findSubsequence(state.lines, hunk.before, state.cursor) + if matchIndex == -1 { + matchIndex = findSubsequence(state.lines, hunk.before, 0) + } + + if matchIndex == -1 && state.options.ignoreWhitespace { + normalizedBefore := normalizeLines(hunk.before, state.options) + if state.normalizedLines != nil { + matchIndex = findSubsequence(state.normalizedLines, normalizedBefore, state.cursor) + if matchIndex == -1 { + matchIndex = findSubsequence(state.normalizedLines, normalizedBefore, 0) + } + } + if matchIndex == -1 { + matchIndex = findByDiffMatchPatch(state, normalizedBefore) + } + } + if matchIndex == -1 { + return fmt.Errorf("Hunk not found in %s.", state.relativePath) + } + + state.lines = splice(state.lines, matchIndex, len(hunk.before), hunk.after) + updateNormalizedLines(state, matchIndex, len(hunk.before), hunk.after) + state.cursor = matchIndex + len(hunk.after) + return nil +} + +func enhanceHunkError(err error, state *fileState, hunk patchHunk, number int) error { + if err == nil { + return nil + } + statuses := append([]hunkStatus{}, state.hunkStatuses...) + statuses = append(statuses, hunkStatus{number: number, status: "no-match"}) + original := state.originalContent + if original == "" { + original = strings.Join(state.lines, "\n") + } + return &hunkApplyError{ + err: err, + relativePath: state.relativePath, + original: original, + hunkStatuses: statuses, + failedHunk: hunk, + } +} + +func formatApplyPatchError(err error) string { + if err == nil { + return "Unknown error occurred." + } + var hunkErr *hunkApplyError + if errors.As(err, &hunkErr) { + lines := []string{hunkErr.err.Error()} + desc := describeHunkStatuses(hunkErr.hunkStatuses) + if desc != "" { + lines = append(lines, "", desc) + } + if len(hunkErr.failedHunk.rawPatchLines) > 0 { + lines = append(lines, "", "Offending hunk:") + lines = append(lines, strings.Join(hunkErr.failedHunk.rawPatchLines, "\n")) + } + relative := filepath.ToSlash(hunkErr.relativePath) + if !strings.HasPrefix(relative, "./") { + relative = "./" + relative + } + lines = append(lines, "", fmt.Sprintf("Full content of file: %s::::", relative), hunkErr.original) + return strings.Join(lines, "\n") + } + return err.Error() +} + +func describeHunkStatuses(statuses []hunkStatus) string { + if len(statuses) == 0 { + return "" + } + var applied []string + var failed *hunkStatus + for _, st := range statuses { + if st.status == "applied" { + applied = append(applied, fmt.Sprintf("%d", st.number)) + } else if failed == nil { + failed = &hunkStatus{number: st.number, status: st.status} + } + } + var lines []string + if len(applied) > 0 { + lines = append(lines, fmt.Sprintf("Hunks applied: %s.", strings.Join(applied, ", "))) + } + if failed != nil { + lines = append(lines, fmt.Sprintf("No match for hunk %d.", failed.number)) + } + return strings.Join(lines, "\n") +} + +func normalizeLines(lines []string, opts applyPatchOptions) []string { + if !opts.ignoreWhitespace { + return nil + } + normalized := make([]string, len(lines)) + for i, line := range lines { + normalized[i] = normalizeLine(line) + } + return normalized +} + +func normalizeLine(input string) string { + if input == "" { + return "" + } + runes := make([]rune, 0, len(input)) + for _, r := range input { + if unicode.IsSpace(r) { + continue + } + runes = append(runes, r) + } + return string(runes) +} + +func updateNormalizedLines(state *fileState, index, deleteCount int, replacement []string) { + if !state.options.ignoreWhitespace { + return + } + if state.normalizedLines == nil { + state.normalizedLines = normalizeLines(state.lines, state.options) + return + } + repl := normalizeLines(replacement, state.options) + state.normalizedLines = splice(state.normalizedLines, index, deleteCount, repl) +} + +func splice[T any](input []T, index, deleteCount int, replacement []T) []T { + if index < 0 { + index = 0 + } + if index > len(input) { + index = len(input) + } + if deleteCount < 0 { + deleteCount = 0 + } + end := index + deleteCount + if end > len(input) { + end = len(input) + } + result := append([]T{}, input[:index]...) + result = append(result, replacement...) + result = append(result, input[end:]...) + return result +} + +func findSubsequence(haystack, needle []string, start int) int { + if len(needle) == 0 { + return -1 + } + if start < 0 { + start = 0 + } + for i := start; i+len(needle) <= len(haystack); i++ { + match := true + for j := range needle { + if haystack[i+j] != needle[j] { + match = false + break + } + } + if match { + return i + } + } + return -1 +} + +func ptrBool(value bool) *bool { + return &value +} + +func findByDiffMatchPatch(state *fileState, normalizedBefore []string) int { + if state == nil || !state.options.ignoreWhitespace { + return -1 + } + if len(normalizedBefore) == 0 { + return -1 + } + normalized := state.normalizedLines + if normalized == nil { + normalized = normalizeLines(state.lines, state.options) + } + if len(normalized) == 0 { + return -1 + } + text := strings.Join(normalized, "\n") + pattern := strings.Join(normalizedBefore, "\n") + if text == "" || pattern == "" { + return -1 + } + matcher := diffmatchpatch.New() + idx := matcher.MatchMain(text, pattern, 0) + if idx == -1 { + return -1 + } + line := 0 + for i := 0; i < len(text) && i < idx; i++ { + if text[i] == '\n' { + line++ + } + } + return line +} diff --git a/internal/core/runtime/apply_patch_command_test.go b/internal/core/runtime/apply_patch_command_test.go new file mode 100644 index 0000000..5bff07c --- /dev/null +++ b/internal/core/runtime/apply_patch_command_test.go @@ -0,0 +1,147 @@ +package runtime + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestApplyPatchCommandIgnoreWhitespace(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + filePath := filepath.Join(dir, "sample.txt") + if err := os.WriteFile(filePath, []byte("value = 42\n"), 0o644); err != nil { + t.Fatalf("failed to seed file: %v", err) + } + + executor := NewCommandExecutor() + if err := registerDefaultInternalCommands(executor); err != nil { + t.Fatalf("failed to register default internal commands: %v", err) + } + + run := "apply_patch <<'PATCH'\n*** Begin Patch\n*** Update File: sample.txt\n@@\n-value=42\n+value=43\n*** End Patch\nPATCH" + step := PlanStep{ID: "apply", Command: CommandDraft{Shell: "agent", Run: run, Cwd: dir}} + payload, err := executor.Execute(context.Background(), step) + if err != nil { + t.Fatalf("apply_patch failed: %v", err) + } + if payload.ExitCode == nil || *payload.ExitCode != 0 { + t.Fatalf("expected exit code 0, got %v", payload.ExitCode) + } + if !strings.Contains(payload.Stdout, "Success. Updated the following files:") { + t.Fatalf("unexpected stdout: %q", payload.Stdout) + } + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("failed to read patched file: %v", err) + } + if string(content) != "value=43\n" { + t.Fatalf("unexpected file content: %q", string(content)) + } +} + +func TestApplyPatchCommandRespectWhitespaceFailure(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + filePath := filepath.Join(dir, "format.txt") + if err := os.WriteFile(filePath, []byte("value = 42\n"), 0o644); err != nil { + t.Fatalf("failed to seed file: %v", err) + } + + executor := NewCommandExecutor() + if err := registerDefaultInternalCommands(executor); err != nil { + t.Fatalf("failed to register default internal commands: %v", err) + } + + run := "apply_patch --respect-whitespace <<'PATCH'\n*** Begin Patch\n*** Update File: format.txt\n@@\n-value=42\n+value=43\n*** End Patch\nPATCH" + step := PlanStep{ID: "apply", Command: CommandDraft{Shell: "agent", Run: run, Cwd: dir}} + payload, err := executor.Execute(context.Background(), step) + if err == nil { + t.Fatalf("expected failure, got success with payload %+v", payload) + } + if payload.ExitCode == nil || *payload.ExitCode == 0 { + t.Fatalf("expected non-zero exit code, got %v", payload.ExitCode) + } + if !strings.Contains(payload.Stderr, "Hunk not found") { + t.Fatalf("expected stderr to reference missing hunk, got: %q", payload.Stderr) + } + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("failed to read file after failure: %v", err) + } + if string(content) != "value = 42\n" { + t.Fatalf("file should remain untouched, got %q", string(content)) + } +} + +func TestApplyPatchCommandAddFile(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + executor := NewCommandExecutor() + if err := registerDefaultInternalCommands(executor); err != nil { + t.Fatalf("failed to register default internal commands: %v", err) + } + + run := "apply_patch <<'PATCH'\n*** Begin Patch\n*** Add File: new.txt\n@@\n+hello world\n*** End Patch\nPATCH" + step := PlanStep{ID: "add", Command: CommandDraft{Shell: "agent", Run: run, Cwd: dir}} + payload, err := executor.Execute(context.Background(), step) + if err != nil { + t.Fatalf("apply_patch add failed: %v", err) + } + if payload.ExitCode == nil || *payload.ExitCode != 0 { + t.Fatalf("expected exit code 0, got %v", payload.ExitCode) + } + content, err := os.ReadFile(filepath.Join(dir, "new.txt")) + if err != nil { + t.Fatalf("failed to read added file: %v", err) + } + if string(content) != "hello world" { + t.Fatalf("unexpected file content: %q", string(content)) + } +} + +func TestApplyPatchCommandMissingPatch(t *testing.T) { + t.Parallel() + + executor := NewCommandExecutor() + if err := registerDefaultInternalCommands(executor); err != nil { + t.Fatalf("failed to register default internal commands: %v", err) + } + + run := "apply_patch" + step := PlanStep{ID: "noop", Command: CommandDraft{Shell: "agent", Run: run}} + _, err := executor.Execute(context.Background(), step) + if err == nil { + t.Fatalf("expected error for missing patch input") + } + if !strings.Contains(err.Error(), "no patch provided") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestApplyPatchCommandHelp(t *testing.T) { + t.Parallel() + + executor := NewCommandExecutor() + if err := registerDefaultInternalCommands(executor); err != nil { + t.Fatalf("failed to register default internal commands: %v", err) + } + + run := "apply_patch --help" + step := PlanStep{ID: "help", Command: CommandDraft{Shell: "agent", Run: run}} + payload, err := executor.Execute(context.Background(), step) + if err != nil { + t.Fatalf("help invocation should succeed: %v", err) + } + if payload.ExitCode == nil || *payload.ExitCode != 0 { + t.Fatalf("expected zero exit code, got %v", payload.ExitCode) + } + if !strings.Contains(payload.Stdout, "Usage: apply_patch") { + t.Fatalf("unexpected usage output: %q", payload.Stdout) + } +} diff --git a/internal/core/runtime/internal_commands.go b/internal/core/runtime/internal_commands.go new file mode 100644 index 0000000..d94d4d7 --- /dev/null +++ b/internal/core/runtime/internal_commands.go @@ -0,0 +1,10 @@ +package runtime + +// registerDefaultInternalCommands wires the built-in internal commands into a new executor. +// Additional commands supplied via RuntimeOptions can extend or override these defaults. +func registerDefaultInternalCommands(executor *CommandExecutor) error { + if executor == nil { + return nil + } + return executor.RegisterInternalCommand("apply_patch", newApplyPatchCommand()) +} diff --git a/internal/core/runtime/runtime.go b/internal/core/runtime/runtime.go index c62a91d..d079e20 100644 --- a/internal/core/runtime/runtime.go +++ b/internal/core/runtime/runtime.go @@ -71,6 +71,10 @@ func NewRuntime(options RuntimeOptions) (*Runtime, error) { contextBudget: ContextBudget{MaxTokens: options.MaxContextTokens, CompactWhenPercent: options.CompactWhenPercent}, } + if err := registerDefaultInternalCommands(rt.executor); err != nil { + return nil, err + } + for name, handler := range options.InternalCommands { if err := rt.executor.RegisterInternalCommand(name, handler); err != nil { return nil, err