Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Stephanie You committed Nov 30, 2023
1 parent 3e5c6b6 commit 5d24edd
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 89 deletions.
31 changes: 11 additions & 20 deletions go/cmd/dolt/commands/reflog.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,27 +107,17 @@ func (cmd ReflogCmd) Exec(ctx context.Context, commandStr string, args []string,
// Also interpolates this query to prevent sql injection
func constructInterpolatedDoltReflogQuery(apr *argparser.ArgParseResults) (string, error) {
var params []interface{}
refPlaceholder := ""
allFlag := ""
var args []string

if apr.NArg() == 1 {
params = append(params, apr.Arg(0))
refPlaceholder = "?"
args = append(args, "?")
}
if apr.Contains(cli.AllFlag) {
allFlag = "'--all'"
args = append(args, "'--all'")
}

args := ""
if refPlaceholder == "" && allFlag != "" {
args = allFlag
} else if refPlaceholder != "" && allFlag == "" {
args = refPlaceholder
} else if refPlaceholder != "" && allFlag != "" {
args = strings.Join([]string{refPlaceholder, allFlag}, ", ")
}

query := strings.Join([]string{"SELECT ref, commit_hash, commit_message FROM DOLT_REFLOG(", args, ")"}, "")
query := fmt.Sprintf("SELECT ref, commit_hash, commit_message FROM DOLT_REFLOG(%s)", strings.Join(args, ", "))
interpolatedQuery, err := dbr.InterpolateForDialect(query, params, dialect.MySQL)
if err != nil {
return "", err
Expand All @@ -147,11 +137,12 @@ func printReflog(rows []sql.Row, queryist cli.Queryist, sqlCtx *sql.Context) int
var reflogInfo []ReflogInfo

// Get the current branch
curBranch := ""
res, err := GetRowsForSql(queryist, sqlCtx, "SELECT active_branch()")
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), nil)
if err == nil {
// still print the reflog even if we can't get the current branch
curBranch = res[0][0].(string)
}
curBranch := res[0][0].(string)

for _, row := range rows {
ref := row[0].(string)
Expand All @@ -174,13 +165,13 @@ func reflogToStdOut(reflogInfo []ReflogInfo, curBranch string) {
pager := outputpager.Start()
defer pager.Stop()

for pos, info := range reflogInfo {
for _, info := range reflogInfo {
// TODO: use short hash instead
line := []string{fmt.Sprintf("\033[33m%s\033[0m", info.commitHash)} // commit hash in yellow (33m)

processedRef := processRefForReflog(info.ref, curBranch)
line = append(line, fmt.Sprintf("\033[33m(%s\033[33m)\033[0m", processedRef)) // () in yellow (33m)
line = append(line, fmt.Sprintf("HEAD@{%d}: %s\n", pos, info.commitMessage))
line = append(line, fmt.Sprintf("%s\n", info.commitMessage))
pager.Writer.Write([]byte(strings.Join(line, " ")))
}
})
Expand All @@ -190,7 +181,7 @@ func reflogToStdOut(reflogInfo []ReflogInfo, curBranch string) {
func processRefForReflog(fullRef string, curBranch string) string {
if strings.HasPrefix(fullRef, "refs/heads/") {
branch := strings.TrimPrefix(fullRef, "refs/heads/")
if branch == curBranch {
if curBranch != "" && branch == curBranch {
return fmt.Sprintf("\033[36;1mHEAD -> \033[32;1m%s\033[0m", branch) // HEAD in cyan (36;1), branch in green (32;1m)
}
return fmt.Sprintf("\033[32;1m%s\033[0m", branch) // branch in green (32;1m)
Expand Down
6 changes: 3 additions & 3 deletions go/libraries/doltcore/sqle/enginetest/dolt_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -4317,15 +4317,15 @@ var DoltReflogTestScripts = []queries.ScriptTest{
Assertions: []queries.ScriptTestAssertion{
{
Query: "select * from dolt_reflog('foo', 'bar');",
ExpectedErrStr: "error: reflog has too many positional arguments. Expected at most 1, found 2: foo, bar",
ExpectedErrStr: "error: dolt_reflog has too many positional arguments. Expected at most 1, found 2: ['foo' 'bar']",
},
{
Query: "select * from dolt_reflog(NULL);",
ExpectedErrStr: "Invalid argument to dolt_reflog: NULL",
ExpectedErrStr: "argument (<nil>) is not a string value, but a <nil>",
},
{
Query: "select * from dolt_reflog(-100);",
ExpectedErrStr: "Invalid argument to dolt_reflog: -100",
ExpectedErrStr: "argument (-100) is not a string value, but a int8",
},
},
},
Expand Down
67 changes: 48 additions & 19 deletions go/libraries/doltcore/sqle/reflog_table_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"

"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/store/hash"
Expand All @@ -32,7 +31,7 @@ import (
type ReflogTableFunction struct {
ctx *sql.Context
database sql.Database
refName string
refExpr sql.Expression
showAll bool
}

Expand Down Expand Up @@ -66,6 +65,20 @@ func (rltf *ReflogTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.Row
return nil, fmt.Errorf("unexpected database type: %T", rltf.database)
}

var refName string
if rltf.refExpr != nil {
target, err := rltf.refExpr.Eval(ctx, row)
if err != nil {
return nil, fmt.Errorf("error evaluating expression (%s): %s",
rltf.refExpr.String(), err.Error())
}

refName, ok = target.(string)
if !ok {
return nil, fmt.Errorf("argument (%v) is not a string value, but a %T", target, target)
}
}

ddb := sqlDb.DbData().Ddb
journal := ddb.ChunkJournal()
if journal == nil {
Expand Down Expand Up @@ -105,15 +118,15 @@ func (rltf *ReflogTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.Row
}

// If a ref expression to filter on was specified, see if we match the current ref
if rltf.refName != "" {
if refName != "" {
// If the caller has supplied a branch or tag name, without the fully qualified ref path,
// take the first match and use that as the canonical ref to filter on
if strings.HasSuffix(strings.ToLower(id), "/"+strings.ToLower(rltf.refName)) {
rltf.refName = id
if strings.HasSuffix(strings.ToLower(id), "/"+strings.ToLower(refName)) {
refName = id
}

// Skip refs that don't match the target we're looking for
if strings.ToLower(id) != strings.ToLower(rltf.refName) {
if strings.ToLower(id) != strings.ToLower(refName) {
return nil
}
}
Expand Down Expand Up @@ -166,11 +179,21 @@ func (rltf *ReflogTableFunction) Schema() sql.Schema {
}

func (rltf *ReflogTableFunction) Resolved() bool {
if rltf.refExpr != nil {
return rltf.refExpr.Resolved()
}
return true
}

func (rltf *ReflogTableFunction) String() string {
return fmt.Sprintf("DOLT_REFLOG(%s)", rltf.refName)
var args []string
if rltf.showAll {
args = append(args, "'--all'")
}
if rltf.refExpr != nil {
args = append(args, rltf.refExpr.String())
}
return fmt.Sprintf("DOLT_REFLOG(%s)", strings.Join(args, ", "))
}

func (rltf *ReflogTableFunction) Children() []sql.Node {
Expand All @@ -195,6 +218,9 @@ func (rltf *ReflogTableFunction) IsReadOnly() bool {
}

func (rltf *ReflogTableFunction) Expressions() []sql.Expression {
if rltf.refExpr != nil {
return []sql.Expression{rltf.refExpr}
}
return []sql.Expression{}
}

Expand All @@ -204,20 +230,23 @@ func (rltf *ReflogTableFunction) WithExpressions(expression ...sql.Expression) (
}

new := *rltf
args, err := getDoltArgs(rltf.ctx, expression, rltf.Name())
if err != nil {
return nil, err
}
apr, err := cli.CreateReflogArgParser().Parse(args)
if err != nil {
return nil, err
}
if apr.NArg() > 0 {
new.refName = apr.Arg(0)

if len(expression) == 2 {
if expression[0].String() == "'--all'" && expression[1].String() == "'--all'" {
return nil, fmt.Errorf("error: multiple values provided for `all`")
}
if expression[0].String() != "'--all'" && expression[1].String() != "'--all'" {
return nil, fmt.Errorf("error: %s has too many positional arguments. Expected at most %d, found %d: %s", rltf.Name(), 1, 2, expression)
}
}
if apr.Contains(cli.AllFlag) {
new.showAll = true
for _, expr := range expression {
if expr.String() != "'--all'" {
new.refExpr = expr
} else {
new.showAll = true
}
}

return &new, nil
}

Expand Down
1 change: 0 additions & 1 deletion integration-tests/bats/helper/local-remote.bash
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ SKIP_SERVER_TESTS=$(cat <<-EOM
~profile.bats~
~ls.bats~
~reflog.bats~
~sql-reflog.bats~
EOM
)

Expand Down
Loading

0 comments on commit 5d24edd

Please sign in to comment.