diff --git a/cmd/kontext/main.go b/cmd/kontext/main.go index 3ed48dd..abc3e36 100644 --- a/cmd/kontext/main.go +++ b/cmd/kontext/main.go @@ -27,6 +27,11 @@ import ( var version = "dev" +var ( + startLocal = run.StartLocal + startManaged = run.StartManaged +) + func main() { root := &cobra.Command{ Use: "kontext", @@ -73,12 +78,19 @@ func startCmd() *cobra.Command { var ( agentName string templateFile string + managed bool verbose bool ) cmd := &cobra.Command{ Use: "start [flags] [-- extra-agent-args...]", - Short: "Launch an agent with Kontext governance", + Short: "Launch an agent with Kontext runtime security", + Long: "Launch an agent with Kontext runtime security.\n\n" + + "By default, this starts a local-only runtime with no hosted login. " + + "Use --managed when you need hosted credentials, shared traces, and team governance.", + Example: " kontext start\n" + + " KONTEXT_MODE=enforce kontext start\n" + + " kontext start --managed", RunE: func(cmd *cobra.Command, args []string) error { if isInteractivePrompt() { if latest := update.Available(version); latest != "" { @@ -90,15 +102,24 @@ func startCmd() *cobra.Command { } else { update.CheckAsync(version) } + if !managed && cmd.Flags().Changed("env-template") { + return errors.New("--env-template is only used with --managed sessions") + } ctx := context.Background() - err := run.Start(ctx, run.Options{ + opts := run.Options{ Agent: agentName, TemplateFile: templateFile, IssuerURL: auth.DefaultIssuerURL, ClientID: auth.DefaultClientID, Verbose: verbose, Args: args, - }) + } + var err error + if managed { + err = startManaged(ctx, opts) + } else { + err = startLocal(ctx, opts) + } if exitErr, ok := err.(*run.AgentExitError); ok { fmt.Fprintf(os.Stderr, "Error: %v\n", exitErr) os.Exit(exitErr.ExitCode()) @@ -108,7 +129,8 @@ func startCmd() *cobra.Command { } cmd.Flags().StringVar(&agentName, "agent", "claude", "Agent to launch (currently: claude)") - cmd.Flags().StringVar(&templateFile, "env-template", ".env.kontext", "Path to env template file") + cmd.Flags().StringVar(&templateFile, "env-template", ".env.kontext", "Path to env template file for --managed sessions") + cmd.Flags().BoolVar(&managed, "managed", false, "Launch with hosted managed credentials and shared traces") cmd.Flags().BoolVar(&verbose, "verbose", false, "Show redacted diagnostic output") return cmd @@ -205,7 +227,7 @@ func hookCmd() *cobra.Command { cmd.Flags().StringVar(&agentName, "agent", "claude", "Agent type") cmd.Flags().StringVar(&socketPath, "socket", "", "Unix socket path for local hook runtime") - cmd.Flags().StringVar(&mode, "mode", os.Getenv("KONTEXT_MODE"), "hook mode: observe or enforce") + cmd.Flags().StringVar(&mode, "mode", "", "hook mode: observe or enforce") return cmd } @@ -275,7 +297,10 @@ func sidecarFailureResult(event hook.Event, reason, mode string) hook.Result { if event.HookName != hook.HookPreToolUse { return hook.Result{Decision: hook.DecisionAllow, Reason: reason} } - if normalizedHookMode(mode) == "enforce" { + if hookMode := normalizedHookMode(mode); hookMode != "" { + if hookMode != "enforce" { + return hook.Result{Decision: hook.DecisionAllow, Reason: reason, Mode: hookMode} + } return hook.Result{Decision: hook.DecisionDeny, Reason: reason, Mode: "enforce"} } if currentHostedAccessMode() == "enforce" { diff --git a/cmd/kontext/main_test.go b/cmd/kontext/main_test.go index 3f36970..e3195af 100644 --- a/cmd/kontext/main_test.go +++ b/cmd/kontext/main_test.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "errors" "fmt" "net" @@ -11,6 +12,7 @@ import ( "time" "github.com/kontext-security/kontext-cli/internal/hook" + "github.com/kontext-security/kontext-cli/internal/run" "github.com/zalando/go-keyring" ) @@ -40,6 +42,98 @@ func TestStartCmdHasVerboseFlag(t *testing.T) { } } +func TestStartCmdHasManagedFlag(t *testing.T) { + cmd := startCmd() + flag := cmd.Flags().Lookup("managed") + if flag == nil { + t.Fatal("start command missing --managed flag") + } + if flag.DefValue != "false" { + t.Fatalf("--managed default = %q, want false", flag.DefValue) + } +} + +func TestStartCmdDefaultsToLocalStart(t *testing.T) { + oldLocal := startLocal + oldManaged := startManaged + defer func() { + startLocal = oldLocal + startManaged = oldManaged + }() + + called := "" + startLocal = func(_ context.Context, opts run.Options) error { + called = "local" + if opts.Agent != "claude" { + t.Fatalf("Agent = %q, want claude", opts.Agent) + } + return nil + } + startManaged = func(context.Context, run.Options) error { + t.Fatal("managed start should not be called") + return nil + } + + cmd := startCmd() + if err := cmd.RunE(cmd, nil); err != nil { + t.Fatalf("RunE() error = %v", err) + } + if called != "local" { + t.Fatalf("called = %q, want local", called) + } +} + +func TestStartCmdManagedFlagRoutesToHostedStart(t *testing.T) { + oldLocal := startLocal + oldManaged := startManaged + defer func() { + startLocal = oldLocal + startManaged = oldManaged + }() + + called := "" + startLocal = func(context.Context, run.Options) error { + t.Fatal("local start should not be called") + return nil + } + startManaged = func(_ context.Context, opts run.Options) error { + called = "managed" + if opts.TemplateFile != "custom.env" { + t.Fatalf("TemplateFile = %q, want custom.env", opts.TemplateFile) + } + return nil + } + + cmd := startCmd() + if err := cmd.Flags().Set("managed", "true"); err != nil { + t.Fatalf("Set managed error = %v", err) + } + if err := cmd.Flags().Set("env-template", "custom.env"); err != nil { + t.Fatalf("Set env-template error = %v", err) + } + if err := cmd.RunE(cmd, nil); err != nil { + t.Fatalf("RunE() error = %v", err) + } + if called != "managed" { + t.Fatalf("called = %q, want managed", called) + } +} + +func TestStartCmdRejectsEnvTemplateWithoutManaged(t *testing.T) { + cmd := startCmd() + if err := cmd.Flags().Set("env-template", "custom.env"); err != nil { + t.Fatalf("Set env-template error = %v", err) + } + + err := cmd.RunE(cmd, nil) + if err == nil { + t.Fatal("RunE() error = nil, want --env-template error") + } + if !strings.Contains(err.Error(), "--env-template is only used with --managed") { + t.Fatalf("error = %q, want env-template managed error", err.Error()) + } +} + func TestGuardCmdRoutesToLocalGuardMode(t *testing.T) { cmd := guardCmd() if cmd.Use != "guard" { @@ -50,6 +144,19 @@ func TestGuardCmdRoutesToLocalGuardMode(t *testing.T) { } } +func TestHookCmdModeDoesNotDefaultFromEnv(t *testing.T) { + t.Setenv("KONTEXT_MODE", "observe") + + cmd := hookCmd() + flag := cmd.Flags().Lookup("mode") + if flag == nil { + t.Fatal("hook command missing --mode flag") + } + if flag.DefValue != "" { + t.Fatalf("--mode default = %q, want empty", flag.DefValue) + } +} + func TestLogoutCmdAlreadyLoggedOut(t *testing.T) { cmd := newLogoutCmd(func() error { return keyring.ErrNotFound }) @@ -147,6 +254,26 @@ func TestEvaluateViaSidecarFailsClosedWhenEnforceSidecarUnavailable(t *testing.T } } +func TestEvaluateViaSidecarObserveModeIgnoresStaleHostedEnforce(t *testing.T) { + t.Setenv("KONTEXT_ACCESS_MODE", "enforce") + + socketPath := fmt.Sprintf("/tmp/kontext-missing-%d.sock", time.Now().UnixNano()) + result, err := evaluateViaSidecarForMode(socketPath, hook.Event{ + Agent: "claude", + HookName: hook.HookPreToolUse, + ToolName: "Bash", + }, "observe") + if err != nil { + t.Fatalf("evaluateViaSidecarForMode() error = %v", err) + } + if result.Decision != hook.DecisionAllow { + t.Fatalf("decision = %q, want ALLOW", result.Decision) + } + if result.Mode != "observe" { + t.Fatalf("mode = %q, want observe", result.Mode) + } +} + func TestEvaluateViaSidecarFailsClosedWhenAccessModePathSet(t *testing.T) { t.Setenv("KONTEXT_ACCESS_MODE_PATH", "/tmp/kontext-missing-mode") diff --git a/internal/guard/modelsnapshot/store.go b/internal/guard/modelsnapshot/store.go index b3816da..85dc783 100644 --- a/internal/guard/modelsnapshot/store.go +++ b/internal/guard/modelsnapshot/store.go @@ -48,6 +48,20 @@ func (s *Store) ActivateFromFile(path string) (Snapshot, error) { if err != nil { return Snapshot{}, err } + return s.activate(path, data) +} + +func (s *Store) ActivateBytes(source string, data []byte) (Snapshot, error) { + if strings.TrimSpace(source) == "" { + return Snapshot{}, fmt.Errorf("model source is required") + } + if len(data) == 0 { + return Snapshot{}, fmt.Errorf("model data is required") + } + return s.activate(source, data) +} + +func (s *Store) activate(source string, data []byte) (Snapshot, error) { model, err := markov.ReadModelJSON(bytes.NewReader(data)) if err != nil { return Snapshot{}, fmt.Errorf("validate model snapshot: %w", err) @@ -75,16 +89,16 @@ func (s *Store) ActivateFromFile(path string) (Snapshot, error) { id := now.Format("20060102T150405.000000000Z") + "-" + hash[:12] snapshotDir := filepath.Join(s.root, "snapshots") - if err := os.MkdirAll(snapshotDir, 0o755); err != nil { + if err := os.MkdirAll(snapshotDir, 0o700); err != nil { return Snapshot{}, err } modelPath := filepath.Join(snapshotDir, id+".json") - if err := writeFileAtomic(modelPath, data, 0o644); err != nil { + if err := writeFileAtomic(modelPath, data, 0o600); err != nil { return Snapshot{}, err } snapshot := Snapshot{ ID: id, - SourcePath: path, + SourcePath: source, Path: modelPath, SHA256: hash, CreatedAt: now, @@ -147,7 +161,7 @@ func (s *Store) writeSnapshot(snapshot Snapshot) error { return err } data = append(data, '\n') - return writeFileAtomic(s.snapshotMetadataPath(snapshot.ID), data, 0o644) + return writeFileAtomic(s.snapshotMetadataPath(snapshot.ID), data, 0o600) } func (s *Store) readSnapshot(id string) (Snapshot, error) { @@ -171,7 +185,7 @@ func (s *Store) writeActive(snapshot Snapshot) error { return err } data = append(data, '\n') - return writeFileAtomic(s.activePath(), data, 0o644) + return writeFileAtomic(s.activePath(), data, 0o600) } func (s *Store) activePath() string { @@ -183,7 +197,7 @@ func (s *Store) snapshotMetadataPath(id string) string { } func writeFileAtomic(path string, data []byte, perm os.FileMode) error { - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { return err } tmp, err := os.CreateTemp(filepath.Dir(path), "."+filepath.Base(path)+".tmp-*") diff --git a/internal/guard/modelsnapshot/store_test.go b/internal/guard/modelsnapshot/store_test.go index 77cb68c..80eb110 100644 --- a/internal/guard/modelsnapshot/store_test.go +++ b/internal/guard/modelsnapshot/store_test.go @@ -37,6 +37,38 @@ func TestActivateFromFilePersistsActiveSnapshot(t *testing.T) { } } +func TestActivateBytesPersistsPrivateSnapshot(t *testing.T) { + source := writeModel(t, t.TempDir(), "model.json", "v1") + data, err := os.ReadFile(source) + if err != nil { + t.Fatal(err) + } + root := t.TempDir() + store := New(root) + + snapshot, err := store.ActivateBytes("embedded:test-model", data) + if err != nil { + t.Fatalf("ActivateBytes() error = %v", err) + } + if snapshot.SourcePath != "embedded:test-model" { + t.Fatalf("SourcePath = %q, want embedded source", snapshot.SourcePath) + } + for _, path := range []string{snapshot.Path, store.activePath(), store.snapshotMetadataPath(snapshot.ID)} { + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat(%s) error = %v", path, err) + } + if got := info.Mode().Perm(); got != 0o600 { + t.Fatalf("%s mode = %o, want 600", path, got) + } + } + if info, err := os.Stat(filepath.Join(root, "snapshots")); err != nil { + t.Fatalf("snapshot dir stat error = %v", err) + } else if got := info.Mode().Perm(); got != 0o700 { + t.Fatalf("snapshot dir mode = %o, want 700", got) + } +} + func TestActivateFromFileReusesActiveSnapshotForSameModel(t *testing.T) { source := writeModel(t, t.TempDir(), "model.json", "v1") store := New(t.TempDir()) diff --git a/internal/localruntime/service.go b/internal/localruntime/service.go index 705dfb4..d379820 100644 --- a/internal/localruntime/service.go +++ b/internal/localruntime/service.go @@ -5,6 +5,7 @@ import ( "errors" "net" "os" + "sync" "time" "github.com/kontext-security/kontext-cli/internal/diagnostic" @@ -24,6 +25,8 @@ type Service struct { onFailure func(hook.Event, error) hook.Result diagnostic diagnostic.Logger cancel context.CancelFunc + serveDone chan struct{} + wg sync.WaitGroup } type Options struct { @@ -65,24 +68,62 @@ func (s *Service) Start(ctx context.Context) error { s.listener = ln ctx, s.cancel = context.WithCancel(ctx) - go s.acceptLoop(ctx) + s.serveDone = make(chan struct{}) + go s.acceptLoop(ctx, ln, s.serveDone) return nil } func (s *Service) Stop() { - if s.cancel != nil { - s.cancel() + _ = s.Shutdown(context.Background()) +} + +func (s *Service) Shutdown(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() } if s.listener != nil { - s.listener.Close() + _ = s.listener.Close() + } + _ = os.Remove(s.socketPath) + + if s.serveDone != nil { + select { + case <-s.serveDone: + case <-ctx.Done(): + if s.cancel != nil { + s.cancel() + } + return ctx.Err() + } + } + + done := make(chan struct{}) + go func() { + s.wg.Wait() + close(done) + }() + select { + case <-done: + if s.cancel != nil { + s.cancel() + } + return nil + case <-ctx.Done(): + if s.cancel != nil { + s.cancel() + } + return ctx.Err() } - os.Remove(s.socketPath) } -func (s *Service) acceptLoop(ctx context.Context) { +func (s *Service) acceptLoop(ctx context.Context, listener net.Listener, done chan<- struct{}) { + defer close(done) for { - conn, err := s.listener.Accept() + conn, err := listener.Accept() if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } select { case <-ctx.Done(): return @@ -91,7 +132,11 @@ func (s *Service) acceptLoop(ctx context.Context) { return } } - go s.handleConn(ctx, conn) + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.handleConn(ctx, conn) + }() } } @@ -115,7 +160,11 @@ func (s *Service) handleConn(ctx context.Context, conn net.Conn) { } if s.asyncIngest && req.HookEvent != hook.HookPreToolUse.String() { - go s.ingestEvent(ctx, &req) + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.ingestEvent(ctx, &req) + }() } } diff --git a/internal/localruntime/service_test.go b/internal/localruntime/service_test.go index 00fa316..aab6c86 100644 --- a/internal/localruntime/service_test.go +++ b/internal/localruntime/service_test.go @@ -73,6 +73,55 @@ func TestServiceCanAckTelemetryBeforeAsyncIngest(t *testing.T) { } } +func TestServiceShutdownDrainsAsyncIngest(t *testing.T) { + t.Parallel() + + runtime := &stubRuntime{ + ingestStarted: make(chan struct{}), + releaseIngest: make(chan struct{}), + } + service := newTestService(t, runtime, true) + client := NewClient(service.SocketPath()) + + result, err := client.Process(context.Background(), hook.Event{ + SessionID: "agent-session", + HookName: hook.HookPostToolUse, + ToolName: "Bash", + }) + if err != nil { + t.Fatalf("Process() error = %v", err) + } + if result.Decision != hook.DecisionAllow { + t.Fatalf("Process().Decision = %q, want allow", result.Decision) + } + + select { + case <-runtime.ingestStarted: + case <-time.After(2 * time.Second): + t.Fatal("async ingest did not start") + } + + done := make(chan error, 1) + go func() { + done <- service.Shutdown(context.Background()) + }() + select { + case err := <-done: + t.Fatalf("Shutdown() returned before async ingest drained: %v", err) + case <-time.After(50 * time.Millisecond): + } + + close(runtime.releaseIngest) + select { + case err := <-done: + if err != nil { + t.Fatalf("Shutdown() error = %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Shutdown() did not return after async ingest completed") + } +} + func TestServiceUsesCustomFailureResult(t *testing.T) { t.Parallel() @@ -191,8 +240,11 @@ type stubRuntime struct { evaluateResult hook.Result evaluateErr error ingested chan hook.Event + ingestStarted chan struct{} + releaseIngest chan struct{} evaluateCalls atomic.Int32 ingestCalls atomic.Int32 + startedIngest atomic.Bool } func (s *stubRuntime) EvaluateHook(_ context.Context, _ hook.Event) (hook.Result, error) { @@ -200,8 +252,18 @@ func (s *stubRuntime) EvaluateHook(_ context.Context, _ hook.Event) (hook.Result return s.evaluateResult, s.evaluateErr } -func (s *stubRuntime) IngestEvent(_ context.Context, event hook.Event) (hook.Result, error) { +func (s *stubRuntime) IngestEvent(ctx context.Context, event hook.Event) (hook.Result, error) { s.ingestCalls.Add(1) + if s.ingestStarted != nil && s.startedIngest.CompareAndSwap(false, true) { + close(s.ingestStarted) + } + if s.releaseIngest != nil { + select { + case <-s.releaseIngest: + case <-ctx.Done(): + return hook.Result{}, ctx.Err() + } + } if s.ingested != nil { s.ingested <- event } diff --git a/internal/run/hooks.go b/internal/run/hooks.go index a52bc7a..c9d7d54 100644 --- a/internal/run/hooks.go +++ b/internal/run/hooks.go @@ -38,7 +38,23 @@ func commandHookGroups(command string) []hookGroup { // and returns the path to the generated file. func GenerateSettings(sessionDir, kontextBinary, agentName string) (string, error) { hookCmd := fmt.Sprintf("%s hook --agent %s", kontextBinary, agentName) + return writeClaudeSettings(sessionDir, hookCmd) +} + +// GenerateLocalSettings creates a Claude Code settings.json that connects +// hooks to the wrapper-owned local runtime socket. +func GenerateLocalSettings(sessionDir, kontextBinary, agentName, socketPath, mode string) (string, error) { + hookCmd := fmt.Sprintf( + "%s hook --agent %s --mode %s --socket %s", + shellQuote(kontextBinary), + shellQuote(agentName), + shellQuote(mode), + shellQuote(socketPath), + ) + return writeClaudeSettings(sessionDir, hookCmd) +} +func writeClaudeSettings(sessionDir, hookCmd string) (string, error) { settings := claudeSettings{ Hooks: map[string][]hookGroup{ "PreToolUse": commandHookGroups(hookCmd), @@ -59,6 +75,10 @@ func GenerateSettings(sessionDir, kontextBinary, agentName string) (string, erro return settingsPath, nil } +func shellQuote(value string) string { + return "'" + strings.ReplaceAll(value, "'", "'\\''") + "'" +} + func VerifyBlockingHookSettings(settingsPath, kontextBinary, agentName string) error { data, err := os.ReadFile(settingsPath) if err != nil { diff --git a/internal/run/local.go b/internal/run/local.go new file mode 100644 index 0000000..b1d6417 --- /dev/null +++ b/internal/run/local.go @@ -0,0 +1,77 @@ +package run + +import ( + "context" + "fmt" + "os" + + "github.com/kontext-security/kontext-cli/internal/diagnostic" + "github.com/kontext-security/kontext-cli/internal/runtimehost" +) + +// StartLocal launches an agent with a wrapper-owned local runtime. +func StartLocal(ctx context.Context, opts Options) error { + diagnostics := diagnostic.New(os.Stderr, opts.Verbose || diagnostic.EnabledFromEnv()) + diagnostics.Printf("start local: agent=%s\n", opts.Agent) + + agentPath, err := preflightAgent(opts.Agent) + if err != nil { + return err + } + diagnostics.Printf("agent preflight: %s -> %s\n", opts.Agent, agentPath) + + mode, err := runtimehost.ResolveMode(os.Getenv("KONTEXT_MODE")) + if err != nil { + return err + } + cwd, _ := os.Getwd() + host, err := runtimehost.Start(ctx, runtimehost.Options{ + AgentName: opts.Agent, + CWD: cwd, + DBPath: os.Getenv("KONTEXT_DB"), + ModelPath: os.Getenv("KONTEXT_MODEL"), + DashboardAddr: os.Getenv("KONTEXT_ADDR"), + StartDashboard: true, + Mode: string(mode), + Diagnostic: diagnostics, + Out: os.Stderr, + }) + if err != nil { + return err + } + defer func() { + _ = host.Close(context.Background()) + }() + + kontextBin, err := os.Executable() + if err != nil { + return fmt.Errorf("resolve kontext executable: %w", err) + } + settingsPath, err := GenerateLocalSettings(host.SessionDir, kontextBin, opts.Agent, host.SocketPath, string(host.Mode)) + if err != nil { + return fmt.Errorf("generate settings: %w", err) + } + + env := append(os.Environ(), "KONTEXT_RUN=1") + env = append(env, "KONTEXT_SOCKET="+host.SocketPath) + env = append(env, "KONTEXT_SESSION_ID="+host.SessionID) + env = append(env, "KONTEXT_MODE="+string(host.Mode)) + + fmt.Fprintf(os.Stderr, "✓ Local session: %s\n", truncateID(host.SessionID)) + if host.DashboardURL != "" { + fmt.Fprintf(os.Stderr, "✓ Dashboard: %s\n", host.DashboardURL) + } + fmt.Fprintf(os.Stderr, "✓ Risk model: %s\n", host.ActiveModelPath) + printLocalMode(os.Stderr, string(host.Mode)) + fmt.Fprintf(os.Stderr, "\nLaunching %s...\n\n", opts.Agent) + + return launchAgentWithSettings(ctx, opts.Agent, agentPath, env, opts.Args, settingsPath) +} + +func printLocalMode(out *os.File, mode string) { + if mode == "enforce" { + fmt.Fprintln(out, "Mode: enforce (ask and deny decisions block supported pre-tool actions).") + return + } + fmt.Fprintln(out, "Mode: observe (agent actions run normally; decisions are recorded as would allow / would ask / would deny).") +} diff --git a/internal/run/run.go b/internal/run/run.go index 870845d..da2333f 100644 --- a/internal/run/run.go +++ b/internal/run/run.go @@ -46,10 +46,15 @@ type Options struct { Args []string } -// Start is the main entry point for `kontext start`. +// Start preserves the historical managed-session entry point. func Start(ctx context.Context, opts Options) error { + return StartManaged(ctx, opts) +} + +// StartManaged launches an agent with a hosted managed Kontext session. +func StartManaged(ctx context.Context, opts Options) error { diagnostics := diagnostic.New(os.Stderr, opts.Verbose || diagnostic.EnabledFromEnv()) - diagnostics.Printf("start: agent=%s env_template=%s\n", opts.Agent, opts.TemplateFile) + diagnostics.Printf("start managed: agent=%s env_template=%s\n", opts.Agent, opts.TemplateFile) agentPath, err := preflightAgent(opts.Agent) if err != nil { @@ -246,7 +251,10 @@ func Start(ctx context.Context, opts Options) error { defer runtimeService.Stop() // 7. Generate hook settings - kontextBin, _ := os.Executable() + kontextBin, err := os.Executable() + if err != nil { + return fmt.Errorf("resolve kontext executable: %w", err) + } settingsPath, err := GenerateSettings(sessionDir, kontextBin, opts.Agent) if err != nil { return fmt.Errorf("generate settings: %w", err) diff --git a/internal/run/run_test.go b/internal/run/run_test.go index 16b16a3..31530fe 100644 --- a/internal/run/run_test.go +++ b/internal/run/run_test.go @@ -1189,6 +1189,48 @@ func TestGenerateSettingsWritesClaudeHooks(t *testing.T) { } } +func TestGenerateLocalSettingsWritesSocketHooks(t *testing.T) { + t.Parallel() + + sessionDir := t.TempDir() + settingsPath, err := GenerateLocalSettings(sessionDir, "/usr/local/bin/kontext", "claude", "/tmp/kontext.sock", "observe") + if err != nil { + t.Fatalf("GenerateLocalSettings() error = %v", err) + } + + data, err := os.ReadFile(settingsPath) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + + var settings claudeSettings + if err := json.Unmarshal(data, &settings); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + wantEvents := []string{"PreToolUse", "PostToolUse"} + wantCommand := `'/usr/local/bin/kontext' hook --agent 'claude' --mode 'observe' --socket '/tmp/kontext.sock'` + for _, event := range wantEvents { + groups, ok := settings.Hooks[event] + if !ok { + t.Fatalf("settings.Hooks missing %q", event) + } + if len(groups) != 1 { + t.Fatalf("%s groups len = %d, want 1", event, len(groups)) + } + if len(groups[0].Hooks) != 1 { + t.Fatalf("%s hooks len = %d, want 1", event, len(groups[0].Hooks)) + } + hook := groups[0].Hooks[0] + if hook.Command != wantCommand { + t.Fatalf("%s hook command = %q, want %q", event, hook.Command, wantCommand) + } + if hook.Timeout != 10 { + t.Fatalf("%s hook timeout = %d, want 10", event, hook.Timeout) + } + } +} + func TestVerifyBlockingHookSettingsRequiresPreToolUseCommand(t *testing.T) { t.Parallel() diff --git a/internal/runtimehost/host.go b/internal/runtimehost/host.go new file mode 100644 index 0000000..64ccef5 --- /dev/null +++ b/internal/runtimehost/host.go @@ -0,0 +1,420 @@ +package runtimehost + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/kontext-security/kontext-cli/internal/diagnostic" + "github.com/kontext-security/kontext-cli/internal/guard/app/server" + guardhookruntime "github.com/kontext-security/kontext-cli/internal/guard/hookruntime" + "github.com/kontext-security/kontext-cli/internal/guard/modelsnapshot" + "github.com/kontext-security/kontext-cli/internal/guard/risk" + "github.com/kontext-security/kontext-cli/internal/localruntime" + "github.com/kontext-security/kontext-cli/internal/runtimecore" + guardmodels "github.com/kontext-security/kontext-cli/models/guard" +) + +const ( + defaultModelSource = "embedded:models/guard/coding-agent-v0.json" + defaultThreshold = 0.5 + defaultHorizon = 5 +) + +type Options struct { + AgentName string + SessionID string + CWD string + DBPath string + ModelPath string + ModelSnapshotDir string + SocketPath string + DashboardAddr string + StartDashboard bool + AllowNonLoopbackDashboard bool + Mode string + Threshold float64 + Horizon int + Diagnostic diagnostic.Logger + Out io.Writer +} + +type Host struct { + SessionID string + SessionDir string + SocketPath string + DBPath string + DashboardURL string + DashboardErr error + ActiveModelPath string + Mode guardhookruntime.Mode + + server *server.Server + closeStore func() error + runtimeService *localruntime.Service + dashboardServer *http.Server + sessionOpened bool + sessionCloseOnce bool +} + +func Start(ctx context.Context, opts Options) (*Host, error) { + if strings.TrimSpace(opts.AgentName) == "" { + return nil, errors.New("runtime host requires agent name") + } + mode, err := ResolveMode(opts.Mode) + if err != nil { + return nil, err + } + dbPath := strings.TrimSpace(opts.DBPath) + usingDefaultDBPath := false + if dbPath == "" { + dbPath = DefaultDBPath() + usingDefaultDBPath = true + } + if err := ensureRuntimeDataDir(filepath.Dir(dbPath), usingDefaultDBPath); err != nil { + return nil, fmt.Errorf("create runtime data dir: %w", err) + } + + scorer, activeModelPath, err := loadScorer(loadScorerOptions{ + DBPath: dbPath, + ModelPath: opts.ModelPath, + ModelSnapshotDir: opts.ModelSnapshotDir, + Threshold: opts.Threshold, + Horizon: opts.Horizon, + }) + if err != nil { + return nil, err + } + + localServer, closeStore, err := server.OpenDefaultServer(dbPath, scorer) + if err != nil { + return nil, err + } + if err := os.Chmod(dbPath, 0o600); err != nil && !errors.Is(err, os.ErrNotExist) { + _ = closeStore() + return nil, fmt.Errorf("secure runtime database: %w", err) + } + + sessionID := strings.TrimSpace(opts.SessionID) + if sessionID == "" { + sessionID = NewSessionID() + } + if err := validateSessionID(sessionID); err != nil { + _ = closeStore() + return nil, err + } + sessionDir := filepath.Join("/tmp", "kontext", sessionID) + if err := createSessionDir(sessionDir); err != nil { + _ = closeStore() + return nil, err + } + socketPath := strings.TrimSpace(opts.SocketPath) + if socketPath == "" { + socketPath = filepath.Join(sessionDir, "kontext.sock") + } + + host := &Host{ + SessionID: sessionID, + SessionDir: sessionDir, + SocketPath: socketPath, + DBPath: dbPath, + ActiveModelPath: activeModelPath, + Mode: mode, + server: localServer, + closeStore: closeStore, + } + + runtimeService, err := localruntime.NewService(localruntime.Options{ + SocketPath: socketPath, + Core: localServer.RuntimeCore(), + SessionID: sessionID, + AgentName: opts.AgentName, + AsyncIngest: true, + Diagnostic: opts.Diagnostic, + }) + if err != nil { + _ = host.Close(context.Background()) + return nil, fmt.Errorf("local runtime: %w", err) + } + if err := runtimeService.Start(ctx); err != nil { + _ = host.Close(context.Background()) + return nil, fmt.Errorf("local runtime start: %w", err) + } + host.runtimeService = runtimeService + + cwd := opts.CWD + if cwd == "" { + cwd, _ = os.Getwd() + } + if _, err := localServer.RuntimeCore().OpenSession(ctx, runtimecore.Session{ + ID: sessionID, + Agent: opts.AgentName, + CWD: cwd, + Source: runtimecore.SessionSourceWrapperOwned, + ExternalID: sessionID, + }); err != nil { + _ = host.Close(context.Background()) + return nil, fmt.Errorf("open runtime session: %w", err) + } + host.sessionOpened = true + + if opts.StartDashboard { + addr, err := DashboardAddr(opts.DashboardAddr, opts.AllowNonLoopbackDashboard) + if err != nil { + _ = host.Close(context.Background()) + return nil, err + } + dashboardServer, dashboardURL, err := startDashboard(addr, localServer.Handler()) + if err != nil { + host.DashboardErr = err + if opts.Out != nil { + fmt.Fprintf(opts.Out, "Local dashboard unavailable: %v\n", err) + } + } else { + host.dashboardServer = dashboardServer + host.DashboardURL = dashboardURL + } + } + + return host, nil +} + +func (h *Host) Close(ctx context.Context) error { + var errs []error + if h == nil { + return nil + } + if h.dashboardServer != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + err := h.dashboardServer.Shutdown(shutdownCtx) + cancel() + if err != nil { + errs = append(errs, err) + } + h.dashboardServer = nil + } + if h.runtimeService != nil { + if err := h.runtimeService.Shutdown(ctx); err != nil { + errs = append(errs, err) + } + h.runtimeService = nil + } + if h.sessionOpened && !h.sessionCloseOnce && h.server != nil { + if err := h.server.RuntimeCore().CloseSession(context.Background(), h.SessionID); err != nil { + errs = append(errs, err) + } + h.sessionCloseOnce = true + } + if h.closeStore != nil { + if err := h.closeStore(); err != nil { + errs = append(errs, err) + } + h.closeStore = nil + } + if h.SessionDir != "" { + if err := os.RemoveAll(h.SessionDir); err != nil { + errs = append(errs, err) + } + h.SessionDir = "" + } + return errors.Join(errs...) +} + +func ResolveMode(value string) (guardhookruntime.Mode, error) { + return guardhookruntime.ParseMode(value) +} + +func NewSessionID() string { + var bytes [16]byte + if _, err := rand.Read(bytes[:]); err == nil { + return hex.EncodeToString(bytes[:]) + } + return strconv.FormatInt(time.Now().UnixNano(), 16) +} + +func DefaultDBPath() string { + if dir, err := os.UserConfigDir(); err == nil && dir != "" { + return filepath.Join(dir, "kontext", "guard.db") + } + if home, err := os.UserHomeDir(); err == nil && home != "" { + return filepath.Join(home, ".kontext", "guard.db") + } + return "kontext-guard.db" +} + +func DashboardAddr(value string, allowNonLoopback bool) (string, error) { + addr := strings.TrimSpace(value) + if addr == "" { + addr = server.DefaultAddr + } + host, _, err := net.SplitHostPort(addr) + if err != nil { + return "", fmt.Errorf("dashboard address %q must be host:port: %w", addr, err) + } + if allowNonLoopback || isLoopbackHost(host) { + return addr, nil + } + return "", fmt.Errorf("dashboard address %q is not loopback; use 127.0.0.1 or localhost", addr) +} + +type loadScorerOptions struct { + DBPath string + ModelPath string + ModelSnapshotDir string + Threshold float64 + Horizon int +} + +func loadScorer(opts loadScorerOptions) (risk.Scorer, string, error) { + threshold, err := envFloat("KONTEXT_THRESHOLD", opts.Threshold, defaultThreshold) + if err != nil { + return nil, "", err + } + horizon, err := envInt("KONTEXT_HORIZON", opts.Horizon, defaultHorizon) + if err != nil { + return nil, "", err + } + snapshotDir := strings.TrimSpace(opts.ModelSnapshotDir) + if snapshotDir == "" { + snapshotDir = defaultModelSnapshotDir(opts.DBPath) + } + store := modelsnapshot.NewWithValidator(snapshotDir, risk.ValidateMarkovModel) + + modelPath := strings.TrimSpace(opts.ModelPath) + if modelPath == "" { + modelPath = strings.TrimSpace(os.Getenv("KONTEXT_MODEL")) + } + var snapshot modelsnapshot.Snapshot + if modelPath == "" { + snapshot, err = store.ActivateBytes(defaultModelSource, guardmodels.CodingAgentV0) + } else { + snapshot, err = store.ActivateFromFile(modelPath) + } + if err != nil { + return nil, "", fmt.Errorf("activate risk model: %w", err) + } + scorer, err := risk.LoadMarkovScorer(snapshot.Path, threshold, horizon) + if err != nil { + return nil, "", fmt.Errorf("load risk model: %w", err) + } + return scorer, snapshot.Path, nil +} + +func startDashboard(addr string, handler http.Handler) (*http.Server, string, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, "", fmt.Errorf("local dashboard listen on %s: %w", addr, err) + } + httpServer := &http.Server{ + Handler: handler, + ReadHeaderTimeout: 5 * time.Second, + } + go func() { + if err := httpServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + fmt.Fprintf(os.Stderr, "Local dashboard stopped: %v\n", err) + } + }() + return httpServer, "http://" + ln.Addr().String(), nil +} + +func createSessionDir(sessionDir string) error { + if err := os.MkdirAll(filepath.Dir(sessionDir), 0o700); err != nil { + return fmt.Errorf("create session parent dir: %w", err) + } + if err := os.Mkdir(sessionDir, 0o700); err != nil { + if errors.Is(err, os.ErrExist) { + return fmt.Errorf("runtime session directory already exists: %s", sessionDir) + } + return fmt.Errorf("create session dir: %w", err) + } + return nil +} + +func ensureRuntimeDataDir(path string, private bool) error { + if path == "." || path == "" { + return nil + } + if err := os.MkdirAll(path, 0o700); err != nil { + return err + } + if !private { + return nil + } + return os.Chmod(path, 0o700) +} + +func defaultModelSnapshotDir(dbPath string) string { + if dbPath != "" { + return filepath.Join(filepath.Dir(dbPath), "models") + } + return filepath.Join(".", "models") +} + +func envFloat(key string, explicit, fallback float64) (float64, error) { + if value := strings.TrimSpace(os.Getenv(key)); value != "" { + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return 0, fmt.Errorf("%s must be a number: %w", key, err) + } + return parsed, nil + } + if explicit != 0 { + return explicit, nil + } + return fallback, nil +} + +func envInt(key string, explicit, fallback int) (int, error) { + if value := strings.TrimSpace(os.Getenv(key)); value != "" { + parsed, err := strconv.Atoi(value) + if err != nil { + return 0, fmt.Errorf("%s must be an integer: %w", key, err) + } + return parsed, nil + } + if explicit != 0 { + return explicit, nil + } + return fallback, nil +} + +func isLoopbackHost(host string) bool { + if strings.EqualFold(host, "localhost") { + return true + } + ip := net.ParseIP(strings.Trim(host, "[]")) + return ip != nil && ip.IsLoopback() +} + +func validateSessionID(sessionID string) error { + if sessionID == "" { + return errors.New("runtime session ID is required") + } + if sessionID != filepath.Base(sessionID) || strings.ContainsAny(sessionID, `/\`) { + return fmt.Errorf("runtime session ID %q is not a safe path segment", sessionID) + } + if sessionID == "." || sessionID == ".." { + return fmt.Errorf("runtime session ID %q is not a safe path segment", sessionID) + } + for i := 0; i < len(sessionID); i++ { + c := sessionID[i] + if (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || + c == '-' || c == '_' || c == '.' { + continue + } + return fmt.Errorf("runtime session ID %q is not a safe path segment", sessionID) + } + return nil +} diff --git a/internal/runtimehost/host_test.go b/internal/runtimehost/host_test.go new file mode 100644 index 0000000..9c942b0 --- /dev/null +++ b/internal/runtimehost/host_test.go @@ -0,0 +1,260 @@ +package runtimehost + +import ( + "context" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/kontext-security/kontext-cli/internal/diagnostic" + "github.com/kontext-security/kontext-cli/internal/guard/store/sqlite" + "github.com/kontext-security/kontext-cli/internal/hook" + "github.com/kontext-security/kontext-cli/internal/localruntime" +) + +func TestStartLoadsEmbeddedModelOutsideRepoCWD(t *testing.T) { + t.Setenv("KONTEXT_MODEL", "") + t.Setenv("KONTEXT_THRESHOLD", "") + t.Setenv("KONTEXT_HORIZON", "") + t.Chdir(t.TempDir()) + + host, err := Start(context.Background(), Options{ + AgentName: "claude", + CWD: t.TempDir(), + DBPath: filepath.Join(t.TempDir(), "guard.db"), + }) + if err != nil { + t.Fatalf("Start() error = %v", err) + } + defer host.Close(context.Background()) + + if host.ActiveModelPath == "" { + t.Fatal("ActiveModelPath is empty") + } + if _, err := os.Stat(host.ActiveModelPath); err != nil { + t.Fatalf("active model stat error = %v", err) + } +} + +func TestStartUsesFullSessionIDForSessionDir(t *testing.T) { + sessionID := "1234567890abcdef1234567890abcdef" + host, err := Start(context.Background(), Options{ + AgentName: "claude", + SessionID: sessionID, + CWD: t.TempDir(), + DBPath: filepath.Join(t.TempDir(), "guard.db"), + }) + if err != nil { + t.Fatalf("Start() error = %v", err) + } + defer host.Close(context.Background()) + + if got, want := filepath.Base(host.SessionDir), sessionID; got != want { + t.Fatalf("session dir base = %q, want %q", got, want) + } +} + +func TestStartRejectsUnsafeSessionID(t *testing.T) { + for _, sessionID := range []string{"../escape", "nested/path", `nested\path`, ".", "..", "bad id"} { + t.Run(sessionID, func(t *testing.T) { + host, err := Start(context.Background(), Options{ + AgentName: "claude", + SessionID: sessionID, + CWD: t.TempDir(), + DBPath: filepath.Join(t.TempDir(), "guard.db"), + }) + if err == nil { + _ = host.Close(context.Background()) + t.Fatal("Start() error = nil, want unsafe session ID error") + } + }) + } +} + +func TestStartDoesNotChmodCustomDBParent(t *testing.T) { + parent := filepath.Join(t.TempDir(), "project") + if err := os.Mkdir(parent, 0o755); err != nil { + t.Fatalf("Mkdir() error = %v", err) + } + if err := os.Chmod(parent, 0o755); err != nil { + t.Fatalf("Chmod() error = %v", err) + } + host, err := Start(context.Background(), Options{ + AgentName: "claude", + CWD: t.TempDir(), + DBPath: filepath.Join(parent, "guard.db"), + }) + if err != nil { + t.Fatalf("Start() error = %v", err) + } + defer host.Close(context.Background()) + + info, err := os.Stat(parent) + if err != nil { + t.Fatalf("Stat() error = %v", err) + } + if got := info.Mode().Perm(); got != 0o755 { + t.Fatalf("custom DB parent mode = %o, want 755", got) + } +} + +func TestStartPersistsLocalDecisions(t *testing.T) { + ctx := context.Background() + dbPath := filepath.Join(t.TempDir(), "guard.db") + host, err := Start(ctx, Options{ + AgentName: "claude", + CWD: t.TempDir(), + DBPath: dbPath, + Diagnostic: diagnostic.New(nil, false), + }) + if err != nil { + t.Fatalf("Start() error = %v", err) + } + + client := localruntime.NewClient(host.SocketPath) + result, err := client.Process(ctx, hook.Event{ + Agent: "claude", + HookName: hook.HookPreToolUse, + ToolName: "Bash", + ToolInput: map[string]any{ + "command": "rm -rf /tmp/kontext-test", + }, + }) + if err != nil { + t.Fatalf("Process() error = %v", err) + } + if result.Decision == "" { + t.Fatal("result decision is empty") + } + sessionID := host.SessionID + if err := host.Close(ctx); err != nil { + t.Fatalf("Close() error = %v", err) + } + + store, err := sqlite.OpenStore(dbPath) + if err != nil { + t.Fatalf("OpenStore() error = %v", err) + } + defer store.Close() + events, err := store.Events(ctx, sessionID) + if err != nil { + t.Fatalf("Events() error = %v", err) + } + if len(events) != 1 { + t.Fatalf("events len = %d, want 1", len(events)) + } + if string(events[0].Decision) != string(result.Decision) { + t.Fatalf("stored decision = %q, want %q", events[0].Decision, result.Decision) + } +} + +func TestCloseDrainsAsyncTelemetryBeforeClosingSession(t *testing.T) { + ctx := context.Background() + dbPath := filepath.Join(t.TempDir(), "guard.db") + host, err := Start(ctx, Options{ + AgentName: "claude", + CWD: t.TempDir(), + DBPath: dbPath, + Diagnostic: diagnostic.New(nil, false), + }) + if err != nil { + t.Fatalf("Start() error = %v", err) + } + + client := localruntime.NewClient(host.SocketPath) + result, err := client.Process(ctx, hook.Event{ + Agent: "claude", + HookName: hook.HookPostToolUse, + ToolName: "Bash", + ToolInput: map[string]any{ + "command": "echo done", + }, + }) + if err != nil { + t.Fatalf("Process() error = %v", err) + } + if result.Decision != hook.DecisionAllow { + t.Fatalf("post-tool decision = %q, want allow ack", result.Decision) + } + sessionID := host.SessionID + if err := host.Close(ctx); err != nil { + t.Fatalf("Close() error = %v", err) + } + + store, err := sqlite.OpenStore(dbPath) + if err != nil { + t.Fatalf("OpenStore() error = %v", err) + } + defer store.Close() + events, err := store.Events(ctx, sessionID) + if err != nil { + t.Fatalf("Events() error = %v", err) + } + if len(events) != 1 { + t.Fatalf("events len = %d, want 1", len(events)) + } + session, err := store.Session(ctx, sessionID) + if err != nil { + t.Fatalf("Session() error = %v", err) + } + if session.Status != "closed" { + t.Fatalf("session status = %q, want closed", session.Status) + } +} + +func TestDashboardAddrRejectsNonLoopback(t *testing.T) { + tests := []string{ + "0.0.0.0:4765", + "[::]:4765", + ":4765", + } + for _, tt := range tests { + t.Run(tt, func(t *testing.T) { + if _, err := DashboardAddr(tt, false); err == nil { + t.Fatal("DashboardAddr() error = nil, want non-loopback error") + } + }) + } +} + +func TestDashboardAddrAllowsLoopback(t *testing.T) { + tests := []string{ + "127.0.0.1:0", + "localhost:0", + "[::1]:0", + } + for _, tt := range tests { + t.Run(tt, func(t *testing.T) { + if _, err := DashboardAddr(tt, false); err != nil { + t.Fatalf("DashboardAddr() error = %v", err) + } + }) + } +} + +func TestStartDashboardUsesLoopbackEphemeralPort(t *testing.T) { + host, err := Start(context.Background(), Options{ + AgentName: "claude", + CWD: t.TempDir(), + DBPath: filepath.Join(t.TempDir(), "guard.db"), + DashboardAddr: "127.0.0.1:0", + StartDashboard: true, + }) + if err != nil { + t.Fatalf("Start() error = %v", err) + } + defer host.Close(context.Background()) + + if host.DashboardURL == "" { + t.Fatal("DashboardURL is empty") + } + resp, err := http.Get(host.DashboardURL + "/healthz") + if err != nil { + t.Fatalf("GET healthz error = %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("healthz status = %d, want 200", resp.StatusCode) + } +} diff --git a/models/guard/embed.go b/models/guard/embed.go new file mode 100644 index 0000000..a4b98d2 --- /dev/null +++ b/models/guard/embed.go @@ -0,0 +1,8 @@ +package guardmodels + +import _ "embed" + +// CodingAgentV0 is the default local risk model shipped inside release builds. +// +//go:embed coding-agent-v0.json +var CodingAgentV0 []byte