Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions cmd/kontext/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,10 @@ func startCmd() *cobra.Command {
RunE: func(cmd *cobra.Command, args []string) error {
if isInteractivePrompt() {
if latest := update.Available(version); latest != "" {
upgraded, _ := update.PromptAndUpgrade(os.Stdin, os.Stderr, version, latest)
if upgraded {
upgraded, err := update.PromptAndUpgrade(os.Stdin, os.Stderr, version, latest)
if err != nil {
fmt.Fprintf(os.Stderr, "Update prompt failed: %v\n", err)
} else if upgraded {
return nil
}
}
Expand Down Expand Up @@ -269,24 +271,29 @@ func evaluateViaSidecar(socketPath string, event hook.Event) (hook.Result, error
func evaluateViaSidecarForMode(socketPath string, event hook.Event, mode string) (hook.Result, error) {
conn, err := net.DialTimeout("unix", socketPath, 5*time.Second)
if err != nil {
warnHookError(fmt.Sprintf("sidecar dial %q", socketPath), err)
return sidecarFailureResult(event, "sidecar unreachable", mode), nil
}
defer conn.Close()
if err := conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
warnHookError("sidecar set deadline", err)
return sidecarFailureResult(event, "sidecar deadline error", mode), nil
}

req, err := localruntime.EvaluateRequestFromEvent(event)
if err != nil {
warnHookError("sidecar marshal request", err)
return sidecarFailureResult(event, "sidecar marshal error", mode), nil
}

if err := localruntime.WriteMessage(conn, req); err != nil {
warnHookError("sidecar write request", err)
return sidecarFailureResult(event, "sidecar write error", mode), nil
}

var result localruntime.EvaluateResult
if err := localruntime.ReadMessage(conn, &result); err != nil {
warnHookError("sidecar read response", err)
return sidecarFailureResult(event, "sidecar read error", mode), nil
}

Expand Down Expand Up @@ -322,6 +329,7 @@ func currentHostedAccessMode() string {
if modePath := os.Getenv("KONTEXT_ACCESS_MODE_PATH"); modePath != "" {
data, err := os.ReadFile(modePath)
if err != nil {
warnHookError(fmt.Sprintf("read KONTEXT_ACCESS_MODE_PATH %q", modePath), err)
return "enforce"
}
if mode := normalizedHostedAccessMode(string(data)); mode != "" {
Expand Down Expand Up @@ -358,3 +366,10 @@ func isCharDevice(f *os.File) bool {
}
return info.Mode()&os.ModeCharDevice != 0
}

func warnHookError(operation string, err error) {
if err == nil {
return
}
fmt.Fprintf(os.Stderr, "kontext: %s: %v\n", operation, err)
}
35 changes: 17 additions & 18 deletions internal/guard/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.
}
}

func newFlagSet(name string) *flag.FlagSet {
fs := flag.NewFlagSet(name, flag.ContinueOnError)
fs.SetOutput(io.Discard)
return fs
}

func usage(out io.Writer) {
fmt.Fprintln(out, "Kontext Guard")
fmt.Fprintln(out, "")
Expand Down Expand Up @@ -96,13 +102,12 @@ func runDaemon(ctx context.Context, args []string, out io.Writer) error {
if err != nil {
return err
}
fs := flag.NewFlagSet("start", flag.ContinueOnError)
fs.SetOutput(io.Discard)
fs := newFlagSet("start")
addr := fs.String("addr", envString("KONTEXT_ADDR", server.DefaultAddr), "listen address")
dbPath := fs.String("db", envString("KONTEXT_DB", defaultDBPath()), "SQLite database path")
skipHookInstall := fs.Bool("skip-hook-install", false, "skip Claude Code hook install")
noOpen := fs.Bool("no-open", false, "do not open the local dashboard")
socketPath := fs.String("socket", defaultGuardSocketPath(), "Unix socket path for local hook runtime")
socketPath := fs.String("socket", localruntime.DefaultSocketPath(), "Unix socket path for local hook runtime")
judgeURL := fs.String("judge-url", envString("KONTEXT_JUDGE_URL", ""), "OpenAI-compatible local judge base URL, for example http://127.0.0.1:8080")
judgeModel := fs.String("judge-model", envString("KONTEXT_JUDGE_MODEL", ""), "local judge model name; with --judge-managed this may be a local GGUF path")
judgeTimeout := fs.Duration("judge-timeout", defaultJudgeTimeout, "local judge timeout")
Expand Down Expand Up @@ -153,7 +158,7 @@ func runDaemon(ctx context.Context, args []string, out io.Writer) error {
defer func() {
_ = closeStore()
}()
if err := ensureGuardSocketDir(*socketPath); err != nil {
if err := localruntime.EnsureSocketDir(*socketPath); err != nil {
return err
}
runtimeService, err := localruntime.NewService(localruntime.Options{
Expand Down Expand Up @@ -356,8 +361,7 @@ func verifyClaudeCode() error {
}

func runStatus(ctx context.Context, args []string, out io.Writer) error {
fs := flag.NewFlagSet("status", flag.ContinueOnError)
fs.SetOutput(io.Discard)
fs := newFlagSet("status")
baseURL := fs.String("daemon-url", envString("KONTEXT_DAEMON_URL", defaultBaseURL), "local daemon URL")
if err := fs.Parse(args); err != nil {
return err
Expand All @@ -375,8 +379,7 @@ func runStatus(ctx context.Context, args []string, out io.Writer) error {
}

func runDashboard(args []string, out io.Writer) error {
fs := flag.NewFlagSet("dashboard", flag.ContinueOnError)
fs.SetOutput(io.Discard)
fs := newFlagSet("dashboard")
baseURL := fs.String("daemon-url", envString("KONTEXT_DAEMON_URL", defaultBaseURL), "local daemon URL")
noOpen := fs.Bool("no-open", false, "print URL without opening a browser")
if err := fs.Parse(args); err != nil {
Expand Down Expand Up @@ -410,8 +413,7 @@ func fetchSummary(ctx context.Context, baseURL string) (sqlite.Summary, error) {
}

func runDoctor(ctx context.Context, args []string, out io.Writer) error {
fs := flag.NewFlagSet("doctor", flag.ContinueOnError)
fs.SetOutput(io.Discard)
fs := newFlagSet("doctor")
baseURL := fs.String("daemon-url", envString("KONTEXT_DAEMON_URL", defaultBaseURL), "local daemon URL")
if err := fs.Parse(args); err != nil {
return err
Expand Down Expand Up @@ -508,9 +510,8 @@ func runHooks(args []string, out io.Writer) error {
}
switch args[0] {
case "install":
fs := flag.NewFlagSet("hooks install claude-code", flag.ContinueOnError)
fs.SetOutput(io.Discard)
socketPath := fs.String("socket", defaultGuardSocketPath(), "Unix socket path for local hook runtime")
fs := newFlagSet("hooks install claude-code")
socketPath := fs.String("socket", localruntime.DefaultSocketPath(), "Unix socket path for local hook runtime")
if err := fs.Parse(args[2:]); err != nil {
return err
}
Expand Down Expand Up @@ -672,8 +673,7 @@ func isGuardHookCommand(command string) bool {
}

func runSmokeTest(ctx context.Context, args []string, out io.Writer) error {
fs := flag.NewFlagSet("smoke-test", flag.ContinueOnError)
fs.SetOutput(io.Discard)
fs := newFlagSet("smoke-test")
if err := fs.Parse(args); err != nil {
return err
}
Expand Down Expand Up @@ -768,8 +768,7 @@ func runJudgeEval(ctx context.Context, args []string, out io.Writer) error {
if err != nil {
return err
}
fs := flag.NewFlagSet("judge eval", flag.ContinueOnError)
fs.SetOutput(io.Discard)
fs := newFlagSet("judge eval")
judgeURL := fs.String("judge-url", envString("KONTEXT_JUDGE_URL", ""), "OpenAI-compatible local judge base URL")
judgeModel := fs.String("judge-model", envString("KONTEXT_JUDGE_MODEL", ""), "local judge model name")
judgeTimeout := fs.Duration("judge-timeout", defaultJudgeTimeout, "local judge timeout")
Expand Down Expand Up @@ -922,7 +921,7 @@ func installedHookCommand(socketPath string) string {
return command
}
if strings.TrimSpace(socketPath) == "" {
socketPath = defaultGuardSocketPath()
socketPath = localruntime.DefaultSocketPath()
}
path := selfPath()
if strings.Contains(path, "go-build") {
Expand Down
13 changes: 0 additions & 13 deletions internal/guard/cli/localruntime.go

This file was deleted.

10 changes: 6 additions & 4 deletions internal/guard/judge/llama_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestStartLlamaServerHealthCheckAndStop(t *testing.T) {
BinaryPath: binaryPath,
ModelPath: modelPath,
Port: port,
StartupTimeout: 2 * time.Second,
StartupTimeout: 5 * time.Second,
})
if err != nil {
t.Fatal(err)
Expand All @@ -108,18 +108,20 @@ func TestStartLlamaServerHealthCheckAndStop(t *testing.T) {
func TestStartLlamaServerEarlyExitDoesNotWaitForStopTimeout(t *testing.T) {
modelPath := writeTestModel(t)
binaryPath := writeFakeExitingLlamaServer(t)
startupTimeout := 5 * time.Second
maxWait := 3 * time.Second
start := time.Now()
_, err := StartLlamaServer(context.Background(), LlamaServerOptions{
BinaryPath: binaryPath,
ModelPath: modelPath,
Port: freeTCPPort(t),
StartupTimeout: 2 * time.Second,
StartupTimeout: startupTimeout,
})
if err == nil {
t.Fatal("StartLlamaServer() error = nil, want early exit error")
}
if elapsed := time.Since(start); elapsed > time.Second {
t.Fatalf("early exit took %s, want less than 1s", elapsed)
if elapsed := time.Since(start); elapsed > maxWait {
t.Fatalf("early exit took %s, want less than %s", elapsed, maxWait)
}
}

Expand Down

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion internal/guard/web/assets/dist/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Kontext Guard</title>
<script type="module" crossorigin src="/assets/index-BpP_w0XF.js"></script>
<script type="module" crossorigin src="/assets/index-CyKZwYkD.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-BRMO20IK.css">
</head>
<body>
Expand Down
24 changes: 6 additions & 18 deletions web/guard-dashboard/src/dashboard/api.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import { API } from "./config";
import type { Decision, Event, PolicyProfile, PolicyProfileID, RiskEvent, Session } from "./types";
import { isDecision, isPolicyProfileID, type Decision, type Event, type PolicyProfile, type PolicyProfileID, type RiskEvent, type Session } from "./types";

export function errorMessage(error: unknown): string {
return error instanceof Error ? error.message : String(error);
}

async function responseJSON(r: Response): Promise<unknown> {
return r.json();
// `Response.json()` is typed as `any` in lib.dom. Force it back to `unknown` at the boundary.
const body: unknown = await r.json();
return body;
}

async function ok(r: Response): Promise<unknown> {
Expand Down Expand Up @@ -51,25 +53,11 @@ function stringList(value: unknown): string[] | undefined {
}

function decision(value: unknown): Decision | undefined {
switch (value) {
case "allow":
case "ask":
case "deny":
return value;
default:
return undefined;
}
return isDecision(value) ? value : undefined;
}

function policyProfileID(value: unknown): PolicyProfileID | undefined {
switch (value) {
case "relaxed":
case "balanced":
case "strict":
return value;
default:
return undefined;
}
return isPolicyProfileID(value) ? value : undefined;
}

function parseRiskEvent(value: unknown): RiskEvent | undefined {
Expand Down
20 changes: 17 additions & 3 deletions web/guard-dashboard/src/dashboard/types.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
export type Decision = "allow" | "ask" | "deny";
export const DECISION_VALUES = ["allow", "ask", "deny"] as const;
export type Decision = (typeof DECISION_VALUES)[number];

export type Tab = "all" | "deny" | "ask" | "allow";
export const TAB_VALUES = ["all", ...DECISION_VALUES] as const;
export type Tab = (typeof TAB_VALUES)[number];

export type PolicyProfileID = "relaxed" | "balanced" | "strict";
export const POLICY_PROFILE_ID_VALUES = ["relaxed", "balanced", "strict"] as const;
export type PolicyProfileID = (typeof POLICY_PROFILE_ID_VALUES)[number];

const decisionSet: ReadonlySet<string> = new Set(DECISION_VALUES);
const policyProfileIDSet: ReadonlySet<string> = new Set(POLICY_PROFILE_ID_VALUES);

export function isDecision(value: unknown): value is Decision {
return typeof value === "string" && decisionSet.has(value);
}

export function isPolicyProfileID(value: unknown): value is PolicyProfileID {
return typeof value === "string" && policyProfileIDSet.has(value);
}

export type RiskEvent = {
type?: string;
Expand Down
Loading