Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions cmd/kontext/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ import (

var version = "dev"

var (
startLocal = run.StartLocal
startManaged = run.StartManaged
)

func main() {
root := &cobra.Command{
Use: "kontext",
Expand Down Expand Up @@ -73,12 +78,19 @@ func startCmd() *cobra.Command {
var (
agentName string
templateFile string
managed bool
verbose bool
)

cmd := &cobra.Command{
Use: "start [flags] [-- extra-agent-args...]",
Short: "Launch an agent with Kontext governance",
Short: "Launch an agent with Kontext runtime security",
Long: "Launch an agent with Kontext runtime security.\n\n" +
"By default, this starts a local-only runtime with no hosted login. " +
"Use --managed when you need hosted credentials, shared traces, and team governance.",
Example: " kontext start\n" +
" KONTEXT_MODE=enforce kontext start\n" +
" kontext start --managed",
RunE: func(cmd *cobra.Command, args []string) error {
if isInteractivePrompt() {
if latest := update.Available(version); latest != "" {
Expand All @@ -90,15 +102,24 @@ func startCmd() *cobra.Command {
} else {
update.CheckAsync(version)
}
if !managed && cmd.Flags().Changed("env-template") {
return errors.New("--env-template is only used with --managed sessions")
}
ctx := context.Background()
err := run.Start(ctx, run.Options{
opts := run.Options{
Agent: agentName,
TemplateFile: templateFile,
IssuerURL: auth.DefaultIssuerURL,
ClientID: auth.DefaultClientID,
Verbose: verbose,
Args: args,
})
}
var err error
if managed {
err = startManaged(ctx, opts)
} else {
err = startLocal(ctx, opts)
}
if exitErr, ok := err.(*run.AgentExitError); ok {
fmt.Fprintf(os.Stderr, "Error: %v\n", exitErr)
os.Exit(exitErr.ExitCode())
Expand All @@ -108,7 +129,8 @@ func startCmd() *cobra.Command {
}

cmd.Flags().StringVar(&agentName, "agent", "claude", "Agent to launch (currently: claude)")
cmd.Flags().StringVar(&templateFile, "env-template", ".env.kontext", "Path to env template file")
cmd.Flags().StringVar(&templateFile, "env-template", ".env.kontext", "Path to env template file for --managed sessions")
cmd.Flags().BoolVar(&managed, "managed", false, "Launch with hosted managed credentials and shared traces")
cmd.Flags().BoolVar(&verbose, "verbose", false, "Show redacted diagnostic output")

return cmd
Expand Down Expand Up @@ -205,7 +227,7 @@ func hookCmd() *cobra.Command {

cmd.Flags().StringVar(&agentName, "agent", "claude", "Agent type")
cmd.Flags().StringVar(&socketPath, "socket", "", "Unix socket path for local hook runtime")
cmd.Flags().StringVar(&mode, "mode", os.Getenv("KONTEXT_MODE"), "hook mode: observe or enforce")
cmd.Flags().StringVar(&mode, "mode", "", "hook mode: observe or enforce")

return cmd
}
Expand Down Expand Up @@ -275,7 +297,10 @@ func sidecarFailureResult(event hook.Event, reason, mode string) hook.Result {
if event.HookName != hook.HookPreToolUse {
return hook.Result{Decision: hook.DecisionAllow, Reason: reason}
}
if normalizedHookMode(mode) == "enforce" {
if hookMode := normalizedHookMode(mode); hookMode != "" {
if hookMode != "enforce" {
return hook.Result{Decision: hook.DecisionAllow, Reason: reason, Mode: hookMode}
}
return hook.Result{Decision: hook.DecisionDeny, Reason: reason, Mode: "enforce"}
}
if currentHostedAccessMode() == "enforce" {
Expand Down
127 changes: 127 additions & 0 deletions cmd/kontext/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"bytes"
"context"
"errors"
"fmt"
"net"
Expand All @@ -11,6 +12,7 @@ import (
"time"

"github.com/kontext-security/kontext-cli/internal/hook"
"github.com/kontext-security/kontext-cli/internal/run"
"github.com/zalando/go-keyring"
)

Expand Down Expand Up @@ -40,6 +42,98 @@ func TestStartCmdHasVerboseFlag(t *testing.T) {
}
}

func TestStartCmdHasManagedFlag(t *testing.T) {
cmd := startCmd()
flag := cmd.Flags().Lookup("managed")
if flag == nil {
t.Fatal("start command missing --managed flag")
}
if flag.DefValue != "false" {
t.Fatalf("--managed default = %q, want false", flag.DefValue)
}
}

func TestStartCmdDefaultsToLocalStart(t *testing.T) {
oldLocal := startLocal
oldManaged := startManaged
defer func() {
startLocal = oldLocal
startManaged = oldManaged
}()

called := ""
startLocal = func(_ context.Context, opts run.Options) error {
called = "local"
if opts.Agent != "claude" {
t.Fatalf("Agent = %q, want claude", opts.Agent)
}
return nil
}
startManaged = func(context.Context, run.Options) error {
t.Fatal("managed start should not be called")
return nil
}

cmd := startCmd()
if err := cmd.RunE(cmd, nil); err != nil {
t.Fatalf("RunE() error = %v", err)
}
if called != "local" {
t.Fatalf("called = %q, want local", called)
}
}

func TestStartCmdManagedFlagRoutesToHostedStart(t *testing.T) {
oldLocal := startLocal
oldManaged := startManaged
defer func() {
startLocal = oldLocal
startManaged = oldManaged
}()

called := ""
startLocal = func(context.Context, run.Options) error {
t.Fatal("local start should not be called")
return nil
}
startManaged = func(_ context.Context, opts run.Options) error {
called = "managed"
if opts.TemplateFile != "custom.env" {
t.Fatalf("TemplateFile = %q, want custom.env", opts.TemplateFile)
}
return nil
}

cmd := startCmd()
if err := cmd.Flags().Set("managed", "true"); err != nil {
t.Fatalf("Set managed error = %v", err)
}
if err := cmd.Flags().Set("env-template", "custom.env"); err != nil {
t.Fatalf("Set env-template error = %v", err)
}
if err := cmd.RunE(cmd, nil); err != nil {
t.Fatalf("RunE() error = %v", err)
}
if called != "managed" {
t.Fatalf("called = %q, want managed", called)
}
}

func TestStartCmdRejectsEnvTemplateWithoutManaged(t *testing.T) {
cmd := startCmd()
if err := cmd.Flags().Set("env-template", "custom.env"); err != nil {
t.Fatalf("Set env-template error = %v", err)
}

err := cmd.RunE(cmd, nil)
if err == nil {
t.Fatal("RunE() error = nil, want --env-template error")
}
if !strings.Contains(err.Error(), "--env-template is only used with --managed") {
t.Fatalf("error = %q, want env-template managed error", err.Error())
}
}

func TestGuardCmdRoutesToLocalGuardMode(t *testing.T) {
cmd := guardCmd()
if cmd.Use != "guard" {
Expand All @@ -50,6 +144,19 @@ func TestGuardCmdRoutesToLocalGuardMode(t *testing.T) {
}
}

func TestHookCmdModeDoesNotDefaultFromEnv(t *testing.T) {
t.Setenv("KONTEXT_MODE", "observe")

cmd := hookCmd()
flag := cmd.Flags().Lookup("mode")
if flag == nil {
t.Fatal("hook command missing --mode flag")
}
if flag.DefValue != "" {
t.Fatalf("--mode default = %q, want empty", flag.DefValue)
}
}

func TestLogoutCmdAlreadyLoggedOut(t *testing.T) {
cmd := newLogoutCmd(func() error { return keyring.ErrNotFound })

Expand Down Expand Up @@ -147,6 +254,26 @@ func TestEvaluateViaSidecarFailsClosedWhenEnforceSidecarUnavailable(t *testing.T
}
}

func TestEvaluateViaSidecarObserveModeIgnoresStaleHostedEnforce(t *testing.T) {
t.Setenv("KONTEXT_ACCESS_MODE", "enforce")

socketPath := fmt.Sprintf("/tmp/kontext-missing-%d.sock", time.Now().UnixNano())
result, err := evaluateViaSidecarForMode(socketPath, hook.Event{
Agent: "claude",
HookName: hook.HookPreToolUse,
ToolName: "Bash",
}, "observe")
if err != nil {
t.Fatalf("evaluateViaSidecarForMode() error = %v", err)
}
if result.Decision != hook.DecisionAllow {
t.Fatalf("decision = %q, want ALLOW", result.Decision)
}
if result.Mode != "observe" {
t.Fatalf("mode = %q, want observe", result.Mode)
}
}

func TestEvaluateViaSidecarFailsClosedWhenAccessModePathSet(t *testing.T) {
t.Setenv("KONTEXT_ACCESS_MODE_PATH", "/tmp/kontext-missing-mode")

Expand Down
26 changes: 20 additions & 6 deletions internal/guard/modelsnapshot/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,20 @@ func (s *Store) ActivateFromFile(path string) (Snapshot, error) {
if err != nil {
return Snapshot{}, err
}
return s.activate(path, data)
}

func (s *Store) ActivateBytes(source string, data []byte) (Snapshot, error) {
if strings.TrimSpace(source) == "" {
return Snapshot{}, fmt.Errorf("model source is required")
}
if len(data) == 0 {
return Snapshot{}, fmt.Errorf("model data is required")
}
return s.activate(source, data)
}

func (s *Store) activate(source string, data []byte) (Snapshot, error) {
model, err := markov.ReadModelJSON(bytes.NewReader(data))
if err != nil {
return Snapshot{}, fmt.Errorf("validate model snapshot: %w", err)
Expand Down Expand Up @@ -75,16 +89,16 @@ func (s *Store) ActivateFromFile(path string) (Snapshot, error) {

id := now.Format("20060102T150405.000000000Z") + "-" + hash[:12]
snapshotDir := filepath.Join(s.root, "snapshots")
if err := os.MkdirAll(snapshotDir, 0o755); err != nil {
if err := os.MkdirAll(snapshotDir, 0o700); err != nil {
return Snapshot{}, err
}
modelPath := filepath.Join(snapshotDir, id+".json")
if err := writeFileAtomic(modelPath, data, 0o644); err != nil {
if err := writeFileAtomic(modelPath, data, 0o600); err != nil {
return Snapshot{}, err
}
snapshot := Snapshot{
ID: id,
SourcePath: path,
SourcePath: source,
Path: modelPath,
SHA256: hash,
CreatedAt: now,
Expand Down Expand Up @@ -147,7 +161,7 @@ func (s *Store) writeSnapshot(snapshot Snapshot) error {
return err
}
data = append(data, '\n')
return writeFileAtomic(s.snapshotMetadataPath(snapshot.ID), data, 0o644)
return writeFileAtomic(s.snapshotMetadataPath(snapshot.ID), data, 0o600)
}

func (s *Store) readSnapshot(id string) (Snapshot, error) {
Expand All @@ -171,7 +185,7 @@ func (s *Store) writeActive(snapshot Snapshot) error {
return err
}
data = append(data, '\n')
return writeFileAtomic(s.activePath(), data, 0o644)
return writeFileAtomic(s.activePath(), data, 0o600)
}

func (s *Store) activePath() string {
Expand All @@ -183,7 +197,7 @@ func (s *Store) snapshotMetadataPath(id string) string {
}

func writeFileAtomic(path string, data []byte, perm os.FileMode) error {
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
return err
}
tmp, err := os.CreateTemp(filepath.Dir(path), "."+filepath.Base(path)+".tmp-*")
Expand Down
32 changes: 32 additions & 0 deletions internal/guard/modelsnapshot/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,38 @@ func TestActivateFromFilePersistsActiveSnapshot(t *testing.T) {
}
}

func TestActivateBytesPersistsPrivateSnapshot(t *testing.T) {
source := writeModel(t, t.TempDir(), "model.json", "v1")
data, err := os.ReadFile(source)
if err != nil {
t.Fatal(err)
}
root := t.TempDir()
store := New(root)

snapshot, err := store.ActivateBytes("embedded:test-model", data)
if err != nil {
t.Fatalf("ActivateBytes() error = %v", err)
}
if snapshot.SourcePath != "embedded:test-model" {
t.Fatalf("SourcePath = %q, want embedded source", snapshot.SourcePath)
}
for _, path := range []string{snapshot.Path, store.activePath(), store.snapshotMetadataPath(snapshot.ID)} {
info, err := os.Stat(path)
if err != nil {
t.Fatalf("Stat(%s) error = %v", path, err)
}
if got := info.Mode().Perm(); got != 0o600 {
t.Fatalf("%s mode = %o, want 600", path, got)
}
}
if info, err := os.Stat(filepath.Join(root, "snapshots")); err != nil {
t.Fatalf("snapshot dir stat error = %v", err)
} else if got := info.Mode().Perm(); got != 0o700 {
t.Fatalf("snapshot dir mode = %o, want 700", got)
}
}

func TestActivateFromFileReusesActiveSnapshotForSameModel(t *testing.T) {
source := writeModel(t, t.TempDir(), "model.json", "v1")
store := New(t.TempDir())
Expand Down
Loading
Loading