Skip to content

Commit

Permalink
Refactor: Pass context around to help with cancellations. (Velocidex#…
Browse files Browse the repository at this point in the history
…3629)

Close LRU caches during shutdown process
  • Loading branch information
scudette authored Jul 18, 2024
1 parent 7c32605 commit 6b83126
Show file tree
Hide file tree
Showing 36 changed files with 149 additions and 49 deletions.
1 change: 1 addition & 0 deletions accessors/process/process_address_space_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ func (self ProcessAccessor) New(scope vfilter.Scope) (
reader.mu.Unlock()
}
}
result.lru.Close()
})
return result, nil
}
Expand Down
2 changes: 2 additions & 0 deletions accessors/raw_registry/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ func getRegFileSystemAccessorCache(scope vfilter.Scope) *RawRegFileSystemAccesso

root_scope.AddDestructor(func() {
cache.Close()
cache.lru.Close()
cache.readdir_lru.Close()
})
vql_subsystem.CacheSet(root_scope, RAW_CACHE_TAG, cache)

Expand Down
2 changes: 2 additions & 0 deletions accessors/registry/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ func getRegFileSystemAccessorCache(scope vfilter.Scope) *RegFileSystemAccessorCa

root_scope.AddDestructor(func() {
cache.Close()
cache.lru.Close()
cache.readdir_lru.Close()
})
vql_subsystem.CacheSet(root_scope, CACHE_TAG, cache)

Expand Down
5 changes: 5 additions & 0 deletions actions/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"www.velocidex.com/golang/velociraptor/actions"
actions_proto "www.velocidex.com/golang/velociraptor/actions/proto"
crypto_proto "www.velocidex.com/golang/velociraptor/crypto/proto"
"www.velocidex.com/golang/velociraptor/datastore"
"www.velocidex.com/golang/velociraptor/file_store/test_utils"
flows_proto "www.velocidex.com/golang/velociraptor/flows/proto"
"www.velocidex.com/golang/velociraptor/responder"
Expand Down Expand Up @@ -69,6 +70,10 @@ func (self *EventsTestSuite) SetupTest() {
self.ConfigObj.Client.WritebackDarwin = self.writeback
self.ConfigObj.Services.ClientMonitoring = true
self.ConfigObj.Services.IndexServer = true

datastore.SetGlobalDatastore(context.Background(),
self.ConfigObj.Datastore.Implementation, self.ConfigObj)

self.TestSuite.SetupTest()

writeback_service := writeback.GetWritebackService()
Expand Down
7 changes: 5 additions & 2 deletions api/assets.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/gorilla/csrf"
"github.com/lpar/gzipped"
context "golang.org/x/net/context"
"www.velocidex.com/golang/velociraptor/api/proto"
utils "www.velocidex.com/golang/velociraptor/api/utils"
config_proto "www.velocidex.com/golang/velociraptor/config/proto"
Expand All @@ -34,11 +35,13 @@ import (
"www.velocidex.com/golang/velociraptor/services"
)

func install_static_assets(config_obj *config_proto.Config, mux *http.ServeMux) {
func install_static_assets(
ctx context.Context,
config_obj *config_proto.Config, mux *http.ServeMux) {
base := utils.GetBasePath(config_obj)
dir := utils.Join(base, "/app/")
mux.Handle(dir, ipFilter(config_obj, http.StripPrefix(
dir, gzipped.FileServer(NewCachedFilesystem(gui_assets.HTTP)))))
dir, gzipped.FileServer(NewCachedFilesystem(ctx, gui_assets.HTTP)))))

mux.Handle("/favicon.png",
http.RedirectHandler(utils.Join(base, "/favicon.ico"),
Expand Down
2 changes: 1 addition & 1 deletion api/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func PrepareGUIMux(
downloadFileStore([]string{"clients"}))))))

// Assets etc do not need auth.
install_static_assets(config_obj, mux)
install_static_assets(ctx, config_obj, mux)

// Add reverse proxy support.
err = AddProxyMux(config_obj, mux)
Expand Down
10 changes: 9 additions & 1 deletion api/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
errors "github.com/go-errors/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
context "golang.org/x/net/context"
"www.velocidex.com/golang/velociraptor/services"
"www.velocidex.com/golang/velociraptor/utils"
)
Expand Down Expand Up @@ -145,14 +146,21 @@ func (self *CachedFilesystem) Exists(path string) bool {
return true
}

func NewCachedFilesystem(fs http.FileSystem) *CachedFilesystem {
func NewCachedFilesystem(
ctx context.Context, fs http.FileSystem) *CachedFilesystem {
result := &CachedFilesystem{
FileSystem: fs,
lru: ttlcache.NewCache(),
}

result.lru.SetTTL(10 * time.Minute)
result.lru.SkipTTLExtensionOnHit(true)

go func() {
<-ctx.Done()
result.lru.Close()
}()

return result
}

Expand Down
16 changes: 13 additions & 3 deletions crypto/client/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"crypto/rsa"
"crypto/x509"

Expand Down Expand Up @@ -54,7 +55,9 @@ func (self *ClientCryptoManager) AddCertificate(
return server_name, nil
}

func NewClientCryptoManager(config_obj *config_proto.Config, client_private_key_pem []byte) (
func NewClientCryptoManager(
ctx context.Context,
config_obj *config_proto.Config, client_private_key_pem []byte) (
*ClientCryptoManager, error) {
private_key, err := crypto_utils.ParseRsaPrivateKeyFromPemStr(client_private_key_pem)
if err != nil {
Expand All @@ -77,13 +80,20 @@ func NewClientCryptoManager(config_obj *config_proto.Config, client_private_key_
lru_size = config_obj.Frontend.Resources.ExpectedClients
}

return &ClientCryptoManager{CryptoManager{
result := &ClientCryptoManager{CryptoManager{
client_id: client_id,
private_key: private_key,
Resolver: NewInMemoryPublicKeyResolver(),
cipher_lru: NewCipherLRU(lru_size),
unauthenticated_lru: ttlcache.NewCache(),
caPool: roots,
logger: logger,
}}, nil
}}

go func() {
<-ctx.Done()
result.unauthenticated_lru.Close()
}()

return result, nil
}
10 changes: 9 additions & 1 deletion crypto/client/manager.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"crypto"
"crypto/aes"
"crypto/cipher"
Expand Down Expand Up @@ -85,7 +86,9 @@ func (self *CryptoManager) GetCSR() ([]byte, error) {
Bytes: csrBytes}), nil
}

func NewCryptoManager(config_obj *config_proto.Config,
func NewCryptoManager(
ctx context.Context,
config_obj *config_proto.Config,
client_id string,
private_key_pem []byte,
public_key_resolver PublicKeyResolver,
Expand All @@ -110,6 +113,11 @@ func NewCryptoManager(config_obj *config_proto.Config,
result.unauthenticated_lru.SetTTL(time.Second * 60)
result.unauthenticated_lru.SkipTTLExtensionOnHit(true)

go func() {
<-ctx.Done()
result.unauthenticated_lru.Close()
}()

return result, nil
}

Expand Down
2 changes: 1 addition & 1 deletion crypto/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (self *TestSuite) SetupTest() {
self.ConfigObj.Writeback.PrivateKey = string(key)

// Configure the client manager.
self.client_manager, err = crypto_client.NewClientCryptoManager(
self.client_manager, err = crypto_client.NewClientCryptoManager(self.Ctx,
self.ConfigObj, []byte(self.ConfigObj.Writeback.PrivateKey))
require.NoError(self.T(), err)

Expand Down
3 changes: 2 additions & 1 deletion crypto/server/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ func NewServerCryptoManager(
return nil, err
}

base, err := client.NewCryptoManager(config_obj, crypto_utils.GetSubjectName(cert),
base, err := client.NewCryptoManager(ctx, config_obj,
crypto_utils.GetSubjectName(cert),
[]byte(config_obj.Frontend.PrivateKey), resolver,
logging.GetLogger(config_obj, &logging.FrontendComponent))
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion crypto/storage/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (self *CrytpoStoreTestSuite) testWriting() {
// Initial state no server connection.
SetCurrentServerPem(nil)

fd, err := NewCryptoFileWriter(self.ConfigObj, 10000, output)
fd, err := NewCryptoFileWriter(self.Ctx, self.ConfigObj, 10000, output)
assert.NoError(self.T(), err)
defer fd.Close()

Expand Down
12 changes: 8 additions & 4 deletions crypto/storage/writer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package storage

import (
"context"
"crypto/rsa"
"errors"
"os"
Expand Down Expand Up @@ -29,6 +30,7 @@ var (
type CryptoFileWriter struct {
mu sync.Mutex

ctx context.Context
config_obj *config_proto.Config
fd *os.File
header *Header
Expand Down Expand Up @@ -78,7 +80,7 @@ func (self *CryptoFileWriter) serverPem() ([]byte, error) {
// the crypto manager until we have contacted the server and fetched
// its certificate. This code delays use of the crypto manager until
// it becomes available.
func (self *CryptoFileWriter) cryptoManager() (
func (self *CryptoFileWriter) cryptoManager(ctx context.Context) (
*crypto_client.ClientCryptoManager, error) {

server_pem, err := self.serverPem()
Expand All @@ -98,7 +100,7 @@ func (self *CryptoFileWriter) cryptoManager() (
return nil, err
}

crypto_manager, err := crypto_client.NewClientCryptoManager(
crypto_manager, err := crypto_client.NewClientCryptoManager(ctx,
self.config_obj, []byte(writeback.PrivateKey))
if err != nil {
return nil, err
Expand Down Expand Up @@ -168,7 +170,7 @@ func (self *CryptoFileWriter) Flush(keep_on_error KeepPolicy) error {

nonce := self.config_obj.Client.Nonce

manager, err := self.cryptoManager()
manager, err := self.cryptoManager(self.ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -222,7 +224,7 @@ func (self *CryptoFileWriter) writeCerts() error {

self.header.Next = pub_key.Next

manager, err := self.cryptoManager()
manager, err := self.cryptoManager(self.ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -250,6 +252,7 @@ func (self *CryptoFileWriter) writeCerts() error {
}

func NewCryptoFileWriter(
ctx context.Context,
config_obj *config_proto.Config,
max_size uint64,
filename string) (*CryptoFileWriter, error) {
Expand Down Expand Up @@ -279,6 +282,7 @@ func NewCryptoFileWriter(
result := &CryptoFileWriter{
config_obj: config_obj,
fd: fd,
ctx: ctx,
header: &Header{},
max_size: max_size,
}
Expand Down
19 changes: 12 additions & 7 deletions datastore/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package datastore

import (
"context"
"errors"
"sync"
"time"
Expand Down Expand Up @@ -125,17 +126,20 @@ func GetDB(config_obj *config_proto.Config) (DataStore, error) {
return nil, err
}

return getImpl(config_obj, implementation)
ctx := context.Background()
return getImpl(ctx, config_obj, implementation)
}

func getImpl(config_obj *config_proto.Config, implementation string) (DataStore, error) {
func getImpl(
ctx context.Context, config_obj *config_proto.Config,
implementation string) (DataStore, error) {
switch implementation {
case "FileBaseDataStore":
return file_based_imp, nil

case "ReadOnlyDataStore":
if read_only_imp == nil {
read_only_imp = NewReadOnlyDataStore(config_obj)
read_only_imp = NewReadOnlyDataStore(ctx, config_obj)
}
return read_only_imp, nil

Expand All @@ -144,23 +148,23 @@ func getImpl(config_obj *config_proto.Config, implementation string) (DataStore,

case "Memcache":
if memcache_imp == nil {
memcache_imp_ := NewMemcacheDataStore(config_obj)
memcache_imp_ := NewMemcacheDataStore(ctx, config_obj)
memcache_imp = memcache_imp_
RegisterMemcacheDatastoreMetrics(memcache_imp_)
}
return memcache_imp, nil

case "MemcacheFileDataStore":
if memcache_file_imp == nil {
memcache_imp_ := NewMemcacheFileDataStore(config_obj)
memcache_imp_ := NewMemcacheFileDataStore(ctx, config_obj)
memcache_file_imp = memcache_imp_
RegisterMemcacheDatastoreMetrics(memcache_imp_)
}
return memcache_file_imp, nil

case "Test":
if memcache_imp == nil {
memcache_imp = NewMemcacheDataStore(config_obj)
memcache_imp = NewMemcacheDataStore(ctx, config_obj)
}
return memcache_imp, nil

Expand All @@ -171,12 +175,13 @@ func getImpl(config_obj *config_proto.Config, implementation string) (DataStore,
}

func SetGlobalDatastore(
ctx context.Context,
implementation string,
config_obj *config_proto.Config) (err error) {
ds_mu.Lock()
defer ds_mu.Unlock()

g_impl, err = getImpl(config_obj, implementation)
g_impl, err = getImpl(ctx, config_obj, implementation)
return err
}

Expand Down
15 changes: 11 additions & 4 deletions datastore/memcache.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package datastore

import (
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -265,7 +266,7 @@ func (self *DirectoryLRUCache) Count() int {
}

func NewDirectoryLRUCache(
config_obj *config_proto.Config,
ctx context.Context, config_obj *config_proto.Config,
max_size, max_item_size int) *DirectoryLRUCache {

result := &DirectoryLRUCache{
Expand All @@ -276,6 +277,11 @@ func NewDirectoryLRUCache(
max_item_size: max_item_size,
}

go func() {
<-ctx.Done()
result.Cache.Close()
}()

result.Cache.SetCacheSizeLimit(max_size)
return result
}
Expand Down Expand Up @@ -668,12 +674,13 @@ func (self *MemcacheDatastore) Stats() *MemcacheStats {
}
}

func NewMemcacheDataStore(config_obj *config_proto.Config) *MemcacheDatastore {
func NewMemcacheDataStore(
ctx context.Context, config_obj *config_proto.Config) *MemcacheDatastore {
// This data store is used for testing so we really do not want to
// expire anything.
result := &MemcacheDatastore{
data_cache: NewDataLRUCache(config_obj, 100000, 1000000),
dir_cache: NewDirectoryLRUCache(config_obj, 100000, 100000),
data_cache: NewDataLRUCache(ctx, config_obj, 100000, 1000000),
dir_cache: NewDirectoryLRUCache(ctx, config_obj, 100000, 100000),
get_dir_metadata: get_dir_metadata,
}

Expand Down
Loading

0 comments on commit 6b83126

Please sign in to comment.