diff --git a/cmd/av/sync.go b/cmd/av/sync.go index 4a9df2c0..9c217933 100644 --- a/cmd/av/sync.go +++ b/cmd/av/sync.go @@ -8,6 +8,7 @@ import ( "emperror.dev/errors" "github.com/aviator-co/av/internal/actions" + "github.com/aviator-co/av/internal/config" "github.com/aviator-co/av/internal/gh" "github.com/aviator-co/av/internal/gh/ghui" "github.com/aviator-co/av/internal/git" @@ -25,14 +26,15 @@ import ( ) var syncFlags struct { - All bool - RebaseToTrunk bool - Current bool - Abort bool - Continue bool - Skip bool - Push string - Prune string + All bool + RebaseToTrunk bool + Current bool + Abort bool + Continue bool + Skip bool + Push string + Prune string + FastForwardTrunk bool } var syncCmd = &cobra.Command{ @@ -84,6 +86,9 @@ but can still be synced explicitly. if cmd.Flags().Changed("parent") { return actions.ErrExitSilently{ExitCode: 1} } + if !cmd.Flags().Changed("ff-trunk") { + syncFlags.FastForwardTrunk = config.Av.Sync.FastForwardTrunk + } repo, err := getRepo(ctx) if err != nil { return err @@ -111,9 +116,10 @@ type savedSyncState struct { } type syncState struct { - TargetBranches []plumbing.ReferenceName - Prune string - Push string + TargetBranches []plumbing.ReferenceName + Prune string + Push string + FastForwardTrunk bool } type syncViewModel struct { @@ -289,6 +295,16 @@ func (vm *syncViewModel) initPruneBranches() tea.Cmd { vm.state.Prune, vm.state.TargetBranches, vm.restackState.InitialBranch, + vm.initFastForwardTrunk, + )) +} + +func (vm *syncViewModel) initFastForwardTrunk() tea.Cmd { + if !vm.state.FastForwardTrunk { + return tea.Quit + } + return vm.AddView(gitui.NewFastForwardTrunkModel( + vm.repo, func() tea.Cmd { return tea.Quit }, @@ -321,8 +337,9 @@ func (vm *syncViewModel) createState() (*savedSyncState, error) { state := savedSyncState{ RestackState: &sequencerui.RestackState{}, SyncState: &syncState{ - Push: syncFlags.Push, - Prune: syncFlags.Prune, + Push: syncFlags.Push, + Prune: syncFlags.Prune, + FastForwardTrunk: syncFlags.FastForwardTrunk, }, } status, err := vm.repo.Status(ctx) @@ -487,6 +504,10 @@ func init() { &syncFlags.RebaseToTrunk, "rebase-to-trunk", false, "rebase the branches to the latest trunk always", ) + syncCmd.Flags().BoolVar( + &syncFlags.FastForwardTrunk, "ff-trunk", false, + "fast-forward the local trunk branch to match the remote", + ) syncCmd.Flags().BoolVar( &syncFlags.Continue, "continue", false, diff --git a/internal/config/config.go b/internal/config/config.go index 47e77669..385ffe86 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -44,6 +44,12 @@ type PullRequest struct { WriteStack bool } +type Sync struct { + // If true, fast-forward the local trunk branch to match the remote + // tracking branch after syncing. + FastForwardTrunk bool +} + type Aviator struct { // The base URL of the Aviator API to use. // By default, this is https://aviator.co, but for on-prem installations @@ -58,6 +64,7 @@ var Av = struct { PullRequest PullRequest GitHub GitHub Aviator Aviator + Sync Sync AdditionalTrunkBranches []string Remote string }{ diff --git a/internal/git/gitui/ff_trunk.go b/internal/git/gitui/ff_trunk.go new file mode 100644 index 00000000..764456ee --- /dev/null +++ b/internal/git/gitui/ff_trunk.go @@ -0,0 +1,125 @@ +package gitui + +import ( + "context" + "fmt" + + "github.com/aviator-co/av/internal/git" + "github.com/aviator-co/av/internal/utils/colors" + tea "github.com/charmbracelet/bubbletea" +) + +// FastForwardTrunkModel is a Bubbletea model that fast-forward merges the local +// trunk branch (e.g., main/master) to match its remote tracking branch. +type FastForwardTrunkModel struct { + repo *git.Repo + onDone func() tea.Cmd + + done bool + skipped bool + diverged bool + trunk string +} + +type ffTrunkDone struct{} + +func NewFastForwardTrunkModel( + repo *git.Repo, + onDone func() tea.Cmd, +) *FastForwardTrunkModel { + return &FastForwardTrunkModel{ + repo: repo, + onDone: onDone, + } +} + +func (m *FastForwardTrunkModel) Init() tea.Cmd { + return m.run +} + +func (m *FastForwardTrunkModel) run() tea.Msg { + ctx := context.Background() + m.trunk = m.repo.DefaultBranch() + remote := m.repo.GetRemoteName() + + // Check if the local trunk branch exists. + exists, err := m.repo.DoesBranchExist(ctx, m.trunk) + if err != nil || !exists { + m.skipped = true + return ffTrunkDone{} + } + + // Check if the remote tracking branch exists. + remoteExists, err := m.repo.DoesRemoteBranchExist(ctx, m.trunk) + if err != nil || !remoteExists { + m.skipped = true + return ffTrunkDone{} + } + + // Check if we are currently on the trunk branch. If so, we need to use + // "git merge --ff-only" directly. Otherwise, we can use update-ref to + // update the local ref without checking it out. + currentBranch, err := m.repo.CurrentBranchName() + if err != nil { + // Detached HEAD or other issue; just try the update-ref approach. + currentBranch = "" + } + + remoteRef := fmt.Sprintf("%s/%s", remote, m.trunk) + + if currentBranch == m.trunk { + // We're on the trunk branch, use merge --ff-only. + _, err := m.repo.Run(ctx, &git.RunOpts{ + Args: []string{"merge", "--ff-only", remoteRef}, + ExitError: true, + }) + if err != nil { + m.diverged = true + return ffTrunkDone{} + } + } else { + // Not on the trunk branch. Verify that the remote is a fast-forward + // of the local branch using merge-base --is-ancestor, then update the + // ref. + _, err := m.repo.Run(ctx, &git.RunOpts{ + Args: []string{"merge-base", "--is-ancestor", fmt.Sprintf("refs/heads/%s", m.trunk), fmt.Sprintf("refs/remotes/%s/%s", remote, m.trunk)}, + ExitError: true, + }) + if err != nil { + m.diverged = true + return ffTrunkDone{} + } + _, err = m.repo.Run(ctx, &git.RunOpts{ + Args: []string{"update-ref", fmt.Sprintf("refs/heads/%s", m.trunk), fmt.Sprintf("refs/remotes/%s/%s", remote, m.trunk)}, + ExitError: true, + }) + if err != nil { + m.diverged = true + return ffTrunkDone{} + } + } + + m.done = true + return ffTrunkDone{} +} + +func (m *FastForwardTrunkModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg.(type) { + case ffTrunkDone: + return m, m.onDone() + } + return m, nil +} + +func (m *FastForwardTrunkModel) View() string { + if m.skipped { + return "" + } + if m.diverged { + return colors.ProgressStyle.Render(fmt.Sprintf(" Could not fast-forward %s (local branch has diverged from remote)", m.trunk)) + "\n" + } + if m.done { + return colors.SuccessStyle.Render(fmt.Sprintf("✓ Fast-forwarded %s to match remote", m.trunk)) + "\n" + } + return "" +}