diff --git a/cmd/scdbserver/config.yml b/cmd/scdbserver/config.yml index 61000cd8..db7421d1 100644 --- a/cmd/scdbserver/config.yml +++ b/cmd/scdbserver/config.yml @@ -21,3 +21,6 @@ engine: "protocol": "SEMI2K", "field": "FM64" } + infoschema_cache: + enabled: true + ttl: 10m diff --git a/cmd/scdbserver/main.go b/cmd/scdbserver/main.go index 2d1d651c..47f6d0db 100644 --- a/cmd/scdbserver/main.go +++ b/cmd/scdbserver/main.go @@ -87,6 +87,7 @@ func main() { log.Info("Starting to connect to database and do bootstrap if necessary...") storage.InitPasswordValidation(cfg.PasswordCheck) + storage.InitInfoSchemaCache(cfg.InfoSchemaCache.Enabled, cfg.InfoSchemaCache.TTL) store, err := server.NewDbConnWithBootstrap(&cfg.Storage) if err != nil { log.Fatalf("Failed to connect to database and bootstrap it: %v", err) diff --git a/pkg/scdb/config/config.go b/pkg/scdb/config/config.go index 251788fb..1f144e82 100644 --- a/pkg/scdb/config/config.go +++ b/pkg/scdb/config/config.go @@ -38,6 +38,7 @@ const ( DefaultProtocol = "https" DefaultLogLevel = "info" DefaultEngineClientMode = "HTTP" + DefaultInfoSchemaCacheTTL = 10 * time.Minute // 10 minutes ) type EngineConfig struct { @@ -62,6 +63,11 @@ type SecurityCompromiseConf struct { RevealGroupCount bool `yaml:"reveal_group_count"` } +type InfoSchemaCacheConf struct { + Enabled bool `yaml:"enabled"` + TTL time.Duration `yaml:"ttl"` +} + // Config contains bootstrap configuration for SCDB type Config struct { // SCDBHost is used as callback url for engine worked in async mode @@ -79,6 +85,7 @@ type Config struct { Engine EngineConfig `yaml:"engine"` SecurityCompromise SecurityCompromiseConf `yaml:"security_compromise"` PartyAuth PartyAuthConf `yaml:"party_auth"` + InfoSchemaCache InfoSchemaCacheConf `yaml:"infoschema_cache"` } const ( @@ -164,6 +171,10 @@ func NewDefaultConfig() *Config { EnableTimestampCheck: true, ValidityPeriod: 30 * time.Second, // 30s } + config.InfoSchemaCache = InfoSchemaCacheConf{ + Enabled: true, + TTL: DefaultInfoSchemaCacheTTL, + } return &config } diff --git a/pkg/scdb/config/config_test.go b/pkg/scdb/config/config_test.go index 32b89065..69a6f7a7 100644 --- a/pkg/scdb/config/config_test.go +++ b/pkg/scdb/config/config_test.go @@ -114,6 +114,10 @@ security_compromise: ValidityPeriod: time.Minute * 3, }, SecurityCompromise: SecurityCompromiseConf{GroupByThreshold: 3}, + InfoSchemaCache: InfoSchemaCacheConf{ + Enabled: true, + TTL: DefaultInfoSchemaCacheTTL, + }, } expectedCfg.Engine.SpuRuntimeCfg = strings.ReplaceAll(expectedCfg.Engine.SpuRuntimeCfg, "\t", " ") r.Equal(expectedCfg, cfg) @@ -206,6 +210,10 @@ party_auth: EnableTimestampCheck: true, ValidityPeriod: time.Minute * 3, }, + InfoSchemaCache: InfoSchemaCacheConf{ + Enabled: true, + TTL: DefaultInfoSchemaCacheTTL, + }, } expectedCfg.Engine.SpuRuntimeCfg = strings.ReplaceAll(expectedCfg.Engine.SpuRuntimeCfg, "\t", " ") r.Equal(expectedCfg, cfg) diff --git a/pkg/scdb/executor/ddl.go b/pkg/scdb/executor/ddl.go index 8d8c9246..a41b41c1 100644 --- a/pkg/scdb/executor/ddl.go +++ b/pkg/scdb/executor/ddl.go @@ -53,14 +53,41 @@ func (e *DDLExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { switch x := e.stmt.(type) { case *ast.CreateDatabaseStmt: err = e.executeCreateDatabase(x) + if err == nil { + storage.InvalidateInfoSchemaCache(strings.ToLower(x.Name)) + } case *ast.CreateTableStmt: err = e.executeCreateTable(x) + if err == nil { + dbName := x.Table.Schema.O + if dbName == "" { + dbName = e.ctx.GetSessionVars().CurrentDB + } + storage.InvalidateInfoSchemaCache(dbName) + } case *ast.DropDatabaseStmt: err = e.executeDropDatabase(x) + if err == nil { + storage.InvalidateInfoSchemaCache(strings.ToLower(x.Name)) + } case *ast.DropTableStmt: err = e.executeDropTableOrView(x) + if err == nil && len(x.Tables) > 0 { + dbName := x.Tables[0].Schema.L + if dbName == "" { + dbName = e.ctx.GetSessionVars().CurrentDB + } + storage.InvalidateInfoSchemaCache(dbName) + } case *ast.CreateViewStmt: err = e.executeCreateView(x) + if err == nil { + dbName := x.ViewName.Schema.L + if dbName == "" { + dbName = e.ctx.GetSessionVars().CurrentDB + } + storage.InvalidateInfoSchemaCache(dbName) + } default: err = fmt.Errorf("ddl.Next: Unsupported statement %v", x) diff --git a/pkg/scdb/storage/storage.go b/pkg/scdb/storage/storage.go index 0179aae2..77de02e5 100644 --- a/pkg/scdb/storage/storage.go +++ b/pkg/scdb/storage/storage.go @@ -21,6 +21,9 @@ import ( "reflect" "sort" "strings" + "sync" + "sync/atomic" + "time" "github.com/sethvargo/go-password/password" "gorm.io/gorm" @@ -53,6 +56,30 @@ var EnablePasswordCheck = false // keep creating order var allTables = []interface{}{&User{}, &Database{}, &Table{}, &Column{}, &DatabasePriv{}, &TablePriv{}, &ColumnPriv{}} +// InfoSchema cache entry with TTL support +type infoCacheEntry struct { + schema infoschema.InfoSchema + createTime time.Time +} + +// InfoSchema cache to avoid frequent database queries +var ( + infoSchemaCache sync.Map // map[string]*infoCacheEntry, key is database name + infoCacheHitCount int64 // accessed via atomic operations + infoCacheMissCount int64 // accessed via atomic operations + cacheEnabled bool // whether cache is enabled + cacheTTL time.Duration +) + +// InitInfoSchemaCache initializes the InfoSchema cache configuration +func InitInfoSchemaCache(enabled bool, ttl time.Duration) { + cacheEnabled = enabled + cacheTTL = ttl + if ttl <= 0 { + ttl = 10 * time.Minute // default to 10 minutes + } +} + // NeedBootstrap checks if the store is empty func NeedBootstrap(store *gorm.DB) bool { for _, tn := range allTables { @@ -173,6 +200,32 @@ func FindUserByParty(store *gorm.DB, partyCode string) (*User, error) { return &user, nil } +// InvalidateInfoSchemaCache invalidates the InfoSchema cache for a specific database +// If dbName is empty, it invalidates all cached InfoSchemas +func InvalidateInfoSchemaCache(dbName string) { + if dbName == "" { + // Clear all cache by iterating and deleting each key to avoid race conditions + infoSchemaCache.Range(func(key, value interface{}) bool { + infoSchemaCache.Delete(key) + return true + }) + } else { + // Clear specific database cache + infoSchemaCache.Delete(dbName) + } +} + +// ResetInfoSchemaCacheStats resets cache statistics (for testing) +func ResetInfoSchemaCacheStats() { + atomic.StoreInt64(&infoCacheHitCount, 0) + atomic.StoreInt64(&infoCacheMissCount, 0) +} + +// GetInfoSchemaCacheStats returns cache hit and miss statistics +func GetInfoSchemaCacheStats() (hits, misses int64) { + return atomic.LoadInt64(&infoCacheHitCount), atomic.LoadInt64(&infoCacheMissCount) +} + func QueryInfoSchema(store *gorm.DB) (result infoschema.InfoSchema, err error) { callFc := func(tx *gorm.DB) error { result, err = queryInfoSchema(tx) @@ -244,6 +297,23 @@ func queryInfoSchema(store *gorm.DB) (infoschema.InfoSchema, error) { } func QueryDBInfoSchema(store *gorm.DB, dbName string) (result infoschema.InfoSchema, err error) { + // Try to get from cache first if enabled + if cacheEnabled { + if cached, ok := infoSchemaCache.Load(dbName); ok { + entry := cached.(*infoCacheEntry) + // Check if cache entry is still valid (not expired) + if time.Since(entry.createTime) < cacheTTL { + atomic.AddInt64(&infoCacheHitCount, 1) + return entry.schema, nil + } + // Cache expired, remove it + infoSchemaCache.Delete(dbName) + } + } + + // Cache miss or disabled, query from database + atomic.AddInt64(&infoCacheMissCount, 1) + callFc := func(tx *gorm.DB) error { result, err = queryDBInfoSchema(tx, dbName) return err @@ -251,6 +321,15 @@ func QueryDBInfoSchema(store *gorm.DB, dbName string) (result infoschema.InfoSch if err := store.Transaction(callFc, &sql.TxOptions{ReadOnly: true}); err != nil { return nil, fmt.Errorf("queryDBInfoSchema: %v", err) } + + // Store in cache if enabled + if cacheEnabled && result != nil { + infoSchemaCache.Store(dbName, &infoCacheEntry{ + schema: result, + createTime: time.Now(), + }) + } + return result, nil } diff --git a/pkg/scdb/storage/storage_test.go b/pkg/scdb/storage/storage_test.go index c330bec5..708b17c2 100644 --- a/pkg/scdb/storage/storage_test.go +++ b/pkg/scdb/storage/storage_test.go @@ -16,6 +16,7 @@ package storage import ( "testing" + "time" "github.com/stretchr/testify/require" "gorm.io/driver/sqlite" @@ -107,3 +108,125 @@ func batchInsert(db *gorm.DB, records []interface{}) error { } return nil } + +func createTestDB(t *testing.T, db *gorm.DB, dbName string) { + r := require.New(t) + records := []interface{}{ + &Database{Db: dbName}, + &Table{Db: dbName, Table: "t1", Owner: "root", Host: "%", RefDb: "ref", RefTable: "ref", DBType: 0}, + &Column{Db: dbName, TableName: "t1", ColumnName: "c1", Type: "int"}, + } + r.NoError(batchInsert(db, records)) +} + +func TestInfoSchemaCache(t *testing.T) { + r := require.New(t) + db, err := newDbStore() + r.NoError(err) + + // Initialize cache with enabled=true and a reasonable TTL + InitInfoSchemaCache(true, 5*time.Minute) + InvalidateInfoSchemaCache("") + ResetInfoSchemaCacheStats() + + createTestDB(t, db, "db1") + createTestDB(t, db, "db2") + + // Test cache miss and hit + _, err = QueryDBInfoSchema(db, "db1") + r.NoError(err) + hits, misses := GetInfoSchemaCacheStats() + r.Equal(int64(0), hits) + r.Equal(int64(1), misses) + + _, err = QueryDBInfoSchema(db, "db1") + r.NoError(err) + hits, misses = GetInfoSchemaCacheStats() + r.Equal(int64(1), hits) + r.Equal(int64(1), misses) + + // Test cache invalidation + InvalidateInfoSchemaCache("db1") + _, err = QueryDBInfoSchema(db, "db1") + r.NoError(err) + hits, misses = GetInfoSchemaCacheStats() + r.Equal(int64(1), hits) + r.Equal(int64(2), misses) + + // Test multiple databases cache isolation + _, err = QueryDBInfoSchema(db, "db2") + r.NoError(err) + hits, misses = GetInfoSchemaCacheStats() + r.Equal(int64(1), hits) + r.Equal(int64(3), misses) + + InvalidateInfoSchemaCache("db2") + _, err = QueryDBInfoSchema(db, "db1") + r.NoError(err) + hits, misses = GetInfoSchemaCacheStats() + r.Equal(int64(2), hits) // db1 still cached + r.Equal(int64(3), misses) +} + +func TestInfoSchemaCacheTTL(t *testing.T) { + r := require.New(t) + db, err := newDbStore() + r.NoError(err) + + // Initialize cache with a short TTL for testing + InitInfoSchemaCache(true, 100*time.Millisecond) + InvalidateInfoSchemaCache("") + ResetInfoSchemaCacheStats() + + createTestDB(t, db, "db1") + + // First query - cache miss + _, err = QueryDBInfoSchema(db, "db1") + r.NoError(err) + hits, misses := GetInfoSchemaCacheStats() + r.Equal(int64(0), hits) + r.Equal(int64(1), misses) + + // Second query immediately - cache hit + _, err = QueryDBInfoSchema(db, "db1") + r.NoError(err) + hits, misses = GetInfoSchemaCacheStats() + r.Equal(int64(1), hits) + r.Equal(int64(1), misses) + + // Wait for cache to expire + time.Sleep(150 * time.Millisecond) + + // Query after TTL expired - should be cache miss + _, err = QueryDBInfoSchema(db, "db1") + r.NoError(err) + hits, misses = GetInfoSchemaCacheStats() + r.Equal(int64(1), hits) + r.Equal(int64(2), misses) +} + +func TestInfoSchemaCacheDisabled(t *testing.T) { + r := require.New(t) + db, err := newDbStore() + r.NoError(err) + + // Disable cache + InitInfoSchemaCache(false, 5*time.Minute) + InvalidateInfoSchemaCache("") + ResetInfoSchemaCacheStats() + + createTestDB(t, db, "db1") + + // All queries should be cache misses when disabled + _, err = QueryDBInfoSchema(db, "db1") + r.NoError(err) + hits, misses := GetInfoSchemaCacheStats() + r.Equal(int64(0), hits) + r.Equal(int64(1), misses) + + _, err = QueryDBInfoSchema(db, "db1") + r.NoError(err) + hits, misses = GetInfoSchemaCacheStats() + r.Equal(int64(0), hits) + r.Equal(int64(2), misses) +}