Skip to content

Commit

Permalink
Add new relic transactions for Registry Service.
Browse files Browse the repository at this point in the history
  • Loading branch information
robinjhuang committed Jan 18, 2025
1 parent d140d34 commit 80f1cc5
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 21 deletions.
4 changes: 3 additions & 1 deletion integration-tests/ci_cd_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/google/uuid"
"github.com/newrelic/go-agent/v3/newrelic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
Expand All @@ -28,6 +29,7 @@ func TestCICD(t *testing.T) {
mockSlackService := new(gateways.MockSlackService)
mockDiscordService := new(gateways.MockDiscordService)
mockPubsubService := new(gateways.MockPubSubService)
newRelicApp := new(newrelic.Application)
mockSlackService.
On("SendRegistryMessageToSlack", mock.Anything).
Return(nil) // Do nothing for all slack messsage calls.
Expand All @@ -36,7 +38,7 @@ func TestCICD(t *testing.T) {
On("IndexNodes", mock.Anything, mock.Anything).
Return(nil)
impl := implementation.NewStrictServerImplementation(
client, &config.Config{}, mockStorageService, mockPubsubService, mockSlackService, mockDiscordService, mockAlgolia)
client, &config.Config{}, mockStorageService, mockPubsubService, mockSlackService, mockDiscordService, mockAlgolia, newRelicApp)

ctx := context.Background()
now := time.Now()
Expand Down
4 changes: 3 additions & 1 deletion integration-tests/test_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/testcontainers/testcontainers-go/wait"

_ "github.com/lib/pq"
"github.com/newrelic/go-agent/v3/newrelic"
)

type MockedServerImplementation struct {
Expand All @@ -55,6 +56,7 @@ func NewStrictServerImplementationWithMocks(
mockSlackService := new(gateways.MockSlackService)
mockDiscordService := new(gateways.MockDiscordService)
mockAlgolia := new(gateways.MockAlgoliaService)
newRelicApp := new(newrelic.Application)

// Set up mock service expectations.
mockDiscordService.On("SendSecurityCouncilMessage", mock.Anything, mock.Anything).
Expand All @@ -73,7 +75,7 @@ func NewStrictServerImplementationWithMocks(
// Initialize the mocked implementation with mocked services.
return &MockedServerImplementation{
DripStrictServerImplementation: implementation.NewStrictServerImplementation(
client, config, mockStorageService, mockPubsubService, mockSlackService, mockDiscordService, mockAlgolia),
client, config, mockStorageService, mockPubsubService, mockSlackService, mockDiscordService, mockAlgolia, newRelicApp),
mockStorageService: mockStorageService,
mockSlackService: mockSlackService,
mockDiscordService: mockDiscordService,
Expand Down
8 changes: 6 additions & 2 deletions server/implementation/api.implementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ import (
dripservices "registry-backend/services/registry"

"github.com/mixpanel/mixpanel-go"
"github.com/newrelic/go-agent/v3/newrelic"
)

type DripStrictServerImplementation struct {
Client *ent.Client
ComfyCIService *dripservices_comfyci.ComfyCIService
RegistryService *dripservices.RegistryService
MixpanelService *mixpanel.ApiClient
NewRelicApp *newrelic.Application
}

func NewStrictServerImplementation(
Expand All @@ -28,11 +30,13 @@ func NewStrictServerImplementation(
pubsubService pubsub.PubSubService,
slackService gateway.SlackService,
discordService discord.DiscordService,
algolia algolia.AlgoliaService) *DripStrictServerImplementation {
algolia algolia.AlgoliaService,
newRelicApp *newrelic.Application) *DripStrictServerImplementation {
return &DripStrictServerImplementation{
Client: client,
ComfyCIService: dripservices_comfyci.NewComfyCIService(config),
RegistryService: dripservices.NewRegistryService(storageService, pubsubService, slackService, discordService, algolia, config),
RegistryService: dripservices.NewRegistryService(storageService, pubsubService, slackService, discordService, algolia, config, newRelicApp),
MixpanelService: mixpanel.NewApiClient("f919d1b9da9a57482453c72ef7b16d88"),
NewRelicApp: newRelicApp,
}
}
32 changes: 18 additions & 14 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,33 @@ type Server struct {
Client *ent.Client
Config *config.Config
Dependencies *ServerDependencies
NewRelicApp *newrelic.Application
}

func NewServer(client *ent.Client, config *config.Config) (*Server, error) {
deps, err := initializeDependencies(config)
if err != nil {
return nil, err
}

app, err := newrelic.NewApplication(
newrelic.ConfigAppName(fmt.Sprintf("registry-%s", config.DripEnv)),
newrelic.ConfigLicense(config.NewRelicLicenseKey),
newrelic.ConfigAppLogForwardingEnabled(true),
newrelic.ConfigDebugLogger(log.Logger),
newrelic.ConfigDistributedTracerEnabled(true),
newrelic.ConfigEnabled(true),
)

if err != nil {
log.Error().Err(err).Msg("Failed to initialize NewRelic application")
}

return &Server{
Client: client,
Config: config,
Dependencies: deps,
NewRelicApp: app,
}, nil
}

Expand Down Expand Up @@ -93,23 +109,11 @@ func initializeDependencies(config *config.Config) (*ServerDependencies, error)
}

func (s *Server) Start() error {
app, err := newrelic.NewApplication(
newrelic.ConfigAppName(fmt.Sprintf("registry-%s", s.Config.DripEnv)),
newrelic.ConfigLicense(s.Config.NewRelicLicenseKey),
newrelic.ConfigAppLogForwardingEnabled(true),
newrelic.ConfigDebugLogger(log.Logger),
newrelic.ConfigDistributedTracerEnabled(true),
newrelic.ConfigEnabled(true),
)
if err != nil {
log.Error().Err(err).Msg("Failed to initialize NewRelic application")
}

e := echo.New()
e.HideBanner = true

// Apply middleware
e.Use(nrecho.Middleware(app))
e.Use(nrecho.Middleware(s.NewRelicApp))
e.Use(middleware.TracingMiddleware)
e.Use(labstack_middleware.CORSWithConfig(labstack_middleware.CORSConfig{
AllowOrigins: []string{"*"},
Expand All @@ -131,7 +135,7 @@ func (s *Server) Start() error {
impl := implementation.NewStrictServerImplementation(
s.Client, s.Config, s.Dependencies.StorageService, s.Dependencies.PubSubService,
s.Dependencies.SlackService,
s.Dependencies.DiscordService, s.Dependencies.AlgoliaService)
s.Dependencies.DiscordService, s.Dependencies.AlgoliaService, s.NewRelicApp)

// Define middleware for authorization
authorizationManager := drip_authorization.NewAuthorizationManager(s.Client, impl.RegistryService)
Expand Down
52 changes: 49 additions & 3 deletions services/registry/registry_svc.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ import (

"entgo.io/ent/dialect/sql"
"github.com/Masterminds/semver/v3"
"google.golang.org/protobuf/proto"

"github.com/google/uuid"
"github.com/newrelic/go-agent/v3/newrelic"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/proto"
)

type RegistryService struct {
Expand All @@ -49,16 +49,18 @@ type RegistryService struct {
algolia algolia.AlgoliaService
discordService discord.DiscordService
config *config.Config
newRelicApp *newrelic.Application
}

func NewRegistryService(storageSvc storage.StorageService, pubsubService pubsub.PubSubService, slackSvc gateway.SlackService, discordSvc discord.DiscordService, algoliaSvc algolia.AlgoliaService, config *config.Config) *RegistryService {
func NewRegistryService(storageSvc storage.StorageService, pubsubService pubsub.PubSubService, slackSvc gateway.SlackService, discordSvc discord.DiscordService, algoliaSvc algolia.AlgoliaService, config *config.Config, newRelicApp *newrelic.Application) *RegistryService {
return &RegistryService{
storageService: storageSvc,
pubsubService: pubsubService,
slackService: slackSvc,
discordService: discordSvc,
algolia: algoliaSvc,
config: config,
newRelicApp: newRelicApp,
}
}

Expand Down Expand Up @@ -308,6 +310,8 @@ func (s *RegistryService) UpdateNode(
}

func (s *RegistryService) GetNode(ctx context.Context, client *ent.Client, nodeID string) (*ent.Node, error) {
txn := s.newRelicApp.StartTransaction("RegistryService.GetNode")
defer txn.End()
log.Ctx(ctx).Info().Msgf("getting node: %v", nodeID)
node, err := client.Node.Get(ctx, nodeID)
if err != nil {
Expand Down Expand Up @@ -388,6 +392,8 @@ type NodeVersionCreation struct {

func (s *RegistryService) ListNodeVersions(
ctx context.Context, client *ent.Client, filter *entity.NodeVersionFilter) (*entity.ListNodeVersionsResult, error) {
txn := s.newRelicApp.StartTransaction("RegistryService.ListNodeVersions")
defer txn.End()
query := client.NodeVersion.Query().
WithStorageFile().
Order(ent.Desc(nodeversion.FieldVersion))
Expand Down Expand Up @@ -489,6 +495,8 @@ func (s *RegistryService) AddNodeReview(ctx context.Context, client *ent.Client,
}

func (s *RegistryService) GetNodeVersionByVersion(ctx context.Context, client *ent.Client, nodeId, nodeVersion string) (*ent.NodeVersion, error) {
txn := s.newRelicApp.StartTransaction("RegistryService.GetNodeVersionByVersion")
defer txn.End()
log.Ctx(ctx).Info().Msgf("getting node version %v@%v", nodeId, nodeVersion)
return client.NodeVersion.
Query().
Expand All @@ -499,12 +507,16 @@ func (s *RegistryService) GetNodeVersionByVersion(ctx context.Context, client *e
}

func (s *RegistryService) GetNodeVersion(ctx context.Context, client *ent.Client, nodeVersionId string) (*ent.NodeVersion, error) {
txn := s.newRelicApp.StartTransaction("RegistryService.GetNodeVersion")
defer txn.End()
log.Ctx(ctx).Info().Msgf("getting node version %v", nodeVersionId)
return client.NodeVersion.
Get(ctx, uuid.MustParse(nodeVersionId))
}

func (s *RegistryService) UpdateNodeVersion(ctx context.Context, client *ent.Client, update *ent.NodeVersionUpdateOne) (*ent.NodeVersion, error) {
txn := s.newRelicApp.StartTransaction("RegistryService.UpdateNodeVersion")
defer txn.End()
log.Ctx(ctx).Info().Msgf("updating node version fields: %v", update.Mutation().Fields())
return db.WithTxResult(ctx, client, func(tx *ent.Tx) (*ent.NodeVersion, error) {
node, err := update.Save(ctx)
Expand All @@ -522,6 +534,8 @@ func (s *RegistryService) UpdateNodeVersion(ctx context.Context, client *ent.Cli
}

func (s *RegistryService) RecordNodeInstallation(ctx context.Context, client *ent.Client, node *ent.Node) (*ent.Node, error) {
txn := s.newRelicApp.StartTransaction("RegistryService.RecordNodeInstallation")
defer txn.End()
var n *ent.Node
err := db.WithTx(ctx, client, func(tx *ent.Tx) (err error) {
n, err = tx.Node.UpdateOne(node).AddTotalInstall(1).Save(ctx)
Expand All @@ -539,6 +553,8 @@ func (s *RegistryService) RecordNodeInstallation(ctx context.Context, client *en
}

func (s *RegistryService) GetLatestNodeVersion(ctx context.Context, client *ent.Client, nodeId string) (*ent.NodeVersion, error) {
txn := s.newRelicApp.StartTransaction("RegistryService.GetLatestNodeVersion")
defer txn.End()
log.Ctx(ctx).Info().Msgf("Getting latest version of node: %v", nodeId)
nodeVersion, err := client.NodeVersion.
Query().
Expand Down Expand Up @@ -595,6 +611,8 @@ func (s *RegistryService) CreateComfyNodes(
comfyNodes map[string]drip.ComfyNode,
info *schema.ComfyNodeCloudBuildInfo,
) error {
txn := s.newRelicApp.StartTransaction("RegistryService.CreateComfyNodes")
defer txn.End()
return db.WithTx(ctx, client, func(tx *ent.Tx) error {
// Query the NodeVersion with the given nodeID and nodeVersion, lock it for updates
nv, err := tx.NodeVersion.Query().
Expand Down Expand Up @@ -682,6 +700,8 @@ func (s *RegistryService) GetComfyNode(
nodeID, nodeVersion, comfyNodeName string,
) (*ent.ComfyNode, error) {
// Query the NodeVersion with the given nodeID and nodeVersion, ensuring extraction status is success
txn := s.newRelicApp.StartTransaction("RegistryService.GetComfyNode")
defer txn.End()
nv, err := client.NodeVersion.Query().
Where(nodeversion.VersionEQ(nodeVersion)).
Where(nodeversion.NodeIDEQ(nodeID)).
Expand All @@ -707,6 +727,8 @@ func (s *RegistryService) GetComfyNode(
func (s *RegistryService) TriggerComfyNodesBackfill(
ctx context.Context, client *ent.Client, max *int) error {
// Query all NodeVersions with pending comfy node extraction status
txn := s.newRelicApp.StartTransaction("RegistryService.TriggerComfyNodesBackfill")
defer txn.End()
q := client.NodeVersion.
Query().
WithStorageFile().
Expand Down Expand Up @@ -748,6 +770,8 @@ func (s *RegistryService) AssertPublisherPermissions(ctx context.Context,
userID string,
permissions []schema.PublisherPermissionType,
) (err error) {
txn := s.newRelicApp.StartTransaction("RegistryService.AssertPublisherPermissions")
defer txn.End()
w, err := client.Publisher.Get(ctx, publisherID)
if err != nil {
return fmt.Errorf("fail to query publisher by id: %s %w", publisherID, err)
Expand All @@ -772,6 +796,8 @@ func (s *RegistryService) IsPersonalAccessTokenValidForPublisher(ctx context.Con
publisherID string,
accessToken string,
) (bool, error) {
txn := s.newRelicApp.StartTransaction("RegistryService.IsPersonalAccessTokenValidForPublisher")
defer txn.End()
w, err := client.Publisher.Get(ctx, publisherID)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msgf("fail to find publisher by id: %s", publisherID)
Expand All @@ -792,6 +818,8 @@ func (s *RegistryService) IsPersonalAccessTokenValidForPublisher(ctx context.Con
}

func (s *RegistryService) AssertNodeBelongsToPublisher(ctx context.Context, client *ent.Client, publisherID string, nodeID string) error {
txn := s.newRelicApp.StartTransaction("RegistryService.AssertNodeBelongsToPublisher")
defer txn.End()
node, err := client.Node.Get(ctx, nodeID)
if err != nil {
return fmt.Errorf("failed to get node: %w", err)
Expand All @@ -803,6 +831,8 @@ func (s *RegistryService) AssertNodeBelongsToPublisher(ctx context.Context, clie
}

func (s *RegistryService) AssertAccessTokenBelongsToPublisher(ctx context.Context, client *ent.Client, publisherID string, tokenId uuid.UUID) error {
txn := s.newRelicApp.StartTransaction("RegistryService.AssertAccessTokenBelongsToPublisher")
defer txn.End()
pat, err := client.PersonalAccessToken.Query().Where(
personalaccesstoken.IDEQ(tokenId),
personalaccesstoken.PublisherIDEQ(publisherID),
Expand All @@ -817,6 +847,8 @@ func (s *RegistryService) AssertAccessTokenBelongsToPublisher(ctx context.Contex
}

func (s *RegistryService) DeletePublisher(ctx context.Context, client *ent.Client, publisherID string) error {
txn := s.newRelicApp.StartTransaction("RegistryService.DeletePublisher")
defer txn.End()
log.Ctx(ctx).Info().Msgf("deleting publisher: %v", publisherID)
return db.WithTx(ctx, client, func(tx *ent.Tx) error {
client = tx.Client()
Expand Down Expand Up @@ -850,6 +882,8 @@ func (s *RegistryService) DeletePublisher(ctx context.Context, client *ent.Clien
}

func (s *RegistryService) DeleteNode(ctx context.Context, client *ent.Client, nodeID string) error {
txn := s.newRelicApp.StartTransaction("RegistryService.DeleteNode")
defer txn.End()
log.Ctx(ctx).Info().Msgf("deleting node: %v", nodeID)
db.WithTx(ctx, client, func(tx *ent.Tx) error {
nv, err := tx.Client().NodeVersion.Query().Where(nodeversion.NodeID(nodeID)).All(ctx)
Expand All @@ -876,6 +910,8 @@ func (s *RegistryService) DeleteNode(ctx context.Context, client *ent.Client, no
}

func (s *RegistryService) DeleteNodeVersion(ctx context.Context, client *ent.Client, nodeIDVersion string) error {
txn := s.newRelicApp.StartTransaction("RegistryService.DeleteNodeVersion")
defer txn.End()
log.Ctx(ctx).Info().Msgf("deleting node version: %v", nodeIDVersion)
db.WithTx(ctx, client, func(tx *ent.Tx) error {
nv, err := tx.Client().NodeVersion.Get(ctx, uuid.MustParse(nodeIDVersion))
Expand Down Expand Up @@ -921,6 +957,8 @@ func IsPermissionError(err error) bool {
}

func (s *RegistryService) BanPublisher(ctx context.Context, client *ent.Client, id string) error {
txn := s.newRelicApp.StartTransaction("RegistryService.BanPublisher")
defer txn.End()
log.Ctx(ctx).Info().Msgf("banning publisher: %v", id)
pub, err := client.Publisher.Get(ctx, id)
if err != nil {
Expand Down Expand Up @@ -972,6 +1010,8 @@ func (s *RegistryService) BanPublisher(ctx context.Context, client *ent.Client,
}

func (s *RegistryService) BanNode(ctx context.Context, client *ent.Client, publisherid, id string) error {
txn := s.newRelicApp.StartTransaction("RegistryService.BanNode")
defer txn.End()
log.Ctx(ctx).Info().Msgf("banning publisher node: %v %v", publisherid, id)

return db.WithTx(ctx, client, func(tx *ent.Tx) error {
Expand Down Expand Up @@ -1007,6 +1047,8 @@ func (s *RegistryService) BanNode(ctx context.Context, client *ent.Client, publi
}

func (s *RegistryService) AssertNodeBanned(ctx context.Context, client *ent.Client, nodeID string) error {
txn := s.newRelicApp.StartTransaction("RegistryService.AssertNodeBanned")
defer txn.End()
node, err := client.Node.Get(ctx, nodeID)
if ent.IsNotFound(err) {
return nil
Expand Down Expand Up @@ -1035,6 +1077,8 @@ func (s *RegistryService) AssertPublisherBanned(ctx context.Context, client *ent
}

func (s *RegistryService) ReindexAllNodes(ctx context.Context, client *ent.Client) error {
txn := s.newRelicApp.StartTransaction("RegistryService.ReindexAllNodes")
defer txn.End()
log.Ctx(ctx).Info().Msgf("reindexing nodes")
nodes, err := s.decorateNodeQueryWithLatestVersion(client.Node.Query()).All(ctx)
if err != nil {
Expand Down Expand Up @@ -1064,6 +1108,8 @@ func (s *RegistryService) ReindexAllNodes(ctx context.Context, client *ent.Clien
var reindexLock = sync.Mutex{}

func (s *RegistryService) ReindexAllNodesBackground(ctx context.Context, client *ent.Client) (err error) {
txn := s.newRelicApp.StartTransaction("RegistryService.ReindexAllNodesBackground")
defer txn.End()
if !reindexLock.TryLock() {
return fmt.Errorf("another reindex is in progress")
}
Expand Down

0 comments on commit 80f1cc5

Please sign in to comment.