diff --git a/cmd/litestream/main.go b/cmd/litestream/main.go index 3bf00b19..d229fcb7 100644 --- a/cmd/litestream/main.go +++ b/cmd/litestream/main.go @@ -185,6 +185,9 @@ The commands are: // Config represents a configuration file for the litestream daemon. type Config struct { + // Global replica settings that serve as defaults for all replicas + ReplicaSettings `yaml:",inline"` + // Bind address for serving metrics. Addr string `yaml:"addr"` @@ -202,10 +205,6 @@ type Config struct { // Litestream will shutdown when subcommand exits. Exec string `yaml:"exec"` - // Global S3 settings - AccessKeyID string `yaml:"access-key-id"` - SecretAccessKey string `yaml:"secret-access-key"` - // Logging Logging LoggingConfig `yaml:"logging"` @@ -230,16 +229,15 @@ type LoggingConfig struct { Stderr bool `yaml:"stderr"` } -// propagateGlobalSettings copies global S3 settings to replica configs. +// propagateGlobalSettings copies global replica settings to individual replica configs. func (c *Config) propagateGlobalSettings() { for _, dbc := range c.DBs { + // Handle both old-style 'replicas' and new-style 'replica' + if dbc.Replica != nil { + dbc.Replica.SetDefaults(&c.ReplicaSettings) + } for _, rc := range dbc.Replicas { - if rc.AccessKeyID == "" { - rc.AccessKeyID = c.AccessKeyID - } - if rc.SecretAccessKey == "" { - rc.SecretAccessKey = c.SecretAccessKey - } + rc.SetDefaults(&c.ReplicaSettings) } } } @@ -533,12 +531,9 @@ func NewDBFromConfig(dbc *DBConfig) (*litestream.DB, error) { return db, nil } -// ReplicaConfig represents the configuration for a single replica in a database. -type ReplicaConfig struct { - Type string `yaml:"type"` // "file", "s3" - Name string `yaml:"name"` // Deprecated - Path string `yaml:"path"` - URL string `yaml:"url"` +// ReplicaSettings contains settings shared across replica configurations. +// These can be set globally in Config or per-replica in ReplicaConfig. +type ReplicaSettings struct { SyncInterval *time.Duration `yaml:"sync-interval"` ValidationInterval *time.Duration `yaml:"validation-interval"` @@ -584,6 +579,129 @@ type ReplicaConfig struct { } `yaml:"age"` } +// SetDefaults merges default settings from src into the current ReplicaSettings. +// Individual settings override defaults when already set. +func (rs *ReplicaSettings) SetDefaults(src *ReplicaSettings) { + if src == nil { + return + } + + // Timing settings + if rs.SyncInterval == nil && src.SyncInterval != nil { + rs.SyncInterval = src.SyncInterval + } + if rs.ValidationInterval == nil && src.ValidationInterval != nil { + rs.ValidationInterval = src.ValidationInterval + } + + // S3 settings + if rs.AccessKeyID == "" { + rs.AccessKeyID = src.AccessKeyID + } + if rs.SecretAccessKey == "" { + rs.SecretAccessKey = src.SecretAccessKey + } + if rs.Region == "" { + rs.Region = src.Region + } + if rs.Bucket == "" { + rs.Bucket = src.Bucket + } + if rs.Endpoint == "" { + rs.Endpoint = src.Endpoint + } + if rs.ForcePathStyle == nil { + rs.ForcePathStyle = src.ForcePathStyle + } + if src.SkipVerify { + rs.SkipVerify = true + } + + // ABS settings + if rs.AccountName == "" { + rs.AccountName = src.AccountName + } + if rs.AccountKey == "" { + rs.AccountKey = src.AccountKey + } + + // SFTP settings + if rs.Host == "" { + rs.Host = src.Host + } + if rs.User == "" { + rs.User = src.User + } + if rs.Password == "" { + rs.Password = src.Password + } + if rs.KeyPath == "" { + rs.KeyPath = src.KeyPath + } + if rs.ConcurrentWrites == nil { + rs.ConcurrentWrites = src.ConcurrentWrites + } + + // NATS settings + if rs.JWT == "" { + rs.JWT = src.JWT + } + if rs.Seed == "" { + rs.Seed = src.Seed + } + if rs.Creds == "" { + rs.Creds = src.Creds + } + if rs.NKey == "" { + rs.NKey = src.NKey + } + if rs.Username == "" { + rs.Username = src.Username + } + if rs.Token == "" { + rs.Token = src.Token + } + if !rs.TLS { + rs.TLS = src.TLS + } + if len(rs.RootCAs) == 0 { + rs.RootCAs = src.RootCAs + } + if rs.ClientCert == "" { + rs.ClientCert = src.ClientCert + } + if rs.ClientKey == "" { + rs.ClientKey = src.ClientKey + } + if rs.MaxReconnects == nil { + rs.MaxReconnects = src.MaxReconnects + } + if rs.ReconnectWait == nil { + rs.ReconnectWait = src.ReconnectWait + } + if rs.Timeout == nil { + rs.Timeout = src.Timeout + } + + // Age encryption settings + if len(rs.Age.Identities) == 0 { + rs.Age.Identities = src.Age.Identities + } + if len(rs.Age.Recipients) == 0 { + rs.Age.Recipients = src.Age.Recipients + } +} + +// ReplicaConfig represents the configuration for a single replica in a database. +type ReplicaConfig struct { + ReplicaSettings `yaml:",inline"` + + Type string `yaml:"type"` // "file", "s3" + Name string `yaml:"name"` // Deprecated + Path string `yaml:"path"` + URL string `yaml:"url"` +} + // NewReplicaFromConfig instantiates a replica for a DB based on a config. func NewReplicaFromConfig(c *ReplicaConfig, db *litestream.DB) (_ *litestream.Replica, err error) { // Ensure user did not specify URL in path. diff --git a/cmd/litestream/main_test.go b/cmd/litestream/main_test.go index dadbc323..f9c02150 100644 --- a/cmd/litestream/main_test.go +++ b/cmd/litestream/main_test.go @@ -622,3 +622,246 @@ func TestConfig_DefaultValues(t *testing.T) { t.Errorf("expected default snapshot retention of 24h, got %v", *config.Snapshot.Retention) } } + +func TestGlobalDefaults(t *testing.T) { + // Test comprehensive global defaults functionality + t.Run("GlobalReplicaDefaults", func(t *testing.T) { + filename := filepath.Join(t.TempDir(), "litestream.yml") + syncInterval := "30s" + validationInterval := "1h" + + if err := os.WriteFile(filename, []byte(` +# Global defaults for all replicas +access-key-id: GLOBAL_ACCESS_KEY +secret-access-key: GLOBAL_SECRET_KEY +region: us-west-2 +endpoint: custom.s3.endpoint.com +sync-interval: `+syncInterval+` +validation-interval: `+validationInterval+` + +dbs: + # Database 1: Uses all global defaults + - path: /tmp/db1.sqlite + replica: + type: s3 + bucket: my-bucket-1 + + # Database 2: Overrides some defaults + - path: /tmp/db2.sqlite + replica: + type: s3 + bucket: my-bucket-2 + region: us-east-1 # Override global region + access-key-id: CUSTOM_KEY # Override global access key + + # Database 3: Uses legacy replicas format + - path: /tmp/db3.sqlite + replicas: + - type: s3 + bucket: my-bucket-3 + # Should inherit all other global settings +`[1:]), 0666); err != nil { + t.Fatal(err) + } + + config, err := main.ReadConfigFile(filename, true) + if err != nil { + t.Fatal(err) + } + + // Test global settings were parsed correctly + if got, want := config.AccessKeyID, "GLOBAL_ACCESS_KEY"; got != want { + t.Errorf("config.AccessKeyID=%v, want %v", got, want) + } + if got, want := config.SecretAccessKey, "GLOBAL_SECRET_KEY"; got != want { + t.Errorf("config.SecretAccessKey=%v, want %v", got, want) + } + if got, want := config.Region, "us-west-2"; got != want { + t.Errorf("config.Region=%v, want %v", got, want) + } + if got, want := config.Endpoint, "custom.s3.endpoint.com"; got != want { + t.Errorf("config.Endpoint=%v, want %v", got, want) + } + + // Parse expected intervals + expectedSyncInterval, err := time.ParseDuration(syncInterval) + if err != nil { + t.Fatal(err) + } + expectedValidationInterval, err := time.ParseDuration(validationInterval) + if err != nil { + t.Fatal(err) + } + + if config.SyncInterval == nil || *config.SyncInterval != expectedSyncInterval { + t.Errorf("config.SyncInterval=%v, want %v", config.SyncInterval, expectedSyncInterval) + } + if config.ValidationInterval == nil || *config.ValidationInterval != expectedValidationInterval { + t.Errorf("config.ValidationInterval=%v, want %v", config.ValidationInterval, expectedValidationInterval) + } + + // Test Database 1: Should inherit all global defaults + db1 := config.DBs[0] + if db1.Replica == nil { + t.Fatal("db1.Replica is nil") + } + replica1 := db1.Replica + + if got, want := replica1.AccessKeyID, "GLOBAL_ACCESS_KEY"; got != want { + t.Errorf("replica1.AccessKeyID=%v, want %v", got, want) + } + if got, want := replica1.SecretAccessKey, "GLOBAL_SECRET_KEY"; got != want { + t.Errorf("replica1.SecretAccessKey=%v, want %v", got, want) + } + if got, want := replica1.Region, "us-west-2"; got != want { + t.Errorf("replica1.Region=%v, want %v", got, want) + } + if got, want := replica1.Endpoint, "custom.s3.endpoint.com"; got != want { + t.Errorf("replica1.Endpoint=%v, want %v", got, want) + } + if got, want := replica1.Bucket, "my-bucket-1"; got != want { + t.Errorf("replica1.Bucket=%v, want %v", got, want) + } + if replica1.SyncInterval == nil || *replica1.SyncInterval != expectedSyncInterval { + t.Errorf("replica1.SyncInterval=%v, want %v", replica1.SyncInterval, expectedSyncInterval) + } + if replica1.ValidationInterval == nil || *replica1.ValidationInterval != expectedValidationInterval { + t.Errorf("replica1.ValidationInterval=%v, want %v", replica1.ValidationInterval, expectedValidationInterval) + } + + // Test Database 2: Should override some defaults + db2 := config.DBs[1] + if db2.Replica == nil { + t.Fatal("db2.Replica is nil") + } + replica2 := db2.Replica + + if got, want := replica2.AccessKeyID, "CUSTOM_KEY"; got != want { + t.Errorf("replica2.AccessKeyID=%v, want %v", got, want) + } + if got, want := replica2.SecretAccessKey, "GLOBAL_SECRET_KEY"; got != want { + t.Errorf("replica2.SecretAccessKey=%v, want %v", got, want) + } + if got, want := replica2.Region, "us-east-1"; got != want { + t.Errorf("replica2.Region=%v, want %v", got, want) + } + if got, want := replica2.Endpoint, "custom.s3.endpoint.com"; got != want { + t.Errorf("replica2.Endpoint=%v, want %v", got, want) + } + if got, want := replica2.Bucket, "my-bucket-2"; got != want { + t.Errorf("replica2.Bucket=%v, want %v", got, want) + } + + // Test Database 3: Legacy replicas format should work + db3 := config.DBs[2] + if len(db3.Replicas) != 1 { + t.Fatalf("db3.Replicas length=%v, want 1", len(db3.Replicas)) + } + replica3 := db3.Replicas[0] + + if got, want := replica3.AccessKeyID, "GLOBAL_ACCESS_KEY"; got != want { + t.Errorf("replica3.AccessKeyID=%v, want %v", got, want) + } + if got, want := replica3.SecretAccessKey, "GLOBAL_SECRET_KEY"; got != want { + t.Errorf("replica3.SecretAccessKey=%v, want %v", got, want) + } + if got, want := replica3.Region, "us-west-2"; got != want { + t.Errorf("replica3.Region=%v, want %v", got, want) + } + if got, want := replica3.Endpoint, "custom.s3.endpoint.com"; got != want { + t.Errorf("replica3.Endpoint=%v, want %v", got, want) + } + if got, want := replica3.Bucket, "my-bucket-3"; got != want { + t.Errorf("replica3.Bucket=%v, want %v", got, want) + } + }) + + // Test different replica types inherit appropriate defaults + t.Run("MultipleReplicaTypes", func(t *testing.T) { + filename := filepath.Join(t.TempDir(), "litestream.yml") + + if err := os.WriteFile(filename, []byte(` +# Global defaults that apply to all supported replica types +access-key-id: GLOBAL_S3_KEY +secret-access-key: GLOBAL_S3_SECRET +region: global-region +endpoint: global.endpoint.com +account-name: global-abs-account +account-key: global-abs-key +host: global.sftp.host +user: global-sftp-user +password: global-sftp-pass +sync-interval: 45s + +dbs: + - path: /tmp/s3.sqlite + replica: + type: s3 + bucket: s3-bucket + + - path: /tmp/abs.sqlite + replica: + type: abs + bucket: abs-container + + - path: /tmp/sftp.sqlite + replica: + type: sftp + path: /backup/path +`[1:]), 0666); err != nil { + t.Fatal(err) + } + + config, err := main.ReadConfigFile(filename, true) + if err != nil { + t.Fatal(err) + } + + expectedSyncInterval, _ := time.ParseDuration("45s") + + // Test S3 replica inherits S3-specific defaults + s3Replica := config.DBs[0].Replica + if got, want := s3Replica.AccessKeyID, "GLOBAL_S3_KEY"; got != want { + t.Errorf("s3Replica.AccessKeyID=%v, want %v", got, want) + } + if got, want := s3Replica.SecretAccessKey, "GLOBAL_S3_SECRET"; got != want { + t.Errorf("s3Replica.SecretAccessKey=%v, want %v", got, want) + } + if got, want := s3Replica.Region, "global-region"; got != want { + t.Errorf("s3Replica.Region=%v, want %v", got, want) + } + if got, want := s3Replica.Endpoint, "global.endpoint.com"; got != want { + t.Errorf("s3Replica.Endpoint=%v, want %v", got, want) + } + if s3Replica.SyncInterval == nil || *s3Replica.SyncInterval != expectedSyncInterval { + t.Errorf("s3Replica.SyncInterval=%v, want %v", s3Replica.SyncInterval, expectedSyncInterval) + } + + // Test ABS replica inherits ABS-specific defaults + absReplica := config.DBs[1].Replica + if got, want := absReplica.AccountName, "global-abs-account"; got != want { + t.Errorf("absReplica.AccountName=%v, want %v", got, want) + } + if got, want := absReplica.AccountKey, "global-abs-key"; got != want { + t.Errorf("absReplica.AccountKey=%v, want %v", got, want) + } + if absReplica.SyncInterval == nil || *absReplica.SyncInterval != expectedSyncInterval { + t.Errorf("absReplica.SyncInterval=%v, want %v", absReplica.SyncInterval, expectedSyncInterval) + } + + // Test SFTP replica inherits SFTP-specific defaults + sftpReplica := config.DBs[2].Replica + if got, want := sftpReplica.Host, "global.sftp.host"; got != want { + t.Errorf("sftpReplica.Host=%v, want %v", got, want) + } + if got, want := sftpReplica.User, "global-sftp-user"; got != want { + t.Errorf("sftpReplica.User=%v, want %v", got, want) + } + if got, want := sftpReplica.Password, "global-sftp-pass"; got != want { + t.Errorf("sftpReplica.Password=%v, want %v", got, want) + } + if sftpReplica.SyncInterval == nil || *sftpReplica.SyncInterval != expectedSyncInterval { + t.Errorf("sftpReplica.SyncInterval=%v, want %v", sftpReplica.SyncInterval, expectedSyncInterval) + } + }) +} diff --git a/cmd/litestream/replicate.go b/cmd/litestream/replicate.go index 2cf2eb74..d37d8562 100644 --- a/cmd/litestream/replicate.go +++ b/cmd/litestream/replicate.go @@ -86,8 +86,10 @@ func (c *ReplicateCommand) ParseFlags(_ context.Context, args []string) (err err } syncInterval := litestream.DefaultSyncInterval dbConfig.Replicas = append(dbConfig.Replicas, &ReplicaConfig{ - URL: u, - SyncInterval: &syncInterval, + URL: u, + ReplicaSettings: ReplicaSettings{ + SyncInterval: &syncInterval, + }, }) } c.Config.DBs = []*DBConfig{dbConfig} diff --git a/cmd/litestream/restore.go b/cmd/litestream/restore.go index fe22b863..78e5ae1b 100644 --- a/cmd/litestream/restore.go +++ b/cmd/litestream/restore.go @@ -95,8 +95,10 @@ func (c *RestoreCommand) loadFromURL(ctx context.Context, replicaURL string, ifD syncInterval := litestream.DefaultSyncInterval r, err := NewReplicaFromConfig(&ReplicaConfig{ - URL: replicaURL, - SyncInterval: &syncInterval, + URL: replicaURL, + ReplicaSettings: ReplicaSettings{ + SyncInterval: &syncInterval, + }, }, nil) if err != nil { return nil, err