diff --git a/README.md b/README.md index 10407485..77eb891f 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,33 @@ root path. Services deployed to other paths on the same host will use the same TLS settings as those specified for the root path. +### On-demand TLS + +In addition to the automatic TLS functionality, Kamal Proxy can also dynamically obtain a TLS certificate +for any host allowed by an external API endpoint of your choice. This avoids hard-coding hosts in the configuration, especially when you don't know the hosts at startup. + + kamal-proxy deploy service1 --target web-1:3000 --host "" --tls --tls-on-demand-url="http://localhost:4567/check" + +The On-demand URL endpoint must return a 200 HTTP status code to allow certificate issuance. +Kamal Proxy will call the on-demand URL with a query string of `?host=` containing the host received by Kamal Proxy. + +- The HTTP request to the on-demand URL will time out after 2 seconds. If the endpoint is unreachable or slow, certificate issuance will fail for that host. +- If the endpoint returns any status other than 200, Kamal Proxy will log the status code and up to 256 bytes of the response body for debugging. +- **Security note:** The on-demand URL acts as an authorization gate for certificate issuance. It should be protected and only allow trusted hosts. If compromised, unauthorized certificates could be issued. +- If `--tls-on-demand-url` is not set, Kamal Proxy falls back to a static whitelist of hosts. + +**Best practice:** +- Ensure your on-demand endpoint is fast, reliable, and protected (e.g., behind authentication or on a private network). +- Only allow hosts you control to prevent abuse. + +Example endpoint logic (pseudo-code): + + if host in allowed_hosts: + return 200 OK + else: + return 403 Forbidden + + ### Custom TLS certificate When you obtained your TLS certificate manually, manage your own certificate authority, @@ -138,6 +165,29 @@ your certificate file and the corresponding private key: kamal-proxy deploy service1 --target web-1:3000 --host app1.example.com --tls --tls-certificate-path cert.pem --tls-private-key-path key.pem +## TLSOnDemandUrl Option + +The `TLSOnDemandUrl` option can be set to either: + +- **An external URL** (e.g., `https://my-allow-service/allow-host`): + - The service will make an HTTP request to this external URL to determine if a certificate should be issued for a given host. + +- **A local path** (e.g., `/allow-host`): + - The service will internally route a request to this path using its own load balancer and handler. You must ensure your service responds to this path appropriately. + +### Example: External URL +```yaml +TLSOnDemandUrl: "https://my-allow-service/allow-host" +``` + +### Example: Local Path +```yaml +TLSOnDemandUrl: "/allow-host" +``` + +When using a local path, your service should implement a handler for the specified path (e.g., `/allow-host`) that returns `200 OK` to allow certificate issuance, or another status code to deny it. + + ## Specifying `run` options with environment variables In some environments, like when running a Docker container, it can be convenient diff --git a/internal/cmd/deploy.go b/internal/cmd/deploy.go index edc98811..d34e0be7 100644 --- a/internal/cmd/deploy.go +++ b/internal/cmd/deploy.go @@ -34,6 +34,7 @@ func newDeployCommand() *deployCommand { deployCommand.cmd.Flags().BoolVar(&deployCommand.args.ServiceOptions.StripPrefix, "strip-path-prefix", true, "With --path-prefix, strip prefix from request before forwarding") deployCommand.cmd.Flags().BoolVar(&deployCommand.args.ServiceOptions.TLSEnabled, "tls", false, "Configure TLS for this target (requires a non-empty host)") + deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSOnDemandUrl, "tls-on-demand-url", "", "Will make an HTTP request to the given URL, asking whether a host is allowed to have a certificate issued.") deployCommand.cmd.Flags().BoolVar(&deployCommand.tlsStaging, "tls-staging", false, "Use Let's Encrypt staging environment for certificate provisioning") deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSCertificatePath, "tls-certificate-path", "", "Configure custom TLS certificate path (PEM format)") deployCommand.cmd.Flags().StringVar(&deployCommand.args.ServiceOptions.TLSPrivateKeyPath, "tls-private-key-path", "", "Configure custom TLS private key path (PEM format)") @@ -100,6 +101,11 @@ func (c *deployCommand) preRun(cmd *cobra.Command, args []string) error { } if c.args.ServiceOptions.TLSEnabled { + if c.args.ServiceOptions.TLSOnDemandUrl != "" { + c.args.ServiceOptions.Hosts = []string{""} + return nil + } + if len(c.args.ServiceOptions.Hosts) == 0 { return fmt.Errorf("host must be set when using TLS") } diff --git a/internal/cmd/deploy_test.go b/internal/cmd/deploy_test.go new file mode 100644 index 00000000..f6c4e84f --- /dev/null +++ b/internal/cmd/deploy_test.go @@ -0,0 +1,45 @@ +package cmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDeployCommand_preRun_TLSOnDemandUrl(t *testing.T) { + t.Run("TLS enabled with TLS on-demand URL should set hosts to empty string", func(t *testing.T) { + deployCmd := newDeployCommand() + + // Set flags for TLS with on-demand URL + deployCmd.cmd.Flags().Set("target", "http://localhost:8080") + deployCmd.cmd.Flags().Set("tls", "true") + deployCmd.cmd.Flags().Set("tls-on-demand-url", "http://example.com/validate") + deployCmd.cmd.Flags().Set("host", "example.com") + deployCmd.cmd.Flags().Set("path-prefix", "/") + + // Call preRun + err := deployCmd.preRun(deployCmd.cmd, []string{"test-service"}) + require.NoError(t, err) + + // Verify that hosts is set to empty string + assert.Equal(t, []string{""}, deployCmd.args.ServiceOptions.Hosts) + }) + + t.Run("TLS enabled without TLS on-demand URL should not modify hosts", func(t *testing.T) { + deployCmd := newDeployCommand() + + // Set flags for TLS without on-demand URL + deployCmd.cmd.Flags().Set("target", "http://localhost:8080") + deployCmd.cmd.Flags().Set("tls", "true") + deployCmd.cmd.Flags().Set("host", "example.com") + deployCmd.cmd.Flags().Set("path-prefix", "/") + + // Call preRun + err := deployCmd.preRun(deployCmd.cmd, []string{"test-service"}) + require.NoError(t, err) + + // Verify that hosts is not modified + assert.Equal(t, []string{"example.com"}, deployCmd.args.ServiceOptions.Hosts) + }) +} diff --git a/internal/server/service.go b/internal/server/service.go index c5132f36..acb809be 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -74,6 +74,7 @@ type ServiceOptions struct { TLSEnabled bool `json:"tls_enabled"` TLSCertificatePath string `json:"tls_certificate_path"` TLSPrivateKeyPath string `json:"tls_private_key_path"` + TLSOnDemandUrl string `json:"tls_on_demand_url"` TLSRedirect bool `json:"tls_redirect"` ACMEDirectory string `json:"acme_directory"` ACMECachePath string `json:"acme_cache_path"` @@ -384,10 +385,15 @@ func (s *Service) createCertManager(options ServiceOptions) (CertManager, error) } } + hostPolicy, err := NewTLSOnDemandChecker(s).HostPolicy() + if err != nil { + return nil, err + } + return &autocert.Manager{ Prompt: autocert.AcceptTOS, Cache: autocert.DirCache(options.ScopedCachePath()), - HostPolicy: autocert.HostWhitelist(options.Hosts...), + HostPolicy: hostPolicy, Client: &acme.Client{DirectoryURL: options.ACMEDirectory}, }, nil } diff --git a/internal/server/service_test.go b/internal/server/service_test.go index de859628..7cf2c366 100644 --- a/internal/server/service_test.go +++ b/internal/server/service_test.go @@ -224,6 +224,7 @@ func testCreateService(t *testing.T, options ServiceOptions, targetOptions Targe func testCreateServiceWithHandler(t *testing.T, options ServiceOptions, targetOptions TargetOptions, handler http.Handler) *Service { server := httptest.NewServer(handler) + t.Cleanup(server.Close) serverURL, err := url.Parse(server.URL) diff --git a/internal/server/tls_on_demand.go b/internal/server/tls_on_demand.go new file mode 100644 index 00000000..c0d3562e --- /dev/null +++ b/internal/server/tls_on_demand.go @@ -0,0 +1,105 @@ +package server + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "time" + + "log/slog" + + "golang.org/x/crypto/acme/autocert" +) + +type TLSOnDemandChecker struct { + service *Service + options ServiceOptions +} + +func NewTLSOnDemandChecker(service *Service) *TLSOnDemandChecker { + return &TLSOnDemandChecker{ + service: service, + options: service.options, + } +} + +func (c *TLSOnDemandChecker) HostPolicy() (autocert.HostPolicy, error) { + if c.options.TLSOnDemandUrl == "" { + return autocert.HostWhitelist(c.options.Hosts...), nil + } + + // If the URL starts with '/', treat it as a local path + if len(c.options.TLSOnDemandUrl) > 0 && c.options.TLSOnDemandUrl[0] == '/' { + return c.LocalHostPolicy(), nil + } + + // Otherwise, treat as external URL + _, err := url.ParseRequestURI(c.options.TLSOnDemandUrl) + + if err != nil { + slog.Error("Unable to parse the tls_on_demand_url URL") + return nil, err + } + + return c.ExternalHostPolicy(), nil +} + +func (c *TLSOnDemandChecker) LocalHostPolicy() autocert.HostPolicy { + return func(ctx context.Context, host string) error { + path := c.buildURLOrPath(host) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, http.NoBody) + if err != nil { + return err + } + + // We use httptest.NewRecorder here to route the request through the service's + // load balancer and handler, capturing the response in-memory without making + // a real network request. This ensures the request is processed as if it were + // an external client, but avoids network overhead and complexity. + recorder := httptest.NewRecorder() + c.service.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + body := recorder.Body.String() + + if len(body) > 256 { + body = body[:256] + } + + return c.handleError(host, recorder.Code, body) + } + return nil + } +} + +func (c *TLSOnDemandChecker) ExternalHostPolicy() autocert.HostPolicy { + return func(ctx context.Context, host string) error { + client := &http.Client{Timeout: 2 * time.Second} + url := c.buildURLOrPath(host) + resp, err := client.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body := make([]byte, 256) + n, _ := resp.Body.Read(body) + bodyStr := string(body[:n]) + return c.handleError(host, resp.StatusCode, bodyStr) + } + return nil + } +} + +func (c *TLSOnDemandChecker) buildURLOrPath(host string) string { + return fmt.Sprintf("%s?host=%s", c.options.TLSOnDemandUrl, url.QueryEscape(host)) +} + +func (c *TLSOnDemandChecker) handleError(host string, status int, body string) error { + slog.Warn("TLS on demand denied host", "host", host, "status", status, "body", body) + + return fmt.Errorf("%s is not allowed to get a certificate (status: %d, body: \"%s\")", host, status, body) +} diff --git a/internal/server/tls_on_demand_test.go b/internal/server/tls_on_demand_test.go new file mode 100644 index 00000000..3896925a --- /dev/null +++ b/internal/server/tls_on_demand_test.go @@ -0,0 +1,174 @@ +package server + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTLSOnDemandChecker_HostPolicy_EmptyURL(t *testing.T) { + service := &Service{options: ServiceOptions{Hosts: []string{"example.com"}}} + checker := NewTLSOnDemandChecker(service) + + policy, _ := checker.HostPolicy() + + // Should allow hosts in the whitelist + err := policy(context.Background(), "example.com") + assert.NoError(t, err) + + // Should deny hosts not in the whitelist + err = policy(context.Background(), "other.com") + assert.Error(t, err) +} + +func TestTLSOnDemandChecker_LocalHostPolicy_Success(t *testing.T) { + // Create a mock service that returns 200 for /allow-host + service := &Service{ + options: ServiceOptions{TLSOnDemandUrl: "/allow-host"}, + } + service.middleware = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/allow-host" && r.URL.Query().Get("host") == "test.example.com" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusForbidden) + } + }) + + checker := NewTLSOnDemandChecker(service) + policy := checker.LocalHostPolicy() + + err := policy(context.Background(), "test.example.com") + assert.NoError(t, err) +} + +func TestTLSOnDemandChecker_LocalHostPolicy_Denied(t *testing.T) { + // Create a mock service that returns 403 for /allow-host + service := &Service{ + options: ServiceOptions{TLSOnDemandUrl: "/allow-host"}, + } + service.middleware = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte("Access denied")) + }) + + checker := NewTLSOnDemandChecker(service) + policy := checker.LocalHostPolicy() + + err := policy(context.Background(), "test.example.com") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not allowed to get a certificate") + assert.Contains(t, err.Error(), "status: 403") +} + +func TestTLSOnDemandChecker_LocalHostPolicy_LargeResponseBody(t *testing.T) { + // Create a mock service that returns a large response body + largeBody := string(make([]byte, 500)) // 500 bytes + + service := &Service{ + options: ServiceOptions{TLSOnDemandUrl: "/allow-host"}, + } + service.middleware = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(largeBody)) + }) + + checker := NewTLSOnDemandChecker(service) + policy := checker.LocalHostPolicy() + + err := policy(context.Background(), "test.example.com") + assert.Error(t, err) + assert.Contains(t, err.Error(), "status: 403") + + // Verify the body is truncated to 256 bytes + assert.Len(t, err.Error(), 256+len("test.example.com is not allowed to get a certificate (status: 403, body: \"")+len("\")")) +} + +func TestTLSOnDemandChecker_ExternalHostPolicy_Success(t *testing.T) { + // Create a test server that returns 200 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("host") == "test.example.com" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusForbidden) + } + })) + defer server.Close() + + service := &Service{options: ServiceOptions{TLSOnDemandUrl: server.URL}} + checker := NewTLSOnDemandChecker(service) + policy := checker.ExternalHostPolicy() + + err := policy(context.Background(), "test.example.com") + assert.NoError(t, err) +} + +func TestTLSOnDemandChecker_ExternalHostPolicy_Denied(t *testing.T) { + // Create a test server that returns 403 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte("Access denied")) + })) + defer server.Close() + + service := &Service{options: ServiceOptions{TLSOnDemandUrl: server.URL}} + checker := NewTLSOnDemandChecker(service) + policy := checker.ExternalHostPolicy() + + err := policy(context.Background(), "test.example.com") + assert.Error(t, err) + assert.Contains(t, err.Error(), "not allowed to get a certificate") + assert.Contains(t, err.Error(), "status: 403") +} + +func TestTLSOnDemandChecker_HostPolicy_LocalPath(t *testing.T) { + service := &Service{ + options: ServiceOptions{TLSOnDemandUrl: "/allow-host"}, + } + service.middleware = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + checker := NewTLSOnDemandChecker(service) + policy, _ := checker.HostPolicy() + + err := policy(context.Background(), "test.example.com") + assert.NoError(t, err) +} + +func TestTLSOnDemandChecker_HostPolicy_ExternalURL(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + service := &Service{options: ServiceOptions{TLSOnDemandUrl: server.URL}} + checker := NewTLSOnDemandChecker(service) + policy, _ := checker.HostPolicy() + + err := policy(context.Background(), "test.example.com") + assert.NoError(t, err) +} + +func TestTLSOnDemandChecker_HostPolicy_InvalidExternalURL(t *testing.T) { + service := &Service{options: ServiceOptions{TLSOnDemandUrl: "://invalid-url"}} + checker := NewTLSOnDemandChecker(service) + _, err := checker.HostPolicy() + + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing protocol scheme") +} + +func TestTLSOnDemandChecker_buildURLOrPath(t *testing.T) { + service := &Service{options: ServiceOptions{TLSOnDemandUrl: "/allow-host"}} + checker := NewTLSOnDemandChecker(service) + + url := checker.buildURLOrPath("test.example.com") + assert.Equal(t, "/allow-host?host=test.example.com", url) + + // Test with special characters + url = checker.buildURLOrPath("test.example.com:8080") + assert.Equal(t, "/allow-host?host=test.example.com%3A8080", url) +}