Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion common/scanning/credscanning/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,13 @@ func getFields(info providers.ProviderInfo,
withRequiredWorkspace = info.Oauth2Opts.ExplicitWorkspaceRequired
}

workspaceMode := optionalType
if info.RequiresWorkspace() || withRequiredWorkspace {
lists.Add(requiredType, Fields.Workspace)
workspaceMode = requiredType
}

lists.Add(workspaceMode, Fields.Workspace)

return lists, nil
}

Expand Down
7 changes: 6 additions & 1 deletion providers/microsoft/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/amp-labs/connectors/internal/components/schema"
"github.com/amp-labs/connectors/internal/components/writer"
"github.com/amp-labs/connectors/providers"
"github.com/amp-labs/connectors/providers/microsoft/internal/batch"
"github.com/amp-labs/connectors/providers/microsoft/internal/metadata"
)

Expand All @@ -28,6 +29,9 @@ type Connector struct {
components.Reader
components.Writer
components.Deleter

// Dependent services.
batchStrategy *batch.Strategy
}

// NewConnector creates a new Microsoft connector. It defaults to the Microsoft
Expand All @@ -49,7 +53,8 @@ func NewConnectorForProvider(provider providers.Provider, params common.Connecto
// nolint:funlen
func constructor(base *components.Connector) (*Connector, error) {
connector := &Connector{
Connector: base,
Connector: base,
batchStrategy: batch.NewStrategy(base.JSONHTTPClient(), base.ProviderInfo()),
}

connector.SchemaProvider = schema.NewOpenAPISchemaProvider(connector.ProviderContext.Module(), metadata.Schemas)
Expand Down
64 changes: 64 additions & 0 deletions providers/microsoft/read-by-ids.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package microsoft

import (
"context"
"fmt"
"net/http"

"github.com/amp-labs/connectors"
"github.com/amp-labs/connectors/common"
"github.com/amp-labs/connectors/common/readhelper"
"github.com/amp-labs/connectors/internal/datautils"
"github.com/amp-labs/connectors/providers/microsoft/internal/batch"
)

var _ connectors.BatchRecordReaderConnector = (*Connector)(nil)

// GetRecordsByIds scoped reading of records given their ids.
// nolint:revive
func (c *Connector) GetRecordsByIds(ctx context.Context,
objectName string, recordIds []string,
fields []string, associations []string,
) ([]common.ReadResultRow, error) {
if len(recordIds) == 0 {
return nil, common.ErrMissingObjects
}

batchParams, requestIdentifiers, err := c.paramsForBatchRead(objectName, recordIds)
if err != nil {
return nil, err
}

batchResponse := batch.Execute[map[string]any](ctx, c.batchStrategy, batchParams)
if err = batchResponse.JoinedErr(); err != nil {
return nil, err
}

marshaler := readhelper.MakeGetMarshaledDataWithId(readhelper.NewIdField("id"))
uniqueFields := datautils.NewSetFromList(fields).List()

return marshaler(batchResponse.GetInOrder(requestIdentifiers), uniqueFields)
}

func (c *Connector) paramsForBatchRead(
objectName string, identifiers []string,
) (*batch.Params, []batch.RequestID, error) {
batchParams := &batch.Params{}

requestIdentifiers := make([]batch.RequestID, len(identifiers))
for index, identifier := range identifiers {
url, err := c.getURL(objectName)
if err != nil {
return nil, nil, err
}

url.AddPath(identifier)
requestIdentifier := batch.RequestID(fmt.Sprintf("%v_%v", objectName, identifier))
requestIdentifiers[index] = requestIdentifier
batchParams.WithRequest(requestIdentifier, http.MethodGet, url, nil, map[string]any{
"Content-Type": "application/json",
})
}

return batchParams, requestIdentifiers, nil
}
115 changes: 115 additions & 0 deletions providers/microsoft/read-by-ids_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package microsoft

import (
"net/http"
"testing"

"github.com/amp-labs/connectors"
"github.com/amp-labs/connectors/common"
"github.com/amp-labs/connectors/test/utils/mockutils/mockcond"
"github.com/amp-labs/connectors/test/utils/mockutils/mockserver"
"github.com/amp-labs/connectors/test/utils/testroutines"
"github.com/amp-labs/connectors/test/utils/testutils"
)

func TestGetRecordsByIds(t *testing.T) { // nolint:funlen,cyclop
t.Parallel()

responseMessages := testutils.DataFromFile(t, "read/messages/batch-by-ids.json")

tests := []testroutines.ReadByIds{
{
Name: "Empty record identifiers",
Server: mockserver.Dummy(),
ExpectedErrs: []error{common.ErrMissingObjects},
},
{
Name: "Read messages by identifiers",
Input: testroutines.ReadByIdsParams{
ObjectName: "me/messages",
RecordIds: []string{"msg1", "msg2", "msg3"},
Fields: []string{"subject"},
},
Server: mockserver.Conditional{
Setup: mockserver.ContentJSON(),
If: mockcond.And{
mockcond.MethodPOST(),
mockcond.Path("/v1.0/$batch"),
mockcond.Body(`{
"requests": [
{
"id": "me/messages_msg1",
"method": "GET",
"url": "/me/messages/msg1",
"headers": {
"Content-Type": "application/json"
}
},
{
"id": "me/messages_msg2",
"method": "GET",
"url": "/me/messages/msg2",
"headers": {
"Content-Type": "application/json"
}
},
{
"id": "me/messages_msg3",
"method": "GET",
"url": "/me/messages/msg3",
"headers": {
"Content-Type": "application/json"
}
}
]
}`),
},
Then: mockserver.Response(http.StatusOK, responseMessages),
}.Server(),
Comparator: testroutines.ComparatorSubsetReadByIds,
Expected: []common.ReadResultRow{{
Id: "msg1",
Fields: map[string]any{
"subject": "Hello",
},
Raw: map[string]any{
"id": "msg1",
"subject": "Hello",
"bodyPreview": "Hi there",
},
}, {
Id: "msg2",
Fields: map[string]any{
"subject": "Meeting",
},
Raw: map[string]any{
"id": "msg2",
"subject": "Meeting",
"bodyPreview": "See you soon",
},
}, {
Id: "msg3",
Fields: map[string]any{
"subject": "Lunch",
},
Raw: map[string]any{
"id": "msg3",
"subject": "Lunch",
"bodyPreview": "Hungry?",
},
}},
ExpectedErrs: nil,
},
}

for _, tt := range tests { // nolint:dupl
// nolint:varnamelen
t.Run(tt.Name, func(t *testing.T) {
t.Parallel()

tt.Run(t, func() (connectors.BatchRecordReaderConnector, error) {
return constructTestConnector(tt.Server.URL)
})
})
}
}
Loading
Loading