diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..867ec08 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,31 @@ +name: Linters, Spellcheck, and Tests + +on: + push: + paths: + - '**.go' + workflow_dispatch: + +jobs: + LintnTest: + runs-on: ubuntu-latest + timeout-minutes: 2 + steps: + - uses: actions/checkout@v4 + - name: Setup go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + - name: Install dependencies + run: make installpkgs + - name: Run linters + run: make lint + - name: Run tests + run: make test + + Spellcheck: + runs-on: ubuntu-latest + timeout-minutes: 1 + steps: + - uses: actions/checkout@v4 + - uses: crate-ci/typos@v1.29.7 diff --git a/.gitignore b/.gitignore index 2ee5bd8..6e3bbcf 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ *.dll *.so *.dylib +*.html +*.json # Test binary, built with `go test -c` *.test diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..ddbd2d1 --- /dev/null +++ b/Makefile @@ -0,0 +1,43 @@ +# Run tests and log the test coverage +test: + go test -v -race -coverprofile=".cover.out" $$(go list ./... | grep -v /tmp) + +# Runs source code linters and catches common errors +lint: + test -z $$(gofmt -l .) || (echo "Code isn't gofmt'ed!" && exit 1) + go vet $$(go list ./... | grep -v /tmp) + gosec -quiet -fmt=golint -exclude-dir="tmp" ./... + staticcheck ./... + govulncheck -test ./... + # pointerinterface ./... + +# Runs spellchecker on the code and comments +# This requires this tool to be installed from https://github.com/crate-ci/typos?tab=readme-ov-file +# Example installation (if you have rust installed): cargo install typos-cli +spellcheck: + typos . + +# All in one check +runchecks: test lint spellcheck + +# Generate pretty coverage report +analyse: + go tool cover -html=".cover.out" -o="cover.html" + @echo -e "\nCOVERAGE\n====================" + go tool cover -func=.cover.out + @echo -e "\nCYCLOMATIC COMPLEXITY\n====================" + gocyclo -avg -top 10 . + +# Updates 3rd party packages and tools +installpkgs: + go mod download + go install github.com/fzipp/gocyclo/cmd/gocyclo@latest + go install github.com/securego/gosec/v2/cmd/gosec@latest + go install honnef.co/go/tools/cmd/staticcheck@latest + go install golang.org/x/vuln/cmd/govulncheck@latest + # go install code.larus.se/lmas/pointerinterface@latest + +# Clean up built binary and other temporary files (ignores errors from rm) +clean: + go clean + rm .cover.out cover.html diff --git a/components/device.go b/components/host.go old mode 100755 new mode 100644 similarity index 99% rename from components/device.go rename to components/host.go index dc389c5..6657404 --- a/components/device.go +++ b/components/host.go @@ -67,7 +67,6 @@ func NewDevice() *HostingDevice { func Hostname() (string, error) { name, err := os.Hostname() if err != nil { - log.Println(err.Error()) return "", err } return name, nil diff --git a/components/host_test.go b/components/host_test.go new file mode 100644 index 0000000..5dbe78d --- /dev/null +++ b/components/host_test.go @@ -0,0 +1,41 @@ +package components + +import ( + "testing" +) + +func TestHostname(t *testing.T) { + res, err := Hostname() + + if res == "" || err != nil { + t.Errorf("Expected a host name and no error, got: %s and %v", res, err) + } +} + +func TestIpAddresses(t *testing.T) { + res, err := IpAddresses() + + if len(res) == 0 || err != nil { + t.Errorf("Expected IP addresses and no error, got: %s and %v", res, err) + } +} + +func TestMacAddresses(t *testing.T) { + ip, err := IpAddresses() + if err != nil { + t.Fatalf("An error occurred in getting IP Addresses for the Mac Address test") + } + res, err := MacAddresses(ip) + + if len(res) == 0 || err != nil { + t.Errorf("Expected no error, got: %s and %v", res, err) + } +} + +func TestNewDevice(t *testing.T) { + res := NewDevice() + + if res == nil { + t.Errorf("Expected a new device, got: %v", res) + } +} diff --git a/components/husk.go b/components/husk.go index 431562a..ddc26ff 100644 --- a/components/husk.go +++ b/components/husk.go @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2024 Synecdoque + * Copyright (c) 2025 Synecdoque * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -35,7 +35,7 @@ type Husk struct { Certificate string `json:"-"` CA_cert string `json:"-"` TlsConfig *tls.Config `json:"-"` // client side mutual TLS configuration - DName pkix.Name `json:"distinguishedName"` + DName pkix.Name `json:"-"` Details map[string][]string `json:"details"` ProtoPort map[string]int `json:"protoPort"` InfoLink string `json:"onlineDocumentation"` diff --git a/components/husk_test.go b/components/husk_test.go new file mode 100644 index 0000000..9364b53 --- /dev/null +++ b/components/husk_test.go @@ -0,0 +1,38 @@ +package components + +import ( + "testing" +) + +type sProtocolsTestStruct struct { + input map[string]int + expectedOutput []string +} + +var sProtocolsTestParams = []sProtocolsTestStruct{ + {makeEmptyProtoPortMap(), nil}, + {makeProtoPortMapWithPortZero(), nil}, + {makeFullProtoPortMap(), []string{"Port1", "Port2"}}, +} + +func makeEmptyProtoPortMap() map[string]int { + return make(map[string]int) +} + +func makeProtoPortMapWithPortZero() map[string]int { + return map[string]int{"Port": 0} +} + +func makeFullProtoPortMap() map[string]int { + return map[string]int{"Port1": 123, "Port2": 404, "Port3": 0} +} + +func TestSProtocols(t *testing.T) { + for _, testCase := range sProtocolsTestParams { + res := SProtocols(testCase.input) + + if len(res) != len(testCase.expectedOutput) { + t.Errorf("Expected %v, got: %v", testCase.expectedOutput, res) + } + } +} diff --git a/components/service.go b/components/service.go index ff8a14c..2243f2d 100644 --- a/components/service.go +++ b/components/service.go @@ -26,7 +26,7 @@ package components type Service struct { ID int `json:"-"` // Id assigned by the Service Registrar Definition string `json:"definition"` // Service definition or purpose - SubPath string `json:"-"` // The URL subpath after the resource's + SubPath string `json:"subpath"` // The URL subpath after the resource's Details map[string][]string `json:"details"` // Metadata or details about the service RegPeriod int `json:"registrationPeriod"` // The period until the registrar is expecting a sign of life RegTimestamp string `json:"-"` // the creation date in the Service Registry to ensure that reRegistration is with the same record @@ -37,7 +37,7 @@ type Service struct { CUnit string `json:"costUnit"` // cost unit } -// type Services is a collection of service stucts +// type Services is a collection of service structs type Services map[string]*Service // Merge method is used in the configuration use case to prevent the subpath or description to be changed or "configured" @@ -113,10 +113,11 @@ func MergeDetails(map1, map2 map[string][]string) map[string][]string { // A Cervice is a consumed service type Cervice struct { - Definition string - Details map[string][]string - Nodes map[string][]string - Protos []string + IReferentce string // Internal reference when consuming more than one service of the same type + Definition string // Service definition or purpose + Details map[string][]string + Nodes map[string][]string + Protos []string } // Cervises is a collection of "Cervice" structs diff --git a/components/service_test.go b/components/service_test.go new file mode 100644 index 0000000..3c400b2 --- /dev/null +++ b/components/service_test.go @@ -0,0 +1,210 @@ +package components + +import ( + "fmt" + "testing" +) + +func manualEqualityCheck(map1 map[string][]string, map2 map[string][]string) error { + if len(map1) != len(map2) { + return fmt.Errorf("Expected map length %d, got %d", len(map2), len(map1)) + } + for key, value := range map2 { + mv, ok := map1[key] + if !ok { + return fmt.Errorf("Expected key %q not found in merged map", key) + } + if len(mv) != len(value) { + return fmt.Errorf("For key %q, expected slice length %d, got %d", key, len(value), len(mv)) + } + for i := range value { + if mv[i] != value[i] { + return fmt.Errorf("For key %q, at index %d, expected %q, got %q", key, i, value[i], mv[i]) + } + } + } + for key := range map1 { + if _, ok := map2[key]; !ok { + return fmt.Errorf("Unexpected key %q found in merged map", key) + } + } + return nil +} + +var testServiceWithEmptyDetails = Service{ + ID: 1, + Definition: "original one", + SubPath: "testOriginalSubPath", + Details: make(map[string][]string), + RegPeriod: 45, + RegTimestamp: "", + RegExpiration: "", + Description: "A test original service", + SubscribeAble: false, + ACost: 0, + CUnit: "", +} + +var testService = Service{ + ID: 1, + Definition: "test", + SubPath: "testSubPath", + Details: make(map[string][]string), + RegPeriod: 45, + RegTimestamp: "", + RegExpiration: "", + Description: "A test service", + SubscribeAble: false, + ACost: 0, + CUnit: "", +} + +var testOriginalService = Service{ + ID: 1, + Definition: "original one", + SubPath: "testOriginalSubPath", + Details: map[string][]string{"test": {"test1", "test2"}}, + RegPeriod: 45, + RegTimestamp: "", + RegExpiration: "", + Description: "A test original service", + SubscribeAble: false, + ACost: 0, + CUnit: "", +} + +func TestMerge(t *testing.T) { + testService.Merge(&testOriginalService) + if testService.Definition != testOriginalService.Definition || + testService.SubPath != testOriginalService.SubPath || + testService.Description != testOriginalService.Description { + t.Errorf("Expected the test service to be the same as the original test service %s, got: %s", + testOriginalService.Definition, testService.Definition) + } +} + +func TestDeepCopy(t *testing.T) { + res := testOriginalService.DeepCopy() + res.Details["test"][0] = "changed" + res.Details["newkey"] = []string{"newTest"} + + if testOriginalService.Details["test"][0] == "changed" { + t.Errorf("DeepCopy failed, expected original slice to remain, original slice was mutated") + } + if _, ok := testOriginalService.Details["newkey"]; ok { + t.Errorf("DeepCopy failed, expected no new key in original, got %s", testOriginalService.Details["newkey"]) + } + + res = testServiceWithEmptyDetails.DeepCopy() + if len(res.Details) != 0 { + t.Errorf("DeepCopy failed, expected details map to be empty after copy, got: %v", res.Details) + } +} + +func makeNewTestService(id int, definition string) *Service { + return &Service{ + ID: id, + Definition: definition, + SubPath: "newTestServiceSubPath", + Details: make(map[string][]string), + RegPeriod: 45, + RegTimestamp: "", + RegExpiration: "", + Description: "A new test Service", + SubscribeAble: false, + ACost: 0, + CUnit: "", + } +} + +func TestCloneServices(t *testing.T) { + test1 := makeNewTestService(1, "test") + test2 := makeNewTestService(2, "test") + + cloned := CloneServices([]Service{*test1, *test2}) + if len(cloned) != 1 { + t.Errorf("Expected 1 Service, got %d", len(cloned)) + } + if cloned["test"].ID != 2 { + t.Errorf("Second Service did not overwrite the first as expected") + } + + cloned["test"].ID = 3 + if test1.ID == 3 || test2.ID == 3 { + t.Errorf("DeepCopy failed: mutation of clone affected either one of the originals") + } + + cloned = CloneServices(nil) + if cloned == nil { + t.Errorf("Expected non-nil empty map for nil input") + } + if len(cloned) != 0 { + t.Errorf("Expected 0 Services, got %d", len(cloned)) + } + + test1 = makeNewTestService(1, "") + test2 = makeNewTestService(2, "") + + cloned = CloneServices([]Service{*test1, *test2}) + + if len(cloned) != 1 { + t.Errorf("Expected 1 entry, got %d", len(cloned)) + } +} + +func makeNewMap(key string, value string) map[string][]string { + newMap := map[string][]string{ + key: {value}, + } + return newMap +} + +var expectedRegularMerge = map[string][]string{ + "a": {"1"}, + "b": {"2"}, +} + +var expectedKeyOverlapMerge = map[string][]string{ + "a": {"1", "3"}, +} + +var expectedOneEmptyMapMerge = map[string][]string{ + "a": {"1"}, +} + +var expectedBothEmptyMapMerge = map[string][]string{} + +type mergeDetailsTestStruct struct { + map1 map[string][]string + map2 map[string][]string + expected map[string][]string +} + +var mergeDetailsTestParams = []mergeDetailsTestStruct{ + {makeNewMap("a", "1"), makeNewMap("b", "2"), expectedRegularMerge}, + {makeNewMap("a", "1"), makeNewMap("a", "3"), expectedKeyOverlapMerge}, + {makeNewMap("a", "1"), make(map[string][]string), expectedOneEmptyMapMerge}, + {make(map[string][]string), make(map[string][]string), expectedBothEmptyMapMerge}, +} + +func TestMergeDetails(t *testing.T) { + for _, test := range mergeDetailsTestParams { + merged := MergeDetails(test.map1, test.map2) + + err := manualEqualityCheck(merged, test.expected) + if err != nil { + t.Errorf("Expected %v, got %v", test.expected, merged) + } + + if len(merged) != 0 { + merged["a"][0] = "changed" + + err1 := manualEqualityCheck(merged, test.map1) + err2 := manualEqualityCheck(merged, test.map2) + if err1 != nil || err2 != nil { + continue // The two maps should not be equal so if we get an "error" the test case has passed + } + t.Errorf("A change in the merged map resulted in a change in the input maps") + } + } +} diff --git a/components/system.go b/components/system.go index abccc4b..a2e7452 100644 --- a/components/system.go +++ b/components/system.go @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2024 Synecdoque + * Copyright (c) 2025 Synecdoque * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -22,16 +22,21 @@ package components import ( + "bytes" "context" "fmt" + "io" + "net/http" + "net/url" "os" "os/signal" + "sync" "syscall" ) -// System struct aggragates an Arrowhead compliant system +// System struct aggregates an Arrowhead compliant system type System struct { - Name string `json:"systemname"` + Name string `json:"systemName"` Host *HostingDevice // the system runs on a device Husk *Husk // the system aggregates a "husk" (a wrapper or a shell) UAssets map[string]*UnitAsset // the system aggregates "asset", which is made up of one or more unit-asset @@ -39,13 +44,15 @@ type System struct { Ctx context.Context // create a context that can be cancelled Sigs chan os.Signal // channel to initiate a graceful shutdown when Ctrl+C is pressed RegistrarChan chan *CoreSystem // channel for the lead service registrar + // Tracks which hosts to send log msgs to (and how many errors were encountered, before being removed) + Messengers map[string]int // list of messenger systems + Mutex *sync.Mutex } // CoreSystem struct holds details about the core system included in the configuration file type CoreSystem struct { - Name string `json:"coresystem"` - Url string `json:"url"` - Certificate string `json:"-"` + Name string `json:"coreSystem"` + Url string `json:"url"` } // NewSystem instantiates the new system and gathers the host information @@ -58,10 +65,82 @@ func NewSystem(name string, ctx context.Context) System { newSystem.RegistrarChan = make(chan *CoreSystem, 1) newSystem.Host = NewDevice() newSystem.UAssets = make(map[string]*UnitAsset) // initialize UAsset as an empty map + // Since the return System isn't a pointer (incorrectly), this map needs to + // be a pointer instead (usually not normal) and initialised (usually not needed) + // in order to avoid linter errors. + // The errors is due to this func returning a copy of newSystem and attempts + // to copy the mutex too, but it's not allowed for sync objects. + // Reference: https://stackoverflow.com/questions/37242009/function-returns-lock-by-value + newSystem.Messengers = make(map[string]int) + newSystem.Mutex = &sync.Mutex{} return newSystem } -// The following code is used only for issues support on GitHub @sdoque -------------------------- +func verifyStatus(u *url.URL) ([]byte, error) { + resp, err := http.Get(u.String()) + if err != nil { + return nil, err + } + defer resp.Body.Close() + // Body must be fully drained AND closed upon returning, otherwise it might leak memory + body, err := io.ReadAll(resp.Body) + if resp.StatusCode < 200 || resp.StatusCode > 299 { + return nil, fmt.Errorf("bad response: %d %s", resp.StatusCode, resp.Status) + } + return body, err +} + +const ServiceRegistrarName string = "serviceregistrar" +const ServiceRegistrarLeader string = "lead Service Registrar since" + +// GetRunningCoreSystemURL returns the URL of a running core system based on the provided type. +// When systemType is "serviceregistrar", it verifies the service is the lead registrar by checking +// its /status endpoint response. For other core system types, it simply tests that the URL is accessible. +func GetRunningCoreSystemURL(sys *System, systemType string) (string, error) { + // Store the latest error encountered when iterating thru the system list + // and then return this error if no matching system was found. + var lastErr error + + for _, core := range sys.CoreS { + // Ignore unrelated systems + if core.Name != systemType { + continue + } + + coreURL, err := url.Parse(core.Url) + if err != nil { + lastErr = fmt.Errorf("parsing core URL: %w", err) + continue + } + + coreSystemURL := coreURL.String() // Preserves the original URL + if core.Name != ServiceRegistrarName { + return coreSystemURL, nil + } + + // Perform extra checks on the response from a service registrar + coreURL = coreURL.JoinPath("status") + body, err := verifyStatus(coreURL) + if err != nil { + lastErr = fmt.Errorf("verifying registrar: %w", err) + continue + } + + // Skips non-leading registrars + if !bytes.HasPrefix(body, []byte(ServiceRegistrarLeader)) { + continue + } + return coreSystemURL, nil + } + + err := fmt.Errorf("core system '%s' not found", systemType) + if lastErr != nil { + err = fmt.Errorf("core system '%s' not found: %w", systemType, lastErr) + } + return "", err +} + +// The following code is used only for issues support on GitHub @sdoque var ( AppName string Version string @@ -70,6 +149,8 @@ var ( ) func getBuildInfo() { + // TODO: This info should be updated when setting up version release tools + // Leaving the fmt.Prints as is for now. if AppName != "" { fmt.Printf("System: %s - %s\n", AppName, Version) fmt.Printf("Build date: %s\n", BuildDate) diff --git a/components/system_test.go b/components/system_test.go new file mode 100644 index 0000000..46fc674 --- /dev/null +++ b/components/system_test.go @@ -0,0 +1,189 @@ +package components + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "testing" +) + +func TestNewSystem(t *testing.T) { + name := "TestingSystem" + ctx, cancel := context.WithCancel(context.Background()) + sys := NewSystem(name, ctx) + + if sys.Name != name { + t.Errorf("expected system name %s, got %s", name, sys.Name) + } + + // It's a bit of a silly test but the system context is an important dependency + // for cancelling some background services (system registration and http servers). + select { + case <-sys.Ctx.Done(): + t.Fatal("expected context to NOT be cancelled") + default: + // pass + } + + cancel() + select { + case <-sys.Ctx.Done(): + // pass + default: + t.Error("expected context to be cancelled") + } +} + +//////////////////////////////////////////////////////////////////////////////// + +type errorReadCloser struct { + r io.Reader + errRead error + errClose error +} + +func (ec errorReadCloser) Read(p []byte) (n int, err error) { + if ec.errRead != nil { + return 0, ec.errRead + } + return ec.r.Read(p) +} + +func (ec errorReadCloser) Close() error { + return ec.errClose +} + +var errMockTrans = fmt.Errorf("mock error") + +type mockTrans struct { + status int + body string + err error + errBody error +} + +func newMockTransport() *mockTrans { + t := &mockTrans{ + status: http.StatusOK, + } + // Hijack the default http client so no actual http requests are sent over the network + http.DefaultClient.Transport = t + return t +} + +func (t *mockTrans) setResponse(status int, body string) { + t.status = status + t.body = body +} + +func (t *mockTrans) setError() { + t.err = errMockTrans +} + +func (t *mockTrans) setBodyError() { + t.errBody = errMockTrans +} + +// RoundTrip method is required to fulfil the RoundTripper interface (as required by the DefaultClient). +// It prevents the request from being sent over the network. +func (t *mockTrans) RoundTrip(req *http.Request) (*http.Response, error) { + if t.err != nil { + return nil, t.err + } + resp := &http.Response{ + StatusCode: t.status, + Status: http.StatusText(t.status), + Body: errorReadCloser{ + strings.NewReader(t.body), + t.errBody, + nil, + }, + ContentLength: int64(len(t.body)), + Request: req, + } + return resp, nil +} + +const coreRegURL = "http://registrar" +const coreFakeURL = "http://fake" + +var coreReg = &CoreSystem{ServiceRegistrarName, coreRegURL} +var coreFake = &CoreSystem{"fakesystem", coreFakeURL} + +type sampleGetRunningCoreSystem struct { + name string + url string + wantErr bool + setup func(*mockTrans) +} + +var tableGetRunningCoreSystem = []sampleGetRunningCoreSystem{ + // Tests for non-registrars + // Case: unrelated system + {"bad name", "", true, nil}, + // Case: url.Parse() error + {coreFake.Name, "", true, func(m *mockTrans) { coreFake.Url = string(rune(0)) }}, + // Case: http.Get() no error + {coreFake.Name, coreFake.Url, false, func(m *mockTrans) { m.setError() }}, + // Case: io.ReadAll() no error + {coreFake.Name, coreFake.Url, false, func(m *mockTrans) { m.setBodyError() }}, + // Case: http < 200 no error + {coreFake.Name, coreFake.Url, false, func(m *mockTrans) { m.setResponse(199, "") }}, + // Case: http > 299 no error + {coreFake.Name, coreFake.Url, false, func(m *mockTrans) { m.setResponse(300, "") }}, + // Case: return url + {coreFake.Name, coreFake.Url, false, nil}, + + // Tests for registrars + // Case: url.Parse() error + {coreReg.Name, "", true, func(m *mockTrans) { coreReg.Url = string(rune(0)) }}, + // Case: http.Get() error + {coreReg.Name, "", true, func(m *mockTrans) { m.setError() }}, + // Case: io.ReadAll() error + {coreReg.Name, "", true, func(m *mockTrans) { m.setBodyError() }}, + // Case: http < 200 error + {coreReg.Name, "", true, func(m *mockTrans) { m.setResponse(199, "") }}, + // Case: http > 299 error + {coreReg.Name, "", true, func(m *mockTrans) { m.setResponse(300, "") }}, + // Case: return error when missing prefix string in body for registrar + {coreReg.Name, "", true, nil}, + // Case: return url + {coreReg.Name, coreReg.Url, false, func(m *mockTrans) { + m.setResponse(200, ServiceRegistrarLeader) + }, + }, +} + +func TestGetRunningCoreSystem(t *testing.T) { + name := "testSystem" + sys := NewSystem(name, context.Background()) + + // Case: return error for empty core system list (and should not match itself) + if len(sys.CoreS) != 0 { + t.Fatalf("expected no core systems, had %d in list", len(sys.CoreS)) + } + _, err := GetRunningCoreSystemURL(&sys, name) + if err == nil { + t.Error("expected error, got nil") + } + sys.CoreS = []*CoreSystem{coreReg, coreFake} + + for _, test := range tableGetRunningCoreSystem { + coreReg.Url = coreRegURL // reset URLs after testing url.Parse() errors + coreFake.Url = coreFakeURL + m := newMockTransport() + if test.setup != nil { + test.setup(m) + } + + gotURL, gotErr := GetRunningCoreSystemURL(&sys, test.name) + switch { + case test.wantErr == (gotErr == nil): + t.Errorf("expected error = %v, got: %v", test.wantErr, gotErr) + case gotURL != test.url: + t.Errorf("expected core system URL '%s', got '%s'", test.url, gotURL) + } + } +} diff --git a/components/uasset.go b/components/uasset.go index 300ba79..5b84ccd 100644 --- a/components/uasset.go +++ b/components/uasset.go @@ -31,5 +31,12 @@ type UnitAsset interface { GetServices() Services GetCervices() Cervices GetDetails() map[string][]string + GetTraits() any Serving(w http.ResponseWriter, r *http.Request, servicePath string) } + +// HasTraits is an interface that defines a method to get traits of a UnitAsset. +// used in usecases configuration and service discovery. +type HasTraits interface { + GetTraits() any // or interface{} in older Go +} diff --git a/forms/certificateForms.go b/forms/certificate_forms.go similarity index 95% rename from forms/certificateForms.go rename to forms/certificate_forms.go index db02549..42d6841 100644 --- a/forms/certificateForms.go +++ b/forms/certificate_forms.go @@ -51,5 +51,8 @@ func Certificate(w http.ResponseWriter, req *http.Request, sys components.System // Set the content type to text/plain w.Header().Set("Content-Type", "text/plain") - w.Write([]byte(cert)) + _, err := w.Write([]byte(cert)) + if err != nil { + log.Println("Error writing the certificate: ", err) + } } diff --git a/forms/costForms.go b/forms/cost_forms.go similarity index 100% rename from forms/costForms.go rename to forms/cost_forms.go diff --git a/forms/fileForms.go b/forms/file_forms.go similarity index 95% rename from forms/fileForms.go rename to forms/file_forms.go index 425fcaa..d94fcdd 100644 --- a/forms/fileForms.go +++ b/forms/file_forms.go @@ -57,6 +57,8 @@ func init() { FormTypeMap["FileForm_v1"] = reflect.TypeOf(FileForm_v1{}) } +const fileDir string = "files" + // TransferFile enables the transfer of different types files when the filename is given in the URL func TransferFile(w http.ResponseWriter, r *http.Request) { // Parse the URL to ensure it's valid and to easily extract parts of it @@ -81,6 +83,10 @@ func TransferFile(w http.ResponseWriter, r *http.Request) { contentType = "application/zip" case ".txt": contentType = "text/plain" + case ".owl": + contentType = "application/rdf+xml" + case ".ttl": + contentType = "text/turtle" case ".html", ".htm": contentType = "text/html" case ".csv": @@ -90,7 +96,7 @@ func TransferFile(w http.ResponseWriter, r *http.Request) { } // Open the requested file from the ./files directory - dir := http.Dir("./files") + dir := http.Dir(fileDir) reqFile, err := dir.Open(filename) if err != nil { log.Println("Requested file not found:", err) diff --git a/forms/file_forms_test.go b/forms/file_forms_test.go new file mode 100644 index 0000000..df763b6 --- /dev/null +++ b/forms/file_forms_test.go @@ -0,0 +1,152 @@ +package forms + +import ( + "fmt" + "net/http" + "net/http/httptest" + "os" + "path" + "path/filepath" + "testing" +) + +type transferFileTestStruct struct { + filename string + expectedBody string + expectedCode int + fileType string + testName string +} + +type mockResponseWriter struct { + http.ResponseWriter + statusCode int +} + +func (e *mockResponseWriter) Write(b []byte) (int, error) { + e.WriteHeader(300) + return 0, fmt.Errorf("Forced write error") +} + +func (e *mockResponseWriter) WriteHeader(statusCode int) { + e.statusCode = statusCode +} + +func (e *mockResponseWriter) Header() http.Header { + return make(http.Header) +} + +var transferFileTestParams = []transferFileTestStruct{ + {"test.jpeg", "\xff\xd8", + 200, ".jpeg", "Good case, jpeg works"}, + {"test.zip", "\x50\x4b\x03\x04", + 200, ".zip", "Good case, zip works"}, + {"test.txt", "\n", 200, ".txt", "Good case, txt works"}, + {"test.owl", ``, + 200, ".owl", "Good case, owl works"}, + {"test.ttl", "@prefix : <#> .@prefix rdf: .", + 200, ".ttl", "Good case, ttl works"}, + {"test.html", "", + 200, ".html", "Good case, html works"}, + {"test.csv", "id,name\n", + 200, ".csv", "Good case, csv works"}, + {"test.mp4", "\x00\x00\x00\x18\x66\x74\x79\x70\x69\x73\x6f\x6d\x00\x00\x02\x00\x69\x73\x6f\x6d\x69\x73\x6f\x32", + 200, ".mp4", "Good case, mp4 works"}, + {"test.txt", "Internal Server Error\n", + 500, ".txt", "Bad case, parsing url fails"}, + {"wrong.txt", "Not Found\n", + 404, ".txt", "Bad case, file not found"}, +} + +var fileTypeMap = map[string][]byte{ + ".jpeg": {0xFF, 0xD8}, + ".zip": {0x50, 0x4B, 0x03, 0x04}, + ".txt": []byte("\n"), + ".owl": []byte(``), + ".ttl": []byte("@prefix : <#> .@prefix rdf: ."), + ".html": []byte(""), + ".csv": []byte("id,name\n"), + ".mp4": {0x00, 0x00, 0x00, 0x18, 0x66, 0x74, 0x79, 0x70, 0x69, 0x73, 0x6F, 0x6D, + 0x00, 0x00, 0x02, 0x00, 0x69, 0x73, 0x6F, 0x6D, 0x69, 0x73, 0x6F, 0x32}, +} + +func createTestFolderAndFile(filename string, fileType string) error { + fullPath := filepath.Join(fileDir, filename) + err := os.MkdirAll(fileDir, 0755) + if err != nil { + return err + } + + f, err := os.OpenFile(fullPath, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return err + } + defer f.Close() + + return os.WriteFile(fullPath, fileTypeMap[fileType], 0644) +} + +func removeTestFolderAndFile() error { + return os.RemoveAll(fileDir) +} + +func TestTransferFile(t *testing.T) { + for _, testCase := range transferFileTestParams { + fileURL := "/" + path.Join(fileDir, testCase.filename) + inputW := httptest.NewRecorder() + inputR := httptest.NewRequest(http.MethodPost, fileURL, nil) + if testCase.testName == "Bad case, parsing url fails" { + inputR.URL.Path = "/foo%ZZbar" + } + if testCase.testName == "Bad case, file not found" { + inputR.URL.Path = "/files/doesNotExist.error" + } + + err := createTestFolderAndFile(testCase.filename, testCase.fileType) + if err != nil { + t.Error(err) + continue + } + TransferFile(inputW, inputR) + err = removeTestFolderAndFile() + if err != nil { + t.Error(err) + } + + if inputW.Body.String() != testCase.expectedBody || inputW.Code != testCase.expectedCode { + t.Errorf("Expected: %s and %d, got: %s and %d", + testCase.expectedBody, testCase.expectedCode, inputW.Body.String(), inputW.Code) + } + } + + // Special case + fullPath := "/files/test.txt" + specialRecorder := &mockResponseWriter{} + inputR := httptest.NewRequest(http.MethodPost, fullPath, nil) + err := createTestFolderAndFile("test.txt", ".txt") + if err != nil { + t.Error(err) + return + } + TransferFile(specialRecorder, inputR) + err = removeTestFolderAndFile() + if err != nil { + t.Error(err) + } + + if specialRecorder.statusCode != 300 { + t.Errorf("Expected status code 300, got: %d", specialRecorder.statusCode) + } +} + +func TestFileEscape(t *testing.T) { + inputW := httptest.NewRecorder() + inputR := httptest.NewRequest(http.MethodPost, "http://localhost/../signal_forms.go", nil) + TransferFile(inputW, inputR) + + if inputW.Code != 404 { + t.Errorf("Expected error code 404, got: %d", inputW.Code) + } +} diff --git a/forms/formsDefiniton.go b/forms/forms_definition.go similarity index 100% rename from forms/formsDefiniton.go rename to forms/forms_definition.go diff --git a/forms/message_forms.go b/forms/message_forms.go new file mode 100644 index 0000000..56a494d --- /dev/null +++ b/forms/message_forms.go @@ -0,0 +1,98 @@ +package forms + +import ( + "fmt" + "reflect" +) + +// Register the forms +func init() { + FormTypeMap[messengerRegistrationVersion] = reflect.TypeOf(MessengerRegistration_v1{}) + FormTypeMap[systemMessageVersion] = reflect.TypeOf(SystemMessage_v1{}) +} + +//////////////////////////////////////////////////////////////////////////////// + +type MessengerRegistration_v1 struct { + Host string `json:"host"` + Version string `json:"version"` +} + +const messengerRegistrationVersion string = "MessengerRegistration_v1" + +func NewMessengerRegistration_v1(host string) MessengerRegistration_v1 { + return MessengerRegistration_v1{ + Host: host, + Version: messengerRegistrationVersion, + } +} + +func (f *MessengerRegistration_v1) NewForm() Form { + new := NewMessengerRegistration_v1("") + return &new +} + +func (f *MessengerRegistration_v1) FormVersion() string { return f.Version } + +//////////////////////////////////////////////////////////////////////////////// + +// MessageLevel indicates the importance or criticality of a message. +type MessageLevel int + +// Mimics the levels from the "slog" package +const ( + LevelDebug MessageLevel = -4 + LevelInfo MessageLevel = 0 + LevelWarn MessageLevel = 4 + LevelError MessageLevel = 8 +) + +func LevelToString(lvl MessageLevel) string { + switch lvl { + case LevelDebug: + return "DEBUG" + case LevelInfo: + return "INFO" + case LevelWarn: + return "WARN" + case LevelError: + return "ERROR" + default: + return "UNKNOWN" + } +} + +// A SystemMessage is a log message sent from a system to one or many messengers. +// The receiving messengers will note the message's time of arrival. +// The timestamp is noted on the messenger side, so as to maintain a uniform +// chronological order of the messages (if, for example, there exists systems +// on other hosts with misconfigured time or timezone). +type SystemMessage_v1 struct { + Level MessageLevel `json:"level"` // Severity level + Body string `json:"body"` // Plaintext string of the actual message to be logged. + System string `json:"system"` // The system sending the log + Version string `json:"version"` +} + +const systemMessageVersion string = "SystemMessage_v1" + +func NewSystemMessage_v1(lvl MessageLevel, body string, system string) SystemMessage_v1 { + return SystemMessage_v1{ + Level: lvl, + Body: body, + System: system, + Version: systemMessageVersion, + } +} + +func (f SystemMessage_v1) String() string { + return fmt.Sprintf("%s %s", LevelToString(f.Level), f.Body) +} + +// NewForm resets the form and defaults to using LevelInfo. +func (f *SystemMessage_v1) NewForm() Form { + new := NewSystemMessage_v1(LevelInfo, "", "") + return &new +} + +func (f *SystemMessage_v1) FormVersion() string { return f.Version } diff --git a/forms/serviceForms.go b/forms/service_forms.go old mode 100755 new mode 100644 similarity index 98% rename from forms/serviceForms.go rename to forms/service_forms.go index e7542a2..88b81e0 --- a/forms/serviceForms.go +++ b/forms/service_forms.go @@ -64,7 +64,7 @@ func init() { /////////////////////////////////////////////////////////////////////////////// type ServiceRecordList_v1 struct { - List []ServiceRecord_v1 `list:"version"` + List []ServiceRecord_v1 `json:"list"` Version string `json:"version"` } diff --git a/forms/servicequestForms.go b/forms/servicequest_forms.go similarity index 97% rename from forms/servicequestForms.go rename to forms/servicequest_forms.go index 79603a9..f3dee27 100644 --- a/forms/servicequestForms.go +++ b/forms/servicequest_forms.go @@ -30,7 +30,7 @@ import "reflect" type ServiceQuest_v1 struct { SysId int `json:"systemId"` RequesterName string `json:"requesterName"` - ServiceDefinition string `json:"serrviceDefinition"` + ServiceDefinition string `json:"serviceDefinition"` Protocol string `json:"protocol"` Details map[string][]string `json:"details"` Version string `json:"version"` diff --git a/forms/signalForms.go b/forms/signal_forms.go similarity index 100% rename from forms/signalForms.go rename to forms/signal_forms.go diff --git a/forms/systemForms.go b/forms/system_forms.go old mode 100755 new mode 100644 similarity index 100% rename from forms/systemForms.go rename to forms/system_forms.go diff --git a/go.mod b/go.mod index 1b71a59..f0890bb 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/sdoque/mbaigo -go 1.23.4 +go 1.24.4 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/tests/configuration.go b/tests/configuration.go new file mode 100644 index 0000000..2f52ea3 --- /dev/null +++ b/tests/configuration.go @@ -0,0 +1,55 @@ +package tests + +import ( + "encoding/json" + "os/signal" + + "github.com/sdoque/mbaigo/components" + "github.com/sdoque/mbaigo/usecases" +) + +// PROPOSAL: new additions to usecases/configuration.go + +// NewResourceFunc is the function type used for loading unit assets that were +// defined in "systemconfig.json". +// A new, custom instance of [Components.UnitAsset] should be created and populated +// with fields from the provided [usecases.ConfigurableAsset]. +// Any services or consumed services should be added too. +// The function should then return the UnitAsset and an optional cleanup function. +// +// TODO: this function really needs an error return +// TODO: feels unnecessarily confusing to provide system instance. +type NewResourceFunc func(usecases.ConfigurableAsset, *components.System) (components.UnitAsset, func()) + +// LoadResources loads all unit assets from rawRes (which was loaded from "systemconfig.json" file) +// and calls newResFunc repeatedly for each loaded asset. +// The fully loaded unit asset and an optional cleanup function are collected from +// newResFunc and are then attached to the sys system. +// LoadResources then returns a system cleanup function and an optional error. +// The error always originate from [json.Unmarshal]. +func LoadResources(sys *components.System, rawRes []json.RawMessage, newResFunc NewResourceFunc) (func(), error) { + // Resets this map so it can be filled with loaded unit assets (rather than templates) + sys.UAssets = make(map[string]*components.UnitAsset) + + var cleanups []func() + for _, raw := range rawRes { + var ca usecases.ConfigurableAsset + if err := json.Unmarshal(raw, &ca); err != nil { + return func() {}, err + } + + ua, f := newResFunc(ca, sys) + sys.UAssets[ua.GetName()] = &ua + cleanups = append(cleanups, f) + } + + doCleanups := func() { + for _, f := range cleanups { + f() + } + // Stops hijacking SIGINT and return signal control to user + signal.Stop(sys.Sigs) + close(sys.Sigs) + } + return doCleanups, nil +} diff --git a/tests/examples_test.go b/tests/examples_test.go new file mode 100644 index 0000000..0e8dd1d --- /dev/null +++ b/tests/examples_test.go @@ -0,0 +1,171 @@ +package tests + +import ( + "context" + "errors" + "math/rand" + "net/http" + "os" + "path" + + "github.com/sdoque/mbaigo/components" + "github.com/sdoque/mbaigo/forms" + "github.com/sdoque/mbaigo/usecases" +) + +const ( + unitName string = "randomiser" + unitService string = "random" +) + +// Traits are Asset-specific configurable parameters +type Traits struct { + Address string `json:"address"` // Address of the IO + Value float64 `json:"value"` // Start up value of the IO + MinValue float64 `json:"minValue"` // Minimum value of the IO + MaxValue float64 `json:"maxValue"` // Maximum value of the IO +} + +// The most simplest unit asset +type uaRandomiser struct { + Name string `json:"-"` + Owner *components.System `json:"-"` + Details map[string][]string `json:"-"` + ServicesMap components.Services `json:"-"` + CervicesMap components.Cervices `json:"-"` + Traits +} + +// Force type check (fulfilling the interface) at compile time +var _ components.UnitAsset = &uaRandomiser{} + +// Add required functions to fulfil the UnitAsset interface +func (ua uaRandomiser) GetName() string { return ua.Name } +func (ua uaRandomiser) GetServices() components.Services { return ua.ServicesMap } +func (ua uaRandomiser) GetCervices() components.Cervices { return ua.CervicesMap } +func (ua uaRandomiser) GetDetails() map[string][]string { return ua.Details } +func (ua uaRandomiser) GetTraits() any { return ua.Traits } +func (ua uaRandomiser) Serving(w http.ResponseWriter, r *http.Request, servicePath string) { + if servicePath != unitService { + http.Error(w, "unknown service path: "+servicePath, http.StatusBadRequest) + return + } + + f := forms.SignalA_v1a{ + Value: rand.Float64(), + } + b, err := usecases.Pack(f.NewForm(), "application/json") + if err != nil { + http.Error(w, "error from Pack: "+err.Error(), http.StatusInternalServerError) + return + } + if _, err := w.Write(b); err != nil { + http.Error(w, "error from Write: "+err.Error(), http.StatusInternalServerError) + } +} + +func createUATemplate(sys *components.System) { + s := &components.Service{ + Definition: unitService, // The "name" of the service + SubPath: unitService, // Not "allowed" to be changed afterwards + Details: map[string][]string{"key1": {"value1"}}, + RegPeriod: 60, + // NOTE: must start with lower-case, it gets embedded into another sentence in the web API + Description: "returns a random float64", + } + ua := components.UnitAsset(&uaRandomiser{ + Name: unitName, // WARN: don't use the system name!! this is an asset! + Details: map[string][]string{"key2": {"value2"}}, + ServicesMap: components.Services{ + s.SubPath: s, + }, + }) + sys.UAssets[ua.GetName()] = &ua +} + +func loadUAConfig(ca usecases.ConfigurableAsset, sys *components.System) (components.UnitAsset, func()) { + s := ca.Services[0] + ua := &uaRandomiser{ + Name: ca.Name, + Owner: sys, + Details: ca.Details, + ServicesMap: usecases.MakeServiceMap(ca.Services), + // Let it consume its own service + CervicesMap: components.Cervices{unitService: &components.Cervice{ + Definition: s.Definition, + Details: s.Details, + // Nodes will be filled up by any discovered cervices + Nodes: make(map[string][]string, 0), + }}, + } + return ua, func() {} +} + +//////////////////////////////////////////////////////////////////////////////// + +const ( + systemName string = "test" + systemPort int = 29999 +) + +var serviceURL = "GET /" + path.Join(systemName, unitName, unitService) + +// The most simplest system +func newSystem() (*components.System, func(), error) { + ctx, cancel := context.WithCancel(context.Background()) + + // TODO: want this to return a pointer type instead! + // easier to use and pointer is used all the time anyway down below + sys := components.NewSystem(systemName, ctx) + sys.Husk = &components.Husk{ + Description: " is the most simplest system possible", + Details: map[string][]string{"key3": {"value3"}}, + ProtoPort: map[string]int{"http": systemPort}, + } + + // Setup default config with default unit asset and values + createUATemplate(&sys) + rawResources, err := usecases.Configure(&sys) + + // Extra check to work around "created config" error. Not required normally! + if err != nil { + // Return errors not related to config creation + if errors.Is(err, usecases.ErrNewConfig) == false { + cancel() + return nil, nil, err + } + // Since Configure() created the config file, it must be cleaned up when this test is done! + defer os.Remove("systemconfig.json") + // Default config file was created, redo the func call to load the file + rawResources, err = usecases.Configure(&sys) + if err != nil { + cancel() + return nil, nil, err + } + } + // NOTE: if the config file already existed (thus the above error block didn't + // get to run), then the config file should be left alone and not removed! + + // Load unit assets defined in the config file + cleanups, err := LoadResources(&sys, rawResources, loadUAConfig) + if err != nil { + cancel() + return nil, nil, err + } + + // TODO: this is not ready for production yet? + // usecases.RequestCertificate(&sys) + + usecases.RegisterServices(&sys) + + // TODO: prints logs + usecases.SetoutServers(&sys) + + stop := func() { + cancel() + // TODO: a waitgroup or something should be used to make sure all goroutines have stopped + // Not doing much in the mock cleanups so this works fine for now...? + cleanups() + } + return &sys, stop, nil +} diff --git a/tests/integration_test.go b/tests/integration_test.go new file mode 100644 index 0000000..3ccf208 --- /dev/null +++ b/tests/integration_test.go @@ -0,0 +1,246 @@ +package tests + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "path" + "runtime" + "runtime/pprof" + "strings" + "sync" + "testing" + "time" + + "github.com/sdoque/mbaigo/components" + "github.com/sdoque/mbaigo/forms" + "github.com/sdoque/mbaigo/usecases" +) + +type requestEvent struct { + event string + hits int + body []byte +} + +// Mock simulating traffic between a system and registrars/orchestrators +type mockTrans struct { + t *testing.T + hits map[string]int // Used to track http requests + mutex sync.Mutex // For protecting access to the above map + events chan requestEvent // Tracks service "events" and requests to the cloud services +} + +func newMockTransport(t *testing.T) *mockTrans { + m := &mockTrans{ + t: t, + hits: make(map[string]int), + events: make(chan requestEvent), + } + // Hijack the default http client so no actual http requests are sent over the network + http.DefaultClient.Transport = m + return m +} + +func (m *mockTrans) waitFor(event string) (int, []byte, error) { + select { + case e := <-m.events: + if e.event != event { + return 0, nil, fmt.Errorf("got %s, expected %s", e.event, event) + } + return e.hits, e.body, nil + case <-time.Tick(10 * time.Second): + return 0, nil, fmt.Errorf("event timeout") + } +} + +func newServiceRecord() []byte { + f := forms.ServiceRecord_v1{ + Id: 13, // NOTE: this should match with eventUnregister + Created: time.Now().Format(time.RFC3339), + EndOfValidity: time.Now().Format(time.RFC3339), + Version: "ServiceRecord_v1", + } + b, err := usecases.Pack(&f, "application/json") + if err != nil { + panic(err) // Hard fail if Pack() can't handle the above form + } + return b +} + +func newServicePoint() []byte { + f := forms.ServicePoint_v1{ + // per usecases/registration.go:serviceRegistrationForm() + ServNode: fmt.Sprintf("localhost_%s_%s_%s", systemName, unitName, unitService), + // per orchestrator/thing.go:selectService() + ServLocation: fmt.Sprintf("http://localhost:%d/%s/%s/%s", + systemPort, systemName, unitName, unitService, + ), + Version: "ServicePoint_v1", + } + b, err := usecases.Pack(&f, "application/json") + if err != nil { + panic(err) // Another hard fail if Pack() can't work with the above form + } + return b +} + +const ( + eventRegistryStatus string = "GET /serviceregistrar/registry/status" + eventRegister string = "POST /serviceregistrar/registry/register" + eventUnregister string = "DELETE /serviceregistrar/registry/unregister/13" + eventOrchestration string = "GET /orchestrator/orchestration" + eventOrchestrate string = "POST /orchestrator/orchestration/squest" +) + +var mockRequests = map[string]struct { + sendEvent bool + status int + body []byte +}{ + eventRegistryStatus: {false, 200, []byte(components.ServiceRegistrarLeader)}, + eventRegister: {true, 200, newServiceRecord()}, + eventUnregister: {true, 200, nil}, + eventOrchestration: {false, 200, nil}, + eventOrchestrate: {true, 200, newServicePoint()}, +} + +func (m *mockTrans) RoundTrip(req *http.Request) (*http.Response, error) { + m.mutex.Lock() // This lock is mainly for guarding concurrent access to the hits map + defer m.mutex.Unlock() + event := req.Method + " " + req.URL.Path + m.hits[event] += 1 + if event == serviceURL { + // The example service will, through the system, return a proper response + return http.DefaultTransport.RoundTrip(req) + } + + // Any other requests needs to be mocked, simulating responses from the + // service registrar and orchestrator. + mock, found := mockRequests[event] + if !found { + m.t.Errorf("unknown request: %s", event) + // Let's see how the system responds to this + mock.status = http.StatusNotImplemented + mock.body = []byte(http.StatusText(mock.status)) + } + rec := httptest.NewRecorder() + rec.Header().Set("Content-Type", "application/json") + rec.WriteHeader(mock.status) + rec.Write(mock.body) // Safe to ignore the returned error, it's always nil + + // Allows for syncing up the test, with the request flow performed by the system + if mock.sendEvent { + var b []byte + if req.Body != nil { + var err error + b, err = io.ReadAll(req.Body) + if err != nil { + m.t.Errorf("failed reading request body: %v", err) + } + defer req.Body.Close() + } + // Using a goroutine prevents thread locking + go func(e string, h int, b []byte) { + m.events <- requestEvent{e, h, b} + }(event, m.hits[event], b) + } + return rec.Result(), nil +} + +//////////////////////////////////////////////////////////////////////////////// + +func countGoroutines() (int, string) { + c := runtime.NumGoroutine() + buf := &bytes.Buffer{} + // A write to this buffer will always return nil error, so safe to ignore here. + // This call will spawn some goroutine too, so need to chill for a little while. + _ = pprof.Lookup("goroutine").WriteTo(buf, 2) + trace := buf.String() + // Calling signal.Notify() will leave an extra goroutine that runs forever, + // so it should be subtracted from the count. For more info, see: + // https://github.com/golang/go/issues/52619 + // https://github.com/golang/go/issues/72803 + // https://github.com/golang/go/issues/21576 + if strings.Contains(trace, "os/signal.signal_recv") { + c -= 1 + } + return c, trace +} + +func assertNotEq(t *testing.T, got, want any) { + if got != want { + t.Errorf("got %v, expected %v", got, want) + } +} + +func TestSimpleSystemIntegration(t *testing.T) { + routinesStart, _ := countGoroutines() + m := newMockTransport(t) + sys, stopSystem, err := newSystem() + if err != nil { + t.Fatalf("expected no error, got: %s", err) + } + + // Validate service registration + hits, body, err := m.waitFor(eventRegister) + assertNotEq(t, err, nil) + if hits != 1 { + t.Errorf("system skipped: %s", eventRegister) + } + var sr forms.ServiceRecord_v1 + err = json.Unmarshal(body, &sr) + assertNotEq(t, err, nil) + assertNotEq(t, sr.SystemName, systemName) + assertNotEq(t, sr.SubPath, path.Join(unitName, unitService)) + + // Validate service usage + ua := *sys.UAssets[unitName] + if ua == nil { + t.Fatalf("system missing unit asset: %s", unitName) + } + service := ua.GetCervices()[unitService] + if service == nil { + t.Fatalf("unit asset missing cervice: %s", unitService) + } + f, err := usecases.GetState(service, sys) + assertNotEq(t, err, nil) + fs, ok := f.(*forms.SignalA_v1a) + if ok == false || fs == nil || fs.Value == 0.0 { + t.Errorf("invalid form: %#v", f) + } + + // Late validation for service discovery + hits, body, err = m.waitFor(eventOrchestrate) + assertNotEq(t, err, nil) + if hits != 1 { + t.Errorf("system skipped: %s", eventUnregister) + } + var sq forms.ServiceQuest_v1 + err = json.Unmarshal(body, &sq) + assertNotEq(t, err, nil) + assertNotEq(t, sq.ServiceDefinition, unitService) + + // Validate service unregister + stopSystem() + hits, _, err = m.waitFor(eventUnregister) // NOTE: doesn't receive a body + assertNotEq(t, err, nil) + if hits != 1 { + t.Errorf("system skipped: %s", eventUnregister) + } + + // Detect any leaking goroutines + // Delay a short moment and let the goroutines finish. Not sure if there's + // a better way to wait for an _unknown number_ of goroutines. + // This might give flaky test results in slower environments! + time.Sleep(1 * time.Second) + routinesStop, trace := countGoroutines() + if (routinesStop - routinesStart) != 0 { + t.Errorf("leaking goroutines: count at start=%d, stop=%d\n%s", + routinesStart, routinesStop, trace, + ) + } +} diff --git a/usecases/authentication.go b/usecases/authentication.go index 134ff25..d37d1ce 100644 --- a/usecases/authentication.go +++ b/usecases/authentication.go @@ -29,6 +29,7 @@ import ( "encoding/pem" "fmt" "log" + "net" "net/http" "strings" @@ -40,18 +41,28 @@ func RequestCertificate(sys *components.System) { // Generate ECDSA Private Key privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { - log.Fatalf("Failed to generate private key: %v", err) + log.Fatalf("Failed to generate private key: %v\n", err) } sys.Husk.Pkey = privateKey + dnsNames := []string{"localhost"} + var ipAddrs []net.IP + for _, ipStr := range sys.Host.IPAddresses { + ip := net.ParseIP(ipStr) + if ip != nil { + ipAddrs = append(ipAddrs, ip) + } + } csrTemplate := x509.CertificateRequest{ Subject: sys.Husk.DName, + DNSNames: dnsNames, // this is the SAN DNS + IPAddresses: ipAddrs, // this is the SAN IPs SignatureAlgorithm: x509.ECDSAWithSHA256, } csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, privateKey) if err != nil { - log.Fatalf("Failed to create CSR: %v", err) + log.Fatalf("Failed to create CSR: %v\n", err) return } @@ -61,7 +72,7 @@ func RequestCertificate(sys *components.System) { // Send the CSR to the CA and receive the certificate in response response, err := sendCSR(sys, csrPEM) if err != nil { - log.Printf("certification failure: %v", err) + log.Printf("certification failure: %v\n", err) return } @@ -71,7 +82,7 @@ func RequestCertificate(sys *components.System) { // Get CA's certificate caCert, err := getCACertificate(sys) if err != nil { - log.Printf("failed to obtain CA's certificate: %v", err) + log.Printf("failed to obtain CA's certificate: %v\n", err) return } sys.Husk.CA_cert = caCert @@ -79,13 +90,13 @@ func RequestCertificate(sys *components.System) { // Load CA certificate caCertPool := x509.NewCertPool() if ok := caCertPool.AppendCertsFromPEM([]byte(caCert)); !ok { - log.Fatalf("Failed to append CA certificate to pool") + log.Fatalf("Failed to append CA certificate to pool\n") } // Prepare the client's certificate and key for TLS configuration clientCert, err := prepareClientCertificate(sys.Husk.Certificate, sys.Husk.Pkey) if err != nil { - log.Fatalf("Failed to prepare client certificate: %v", err) + log.Fatalf("Failed to prepare client certificate: %v\n", err) } // Configure Transport Layer Security (TLS) @@ -93,6 +104,7 @@ func RequestCertificate(sys *components.System) { Certificates: []tls.Certificate{clientCert}, RootCAs: caCertPool, InsecureSkipVerify: false, + MinVersion: tls.VersionTLS12, } sys.Husk.TlsConfig = tlsConfig @@ -100,7 +112,7 @@ func RequestCertificate(sys *components.System) { fmt.Printf("System %s's parsed Certificate:\n", sys.Name) cert, err := x509.ParseCertificate(clientCert.Certificate[0]) if err != nil { - log.Printf("failed to parse certificate: %v", err) + log.Printf("failed to parse certificate: %v\n", err) return } fmt.Printf(" Subject: %s\n", cert.Subject) @@ -108,23 +120,25 @@ func RequestCertificate(sys *components.System) { fmt.Printf(" Serial Number: %d\n", cert.SerialNumber) fmt.Printf(" Not Before: %s\n", cert.NotBefore) fmt.Printf(" Not After: %s\n", cert.NotAfter) + fmt.Printf(" DNS Names: %v\n", cert.DNSNames) + fmt.Printf(" IP Addresses: %v\n", cert.IPAddresses) + } func sendCSR(sys *components.System, csrPEM []byte) (string, error) { - var err error - url := "" - for _, cSys := range sys.CoreS { - core := cSys - if core.Name == "ca" { - url = core.Url - } - } - if url == "" { - return "", fmt.Errorf("failed to locate certificate authority: %w", err) + url, err := components.GetRunningCoreSystemURL(sys, "ca") // Assuming the first core system is the CA + if err != nil { + return "", fmt.Errorf("failed to get CA URL: %w", err) } url += "/certify" - resp, err := http.Post(url, "application/x-pem-file", bytes.NewReader(csrPEM)) + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(csrPEM)) + if err != nil { + log.Printf("Error creating request: %v", err) + return "", err + } + req.Header.Set("Content-Type", "application/x-pem-file") + resp, err := http.DefaultClient.Do(req) if err != nil { return "", fmt.Errorf("failed to send CSR: %w", err) } @@ -137,28 +151,31 @@ func sendCSR(sys *components.System, csrPEM []byte) (string, error) { // Read the response body buf := new(bytes.Buffer) - buf.ReadFrom(resp.Body) + _, err = buf.ReadFrom(resp.Body) + if err != nil { + log.Printf("Error while reading body: %v", err) + return "", err + } return buf.String(), nil } // getCACertificate gets the CA's certificate necessary for the dual server-client authentication in the TLS setup func getCACertificate(sys *components.System) (string, error) { - var err error - coreUAurl := "" - for _, cSys := range sys.CoreS { - core := cSys - if core.Name == "ca" { - coreUAurl = core.Url - } - } - if coreUAurl == "" { - return "", fmt.Errorf("failed to locate certificate authority: %w", err) + coreUAurl, err := components.GetRunningCoreSystemURL(sys, "ca") // Assuming the first core system is the CA + if err != nil { + return "", fmt.Errorf("failed to get CA URL: %w", err) } - url := strings.TrimSuffix(coreUAurl, "ification") // the configuration file address to the CA includes the unit asset + // Remove the "ification" suffix from the URL to get the CA's address + url := strings.TrimSuffix(coreUAurl, "ification") // Make a GET request to the CA's endpoint - resp, err := http.Get(url) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + log.Printf("Error creating request: %v", err) + return "", err + } + resp, err := http.DefaultClient.Do(req) if err != nil { return "", fmt.Errorf("failed to send request to CA: %w", err) } @@ -171,7 +188,11 @@ func getCACertificate(sys *components.System) (string, error) { // Read the response body buf := new(bytes.Buffer) - buf.ReadFrom(resp.Body) + _, err = buf.ReadFrom(resp.Body) + if err != nil { + log.Printf("Error while reading body: %v", err) + return "", err + } return buf.String(), nil } diff --git a/usecases/configuration.go b/usecases/configuration.go index 4862200..221f943 100644 --- a/usecases/configuration.go +++ b/usecases/configuration.go @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2024 Synecdoque + * Copyright (c) 2025 Synecdoque * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -22,160 +22,173 @@ package usecases import ( - "crypto/x509/pkix" "encoding/json" + "errors" "fmt" "os" "github.com/sdoque/mbaigo/components" ) -// templateOut is the stuct used to prepare the systemconfig.json file +// configurableAsset is a struct that contains the name of the asset and its +// configurable details and services +type ConfigurableAsset struct { + Name string `json:"name"` + Details map[string][]string `json:"details"` + Services []components.Service `json:"services"` + Traits []json.RawMessage `json:"traits"` +} + +// templateOut is the struct used to prepare the systemconfig.json file type templateOut struct { CName string `json:"systemname"` - UAsset []components.UnitAsset `json:"unit_assets"` - CServices []components.Service `json:"services"` + LocalCloud string `json:"localcloud,omitempty"` + Assets []ConfigurableAsset `json:"unit_assets"` Protocols map[string]int `json:"protocolsNports"` - PKIdetails pkix.Name `json:"distinguishedName"` CCoreS []components.CoreSystem `json:"coreSystems"` } -// configFileIn is used to extact out the information of the systemconfig.json file +// configFileIn is used to extract out the information of the systemconfig.json file // Since it does not know about the details of the Thing, it does not unmarsahll this // information type configFileIn struct { - CName string `json:"systemname"` - rawResources []json.RawMessage `json:"-"` - CServices []components.Service `json:"services"` - Protocols map[string]int `json:"protocolsNports"` - PKIdetails pkix.Name `json:"distinguishedName"` - CCoreS []components.CoreSystem `json:"coreSystems"` + CName string `json:"systemname"` + LocalCloud string `json:"localcloud,omitempty"` + Protocols map[string]int `json:"protocolsNports"` + CCoreS []components.CoreSystem `json:"coreSystems"` + Resources []json.RawMessage `json:"unit_assets"` } -// Configure read the system configuration JSON file to get the deployment details. -// If the file is missing, it generates a default systemconfig.json file and shuts down the system -func Configure(sys *components.System) ([]json.RawMessage, []components.Service, error) { +var ErrNewConfig = errors.New("new config file was created") - var rawBytes []json.RawMessage // the mbaigo library does not know about the unit asset's structure (defined in the file thing.go and not part of the library) - var servicesList []components.Service // this is the list of services for each unit asset - // prepare content of configuration file - var defaultConfig templateOut +func setupDefaultConfig(sys *components.System) (defaultConfig templateOut, err error) { + var assetTemplate components.UnitAsset + if sys.UAssets == nil { + return templateOut{}, fmt.Errorf("unitAssets missing") + } + + for _, ua := range sys.UAssets { + assetTemplate = *ua // this creates a copy (value, not reference) + break + } + + servicesTemplate := getServicesList(assetTemplate) + confAsset := ConfigurableAsset{ + Name: assetTemplate.GetName(), + Details: assetTemplate.GetDetails(), + Services: servicesTemplate, + } + + // If the asset exposes traits, serialize them and store as raw JSON + if assetWithTraits, ok := assetTemplate.(components.HasTraits); ok { + if traits := assetWithTraits.GetTraits(); traits != nil { + traitJSON, err := json.Marshal(traits) + if err != nil { + return templateOut{}, fmt.Errorf("couldn't marshal traits: %v", err) + } + confAsset.Traits = []json.RawMessage{traitJSON} + } + } + + // prepare content of configuration file defaultConfig.CName = sys.Name + for key, values := range sys.Husk.Details { // if the system has a LocalCloud detail, add it to the config file + if key == "LocalCloud" && len(values) > 0 { + defaultConfig.LocalCloud = values[0] + break + } + } defaultConfig.Protocols = sys.Husk.ProtoPort - defaultConfig.UAsset = getFirstAsset(sys.UAssets) - originalSs := getServicesList(defaultConfig.UAsset[0]) - defaultConfig.CServices = originalSs - - defaultConfig.PKIdetails.CommonName = "arrowhead.eu" - defaultConfig.PKIdetails.Country = []string{"SE"} - defaultConfig.PKIdetails.Province = []string{"Norrbotten"} - defaultConfig.PKIdetails.Locality = []string{"Luleaa"} - defaultConfig.PKIdetails.Organization = []string{"Luleaa University of Technology"} - defaultConfig.PKIdetails.OrganizationalUnit = []string{"CPS"} - - serReg := components.CoreSystem{ - Name: "serviceregistrar", - Url: "http://localhost:20102/serviceregistrar/registry", - Certificate: ".X509pubKey", + defaultConfig.Assets = []ConfigurableAsset{confAsset} // this is a list of unit assets + + servReg := components.CoreSystem{ + Name: "serviceregistrar", + Url: "http://localhost:20102/serviceregistrar/registry", } orches := components.CoreSystem{ - Name: "orchestrator", - Url: "http://localhost:20103/orchestrator/orchestration", - Certificate: ".X509pubKey", + Name: "orchestrator", + Url: "http://localhost:20103/orchestrator/orchestration", } ca := components.CoreSystem{ - Name: "ca", - Url: "http://localhost:20100/ca/certification", - Certificate: ".X509pubKey", + Name: "ca", + Url: "http://localhost:20100/ca/certification", + } + maitreD := components.CoreSystem{ + Name: "maitreD", + Url: "http://localhost:20101/maitreD/maitreD", } - coreSystems := []components.CoreSystem{serReg, orches, ca} + + // add the core systems to the configuration file + // the system is part of a local cloud with mandatory core systems + coreSystems := []components.CoreSystem{servReg, orches, ca, maitreD} defaultConfig.CCoreS = coreSystems + return defaultConfig, nil +} - // open the configuration file or create one with the default content prepared above - systemConfigFile, err := os.Open("systemconfig.json") +// Configure reads the system configuration JSON file to get the deployment details. +// If the file is missing, it generates a default systemconfig.json file and shuts down the system +func Configure(sys *components.System) ([]json.RawMessage, error) { + defaultConfig, err := setupDefaultConfig(sys) + if err != nil { + return nil, fmt.Errorf("couldn't create default config: %v", err) + } - if err != nil { // could not find the systemconfig.json so a default one is being created - defaultConfigFile, err := os.Create("systemconfig.json") - if err != nil { - return rawBytes, servicesList, err - } - defer defaultConfigFile.Close() - systemconfigjson, err := json.MarshalIndent(defaultConfig, "", " ") - if err != nil { - return rawBytes, servicesList, err - } - nBytes, err := defaultConfigFile.Write(systemconfigjson) + // 0600 allows user Read/Write permission (secure config file), but no R/W for groups and others, 0644 to allow R/W on sudo and only R on groups/others, 0666 for R/W permissions for everyone + systemConfigFile, err := os.OpenFile("systemconfig.json", os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + return nil, fmt.Errorf("error while opening/creating systemconfig file: %v", err) + } + defer systemConfigFile.Close() + + fileInfo, err := systemConfigFile.Stat() // *.Stat() returns fileInfo/stats + if err != nil { + return nil, fmt.Errorf("error occurred while getting config file stats: %s", err) + } + if fileInfo.Size() == 0 { // *.Size() returns the filesize (number bytes) as an int, 0 is an empty file + enc := json.NewEncoder(systemConfigFile) + enc.SetIndent("", " ") + err = enc.Encode(defaultConfig) // Write default values into systemconfig since file was empty if err != nil { - return rawBytes, servicesList, err + return nil, fmt.Errorf("error writing default values to system config: %v", err) } - return rawBytes, servicesList, fmt.Errorf("a new configuration file has been written with %d bytes. Please update it and restart the system", nBytes) + return nil, ErrNewConfig } - // the system configuration file could be open, read the configurations and pass them on to the system - defer systemConfigFile.Close() - configBytes, err := os.ReadFile("systemconfig.json") + var configurationIn configFileIn + err = json.NewDecoder(systemConfigFile).Decode(&configurationIn) // Read the contents of systemconfig into configurationIn if err != nil { - return rawBytes, servicesList, err + return nil, fmt.Errorf("error reading systemconfig: %v", err) } - // the challenge is that the definition of the unit asset is unknown to the mbaigo library and only known to the system that invokes the library - var configurationIn configFileIn - // extract the information related to the system separately from the unit_assets (i.e., the resources) - type Alias configFileIn - aux := &struct { - Resources []json.RawMessage `json:"unit_assets"` - *Alias - }{ - Alias: (*Alias)(&configurationIn), - } - if err := json.Unmarshal(configBytes, aux); err != nil { - return rawBytes, servicesList, err - } - if len(aux.Resources) > 0 { - configurationIn.rawResources = aux.Resources + var rawResources []json.RawMessage + if len(configurationIn.Resources) > 0 { // If unit assets was present in systemconfig file, send those + rawResources = configurationIn.Resources } else { - var rawMessages []json.RawMessage - for _, s := range defaultConfig.UAsset { - // convert the struct to JSON-encoded byte array + for _, s := range defaultConfig.Assets { // Otherwise send the system default jsonBytes, err := json.Marshal(s) if err != nil { - fmt.Println("Failed to marshal struct:", err) + return nil, fmt.Errorf("failed to marshal struct: %v", err) } - rawMessages = append(rawMessages, json.RawMessage(jsonBytes)) // append the json.RawMessage to the slice + rawResources = append(rawResources, json.RawMessage(jsonBytes)) } - configurationIn.rawResources = rawMessages } sys.Name = configurationIn.CName - sys.Husk.DName = configurationIn.PKIdetails + // If the systemconfig file has a LocalCloud defined, add it to the system details + if configurationIn.LocalCloud != "" { + if sys.Husk.Details == nil { + sys.Husk.Details = make(map[string][]string) + } + sys.Husk.Details["LocalCloud"] = []string{configurationIn.LocalCloud} + } sys.Husk.ProtoPort = configurationIn.Protocols for _, ccore := range configurationIn.CCoreS { newCore := ccore sys.CoreS = append(sys.CoreS, &newCore) } - // update the services (e.g., re-registration period, costs, or units) - for i := range configurationIn.CServices { - for _, originalService := range originalSs { - if originalService.Definition == configurationIn.CServices[i].Definition { - configurationIn.CServices[i].Merge(&originalService) // keep the original definition and subpath as the original ones - } - } - } - servicesList = configurationIn.CServices - - return configurationIn.rawResources, servicesList, nil -} - -// getFirstAsset returns the first key-value pair in the Assets map -func getFirstAsset(assetMap map[string]*components.UnitAsset) []components.UnitAsset { - var assetList []components.UnitAsset - for key := range assetMap { - assetList = append(assetList, *assetMap[key]) - return assetList - } - return assetList + return rawResources, nil } // getServicesList() returns the original list of services @@ -187,3 +200,14 @@ func getServicesList(uat components.UnitAsset) []components.Service { } return serviceList } + +// MakeServiceMap() creates a map of services from a slice of services +// The map is indexed by the service subpath +func MakeServiceMap(services []components.Service) map[string]*components.Service { + serviceMap := make(map[string]*components.Service) + for i := range services { + svc := services[i] // take the address of the element in the slice + serviceMap[svc.SubPath] = &svc + } + return serviceMap +} diff --git a/usecases/configuration_test.go b/usecases/configuration_test.go new file mode 100644 index 0000000..6a5d55f --- /dev/null +++ b/usecases/configuration_test.go @@ -0,0 +1,411 @@ +package usecases + +import ( + "encoding/json" + "fmt" + "net/http" + "os" + "testing" + + "github.com/sdoque/mbaigo/components" +) + +// A mocked UnitAsset used for testing +type mockUnitAssetWithTraits struct { + Name string `json:"name"` + Owner *components.System `json:"-"` + Details map[string][]string `json:"details"` + ServicesMap components.Services `json:"-"` + CervicesMap components.Cervices `json:"-"` + Traits map[string][]string `json:"-"` +} + +func (mua mockUnitAssetWithTraits) GetTraits() any { + return mua.Traits +} + +func (mua mockUnitAssetWithTraits) GetName() string { + return mua.Name +} + +func (mua mockUnitAssetWithTraits) GetServices() components.Services { + return mua.ServicesMap +} + +func (mua mockUnitAssetWithTraits) GetCervices() components.Cervices { + return mua.CervicesMap +} + +func (mua mockUnitAssetWithTraits) GetDetails() map[string][]string { + return mua.Details +} + +func (mua mockUnitAssetWithTraits) Serving(w http.ResponseWriter, r *http.Request, servicePath string) { +} + +// --------------------------------------------------------- // +// Helpfunctions that creates a default config file +// with/without any asset traits +// --------------------------------------------------------- // + +// This is pretty much a copy of setupDefaultConfig() in configuration.go, +// but this also creates and writes to a systemconfig.json file +func createConfigHasTraits(sys *components.System) (err error) { + var defaultConfig templateOut + + var assetTemplate components.UnitAsset + for _, ua := range sys.UAssets { + assetTemplate = *ua + break + } + servicesTemplate := getServicesList(assetTemplate) + + confAsset := ConfigurableAsset{ + Name: assetTemplate.GetName(), + Details: assetTemplate.GetDetails(), + Services: servicesTemplate, + } + + setTest := &components.Service{ + ID: 1, + Definition: "test", + SubPath: "test", + Details: map[string][]string{"Forms": {"SignalA_v1a"}}, + Description: "A test service", + RegPeriod: 45, + RegTimestamp: "now", + RegExpiration: "45", + } + ServicesMap := &components.Services{ + setTest.SubPath: setTest, + } + mua := &mockUnitAssetWithTraits{ + Name: "testUnitAsset", + Details: map[string][]string{"Test": {"Test"}}, + ServicesMap: *ServicesMap, + CervicesMap: nil, + Traits: map[string][]string{"Trait": {"testTrait"}}, + } + var muaInterface components.UnitAsset = mua + sys.UAssets[mua.GetName()] = &muaInterface + + // If the asset exposes traits, serialize them and store as raw JSON + if assetWithTraits, ok := assetTemplate.(components.HasTraits); ok { + if traits := assetWithTraits.GetTraits(); traits != nil { + traitJSON, err := json.Marshal(traits) + if err == nil { + confAsset.Traits = []json.RawMessage{traitJSON} + } else { + return err + } + } + } + defaultConfig.Assets = []ConfigurableAsset{confAsset} + + leadingRegistrar := components.CoreSystem{ + Name: "serviceregistrar", + Url: "http://localhost:20102/serviceregistrar/registry", + } + orchestrator := components.CoreSystem{ + Name: "orchestrator", + Url: "http://localhost:20103/orchestrator/orchestration", + } + ca := components.CoreSystem{ + Name: "ca", + Url: "http://localhost:20100/ca/certification", + } + maitreD := components.CoreSystem{ + Name: "maitreD", + Url: "http://localhost:20101/maitreD/maitreD", + } + + defaultConfig.CCoreS = []components.CoreSystem{leadingRegistrar, orchestrator, ca, maitreD} + defaultConfig.CName = sys.Name + defaultConfig.Protocols = sys.Husk.ProtoPort + defaultConfigFile, err := os.Create("systemconfig.json") + if err != nil { + return fmt.Errorf("encountered error while creating default config file: %v", err) + } + defer defaultConfigFile.Close() + + enc := json.NewEncoder(defaultConfigFile) + enc.SetIndent("", " ") + err = enc.Encode(defaultConfig) + if err != nil { + return fmt.Errorf("jsonEncode: %v", err) + } + return +} + +// This is pretty much a copy of setupDefaultConfig() in configuration.go, +// but this also creates and writes to a systemconfig.json file +func createConfigNoTraits(sys *components.System, assetAmount int) (err error) { + var defaultConfig templateOut + + for x := range assetAmount { + setTest := components.Service{ + ID: x, + Definition: fmt.Sprintf("test%d", x), + SubPath: fmt.Sprintf("test%d", x), + Details: map[string][]string{"Forms": {"SignalA_v1a"}}, + Description: "A test service", + RegPeriod: 45, + RegTimestamp: "now", + RegExpiration: "45", + } + servList := []components.Service{setTest} + mua := ConfigurableAsset{ + Name: fmt.Sprintf("testUnitAsset%d", x), + Details: map[string][]string{"Test": {"Test"}}, + Services: servList, + } + defaultConfig.Assets = append(defaultConfig.Assets, mua) + } + + leadingRegistrar := components.CoreSystem{ + Name: "serviceregistrar", + Url: "http://localhost:20102/serviceregistrar/registry", + } + orchestrator := components.CoreSystem{ + Name: "orchestrator", + Url: "http://localhost:20103/orchestrator/orchestration", + } + ca := components.CoreSystem{ + Name: "ca", + Url: "http://localhost:20100/ca/certification", + } + maitreD := components.CoreSystem{ + Name: "maitreD", + Url: "http://localhost:20101/maitreD/maitreD", + } + + defaultConfig.CCoreS = []components.CoreSystem{leadingRegistrar, orchestrator, ca, maitreD} + defaultConfig.CName = sys.Name + defaultConfig.Protocols = sys.Husk.ProtoPort + defaultConfigFile, err := os.Create("systemconfig.json") + if err != nil { + return fmt.Errorf("encountered error while creating config file: %v", err) + } + defer defaultConfigFile.Close() + + enc := json.NewEncoder(defaultConfigFile) + enc.SetIndent("", " ") + err = enc.Encode(defaultConfig) + if err != nil { + return fmt.Errorf("jsonEncode: %v", err) + } + return +} + +// --------------------------------------------------------- // +// Helpfunctions and structs for testing SetupDefaultConfig() +// --------------------------------------------------------- // + +func cleanup() error { + return os.Remove("systemconfig.json") +} + +type setupDefConfigParams struct { + expectError bool + setup func(*components.System) (err error) + cleanup func() (err error) + testCase string +} + +func TestSetupDefaultConfig(t *testing.T) { + testParams := []setupDefConfigParams{ + { + false, + func(sys *components.System) (err error) { return createConfigNoTraits(sys, 1) }, + func() (err error) { return cleanup() }, + "Best case", + }, + { + false, + func(sys *components.System) (err error) { return createConfigHasTraits(sys) }, + func() (err error) { return cleanup() }, + "Good case, asset has traits", + }, + { + true, + func(sys *components.System) (err error) { return createConfigHasTraits(sys) }, + func() (err error) { return cleanup() }, + "No assets in sys", + }, + } + + // Start of test + for _, c := range testParams { + testSys := createTestSystem(false) + + // Setup + err := c.setup(&testSys) + if err != nil { + t.Errorf("setup failed: %v", err) + } + + if c.testCase == "No assets in sys" { + testSys.UAssets = nil + } + + // Test + _, err = setupDefaultConfig(&testSys) + if c.expectError == false && err != nil { + t.Errorf("Expected no errors in testcase '%s', got: %v", c.testCase, err) + } + if c.expectError == true && err == nil { + t.Errorf("expected errors in testcase '%s', got none", c.testCase) + } + + // Cleanup + err = c.cleanup() + if err != nil { + t.Errorf("failed to remove 'systemconfig.json' in testcase '%s': %v", c.testCase, err) + } + } +} + +// --------------------------------------------------------- // +// Helpfunctions and structs for testing Configure() +// --------------------------------------------------------- // + +type configureParams struct { + expectError bool + + setup func(*components.System) (err error) + cleanup func() (err error) + testCase string +} + +func TestConfigure(t *testing.T) { + testParams := []configureParams{ + { + false, + func(sys *components.System) (err error) { return createConfigNoTraits(sys, 1) }, + func() (err error) { return cleanup() }, + "Best case, one asset", + }, + { + true, + func(sys *components.System) (err error) { + _, err = os.OpenFile("systemconfig.json", os.O_RDWR|os.O_CREATE, 0000) + return + }, + func() (err error) { return cleanup() }, + "Can't open/create config", + }, + { + true, + func(sys *components.System) (err error) { return nil }, + func() (err error) { return cleanup() }, + "Config missing", + }, + { + false, + func(sys *components.System) (err error) { return createConfigNoTraits(sys, 0) }, + func() (err error) { return cleanup() }, + "No Assets in config", + }, + { + false, + func(sys *components.System) (err error) { return createConfigNoTraits(sys, 3) }, + func() (err error) { return cleanup() }, + "Multiple Assets in config", + }, + { + true, + func(sys *components.System) (err error) { + sys.UAssets = nil + return createConfigNoTraits(sys, 1) + }, + func() (err error) { return cleanup() }, + "No assets in sys", + }, + } + + // Start of test + for _, testCase := range testParams { + testSys := createTestSystem(false) + + // Setup + err := testCase.setup(&testSys) + if err != nil { + t.Errorf("failed during setup: %v", err) + } + + // Test + _, err = Configure(&testSys) + if testCase.expectError == false && err != nil { + t.Errorf("Expected no errors in '%s', got: %v", testCase.testCase, err) + } + if testCase.expectError == true && err == nil { + t.Errorf("Expected errors in testcase '%s'", testCase.testCase) + } + + //Cleanup + err = testCase.cleanup() + if err != nil { + t.Errorf("failed to remove 'systemconfig.json' in testcase '%s'", testCase.testCase) + } + } +} + +// --------------------------------------------------------- // +// Testing GetServiceList() +// --------------------------------------------------------- // + +func TestGetServiceList(t *testing.T) { + setTest := &components.Service{ + ID: 1, + Definition: "test", + SubPath: "test", + Details: map[string][]string{"Forms": {"SignalA_v1a"}}, + Description: "A test service", + RegPeriod: 45, + RegTimestamp: "now", + RegExpiration: "45", + } + ServicesMap := &components.Services{ + setTest.SubPath: setTest, + } + mua := mockUnitAsset{ + Name: "test", + Owner: nil, + Details: nil, + ServicesMap: *ServicesMap, + } + servList := getServicesList(mua) + if len(servList) != 1 && servList[0].Definition != "test" { + t.Errorf("Expected length: 1, got %d\tExpected 'Definition': test, got %s", + len(servList), servList[0].Definition) + } +} + +// --------------------------------------------------------- // +// Testing MakeServiceMap() +// --------------------------------------------------------- // + +func TestMakeServiceMap(t *testing.T) { + var servList []components.Service + for x := range 6 { + serv := components.Service{ + ID: x, + Definition: fmt.Sprintf("testDef%d", x), + SubPath: fmt.Sprintf("test%d", x), + Details: map[string][]string{"Forms": {"SignalA_v1a"}}, + Description: fmt.Sprintf("test service %d", x), + RegPeriod: 45, + RegTimestamp: "now", + RegExpiration: "45", + } + servList = append(servList, serv) + } + servMap := MakeServiceMap(servList) + for c := range 6 { + service := fmt.Sprintf("test%d", c) + if servMap[service].SubPath != service || servMap[service].ID != c { + t.Errorf(`Expected servMap["%s"].SubPath to be "%s", with ID: "%d". Got: "%s", with ID: "%d"`, + service, service, c, servMap[service].SubPath, servMap[service].ID) + } + } +} diff --git a/usecases/consumption.go b/usecases/consumption.go index b983ec8..2921e4e 100644 --- a/usecases/consumption.go +++ b/usecases/consumption.go @@ -20,13 +20,13 @@ package usecases import ( - "bytes" - "context" "fmt" "io" "log" + "testing" + "net/http" - "time" + "net/url" "github.com/sdoque/mbaigo/components" "github.com/sdoque/mbaigo/forms" @@ -34,16 +34,27 @@ import ( // GetState request the current state of a unit asset (via the asset's service) func GetState(cer *components.Cervice, sys *components.System) (f forms.Form, err error) { - // if no known providers, search for one via the Orchestrator + return stateHandler(http.MethodGet, cer, sys, nil) +} + +// GetStates requests the current state of certain services of a unit asset depending on requested definition and/or details +func GetStates(cer *components.Cervice, sys *components.System) (f []forms.Form, err []error) { + return stateHandlers(http.MethodGet, cer, sys, nil) +} + +// SetState puts a request to change the state of a unit asset (via the asset's service) +func SetState(cer *components.Cervice, sys *components.System, bodyBytes []byte) (f forms.Form, err error) { + return stateHandler(http.MethodPut, cer, sys, bodyBytes) +} + +func stateHandler(httpMethod string, cer *components.Cervice, sys *components.System, bodyBytes []byte) (f forms.Form, err error) { if len(cer.Nodes) == 0 { - err := Search4Services(cer, sys) + err = Search4Services(cer, sys) if err != nil { return f, err } } - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) // Create a new context, with a 2-second timeout - defer cancel() - // Create a new HTTP request using the first known provider + var serviceUrl string for _, values := range cer.Nodes { if len(values) > 0 { @@ -51,99 +62,156 @@ func GetState(cer *components.Cervice, sys *components.System) (f forms.Form, er break } } - req, err := http.NewRequest(http.MethodGet, serviceUrl, nil) + + resp, err := sendHTTPReq(httpMethod, serviceUrl, bodyBytes) if err != nil { + cer.Nodes = make(map[string][]string) // Failed to get the resource at that location: reset the providers list, which will trigger a new service search return f, err } - // Associate the cancellable context with the request - req = req.WithContext(ctx) - // Send the request ///////////////////////////////// - client := &http.Client{} - resp, err := client.Do(req) + defer resp.Body.Close() + + // If the response includes a payload, unpack it into a forms.Form + bodyBytes, err = io.ReadAll(resp.Body) if err != nil { - cer.Nodes = make(map[string][]string) // failed to get the resource at that location: reset the providers list, which will trigger a new service search - return f, err + return f, fmt.Errorf("reading state response body: %w", err) } - defer resp.Body.Close() - // Check if the status code indicates an error (anything outside the 200–299 range) - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return f, fmt.Errorf("received non-2xx status code: %d, response: %s", resp.StatusCode, http.StatusText(resp.StatusCode)) + if len(bodyBytes) < 1 { + return f, fmt.Errorf("got empty response body") + } - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - log.Printf("GetRValue-Error reading registration response body: %v", err) - return + headerContentType := resp.Header.Get("Content-Type") + return Unpack(bodyBytes, headerContentType) +} + +const messengerMaxErrors int = 3 + +func LogDebug(sys *components.System, msg string, args ...any) { + Log(sys, forms.LevelDebug, msg, args...) +} + +func LogInfo(sys *components.System, msg string, args ...any) { + Log(sys, forms.LevelInfo, msg, args...) +} + +func LogWarn(sys *components.System, msg string, args ...any) { + Log(sys, forms.LevelWarn, msg, args...) +} + +func LogError(sys *components.System, msg string, args ...any) { + Log(sys, forms.LevelError, msg, args...) +} + +func Log(sys *components.System, lvl forms.MessageLevel, msg string, args ...any) { + sm := forms.NewSystemMessage_v1(lvl, fmt.Sprintf(msg, args...), sys.Name) + if !testing.Testing() { + // Only print the msg locally if not running during `go test` + log.Println(sm.String()) + } + var body []byte + sys.Mutex.Lock() + defer sys.Mutex.Unlock() + + // Iterate over all messengers and try sending a copy of the log msg + for host, errors := range sys.Messengers { + // Lazy-load the packed body, only at the first iteration + if body == nil { + var err error + body, err = Pack(forms.Form(&sm), "application/json") + if err != nil { + log.Printf("failed to pack SystemMessage: %v\n", err) + return + } + } + + errCount := 0 // If there's no error while sending msg, the count is reset + if err := sendLogMessage(host, body); err != nil { + // Don't care what kinds of errors might be returned + errCount = errors + 1 + } + if errCount >= messengerMaxErrors { + // Too many errors indicates a problematic messenger + delete(sys.Messengers, host) + continue + } + sys.Messengers[host] = errCount } +} - headerContentTtype := resp.Header.Get("Content-Type") - f, err = Unpack(bodyBytes, headerContentTtype) +// Hard-coding the path is ugly but it skips an extra service discovery cycle for now +const logMessagePath string = "/log/message" + +func sendLogMessage(host string, body []byte) error { + u, err := url.Parse(host) if err != nil { - fmt.Printf("error unpacking the service response: %s", err) + return err } - return f, nil + u = u.JoinPath(logMessagePath) + resp, err := sendHTTPReq(http.MethodPost, u.String(), body) + if err != nil { + return err + } + _ = resp.Body.Close() // Don't care about the response body or any errors it might cause + return nil } -// SetState puts a request to change the state of a unit asset (via the asset's service) -func SetState(cer *components.Cervice, sys *components.System, bodyBytes []byte) (f forms.Form, err error) { - // Get the address of the informing service of the target asset via the Orchestrator +func stateHandlers(httpMethod string, cer *components.Cervice, sys *components.System, bodyBytes []byte) (f []forms.Form, err []error) { if len(cer.Nodes) == 0 { - err := Search4Services(cer, sys) - if err != nil { + currentErr := Search4MultipleServices(cer, sys) + if currentErr != nil { + f = append(f, nil) + err = append(err, currentErr) return f, err } } - // Create a new context, with a 2-second timeout - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - // Create a new HTTP request - var serviceUrl string + var serviceUrls []string for _, values := range cer.Nodes { if len(values) > 0 { - serviceUrl = values[0] - break + serviceUrls = append(serviceUrls, values...) } } - req, err := http.NewRequest(http.MethodPut, serviceUrl, bytes.NewReader(bodyBytes)) - if err != nil { - return f, err - } - - // Set the Content-Type header - req.Header.Set("Content-Type", "application/json") - // Associate the cancellable context with the request - req = req.WithContext(ctx) - - // Send the request - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - cer.Nodes = make(map[string][]string) // Failed to get the resource at that location: reset the providers list, which will trigger a new service search - return f, err - } - defer resp.Body.Close() - // Check if the status code indicates an error (anything outside the 200–299 range) - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return f, fmt.Errorf("received non-2xx status code: %d, response: %s", resp.StatusCode, http.StatusText(resp.StatusCode)) - } + for _, serviceUrl := range serviceUrls { + if len(serviceUrl) == 0 { + continue + } + resp, currentErr := sendHTTPReq(httpMethod, serviceUrl, bodyBytes) + if currentErr != nil { + cer.Nodes = make(map[string][]string) + f = append(f, nil) + err = append(err, currentErr) + continue + } + defer resp.Body.Close() + + // If the response includes a payload, unpack it into a forms.Form + bodyBytes, currentErr = io.ReadAll(resp.Body) + if currentErr != nil { + currentErr = fmt.Errorf("reading state response body: %w", currentErr) + f = append(f, nil) + err = append(err, currentErr) + continue + } - // If the response includes a payload, unpack it into a forms.Form - bodyBytes, err = io.ReadAll(resp.Body) - if err != nil { - return f, fmt.Errorf("error reading response body: %v", err) - } + if len(bodyBytes) < 1 { + currentErr = fmt.Errorf("got empty response body") + f = append(f, nil) + err = append(err, currentErr) + continue + } - if len(bodyBytes) > 0 { headerContentType := resp.Header.Get("Content-Type") - f, err = Unpack(bodyBytes, headerContentType) - if err != nil { - return f, fmt.Errorf("error unpacking the service response: %v", err) + formValue, currentErr := Unpack(bodyBytes, headerContentType) + if currentErr != nil { + currentErr = fmt.Errorf("unpacking response body: %w", currentErr) + f = append(f, nil) + err = append(err, currentErr) + continue } + f = append(f, formValue) + err = append(err, nil) } - - return f, nil + return f, err } diff --git a/usecases/consumption_test.go b/usecases/consumption_test.go new file mode 100644 index 0000000..bd4adf8 --- /dev/null +++ b/usecases/consumption_test.go @@ -0,0 +1,503 @@ +package usecases + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/sdoque/mbaigo/components" + "github.com/sdoque/mbaigo/forms" +) + +type stateParams struct { + testCer *components.Cervice + testSys components.System + bodyBytes []byte + body func() *http.Response + mockTransportErr int + errHTTP error + expectedfForm forms.Form + expectedErr error + testCase string +} + +func newTestCerviceWithNodes() *components.Cervice { + return &components.Cervice{ + IReferentce: "test", + Definition: "A test Cervice with nodes", + Details: map[string][]string{"Forms": {"SignalA_v1a"}}, + Nodes: map[string][]string{"test": {"https://testSystem/testUnitAsset/test"}}, + Protos: []string{"http"}, + } +} + +func newTestCerviceWithoutNodes() *components.Cervice { + return &components.Cervice{ + IReferentce: "test", + Definition: "A test Cervice without nodes", + Details: map[string][]string{"Forms": {"SignalA_v1a"}}, + Nodes: make(map[string][]string), + Protos: []string{"http"}, + } +} + +func newTestCerviceWithBrokenUrl() *components.Cervice { + return &components.Cervice{ + IReferentce: "test", + Definition: "A test Cervice with nodes", + Details: map[string][]string{"Forms": {"SignalA_v1a"}}, + Nodes: map[string][]string{"test": {brokenUrl}}, + Protos: []string{"http"}, + } +} + +var form forms.SignalA_v1a + +var errEmptyRespBody = errors.New("got empty response body") + +var errUnpack = errors.New("problem unpacking response body") + +func createTestBytes() []byte { + return []byte("{\n \"value\": 0,\n \"unit\": \"\",\n \"timestamp\": " + + "\"0001-01-01T00:00:00Z\",\n \"version\": \"SignalA_v1.0\"\n}") +} + +func createWorkingHttpResp() func() *http.Response { + httpResp := func() *http.Response { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(string("{\n \"value\": 0,\n \"unit\": \"\",\n " + + " \"timestamp\": \"0001-01-01T00:00:00Z\",\n \"version\": \"SignalA_v1.0\"\n}"))), + } + } + return httpResp +} + +// This function creates two different http responses with a different body, +// since some tests build on receiving multiple correct http responses +func createDoubleHttpResp() func() *http.Response { + f := createServicePointTestForm() + // Create mock response from orchestrator + fakeBody, err := json.Marshal(f) + if err != nil { + log.Println("Fail Marshal at start of test") + } + count := 0 + return func() *http.Response { + count++ + if count == 1 || count == 3 { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(string(fakeBody))), + } + } + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(string("{\n \"value\": 0,\n \"unit\": \"\",\n " + + " \"timestamp\": \"0001-01-01T00:00:00Z\",\n \"version\": \"SignalA_v1.0\"\n}"))), + } + } +} + +func createEmptyHttpResp() func() *http.Response { + httpResp := func() *http.Response { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(string(""))), + } + } + return httpResp +} + +func createStatusErrorHttpResp() func() *http.Response { + httpResp := func() *http.Response { + return &http.Response{ + Status: "300 NAK", + StatusCode: 300, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(string("{\n \"value\": 0,\n \"unit\": \"\",\n " + + " \"timestamp\": \"0001-01-01T00:00:00Z\",\n \"version\": \"SignalA_v1.0\"\n}"))), + } + } + return httpResp +} + +func createErrorReaderHttpResp() func() *http.Response { + httpResp := func() *http.Response { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(errorReader{}), + } + } + return httpResp +} + +func createUnpackErrorHttpResp() func() *http.Response { + httpResp := func() *http.Response { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"Wrong content type"}}, + Body: io.NopCloser(strings.NewReader(string("{\n \"value\": 0,\n \"unit\": \"\",\n " + + " \"timestamp\": \"0001-01-01T00:00:00Z\",\n \"version\": \"SignalA_v1.0\"\n}"))), + } + } + return httpResp +} + +var testStateParams = []stateParams{ + {newTestCerviceWithNodes(), createTestSystem(false), createTestBytes(), + createWorkingHttpResp(), 0, nil, form.NewForm(), nil, "No errors with nodes"}, + {newTestCerviceWithoutNodes(), createTestSystem(false), createTestBytes(), + createDoubleHttpResp(), 0, nil, form.NewForm(), nil, "No errors without nodes"}, + {newTestCerviceWithNodes(), createTestSystem(false), nil, + createEmptyHttpResp(), 0, nil, nil, errEmptyRespBody, "Empty response body error"}, + {newTestCerviceWithoutNodes(), createTestSystem(false), createTestBytes(), + createWorkingHttpResp(), 1, errHTTP, nil, errHTTP, "Search4Services error"}, + {newTestCerviceWithBrokenUrl(), createTestSystem(false), createTestBytes(), + createWorkingHttpResp(), 2, errHTTP, nil, errHTTP, "NewRequest() error"}, + {newTestCerviceWithNodes(), createTestSystem(false), createTestBytes(), + createStatusErrorHttpResp(), 2, errHTTP, nil, errHTTP, "Status code error"}, + {newTestCerviceWithNodes(), createTestSystem(false), createTestBytes(), + createErrorReaderHttpResp(), 0, nil, nil, errBodyRead, "io.ReadAll() error"}, + {newTestCerviceWithNodes(), createTestSystem(false), createTestBytes(), + createUnpackErrorHttpResp(), 0, nil, nil, errUnpack, "Unpack() error"}, + {newTestCerviceWithNodes(), createTestSystem(false), createTestBytes(), + createWorkingHttpResp(), 1, errHTTP, nil, errHTTP, "DefaultClient.Do() error"}, +} + +func TestGetState(t *testing.T) { + for _, test := range testStateParams { + newMockTransport(test.body, test.mockTransportErr, test.errHTTP) + res, err := GetState(test.testCer, &test.testSys) + + if test.expectedfForm != nil { + expected := test.expectedfForm.(*forms.SignalA_v1a) + actual, ok := res.(*forms.SignalA_v1a) + if !ok { + t.Fatalf("Test case: %s, got %v, expected a forms.Form", + test.testCase, res, + ) + } + if expected.Value != actual.Value || expected.Unit != actual.Unit || + expected.Timestamp != actual.Timestamp || expected.Version != actual.Version || + err != test.expectedErr { + t.Errorf("Test case: %s got error: %v. \nExpected form: \n%+v\n, got: \n%+v", + test.testCase, err, expected, actual) + } + } else if err == nil { + t.Errorf("Test case: %s got error: %v:", test.testCase, err) + } + } +} + +func TestSetState(t *testing.T) { + for _, test := range testStateParams { + newMockTransport(test.body, test.mockTransportErr, test.errHTTP) + + if test.testCase == "DefaultClient.Do() error" { + test.testCer.Nodes = map[string][]string{"test": {"https://testSystem/testUnitAsset/test"}} + } + if test.testCase == "No errors without nodes" { + test.testCer.Nodes = make(map[string][]string) + } + res, err := SetState(test.testCer, &test.testSys, test.bodyBytes) + + if test.expectedfForm != nil { + expected := test.expectedfForm.(*forms.SignalA_v1a) + actual, ok := res.(*forms.SignalA_v1a) + if !ok { + t.Fatalf("Test case: %s, got %v, expected a forms.Form", + test.testCase, res, + ) + } + if expected.Value != actual.Value || expected.Unit != actual.Unit || + expected.Timestamp != actual.Timestamp || expected.Version != actual.Version || + err != test.expectedErr { + t.Errorf("Test case: %s got error: %v. \nExpected form: \n%+v\n, got: \n%+v", + test.testCase, err, expected, actual) + } + } else if err == nil { + t.Errorf("Test case: %s got error: %v:", test.testCase, err) + } + } +} + +func createServRecListTestForm(amount int) (servRecList forms.ServiceRecordList_v1) { + servRecList.NewForm() + servRecList.List = make([]forms.ServiceRecord_v1, amount) + for i := range amount { + servRecList.List[i].IPAddresses = []string{"123.456.789"} + servRecList.List[i].ProtoPort = map[string]int{"http": 123} + } + return servRecList +} + +// Use this one if a mock response from an orchestrator is needed +func createDoubleHttpRespWithServRecList(amount int, empty bool, statusErr bool, + readErr bool, unpackErr bool) func() *http.Response { + f := createServRecListTestForm(amount) + // Create mock response from orchestrator + fakeBody, err := json.Marshal(f) + if err != nil { + log.Println("Fail Marshal at start of test") + } + count := 0 + return func() *http.Response { + resp := &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(string("{\n \"value\": 0,\n \"unit\": \"\",\n " + + " \"timestamp\": \"0001-01-01T00:00:00Z\",\n \"version\": \"SignalA_v1.0\"\n}"))), + } + count++ + if count == 1 { + resp.Body = io.NopCloser(strings.NewReader(string(fakeBody))) + return resp + } + if empty == true { + resp.Body = io.NopCloser(strings.NewReader(string(""))) + return resp + } + if statusErr == true { + resp.Status = "300 NAK" + resp.StatusCode = 300 + return resp + } + if readErr == true { + resp.Body = io.NopCloser(errorReader{}) + return resp + } + if unpackErr == true { + resp.Header = http.Header{"Content-Type": []string{"Wrong content type"}} + return resp + } + return resp + } +} + +func formsEqual(a, b []forms.Form) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] == nil && b[i] == nil { + continue + } + aForm, ok := a[i].(*forms.SignalA_v1a) + if !ok { + return false + } + bForm, ok := b[i].(*forms.SignalA_v1a) + if !ok { + return false + } + if aForm.Value != bForm.Value || aForm.Unit != bForm.Unit || + aForm.Timestamp != bForm.Timestamp || aForm.Version != bForm.Version { + return false + } + } + return true +} + +func errEqual(a, b []error) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if (a[i] != nil && b[i] == nil) || (a[i] == nil && b[i] != nil) { + return false + } + } + return true +} + +type getStatesTestStruct struct { + body func() *http.Response + mockTransportErr int + errHTTP error + expectedForm []forms.Form + expectedErr []error + testName string +} + +var ( + threeForms = []forms.Form{form.NewForm(), form.NewForm(), form.NewForm()} + oneNilForm = []forms.Form{form.NewForm(), form.NewForm(), nil} + nilForms = []forms.Form{nil, nil, nil} + singleNilForm = []forms.Form{nil} + threeErr = []error{fmt.Errorf("Error"), fmt.Errorf("Error"), fmt.Errorf("Error")} + oneErr = []error{nil, nil, fmt.Errorf("Error")} + nilErr = []error{nil, nil, nil} + singleErr = []error{fmt.Errorf("Error")} +) + +var getStatesTestParams = []getStatesTestStruct{ + {createDoubleHttpRespWithServRecList(3, false, false, false, false), 0, nil, threeForms, + nilErr, "No errors without nodes"}, + {createDoubleHttpRespWithServRecList(3, false, false, false, false), 4, errHTTP, oneNilForm, + oneErr, "Error in one of the services"}, + {createDoubleHttpRespWithServRecList(3, true, false, false, false), 0, nil, nilForms, + threeErr, "Empty response body error"}, + {createWorkingHttpResp(), 1, errHTTP, singleNilForm, + singleErr, "Search4Services error"}, + {createDoubleHttpRespWithServRecList(3, false, true, false, false), 0, nil, nilForms, + threeErr, "Status code error"}, + {createDoubleHttpRespWithServRecList(3, false, false, true, false), 0, nil, nilForms, + threeErr, "io.ReadAll() error"}, + {createDoubleHttpRespWithServRecList(3, false, false, false, true), 0, nil, nilForms, + threeErr, "Unpack() error"}, +} + +func TestGetStates(t *testing.T) { + for _, testCase := range getStatesTestParams { + testCer := newTestCerviceWithoutNodes() + testSys := createTestSystem(false) + newMockTransport(testCase.body, testCase.mockTransportErr, testCase.errHTTP) + + res, err := GetStates(testCer, &testSys) + + if !formsEqual(res, testCase.expectedForm) || !errEqual(err, testCase.expectedErr) { + t.Errorf("Test case: %s\nExpected forms: %+v\nGot: %+v\nExpected error: %v, Got error: %v", + testCase.testName, testCase.expectedForm, res, testCase.expectedErr, err) + } + } + + // Special case: No errors with existing nodes + cerWithNodes := components.Cervice{ + IReferentce: "test", + Definition: "A test Cervice with nodes", + Details: map[string][]string{"Forms": {"SignalA_v1a"}}, + Nodes: map[string][]string{"test": {"test1", "test2", "test3"}}, + Protos: []string{"http"}, + } + testSys := createTestSystem(false) + newMockTransport(createWorkingHttpResp(), 0, nil) + + res, err := GetStates(&cerWithNodes, &testSys) + expectedForm := []forms.Form{form.NewForm(), form.NewForm(), form.NewForm()} + expectedErr := []error{nil, nil, nil} + + if !formsEqual(res, expectedForm) || !errEqual(err, expectedErr) { + t.Errorf("Test case: No errors with nodes \nExpected forms: %v\nGot: %v\nExpected error: %v, Got error: %v", + expectedForm, res, expectedErr, err) + } + + // Special case: Error with a broken url in nodes + cerWithBrokenUrlNode := components.Cervice{ + IReferentce: "test", + Definition: "A test Cervice with nodes", + Details: map[string][]string{"Forms": {"SignalA_v1a"}}, + Nodes: map[string][]string{"test": {"test1", brokenUrl, "test3"}}, + Protos: []string{"http"}, + } + testSys = createTestSystem(false) + newMockTransport(createWorkingHttpResp(), 0, nil) + + res, err = GetStates(&cerWithBrokenUrlNode, &testSys) + expectedForm = []forms.Form{form.NewForm(), nil, form.NewForm()} + expectedErr = []error{nil, fmt.Errorf("Error"), nil} + + if !formsEqual(res, expectedForm) || !errEqual(err, expectedErr) { + t.Errorf("Test case: Error with broken url \nExpected forms: %v\nGot: %v\nExpected error: %v, Got error: %v", + expectedForm, res, expectedErr, err) + } +} + +type logTransportMock struct { + t *testing.T + errResponse error +} + +func newLogTransportMock(t *testing.T) *logTransportMock { + lt := &logTransportMock{t, nil} + http.DefaultClient.Transport = lt + return lt +} + +func (mock *logTransportMock) setError(err error) { + mock.errResponse = err +} + +// This mock transport also verifies that the system message forms are valid. +func (mock *logTransportMock) RoundTrip(req *http.Request) (res *http.Response, err error) { + body, err := io.ReadAll(req.Body) + if err != nil { + mock.t.Errorf("unexpected error while reading request body: %v", err) + return + } + defer req.Body.Close() + form, err := Unpack(body, req.Header.Get("Content-Type")) + if err != nil { + mock.t.Errorf("unexpected error from unpack: %v", err) + return + } + message, ok := form.(*forms.SystemMessage_v1) + if !ok { + mock.t.Error("unexpected form") + return + } + if message.System != testLogSys || message.Body != testLogMsg { + mock.t.Errorf("unexpected message: %v", message) + } + + if mock.errResponse != nil { + return nil, mock.errResponse + } + rec := httptest.NewRecorder() + rec.WriteHeader(http.StatusOK) + return rec.Result(), nil +} + +const testLogHost = "host" +const testLogSys = "test system" +const testLogMsg = "test msg" + +// NOTE: this test also covers sendLogMessage function + +func TestLog(t *testing.T) { + mock := newLogTransportMock(t) + mock.setError(fmt.Errorf("mock err")) + sys := components.NewSystem(testLogSys, context.Background()) + + // Case: increase error count by one + sys.Messengers[testLogHost] = 0 + Log(&sys, forms.LevelDebug, testLogMsg) + if got, want := sys.Messengers[testLogHost], 1; got != want { + t.Errorf("expected error count %d, got %d", want, got) + } + + // Case: removes messenger after too many errors + sys.Messengers[testLogHost] = messengerMaxErrors + Log(&sys, forms.LevelDebug, testLogMsg) + _, found := sys.Messengers[testLogHost] + if found { + t.Errorf("expected messenger being removed") + } + + // Case: transfer ok + mock.setError(nil) + sys.Messengers[testLogHost] = 0 + Log(&sys, forms.LevelDebug, testLogMsg) + if got, want := sys.Messengers[testLogHost], 0; got != want { + t.Errorf("expected error count %d, got %d", want, got) + } +} diff --git a/usecases/cost.go b/usecases/cost.go index 1a3f1a2..865c5d5 100644 --- a/usecases/cost.go +++ b/usecases/cost.go @@ -21,9 +21,8 @@ package usecases import ( "encoding/json" - "errors" + "fmt" "io" - "log" "net/http" "time" @@ -44,68 +43,49 @@ func GetActivitiesCost(serv *components.Service) (payload []byte, err error) { // SetActivitiesCost updates the service cost func SetActivitiesCost(serv *components.Service, bodyBytes []byte) (err error) { - var jsonData map[string]interface{} - err = json.Unmarshal(bodyBytes, &jsonData) + f, err := Unpack(bodyBytes, "application/json") if err != nil { - log.Printf("Error unmarshaling JSON data: %v", err) - return + return fmt.Errorf("unmarshalling cost form: %w", err) } - formVersion, ok := jsonData["version"].(string) + acForm, ok := f.(*forms.ActivityCostForm_v1) if !ok { - log.Printf("Error: 'version' key not found in JSON data") - return + return fmt.Errorf("couldn't convert to correct form") } - var acForm forms.ActivityCostForm_v1 - switch formVersion { - case "ActivityCostForm_v1": - var f forms.ActivityCostForm_v1 - err = json.Unmarshal(bodyBytes, &f) - if err != nil { - log.Println("Unable to extract new activity costs request ") - return - } - acForm = f - default: - err = errors.New("unsupported version of activity costs form") - return - } - - if serv.Definition == acForm.Activity { - serv.ACost = acForm.Cost // update the service's cost - log.Printf("The new service cost is %f => the service is %+v\n", acForm.Cost, serv) - } else { - err = errors.New("mismatch between service list order") // corrected typo - return + if serv.Definition != acForm.Activity { + return fmt.Errorf("service definition and activity cost forms activity field doesn't match") } + serv.ACost = acForm.Cost // update the service's cost return } // ACServices handles the http request for the cost of a service func ACServices(w http.ResponseWriter, r *http.Request, ua *components.UnitAsset, serviceP string) { + // Has to use (*ua) in order to reach the methods for the interface UnitAsset, since ua is a pointer to an interface servicesList := (*ua).GetServices() serv := servicesList[serviceP] switch r.Method { case "GET": payload, err := GetActivitiesCost(serv) if err != nil { - log.Printf("Error in getting the activity costs\n") http.Error(w, "Error marshaling data.", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(payload) - return + _, err = w.Write(payload) + if err != nil { + http.Error(w, "Error while writing to response body", http.StatusInternalServerError) + } case "PUT": defer r.Body.Close() bodyBytes, err := io.ReadAll(r.Body) // Use io.ReadAll instead of ioutil.ReadAll if err != nil { - log.Printf("Error reading registration response body: %v", err) + http.Error(w, "Error reading registration response body", http.StatusBadRequest) return } err = SetActivitiesCost(serv, bodyBytes) if err != nil { - log.Printf("there was an error updating the activittiy costs, %s\n", err) + http.Error(w, "Error occurred while updating activity costs", http.StatusInternalServerError) } default: http.Error(w, "Method is not supported.", http.StatusNotFound) diff --git a/usecases/cost_test.go b/usecases/cost_test.go new file mode 100644 index 0000000..6af83ea --- /dev/null +++ b/usecases/cost_test.go @@ -0,0 +1,214 @@ +package usecases + +import ( + "io" + "math" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/sdoque/mbaigo/components" +) + +func TestGetActivitiesCost(t *testing.T) { + testServ := &components.Service{ + Definition: "testDef", + ACost: 123, + CUnit: "testCUnit", + } + data, err := GetActivitiesCost(testServ) + if err != nil { + t.Errorf("no error expected, got: %v", err) + } + + // Check that correct data is present + if strings.Contains(string(data), `"activity": "testDef"`) == false { + t.Errorf("Definition/activity doesn't match") + } + if (strings.Contains(string(data), `"cost": 123`)) == false { + t.Errorf("ACost/cost doesn't match") + } +} + +// ------------------------------------------------------ // +// Helper functions and structs for TestSetActivitiesCost() +// ------------------------------------------------------ // + +type setACparams struct { + dataString string + expectError bool + testCase string +} + +func createTestService() (serv *components.Service) { + testServ := &components.Service{ + ID: 0, + Definition: "testDefinition", + SubPath: "testService", + Details: map[string][]string{"Details": {"detail1", "detail2"}}, + RegPeriod: 45, + RegTimestamp: "Now", + RegExpiration: "Later", + Description: "A service for testing purposes", + SubscribeAble: false, + ACost: 123, + CUnit: "testCUnit", + } + return testServ +} + +func TestSetActivitiesCost(t *testing.T) { + testParams := []setACparams{ + // Best case: No errors + { + `{"activity":"testDefinition","cost":321,"unit":"", + "timestamp":"0001-01-01T00:00:00Z","version":"ActivityCostForm_v1"}`, + false, "Best case, no errors", + }, + // Bad case: Fail @ unmarshal + {"", true, "Bad case, break first unmarshal"}, + // Bad case: No version field in byte array + { + `{"activity":"testDefinition","cost":321,"unit":"","timestamp":"0001-01-01T00:00:00Z"}`, + true, "Bad case, version missing", + }, + // Bad case: Unsupported version + { + `{"activity":"testDefinition","cost":321,"unit":"", + "timestamp":"0001-01-01T00:00:00Z","version":"WrongVersion"}`, + true, "Bad case, unsupported version", + }, + // Bad case: Mismatch between 'serv.Definition' and 'acForm.Activity' + { + `{"activity":"WrongDef","cost":321,"unit":"", + "timestamp":"0001-01-01T00:00:00Z","version":"ActivityCostForm_v1"}`, + true, "Bad case, serv.Definition != acForm.Activity", + }, + // Bad case: Fail @ 2nd unmarshal + { + `{"activity":"testDefinition","cost":"321","unit":"", + "timestamp":"0001-01-01T00:00:00Z","version":"ActivityCostForm_v1"}`, + true, "Bad case, break first unmarshal", + }, + // Bad case: Couldn't convert to ActivityCostForm_v1 + { + `{"file_url":"filepath", + "timestamp":"0001-01-01T00:00:00Z","version":"FileForm_v1"}`, + true, "Bad case, couldn't convert to ActivityCostForm_v1", + }, + } + testServ := createTestService() + + for _, c := range testParams { + err := SetActivitiesCost(testServ, []byte(c.dataString)) + + if (c.expectError == true && err == nil) || (c.expectError == false && err != nil) { + t.Errorf("Testcase '%s' failed, expectError was %v error was: %v", c.testCase, c.expectError, err) + } + } +} + +// ------------------------------------------------------ // +// Helper functions and structs for TestACServices() +// ------------------------------------------------------ // + +// Creates a unitasset with values used for testing +func createUnitAsset(cost float64) components.UnitAsset { + setTest := &components.Service{ + ID: 1, + Definition: "test", + SubPath: "test", + Details: map[string][]string{"Forms": {"SignalA_v1a"}}, + Description: "A test service", + RegPeriod: 45, + RegTimestamp: "now", + RegExpiration: "45", + ACost: cost, + } + ServicesMap := &components.Services{ + setTest.SubPath: setTest, + } + var ua components.UnitAsset = &mockUnitAsset{ + Name: "testUnitAsset", + Details: map[string][]string{"Test": {"Test"}}, + ServicesMap: *ServicesMap, + CervicesMap: nil, + } + return ua +} + +type acServicesParams struct { + httpMethod string + responseWriter *httptest.ResponseRecorder + expectError bool + request *http.Request + unitAsset components.UnitAsset + testCase string +} + +func TestACServices(t *testing.T) { + testParams := []acServicesParams{ + // Good case: no errors in GET/PUT + { + "GET", httptest.NewRecorder(), false, + httptest.NewRequest( + http.MethodGet, + "http://localhost", + io.NopCloser(strings.NewReader(``)), + ), + createUnitAsset(0), "GET, Best case: no errors in GET", + }, + { + "PUT", httptest.NewRecorder(), false, + httptest.NewRequest( + http.MethodPut, + "http://localhost", + io.NopCloser(strings.NewReader( + `{"activity":"test", "cost": 321, "version":"ActivityCostForm_v1"}`, + )), + ), + createUnitAsset(0), "PUT, Best case: no errors in PUT", + }, + // GET, Bad case: GetActivitiesCost() returns error + { + "GET", httptest.NewRecorder(), true, + httptest.NewRequest(http.MethodGet, "http://localhost", io.NopCloser(strings.NewReader(``))), + createUnitAsset(math.NaN()), "GET, Bad case: error from GetActivitiesCost()"}, + // PUT, Bad case: Reading response body returns an error + { + "PUT", httptest.NewRecorder(), true, + httptest.NewRequest(http.MethodPut, "http://localhost", io.NopCloser(errReader(0))), + createUnitAsset(0), "PUT, Bad case: reading response body", + }, + // PUT, Bad case: SetActivitiesCost() returns error + { + "PUT", httptest.NewRecorder(), true, + httptest.NewRequest(http.MethodPut, "http://localhost", io.NopCloser(strings.NewReader(``))), + createUnitAsset(0), "PUT, Bad case: error updating activities cost", + }, + // DEFAULT: Method not supported (POST), + { + "POST", httptest.NewRecorder(), true, + httptest.NewRequest(http.MethodPost, "http://localhost", io.NopCloser(strings.NewReader(``))), + createUnitAsset(0), "POST, Bad case: Method not supported", + }, + // TODO: GET, Bad case: Couldn't write to responsewriter + } + + for _, c := range testParams { + // Setup + ua := c.unitAsset + w := c.responseWriter + r := c.request + // Test + ACServices(w, r, &ua, "test") + + if c.expectError == false && w.Result().StatusCode != 200 { + t.Errorf("Expected statuscode 200 in testcase '%s' got: %d", c.testCase, w.Result().StatusCode) + } + if c.expectError == true && w.Result().StatusCode == 200 { + t.Errorf("Expected statuscode not to be 200 in testcase '%s'", c.testCase) + } + } +} diff --git a/usecases/docs.go b/usecases/docs.go index 412b542..13a44c1 100644 --- a/usecases/docs.go +++ b/usecases/docs.go @@ -29,6 +29,7 @@ package usecases import ( "fmt" + "log" "net/http" "strconv" "strings" @@ -39,102 +40,96 @@ import ( // System Documentation (based on HATEOAS) provides an initial documentation on the system's web server of with hyperlinks to the services for browsers // HATEOAS is the acronym for Hypermedia as the Engine of Application State, using hyperlinks to navigate the API func SysHateoas(w http.ResponseWriter, req *http.Request, sys components.System) { - text := "" - w.Write([]byte(text)) - text = "

System Description

" - w.Write([]byte(text)) - text = "

The system " + sys.Name + " " + sys.Husk.Description + "


" - w.Write([]byte(text)) - text = "Online Documentation

" - w.Write([]byte(text)) - - text = "

The resource list is

" + _, err := w.Write([]byte(text)) + if err != nil { + log.Printf("Error while writing to response body for SysHateoas: %v", err) + } } // ResHateoas provides information about the unit asset(s) and each service and is accessed via the system's web server func ResHateoas(w http.ResponseWriter, req *http.Request, ua components.UnitAsset, sys components.System) { - text := "" - w.Write([]byte(text)) - - text = "

Unit Asset Description

" - w.Write([]byte(text)) + text := "\n" + text += "

Unit Asset Description

\n" uaName := ua.GetName() metaservice := "" for key, values := range ua.GetDetails() { metaservice += key + ": " + fmt.Sprintf("%v", values) + " " } - text = "The resource " + uaName + " belongs to system " + sys.Name + " and has the details " + metaservice + " with the following services:" + "" + _, err := w.Write([]byte(text)) + if err != nil { + log.Printf("Error while writing response body for ResHateoas: %v", err) + } } // ServiceHateoas provides information about the service and is accessed via the system's web server -func ServiceHateoas(w http.ResponseWriter, req *http.Request, ser components.Service, sys components.System) { +func ServiceHateoas(w http.ResponseWriter, req *http.Request, serv components.Service, sys components.System) { parts := strings.Split(req.URL.Path, "/") uaName := parts[2] - text := "" - w.Write([]byte(text)) - - text = "

Service Description

" - w.Write([]byte(text)) + text := "\n" + text += "

Service Description

\n" metaservice := "" - for key, values := range ser.Details { + for key, values := range serv.Details { metaservice += key + ": " + fmt.Sprintf("%v", values) + " " } - text = "The service " + ser.Definition + " " + ser.Description + " and has the details " + metaservice - w.Write([]byte(text)) + text += "The service " + serv.Definition + " " + serv.Description + " and has the details " + metaservice + _, err := w.Write([]byte(text)) + if err != nil { + log.Printf("Error while writing response body for ServiceHateoas: %v", err) + } } // // getFirstAsset returns the first key-value pair in the Assets map diff --git a/usecases/extra_utils_test.go b/usecases/extra_utils_test.go new file mode 100644 index 0000000..3f043b5 --- /dev/null +++ b/usecases/extra_utils_test.go @@ -0,0 +1,195 @@ +package usecases + +import ( + "context" + "encoding/xml" + "fmt" + "net/http" + + "github.com/sdoque/mbaigo/components" + "github.com/sdoque/mbaigo/forms" +) + +// mockTransport is used for replacing the default network Transport (used by +// http.DefaultClient) and it will intercept network requests. +type mockTransport struct { + respFunc func() *http.Response + hits int + err error +} + +func newMockTransport(respFunc func() *http.Response, v int, err error) *mockTransport { + t := &mockTransport{ + respFunc: respFunc, + hits: v, + err: err, + } + // Hijack the default http client so no actual http requests are sent over the network + http.DefaultClient.Transport = t + return t +} + +// RoundTrip method is required to fulfil the RoundTripper interface (as required by the DefaultClient). +// It prevents the request from being sent over the network, and count how many times +// a http request was sent +func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + t.hits -= 1 + if t.hits == 0 { + return resp, t.err + } + resp = t.respFunc() + resp.Request = req + return resp, nil +} + +// Traits are Asset-specific configurable parameters and variables +type Traits struct { +} + +// A mocked UnitAsset used for testing +type mockUnitAsset struct { + Name string `json:"name"` // Must be a unique name, ie. a sensor ID + Owner *components.System `json:"-"` // The parent system this UA is part of + Details map[string][]string `json:"details"` // Metadata or details about this UA + ServicesMap components.Services `json:"-"` + CervicesMap components.Cervices `json:"-"` + Traits +} + +func (mua mockUnitAsset) GetName() string { + return mua.Name +} + +func (mua mockUnitAsset) GetServices() components.Services { + return mua.ServicesMap +} + +func (mua mockUnitAsset) GetCervices() components.Cervices { + return mua.CervicesMap +} + +func (mua mockUnitAsset) GetDetails() map[string][]string { + return mua.Details +} + +// GetTraits returns the traits of the Resource. +func (ua mockUnitAsset) GetTraits() any { + return ua.Traits +} + +func (mua mockUnitAsset) Serving(w http.ResponseWriter, r *http.Request, servicePath string) {} + +// A mocked form used for testing +type mockForm struct { + XMLName xml.Name `json:"-" xml:"testName"` + Value any `json:"value" xml:"value"` + Unit string `json:"unit" xml:"unit"` + Version string `json:"version" xml:"version"` +} + +// NewForm creates a new form +func (f mockForm) NewForm() forms.Form { + f.Version = "testVersion" + return f +} + +// FormVersion returns the version of the form +func (f mockForm) FormVersion() string { + return f.Version +} + +// Create a error reader to break json.Unmarshal() +type errReader int + +var errBodyRead error = fmt.Errorf("bad body read") + +func (errReader) Read(p []byte) (n int, err error) { + return 0, errBodyRead +} +func (errReader) Close() error { + return nil +} + +// Variables used in testing +var brokenUrl = string(rune(0)) +var errHTTP error = fmt.Errorf("bad http request") + +// Help function to create a test system +func createTestSystem(broken bool) (sys components.System) { + // instantiate the System + ctx := context.Background() + sys = components.NewSystem("testSystem", ctx) + + // Instantiate the Capsule + sys.Husk = &components.Husk{ + Description: "A test system", + Details: map[string][]string{"Developer": {"Test dev"}}, + ProtoPort: map[string]int{"https": 0, "http": 1234, "coap": 0}, + InfoLink: "https://for.testing.purposes", + } + + // create fake services and cervices for a mocked unit asset + testCerv := &components.Cervice{ + Definition: "testCerv", + Details: map[string][]string{"Forms": {"SignalA_v1a"}}, + Nodes: map[string][]string{}, + } + + CervicesMap := &components.Cervices{ + testCerv.Definition: testCerv, + } + setTest := &components.Service{ + ID: 1, + Definition: "test", + SubPath: "test", + Details: map[string][]string{"Forms": {"SignalA_v1a"}}, + Description: "A test service", + RegPeriod: 45, + RegTimestamp: "now", + RegExpiration: "45", + } + ServicesMap := &components.Services{ + setTest.SubPath: setTest, + } + mua := &mockUnitAsset{ + Name: "testUnitAsset", + Details: map[string][]string{"Test": {"Test"}}, + ServicesMap: *ServicesMap, + CervicesMap: *CervicesMap, + } + + sys.UAssets = make(map[string]*components.UnitAsset) + var muaInterface components.UnitAsset = mua + sys.UAssets[mua.GetName()] = &muaInterface + + leadingRegistrar := &components.CoreSystem{ + Name: components.ServiceRegistrarName, + Url: "https://leadingregistrar", + } + test := &components.CoreSystem{ + Name: "test", + Url: "https://test", + } + if broken == false { + orchestrator := &components.CoreSystem{ + Name: "orchestrator", + Url: "https://orchestator", + } + sys.CoreS = []*components.CoreSystem{ + leadingRegistrar, + orchestrator, + test, + } + } else { + orchestrator := &components.CoreSystem{ + Name: "orchestrator", + Url: brokenUrl, + } + sys.CoreS = []*components.CoreSystem{ + leadingRegistrar, + orchestrator, + test, + } + } + return +} diff --git a/usecases/kgraphing.go b/usecases/kgraphing.go index 8bb42f6..349fdd4 100644 --- a/usecases/kgraphing.go +++ b/usecases/kgraphing.go @@ -29,6 +29,7 @@ package usecases import ( "fmt" + "log" "net/http" "strconv" "strings" @@ -45,7 +46,10 @@ func KGraphing(w http.ResponseWriter, req *http.Request, sys *components.System) rdf += modelUAsset(sys) w.Header().Set("Content-Type", "text/turtle") - w.Write([]byte(rdf)) + _, err := w.Write([]byte(rdf)) + if err != nil { + log.Println("Failed to write KGraphing information: ", err) + } } func prefixes() (description string) { @@ -69,6 +73,16 @@ func modelSystem(sys *components.System) (systemModel string) { } details := sys.Husk.Details for key, values := range details { + if key == "LocalCloud" { // it is expected that only the System Registrars have those keys and all have the same name (if not the KGrapher will use the first one it finds) + if len(values) > 0 { + v := values[0] + if !(strings.HasPrefix(v, "<") && strings.HasSuffix(v, ">")) && !strings.HasPrefix(v, "alc:") { + v = "alc:" + v + } + systemModel += fmt.Sprintf(" afo:isContainedIn %s ;\n", v) + } + continue + } for _, value := range values { if !(strings.HasPrefix(value, "<") && strings.HasSuffix(value, ">")) { value = "alc:" + value @@ -123,6 +137,23 @@ func modelUAsset(sys *components.System) string { details := (*asset).GetDetails() for key, values := range details { + fmt.Printf("key: %s, values: %v\n", key, values) + if strings.HasSuffix(key, ":") { + for _, value := range values { + if value == "" { + log.Printf("Warning: empty value for key '%s' in asset '%s'. Skipping.", key, assetName) + continue + } + relationship := value[0] // byte + reference := value[1:] // string (from second character onward) + + switch relationship { + case '=': // single quotes for byte comparison + assetModels += fmt.Sprintf(" owl:sameAs %s ;\n", reference) + } + } + continue + } for _, value := range values { if !(strings.HasPrefix(value, "<") && strings.HasSuffix(value, ">")) { value = "alc:" + value @@ -145,7 +176,7 @@ func modelUAsset(sys *components.System) string { servicesLen := len(services) serviceCount := 0 for _, service := range services { - assetModels += fmt.Sprintf(" afo:providesService alc:%s_%s_%s", sName, assetName, service.Definition) + assetModels += fmt.Sprintf(" afo:providesService alc:%s_%s_%s", sName, assetName, service.SubPath) serviceCount++ if serviceCount < servicesLen { assetModels += " ;\n" @@ -230,7 +261,7 @@ func modelServices(sName string, ua *components.UnitAsset, sys *components.Syste servicesModel += fmt.Sprintf(" afo:hasServiceDefinition \"%s\" ;\n", service.Definition) for protocol, port := range sys.Husk.ProtoPort { if port != 0 { - addr := protocol + "://" + sys.Host.IPAddresses[0] + ":" + strconv.Itoa(port) + "/" + sys.Name + "/" + assetName + "/" + service.Definition + addr := protocol + "://" + sys.Host.IPAddresses[0] + ":" + strconv.Itoa(port) + "/" + sys.Name + "/" + assetName + "/" + service.SubPath servicesModel += fmt.Sprintf(" afo:hasUrl <%s> ;\n", addr) } } diff --git a/usecases/provision.go b/usecases/provision.go index 4f86d19..f731cba 100644 --- a/usecases/provision.go +++ b/usecases/provision.go @@ -20,18 +20,19 @@ package usecases import ( - "encoding/json" - "errors" "fmt" "io" "log" "net/http" "strings" + "github.com/sdoque/mbaigo/components" "github.com/sdoque/mbaigo/forms" ) // HTTPProcessSetRequest processes a Get request +// TODO: this function should really return an error too and behave like everyone +// else. And causing http.Errors is an ugly side effect. func HTTPProcessGetRequest(w http.ResponseWriter, r *http.Request, f forms.Form) { if f == nil { http.Error(w, "No payload found.", http.StatusNotFound) @@ -47,46 +48,39 @@ func HTTPProcessGetRequest(w http.ResponseWriter, r *http.Request, f forms.Form) responseData, err := Pack(f, bestContentType) if err != nil { - http.Error(w, fmt.Sprintf("Error packing response: %v", err), http.StatusInternalServerError) + log.Printf("Error packing response: %v", err) + http.Error(w, "Error packing response.", http.StatusInternalServerError) return } w.Header().Set("Content-Type", bestContentType) w.WriteHeader(http.StatusOK) - w.Write(responseData) + _, err = w.Write(responseData) + if err != nil { + log.Printf("Error while writing response: %v", err) + http.Error(w, "Error writing response.", http.StatusInternalServerError) + } } // HTTPProcessSetRequest processes a SET request -func HTTPProcessSetRequest(w http.ResponseWriter, req *http.Request) (f forms.SignalA_v1a, err error) { - defer req.Body.Close() +func HTTPProcessSetRequest(w http.ResponseWriter, req *http.Request) (sig forms.SignalA_v1a, err error) { bodyBytes, err := io.ReadAll(req.Body) // Use io.ReadAll instead of ioutil.ReadAll if err != nil { - log.Printf("Error reading request body: %v", err) + err = fmt.Errorf("reading request body: %w", err) return } - var jsonData map[string]interface{} - err = json.Unmarshal(bodyBytes, &jsonData) + defer req.Body.Close() + headerContentType := req.Header.Get("Content-Type") + f, err := Unpack(bodyBytes, headerContentType) if err != nil { - log.Printf("Error unmarshaling JSON data: %v", err) return } - formVersion, ok := jsonData["version"].(string) + temp, ok := f.(*forms.SignalA_v1a) if !ok { - log.Printf("Error: 'version' key not found in JSON data") + err = fmt.Errorf("form is not of type SignalA_v1a") return } - switch formVersion { - case "SignalA_v1.0": - var sig forms.SignalA_v1a - err = json.Unmarshal(bodyBytes, &sig) - if err != nil { - log.Println("Unable to extract signal set request ") - return - } - f = sig - default: - err = errors.New("unsupported service set request form version") - } + sig = *temp // Stupid type conversion because return type was picked incorrectly return } @@ -108,7 +102,10 @@ func getBestContentType(acceptHeader string) string { // Check for q-value in the MIME type if len(parts) > 1 && strings.HasPrefix(parts[1], "q=") { - fmt.Sscanf(parts[1], "q=%f", &qValue) + _, err := fmt.Sscanf(parts[1], "q=%f", &qValue) + if err != nil { + continue + } } // Update the best content type if this one has a higher q-value @@ -125,3 +122,43 @@ func getBestContentType(acceptHeader string) string { return bestType } + +func RegisterMessenger(resp http.ResponseWriter, req *http.Request, sys *components.System) { + if req.Method != "POST" { + http.Error(resp, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(req.Body) + if err != nil { + http.Error(resp, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + defer req.Body.Close() + + // Won't bother logging the following errors as they are caused by bad/poor + // client requests, which we don't really care about on the server side. + form, err := Unpack(body, req.Header.Get("Content-Type")) + if err != nil { + http.Error(resp, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + registration, ok := form.(*forms.MessengerRegistration_v1) + if !ok { + http.Error(resp, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + if len(registration.Host) < 1 { + http.Error(resp, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + + sys.Mutex.Lock() + defer sys.Mutex.Unlock() + if _, found := sys.Messengers[registration.Host]; found { + // The system already knows the messenger, avoid re-storing it so that + // the error count don't get reset + return + } + sys.Messengers[registration.Host] = 0 // Registers the new messenger with zero errors +} diff --git a/usecases/provision_test.go b/usecases/provision_test.go new file mode 100644 index 0000000..2b6817b --- /dev/null +++ b/usecases/provision_test.go @@ -0,0 +1,248 @@ +package usecases + +import ( + "context" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/sdoque/mbaigo/components" + "github.com/sdoque/mbaigo/forms" +) + +type httpProcessGetRequestStruct struct { + inputW http.ResponseWriter + inputBody string + inputF forms.Form + expectedBody string + testName string +} + +type mockResponseWriter struct { + http.ResponseWriter +} + +func (e *mockResponseWriter) Write(b []byte) (int, error) { + return 0, fmt.Errorf("Forced write error") +} + +func (e *mockResponseWriter) WriteHeader(statusCode int) {} + +func (e *mockResponseWriter) Header() http.Header { + return make(http.Header) +} + +func createEmptyFormVersion() mockForm { + form := mockForm{ + XMLName: xml.Name{}, + Value: 0, + Unit: "testUnit", + Version: "", + } + return form +} + +func createBrokenForm() mockForm { + form := mockForm{ + XMLName: xml.Name{}, + Value: complex(1, 2), + Unit: "testUnit", + Version: "SignalA_v1.0", + } + return form +} + +var httpProcessGetRequestParams = []httpProcessGetRequestStruct{ + {httptest.NewRecorder(), "{\n \"value\": 0,\n \"unit\": \"\",\n}", form.NewForm(), + "{\n \"value\": 0,\n \"unit\": \"\",\n \"timestamp\": \"0001-01-01T00:00:00Z\",\n " + + " \"version\": \"SignalA_v1.0\"\n}", "Good case"}, + {httptest.NewRecorder(), "
0
", nil, + "No payload found.\n", "Bad case, form is nil"}, + {httptest.NewRecorder(), "\n", createEmptyFormVersion(), + "No payload information found.\n", "Bad case, form version is empty"}, + {httptest.NewRecorder(), "", createBrokenForm(), + "Error packing response.\n", "Bad case, form value is invalid"}, + {&mockResponseWriter{}, "", form.NewForm(), + "", "Bad case, Write fails"}, +} + +func TestHTTPProcessGetRequest(t *testing.T) { + for _, testCase := range httpProcessGetRequestParams { + inputR := httptest.NewRequest(http.MethodGet, "/test123", io.NopCloser(strings.NewReader(testCase.inputBody))) + HTTPProcessGetRequest(testCase.inputW, inputR, testCase.inputF) + + if testCase.testName == "Bad case, Write fails" { + if _, ok := testCase.inputW.(*mockResponseWriter); !ok { + t.Errorf("Expected inputW to be of type *mockResponseWriter") + } + } + recorder, ok := testCase.inputW.(*httptest.ResponseRecorder) + if ok { + if recorder.Body.String() != testCase.expectedBody { + t.Errorf("Expected %s, got: %s", testCase.expectedBody, recorder.Body.String()) + } + } + } +} + +type httpProcessSetRequestStruct struct { + inputW http.ResponseWriter + inputBody string + expectedErr bool + expectedForm forms.SignalA_v1a + testName string +} + +func createForm() forms.SignalA_v1a { + form.NewForm() + form.Value = 0 + return form +} + +var httpProcessSetRequestParams = []httpProcessSetRequestStruct{ + {httptest.NewRecorder(), "{\n \"value\": 0,\n \"unit\": \"\",\n " + + " \"timestamp\": \"0001-01-01T00:00:00Z\",\n \"version\": \"SignalA_v1.0\"\n}", + false, createForm(), "Good case"}, + {httptest.NewRecorder(), "\n", true, forms.SignalA_v1a{}, "Bad case, Unmarshal returns error"}, + {httptest.NewRecorder(), "{\n \"value\": 0,\n \"unit\": \"\",\n " + + " \"timestamp\": \"0001-01-01T00:00:00Z\",\n \"bersion\": \"SignalA_v1.0\"\n}", + true, forms.SignalA_v1a{}, "Bad case, version key missing"}, + {httptest.NewRecorder(), "{\n \"value\": \"not-a-number\",\n \"unit\": \"\",\n " + + " \"timestamp\": \"0001-01-01T00:00:00Z\",\n \"version\": \"SignalA_v1.0\"\n}", + true, forms.SignalA_v1a{}, "Bad case, Second Unmarshal breaks"}, + {httptest.NewRecorder(), "{\n \"value\": 0,\n \"unit\": \"\",\n " + + " \"timestamp\": \"0001-01-01T00:00:00Z\",\n \"version\": \"SignalB_v1.0\"\n}", + true, forms.SignalA_v1a{}, "Bad case, version is wrong"}, + {httptest.NewRecorder(), "{\n \"value\": false,\n " + + " \"timestamp\": \"0001-01-01T00:00:00Z\",\n \"version\": \"SignalB_v1.0\"\n}", + true, forms.SignalA_v1a{}, "Bad case, form version is SignalB_v1a"}, +} + +func TestHTTPProcessSetRequest(t *testing.T) { + for _, testCase := range httpProcessSetRequestParams { + inputR := httptest.NewRequest(http.MethodPut, "/test123", io.NopCloser(strings.NewReader(testCase.inputBody))) + inputR.Header.Set("Content-Type", "application/json") + f, err := HTTPProcessSetRequest(testCase.inputW, inputR) + + if f != testCase.expectedForm || (err == nil && testCase.expectedErr == true) || + (err != nil && testCase.expectedErr == false) { + t.Errorf("Expected %v and %v, got: %v and %v", testCase.expectedForm, testCase.expectedErr, f, err) + } + } + + // Special case + specialRequest := httptest.NewRequest(http.MethodPut, "/test123", io.NopCloser(errorReader{})) + specialRequest.Header.Set("Content-Type", "application/json") + expectedForm := forms.SignalA_v1a{} + f, err := HTTPProcessSetRequest(httptest.NewRecorder(), specialRequest) + + if f != expectedForm || err == nil { + t.Errorf("Expected %v, got: %v", expectedForm, f) + } +} + +type getBestContentTypeStruct struct { + acceptHeaderInput string + bestContentTypeOutput string + testName string +} + +var getBestContentTypeParams = []getBestContentTypeStruct{ + {"", "application/json", + "Good case, no accept header provided"}, + {"application/xml", "application/xml", + "Good case, accept header provided without q-values"}, + {"application/xml;q=0.7, application/json;q=0.9", "application/json", + "Good case, accept header provided with q-values"}, + {"application/xml;q=wrong, application/json;q=1.1", "application/json", + "Good case, xml gets skipped"}, + {"application/xml;q=0.9, application/json;q=0.9", "application/xml", + "Good case, equal q-values selects the first one"}, + {"application/xml;q=-0.9", "application/json", + "Good case, no MIME type found"}, +} + +func TestGetBestContentType(t *testing.T) { + for _, testCase := range getBestContentTypeParams { + res := getBestContentType(testCase.acceptHeaderInput) + + if res != testCase.bestContentTypeOutput { + t.Errorf("Expected %v, got: %v in test case: %s", testCase.bestContentTypeOutput, res, testCase.testName) + } + } +} + +const testMessenger string = "testmessenger" +const testRegMesForm string = `{"version": "MessengerRegistration_v1", "host": "` + testMessenger + `"}` + +func TestRegisterMessenger(t *testing.T) { + table := []struct { + method string + contentType string + body io.ReadCloser + expectedStatus int + }{ + // Bad method + {http.MethodGet, "application/json", nil, http.StatusMethodNotAllowed}, + // Bad body + {http.MethodPost, "application/json", errReader(0), http.StatusInternalServerError}, + // Bad unpack + {http.MethodPost, "bad type", nil, http.StatusBadRequest}, + // Bad form + {http.MethodPost, "application/json", + io.NopCloser(strings.NewReader(`{"version": "SystemMessage_v1"}`)), + http.StatusBadRequest, + }, + // Missing host + {http.MethodPost, "application/json", + io.NopCloser(strings.NewReader(`{"version": "MessengerRegistration_v1"}`)), + http.StatusBadRequest, + }, + // All good + // WARN: this case is expected to be the last one in this table, as its + // result is being used in the special cases! + {http.MethodPost, "application/json", + io.NopCloser(strings.NewReader(testRegMesForm)), + http.StatusOK, + }, + } + + sys := components.NewSystem("testsys", context.Background()) + testFunc := func(method, content string, body io.ReadCloser) *http.Response { + rec := httptest.NewRecorder() + req := httptest.NewRequest(method, "/msg", body) + req.Header.Set("Content-Type", content) + RegisterMessenger(rec, req, &sys) + + return rec.Result() + } + for _, test := range table { + res := testFunc(test.method, test.contentType, test.body) + if got, want := res.StatusCode, test.expectedStatus; got != want { + t.Errorf("expected status %d, got %d", want, got) + } + } + + // Verify the messenger was registered from the last test case + errors, found := sys.Messengers[testMessenger] + if errors != 0 || found == false { + t.Errorf("expected registered messenger, found none") + } + + // Verify duplicate registration doesn't lose error count + errCount := -1 + sys.Messengers[testMessenger] = errCount + res := testFunc(http.MethodPost, "application/json", + io.NopCloser(strings.NewReader(testRegMesForm)), + ) + if got, want := res.StatusCode, http.StatusOK; got != want { + t.Errorf("expected status %d, got %d", want, got) + } + if got, want := sys.Messengers[testMessenger], errCount; got != want { + t.Errorf("expected error count %d, got %d", want, got) + } +} diff --git a/usecases/registration.go b/usecases/registration.go old mode 100755 new mode 100644 index 78b7a05..63de4b3 --- a/usecases/registration.go +++ b/usecases/registration.go @@ -26,101 +26,76 @@ import ( "fmt" "io" "log" - "net" "net/http" "strconv" - "strings" + "sync" "time" "github.com/sdoque/mbaigo/components" "github.com/sdoque/mbaigo/forms" ) +type registrarTracker struct { + url string + mutex sync.RWMutex +} + +func (rt *registrarTracker) set(url string) { + rt.mutex.Lock() + rt.url = url + rt.mutex.Unlock() +} + +func (rt *registrarTracker) get() string { + rt.mutex.RLock() + defer rt.mutex.RUnlock() + return rt.url +} + // RegisterServices keeps track of the leading Service Registrar and keeps all services registered func RegisterServices(sys *components.System) { - - var leadingRegistrar *components.CoreSystem - // Create a buffered channel for the pointer to the leading service registrar - registrarStream := make(chan *components.CoreSystem, 1) + // Keep track of the registrar URL. The URL is shared between goroutines, + // so it must be protected from data races using a mutex. + registrar := ®istrarTracker{} // Goroutine looking for leading service registrar every 5 seconds go func() { - defer close(registrarStream) - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() + ticker := time.Tick(5 * time.Second) for { - if leadingRegistrar != nil { - resp, err := http.Get(leadingRegistrar.Url + "/status") - if err != nil { - log.Println("lost leading registrar status:", err) - leadingRegistrar = nil - continue // Skip to the next iteration of the loop - } - - // Read from resp.Body and then close it directly after - bodyBytes, err := io.ReadAll(resp.Body) - resp.Body.Close() // Close the body directly after reading from it - if err != nil { - log.Println("\rError reading response from leading registrar:", err) - leadingRegistrar = nil - continue // Skip to the next iteration of the loop - } - - if !strings.HasPrefix(string(bodyBytes), "lead Service Registrar since") { - leadingRegistrar = nil - log.Println("lost previous leading registrar") - } - } else { - for _, cSys := range sys.CoreS { - core := cSys - if core.Name == "serviceregistrar" { - resp, err := http.Get(core.Url + "/status") - if err != nil { - fmt.Println("error checking service registrar status:", err) - continue // Skip to the next iteration of the loop - } - - // Read from resp.Body and then close it directly after - bodyBytes, err := io.ReadAll(resp.Body) - resp.Body.Close() // Close the body directly after reading from it - if err != nil { - fmt.Println("Error reading service registrar response body:", err) - continue // Skip to the next iteration of the loop - } - - if strings.HasPrefix(string(bodyBytes), "lead Service Registrar since") { - leadingRegistrar = core - fmt.Printf("\nlead registrar found at: %s\n", leadingRegistrar.Url) - } - } - } + newURL, err := components.GetRunningCoreSystemURL(sys, components.ServiceRegistrarName) + registrar.set(newURL) // should be empty on error anyway + if err != nil { + log.Println("failed to find lead registrar:", err) } select { - case <-ticker.C: + case <-ticker: case <-sys.Ctx.Done(): return } } }() + + // Run registration loops for each services assetList := &sys.UAssets for _, aResource := range *assetList { servs := (*aResource).GetServices() for _, service := range servs { - // service := (*servs)[j] // Correctly dereference the slice pointer and access the element go func(theUnitAsset *components.UnitAsset, theService *components.Service) { delay := 1 * time.Second + var err error for { - timer := time.NewTimer(delay) select { - case <-timer.C: - if leadingRegistrar != nil { - delay = registerService(sys, theUnitAsset, theService, leadingRegistrar) - } else { - delay = 15 * time.Second + case <-time.Tick(delay): + delay, err = registerService(sys, registrar.get(), theUnitAsset, theService) + if err != nil { + log.Println("registering service:", err) } case <-sys.Ctx.Done(): - deregisterService(leadingRegistrar, theService) + err = unregisterService(registrar.get(), theService) + if err != nil { + log.Println("unregistering service:", err) + } return } } @@ -130,122 +105,126 @@ func RegisterServices(sys *components.System) { } // registerService makes a POST or PUT request to register or register individual services -func registerService(sys *components.System, ua *components.UnitAsset, ser *components.Service, registrar *components.CoreSystem) (delay time.Duration) { - +func registerService(sys *components.System, registrar string, ua *components.UnitAsset, serv *components.Service) (delay time.Duration, err error) { delay = 15 * time.Second + if registrar == "" { + if serv.ID != 0 { + serv.ID = 0 // reset the service ID, so that a new registration (POST) will be made when the registrar is back + } + return + } + // Prepare request - reqPayload, err := serviceRegistrationForm(sys, ua, ser, "ServiceRecord_v1") + reqPayload, err := serviceRegistrationForm(sys, ua, serv, "ServiceRecord_v1") if err != nil { - log.Println("Registration marshall error, ", err) + err = fmt.Errorf("registration marshall: %w", err) return } - registrationurl := registrar.Url + "/register" + registrationURL := registrar + "/register" var req *http.Request // Declare req outside the blocks - if ser.ID == 0 { - req, err = http.NewRequest("POST", registrationurl, bytes.NewBuffer(reqPayload)) + if serv.ID == 0 { + req, err = http.NewRequest("POST", registrationURL, bytes.NewBuffer(reqPayload)) if err != nil { - log.Printf("unable to register service %s with lead registrar\n", ser.Definition) + err = fmt.Errorf("unable to register service %s with lead registrar", serv.Definition) return } } else { - req, err = http.NewRequest("PUT", registrationurl, bytes.NewBuffer(reqPayload)) + req, err = http.NewRequest("PUT", registrationURL, bytes.NewBuffer(reqPayload)) if err != nil { - log.Printf("unable to confirm the %s service with lead registar", ser.Definition) + err = fmt.Errorf("unable to confirm the %s service with lead registrar", serv.Definition) return } } req.Header.Set("Content-Type", "application/json; charset=UTF-8") - client := &http.Client{Timeout: time.Second * 5} - resp, err := client.Do(req) // execute the request and get the reply + + resp, err := http.DefaultClient.Do(req) // execute the request and get the reply if err != nil { - switch err := err.(type) { - case net.Error: - if err.Timeout() { - log.Printf("registry timeout with lead registrar %s\n", registrationurl) - } else { - log.Printf("unable to (re-)register service %s with lead registrar\n", ser.Definition) - } - default: - log.Printf("registration request error with %s, and error %s\n", registrationurl, err) - } - registrar = nil - ser.ID = 0 // if re-registration failed, a complete new one should be made (POST) + err = fmt.Errorf("registration request: %w", err) + serv.ID = 0 // if re-registration failed, a complete new one should be made (POST) return } - // Handle response ------------------------------------------------ + if resp.StatusCode < 200 || resp.StatusCode > 299 { + err = fmt.Errorf("bad registration response: %s", resp.Status) + serv.ID = 0 + return + } - if resp != nil { - defer resp.Body.Close() - bodyBytes, err := io.ReadAll(resp.Body) // Use io.ReadAll instead of ioutil.ReadAll - if err != nil { - log.Printf("Error reading registration response body: %v", err) - return - } + // Handle response ------------------------------------------------ - headerContentTtype := resp.Header.Get("Content-Type") - rRecord, err := Unpack(bodyBytes, headerContentTtype) - if err != nil { - log.Printf("error extracting the registration record relpy %v\n", err) - } + var b []byte + b, err = io.ReadAll(resp.Body) // Use io.ReadAll instead of ioutil.ReadAll + if err != nil { + err = fmt.Errorf("reading registration response body: %w", err) + return + } + defer resp.Body.Close() - // Perform a type assertion to convert the returned Form to ServiceRecord_v1 - rr, ok := rRecord.(*forms.ServiceRecord_v1) - if !ok { - fmt.Println("Problem unpacking the service registration reply") - return - } + headerContentType := resp.Header.Get("Content-Type") + rRecord, err := Unpack(b, headerContentType) + if err != nil { + err = fmt.Errorf("extracting the registration record reply: %w", err) + return + } - ser.ID = rr.Id - ser.RegTimestamp = rr.Created - ser.RegExpiration = rr.EndOfValidity - parsedTime, err := time.Parse(time.RFC3339, rr.EndOfValidity) - if err != nil { - log.Printf("Error parsing input: %s", err) - return - } - delay = time.Until(parsedTime.Add(-5 * time.Second)) // should not wait until the deadline to start to confirrm live status + // Perform a type assertion to convert the returned Form to ServiceRecord_v1 + rr, ok := rRecord.(*forms.ServiceRecord_v1) + if !ok { + err = fmt.Errorf("invalid form from the service registration reply") + return } + serv.ID = rr.Id + serv.RegTimestamp = rr.Created + serv.RegExpiration = rr.EndOfValidity + parsedTime, err := time.Parse(time.RFC3339, rr.EndOfValidity) + if err != nil { + err = fmt.Errorf("parsing time: %w", err) + return + } + // should not wait until the deadline to start to confirm live status + delay = time.Until(parsedTime.Add(-5 * time.Second)) + if delay < 1*time.Second { + // Avoid using zero/negative delays + delay = 1 * time.Second + } return } -// deregisterService deletes a service from the database based on its service id -func deregisterService(registrar *components.CoreSystem, ser *components.Service) { - if registrar == nil { - return // there is no need to deregister if there is no leading registrar +// unregisterService deletes a service from the database based on its service id +func unregisterService(registrar string, serv *components.Service) error { + if registrar == "" { + return nil // there is no need to deregister if there is no leading registrar } - client := &http.Client{} - deRegServURL := registrar.Url + "/unregister/" + strconv.Itoa(ser.ID) - fmt.Printf("Trying to unregiseter %s\n", deRegServURL) - req, err := http.NewRequest("DELETE", deRegServURL, nil) // create a new request using http + u := registrar + "/unregister/" + strconv.Itoa(serv.ID) + req, err := http.NewRequest("DELETE", u, nil) if err != nil { - log.Println(err) - return + return err } - resp, err := client.Do(req) // make the request + resp, err := http.DefaultClient.Do(req) if err != nil { - log.Println(err) - return + // Can't do anything about network errors. Don't care much either, + // since this system is shutting down. Ignorering this error for now. + return nil } defer resp.Body.Close() - fmt.Printf("service %s deleted from the service registrar with HTTP Response Status: %d, %s\n", ser.Definition, resp.StatusCode, http.StatusText(resp.StatusCode)) + return nil } // serviceRegistrationForm returns a json data byte array with the data of the service to be registered // in the form of choice [Sending @ Application system] -func serviceRegistrationForm(sys *components.System, ua *components.UnitAsset, ser *components.Service, version string) (payload []byte, err error) { +func serviceRegistrationForm(sys *components.System, ua *components.UnitAsset, serv *components.Service, version string) (payload []byte, err error) { var f forms.Form switch version { case "ServiceRecord_v1": resName := (*ua).GetName() var sr forms.ServiceRecord_v1 // declare a new service form sr.NewForm() - sr.Id = ser.ID - sr.ServiceDefinition = ser.Definition + sr.Id = serv.ID + sr.ServiceDefinition = serv.Definition sr.SystemName = sys.Name - sr.ServiceNode = sys.Host.Name + "_" + sys.Name + "_" + resName + "_" + ser.Definition + sr.ServiceNode = sys.Host.Name + "_" + sys.Name + "_" + resName + "_" + serv.Definition sr.IPAddresses = sys.Host.IPAddresses sr.ProtoPort = make(map[string]int) // initialize the map for key, port := range sys.Husk.ProtoPort { @@ -254,17 +233,17 @@ func serviceRegistrationForm(sys *components.System, ua *components.UnitAsset, s } } sr.Details = deepCopyMap((*ua).GetDetails()) - for key, valueSlice := range ser.Details { + for key, valueSlice := range serv.Details { sr.Details[key] = append(sr.Details[key], valueSlice...) } - sr.SubPath = resName + "/" + ser.SubPath + sr.SubPath = resName + "/" + serv.SubPath - if ser.RegPeriod != 0 { - sr.RegLife = ser.RegPeriod + if serv.RegPeriod != 0 { + sr.RegLife = serv.RegPeriod } else { sr.RegLife = 30 } - sr.Created = ser.RegTimestamp + sr.Created = serv.RegTimestamp f = &sr default: err = errors.New("unsupported service registration form version") diff --git a/usecases/registration_test.go b/usecases/registration_test.go new file mode 100644 index 0000000..2b7b314 --- /dev/null +++ b/usecases/registration_test.go @@ -0,0 +1,308 @@ +package usecases + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strings" + "testing" + "time" + + "github.com/sdoque/mbaigo/components" + "github.com/sdoque/mbaigo/forms" +) + +type timeoutError struct{} + +func (timeoutError) Error() string { return "timeout" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +type errorReader struct{} + +func (errorReader) Read(p []byte) (int, error) { + return 0, fmt.Errorf("forced read error") +} + +func manualEqualityCheck(map1 map[string][]string, map2 map[string][]string) error { + if len(map1) != len(map2) { + return fmt.Errorf("Expected map length %d, got %d", len(map2), len(map1)) + } + for key, value := range map2 { + mv, ok := map1[key] + if !ok { + return fmt.Errorf("Expected key %q not found in merged map", key) + } + if len(mv) != len(value) { + return fmt.Errorf("For key %q, expected slice length %d, got %d", key, len(value), len(mv)) + } + for i := range value { + if mv[i] != value[i] { + return fmt.Errorf("For key %q, at index %d, expected %q, got %q", key, i, value[i], mv[i]) + } + } + } + for key := range map1 { + if _, ok := map2[key]; !ok { + return fmt.Errorf("Unexpected key %q found in merged map", key) + } + } + return nil +} + +func TestDeepCopyMap(t *testing.T) { + var original = map[string][]string{"a": {"1", "2"}, "b": {"3"}} + + test := deepCopyMap(original) + + // If they are not equal from the beginning then the copy was not successful + err := manualEqualityCheck(original, test) + if err != nil { + t.Errorf("Expected deep copied map to be equal to original, Expected: %v, got: %v", original, test) + } + + // When we change something in the original, the deep copied map should not change + original["a"][0] = "changed original" + err = manualEqualityCheck(original, test) + if err == nil { + t.Errorf("Deep copy failed, changes in original affected the deep copied map."+ + " Expected: %v, got %v", original, test) + } + original["a"][0] = "1" + + // When we change something in the deep copied map, the original should not change + test["a"][0] = "changed deep copy" + err = manualEqualityCheck(original, test) + if err == nil { + t.Errorf("Deep copy failed, changes in deep copied map affected the original."+ + " Expected: %v, got %v", original, test) + } +} + +type serviceRegistrationFormTestStruct struct { + version string + expectedErr bool + testName string +} + +var serviceRegistrationFormTestParams = []serviceRegistrationFormTestStruct{ + {"ServiceRecord_v1", false, "Good case, everything works"}, + {"Wrong version", true, "Bad case, the wrong version string is sent in"}, +} + +func TestServiceRegistrationForm(t *testing.T) { + for _, testCase := range serviceRegistrationFormTestParams { + testSys := createTestSystem(false) + mua := testSys.UAssets["testUnitAsset"] + serv := (*testSys.UAssets["testUnitAsset"]).GetServices()["test"] + + payload, err := serviceRegistrationForm(&testSys, mua, serv, testCase.version) + if (testCase.expectedErr == true && err == nil) || (testCase.expectedErr == false && err != nil) { + t.Errorf("In test case: %s: Expected %t error, got: %v", testCase.testName, testCase.expectedErr, err) + } + + if testCase.expectedErr == false { + var sr forms.ServiceRecord_v1 + if err = json.Unmarshal(payload, &sr); err != nil { + t.Fatalf("Invalid JSON: %v", err) + } + + // Check that the ServiceNode is created correctly + expectedNode := testSys.Host.Name + "_" + testSys.Name + "_" + + (*testSys.UAssets["testUnitAsset"]).GetName() + "_" + + (*testSys.UAssets["testUnitAsset"]).GetServices()["test"].Definition + if sr.ServiceNode != expectedNode { + t.Errorf("Expected ServiceNode %q, got: %q", expectedNode, sr.ServiceNode) + } + + // Check that the ProtoPorts that are equal to 0 gets removed + if len(sr.ProtoPort) != 1 { + t.Errorf("Expected: one proto port (excluding 0s), got: %v", sr.ProtoPort) + } + + // Check that the unit asset details exists and are ok + if v, ok := sr.Details["Test"]; !ok || len(v) != 1 { + t.Errorf("Missing or incorrect unit asset details. Expected: %v, got: %v", (*mua).GetDetails(), v) + } + + // Check that the service forms exists and are ok + if v, ok := sr.Details["Forms"]; !ok || len(v) != 1 { + t.Errorf("Missing or incorrect service forms. Expected: %v, got: %v", (*serv).Details, v) + } + } + } + + // Special case + // Check that when the Service RegPeriod equals 0, + // ServiceRegistrationForm defaults to its RegLife default value of 30 + testSys := createTestSystem(false) + mua := testSys.UAssets["testUnitAsset"] + serv := (*testSys.UAssets["testUnitAsset"]).GetServices()["test"] + (*testSys.UAssets["testUnitAsset"]).GetServices()["test"].RegPeriod = 0 + version := "ServiceRecord_v1" + payload, err := serviceRegistrationForm(&testSys, mua, serv, version) + if err != nil { + t.Fatalf("The Service Record version was wrong.") + } + var sr forms.ServiceRecord_v1 + if err := json.Unmarshal(payload, &sr); err != nil { + t.Fatalf("Invalid JSON: %v", err) + } + if sr.RegLife != 30 { + t.Errorf("Expected RegLife: 30, got: %d", sr.RegLife) + } +} + +type unregisterServiceTestStruct struct { + registrarUrl string + expectedErr bool + mockTransportErr int + errHTTP error + testName string +} + +var unregisterServiceTestParams = []unregisterServiceTestStruct{ + {"https://leadingregistrar", false, 0, nil, "Good case, an unregistered service tries to unregister"}, + {"https://leadingregistrar", false, 0, nil, "Good case, an registered service tries to unregister"}, + {"", false, 0, nil, "Good case, no leading registrar URL was sent in"}, + {"https://leadingregistrar", false, 1, errHTTP, "Bad case, empty error from response body"}, + {brokenUrl, true, 0, nil, "Bad case, broken URL"}, +} + +func TestUnregisterService(t *testing.T) { + for _, testCase := range unregisterServiceTestParams { + testSys := createTestSystem(false) + serv := (*testSys.UAssets["testUnitAsset"]).GetServices()["test"] + + newMockTransport(createWorkingHttpResp(), testCase.mockTransportErr, testCase.errHTTP) + err := unregisterService(testCase.registrarUrl, serv) + if (testCase.expectedErr == true && err == nil) || (testCase.expectedErr == false && err != nil) { + t.Errorf("In test case: %s: We expected %t error, got: %v", testCase.testName, testCase.expectedErr, err) + } + } +} + +type registerServiceTestStruct struct { + registrarUrl string + contentType string + mockServID int + correctTime bool + brokenBody bool + expectedErr bool + mockTransportErr int + errHTTP error + testName string +} + +func createWorkingRegisterServiceBody(mockSys components.System, mua *components.UnitAsset, serv *components.Service, + correctTime bool, contentType string, brokenBody bool) func() *http.Response { + + payload, err := serviceRegistrationForm(&mockSys, mua, serv, "ServiceRecord_v1") + if err != nil { + log.Fatalf("The service Record version was wrong") + } + var sr forms.ServiceRecord_v1 + if err = json.Unmarshal(payload, &sr); err != nil { + log.Fatalf("Invalid JSON: %v", err) + } + if correctTime == true { + sr.EndOfValidity = time.Now().Format(time.RFC3339) + } else { + sr.EndOfValidity = "" + } + + fakebody, err := json.Marshal(sr) + if err != nil { + log.Fatalf("Fail marshal at start of test: %v", err) + } + + if brokenBody == false { + respFunc := func() *http.Response { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{contentType}}, + Body: io.NopCloser(strings.NewReader(string(fakebody))), + } + } + return respFunc + } else { + respFunc := func() *http.Response { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{contentType}}, + Body: io.NopCloser(errorReader{}), + } + } + return respFunc + } +} + +func createMockSysMockUnitAssetandMockService(id int) (mockSys components.System, mua *components.UnitAsset, + mockServ *components.Service) { + mockSys = createTestSystem(false) + mua = mockSys.UAssets["testUnitAsset"] + mockServ = (*mockSys.UAssets["testUnitAsset"]).GetServices()["test"] + mockServ.ID = id + return +} + +var registerServiceTestParams = []registerServiceTestStruct{ + {"https://leadingregistrar", "application/json", 1, true, false, false, 0, nil, + "Good case, with PUT method"}, + {"https://leadingregistrar", "application/json", 0, true, false, false, 0, nil, + "Good case, with POST method"}, + {"https://leadingregistrar", "application/json", 1, true, false, true, 1, timeoutError{}, + "Bad case, timeout error"}, + {"https://leadingregistrar", "application/json", 1, true, false, true, 1, errHTTP, + "Bad case, error in defaultClint"}, + {"https://leadingregistrar", "application/json", 1, true, true, true, 0, nil, + "Bad case, error in ReadAll"}, + {"https://leadingregistrar", "", 1, true, false, true, 0, nil, + "Bad case, error in Unpack"}, + {"https://leadingregistrar", "application/json", 1, false, false, true, 0, nil, + "Bad case, error parsing time"}, + {"", "application/json", 1, true, false, false, 0, nil, + "Good case, no leading registrar URL sent in"}, + {brokenUrl, "application/json", 1, true, false, true, 0, nil, + "Bad case, broken URL with PUT method"}, + {brokenUrl, "application/json", 0, true, false, true, 0, nil, + "Bad case, broken URL with POST method"}, +} + +var delay = time.Duration(15) * time.Second + +func TestRegisterService(t *testing.T) { + for _, testCase := range registerServiceTestParams { + mockSys, mua, mockServ := createMockSysMockUnitAssetandMockService(testCase.mockServID) + respFunc := createWorkingRegisterServiceBody(mockSys, mua, mockServ, testCase.correctTime, + testCase.contentType, testCase.brokenBody) + newMockTransport(respFunc, testCase.mockTransportErr, testCase.errHTTP) + + test, err := registerService(&mockSys, testCase.registrarUrl, mua, mockServ) + + // Special case + if testCase.registrarUrl == "" { + if err != nil || test != delay { + t.Errorf("In test case: %s: Did we expect error? %t, got: %v and %d delay.", + testCase.testName, testCase.expectedErr, err, test) + } + continue + } + + if testCase.expectedErr == false { + if err != nil || test == delay { + t.Errorf("In test case: %s: Did we expect error? %t, got: %v and %d delay.", + testCase.testName, testCase.expectedErr, err, test) + } + } else { + if err == nil || test != delay { + t.Errorf("In test case: %s: Did we expect error? %t, got: %v and %d delay.", + testCase.testName, testCase.expectedErr, err, test) + } + } + } +} diff --git a/usecases/serversNhandlers.go b/usecases/servers_handlers.go similarity index 81% rename from usecases/serversNhandlers.go rename to usecases/servers_handlers.go index cba86df..13e557b 100644 --- a/usecases/serversNhandlers.go +++ b/usecases/servers_handlers.go @@ -22,6 +22,7 @@ package usecases import ( + "context" "crypto/ecdsa" "crypto/tls" "crypto/x509" @@ -37,15 +38,14 @@ import ( "github.com/sdoque/mbaigo/forms" ) -// SetoutServers setup the http and https servers and starts them -func SetoutServers(sys *components.System) (err error) { +// SetoutServers setups the http and https servers and starts them +func SetoutServers(sys *components.System) error { // get the servers port number (from configuration file) httpPort := sys.Husk.ProtoPort["http"] httpsPort := sys.Husk.ProtoPort["https"] if httpPort == 0 && httpsPort == 0 { - fmt.Printf("The system %s has no web server configured\n", sys.Name) - return + return fmt.Errorf("missing http(s) port in configuration") } // how to handle requests to the servers @@ -56,13 +56,13 @@ func SetoutServers(sys *components.System) (err error) { // Encode the ECDSA private key to PEM format privateKeyPEM, err := encodeECDSAPrivateKeyToPEM(sys.Husk.Pkey) if err != nil { - log.Fatalf("Failed to encode private key: %v", err) + return fmt.Errorf("encoding private key: %w", err) } // Load the certificate and key cert, err := tls.X509KeyPair([]byte(sys.Husk.Certificate), privateKeyPEM) if err != nil { - log.Fatalf("Failed to parse certificate or private key: %v", err) + return fmt.Errorf("parsing certificate/private key: %w", err) } caCertPool := x509.NewCertPool() @@ -73,32 +73,35 @@ func SetoutServers(sys *components.System) (err error) { Certificates: []tls.Certificate{cert}, ClientAuth: tls.RequireAndVerifyClientCert, ClientCAs: caCertPool, + MinVersion: tls.VersionTLS12, } // Create a HTTPS server with the TLS config httpsServer := &http.Server{ - Addr: ":" + strconv.Itoa(httpsPort), - TLSConfig: tlsConfig, - Handler: nil, + Addr: ":" + strconv.Itoa(httpsPort), + ReadTimeout: 30 * time.Second, + WriteTimeout: 60 * time.Second, + TLSConfig: tlsConfig, + Handler: nil, } // Initiate graceful shutdown on signal reception go func() { <-sys.Ctx.Done() - time.Sleep(1 * time.Second) // this line is for the leading service registrar to deregister its own services - fmt.Printf("Initiating graceful shutdown of the HTTPS server.\n") - httpsServer.Shutdown(sys.Ctx) + if err := httpsServer.Shutdown(context.Background()); err != nil { + log.Printf("Error during shutdown: %v", err) + } }() // Inform the user how to access the system's web server (black box documentation) httpsURL := "https://" + sys.Host.IPAddresses[0] + ":" + strconv.Itoa(httpsPort) + "/" + sys.Name - fmt.Printf("The system %s is up with its web server available at %s\n", sys.Name, httpsURL) + log.Printf("The system %s is up with its web server available at %s\n", sys.Name, httpsURL) // Start and monitor the server go func() { - err = httpsServer.ListenAndServeTLS("", "") + err := httpsServer.ListenAndServeTLS("", "") if err != nil && err != http.ErrServerClosed { - log.Fatalf("Listen: %s\n", err) + log.Fatalf("Error from web server: %v\n", err) } }() } @@ -107,27 +110,31 @@ func SetoutServers(sys *components.System) (err error) { if httpPort != 0 { // Create a HTTP server httpServer := &http.Server{ - Addr: ":" + strconv.Itoa(httpPort), - Handler: nil, + Addr: ":" + strconv.Itoa(httpPort), + ReadTimeout: 30 * time.Second, + WriteTimeout: 60 * time.Second, + Handler: nil, } // Initiate graceful shutdown on signal reception go func() { <-sys.Ctx.Done() - time.Sleep(1 * time.Second) // this line is for the leading service registrar to deregister its own services - fmt.Printf("Initiating graceful shutdown of the HTTP server.\n") - httpServer.Shutdown(sys.Ctx) + if err := httpServer.Shutdown(context.Background()); err != nil { + log.Printf("Error during shutdown: %v", err) + } }() // Inform the user how to access the system's web server (black box documentation) httpURL := "http://" + sys.Host.IPAddresses[0] + ":" + strconv.Itoa(httpPort) + "/" + sys.Name - fmt.Printf("The system %s is up with its web server available at %s\n", sys.Name, httpURL) + log.Printf("The system %s is up with its web server available at %s\n", sys.Name, httpURL) // Start and monitor the server - err = httpServer.ListenAndServe() - if err != nil && err != http.ErrServerClosed { - log.Fatalf("Listen: %s\n", err) - } + go func() { + err := httpServer.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + log.Fatalf("Error from web server: %v\n", err) + } + }() } return nil @@ -166,7 +173,7 @@ func ResourceHandler(sys *components.System, w http.ResponseWriter, r *http.Requ return } - resourceName := parts[2] + assetName := parts[2] servicePath := "" if len(parts) > 3 { servicePath = parts[3] @@ -180,9 +187,9 @@ func ResourceHandler(sys *components.System, w http.ResponseWriter, r *http.Requ case 3: handleThreeParts(w, r, parts[2], sys) case 4: - handleFourParts(w, r, resourceName, servicePath, sys) + handleFourParts(w, r, assetName, servicePath, sys) case 5: - handleFiveParts(w, r, resourceName, servicePath, record, sys) + handleFiveParts(w, r, assetName, servicePath, record, sys) default: http.Error(w, "Invalid request", http.StatusBadRequest) } @@ -199,6 +206,8 @@ func handleThreeParts(w http.ResponseWriter, r *http.Request, part string, sys * KGraphing(w, r, sys) case "cert": forms.Certificate(w, r, *sys) + case "msg": + RegisterMessenger(w, r, sys) default: http.Error(w, "Invalid request", http.StatusBadRequest) } @@ -234,6 +243,7 @@ func handleFiveParts(w http.ResponseWriter, r *http.Request, resourceName, servi uAsset := *Resource if servicePath == "files" { forms.TransferFile(w, r) + // return } switch record { @@ -270,6 +280,10 @@ func findServiceByPath(services map[string]*components.Service, path string) *co // findServiceByDefinition returns a service's pointer based on its definition func findServiceByDefinition(services map[string]*components.Service, definition string) *components.Service { - service := services[definition] - return service + for _, service := range services { + if service.Definition == definition { + return service + } + } + return nil } diff --git a/usecases/serviceDiscovery.go b/usecases/service_discovery.go similarity index 62% rename from usecases/serviceDiscovery.go rename to usecases/service_discovery.go index 6ab3921..42737b5 100644 --- a/usecases/serviceDiscovery.go +++ b/usecases/service_discovery.go @@ -20,15 +20,11 @@ package usecases import ( - "bytes" - "context" "encoding/json" - "errors" "fmt" "io" - "log" "net/http" - "time" + "strconv" "github.com/sdoque/mbaigo/components" "github.com/sdoque/mbaigo/forms" @@ -45,8 +41,6 @@ func FillQuestForm(sys *components.System, res components.UnitAsset, sDef, proto f.NewForm() f.RequesterName = sys.Name f.ServiceDefinition = sDef - // TODO: known bug on commit - // f.Protocol = append() f.Protocol = protocol f.Details = res.GetDetails() return f @@ -57,12 +51,12 @@ func ExtractQuestForm(bodyBytes []byte) (rec forms.ServiceQuest_v1, err error) { var jsonData map[string]interface{} err = json.Unmarshal(bodyBytes, &jsonData) if err != nil { - log.Printf("Error unmarshaling JSON data: %v", err) + err = fmt.Errorf("unmarshalling JSON data: %v", err) return } formVersion, ok := jsonData["version"].(string) if !ok { - log.Printf("Error: 'version' key not found in JSON data") + err = fmt.Errorf("'version' key not found in JSON data") return } @@ -71,62 +65,40 @@ func ExtractQuestForm(bodyBytes []byte) (rec forms.ServiceQuest_v1, err error) { var f forms.ServiceQuest_v1 err = json.Unmarshal(bodyBytes, &f) if err != nil { - log.Println("Unable to extract the discovery form request ") + err = fmt.Errorf("unable to extract the discovery form request ") return } rec = f default: - err = errors.New("unsupported service registration form version") + err = fmt.Errorf("unsupported service registration form version") } return } // Search4Service requests from the core systems the address of resources's services that meet the need func Search4Service(qf forms.ServiceQuest_v1, sys *components.System) (servLocation forms.ServicePoint_v1, err error) { - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) // Create a new context, with a 2-second timeout - defer cancel() // Create a new HTTP request to the Orchestrator system (for now the Service Registrar) - var orchestratorPointer *components.CoreSystem - for _, cSys := range sys.CoreS { - if cSys.Name == "orchestrator" { - orchestratorPointer = cSys - } + orURL, err := components.GetRunningCoreSystemURL(sys, "orchestrator") + if err != nil { + return } - // prepare the payload to perform a service quest - oURL := orchestratorPointer.Url + "/squest" + orURL = orURL + "/squest" jsonQF, err := json.MarshalIndent(qf, "", " ") if err != nil { - log.Printf("problem encountered when marshalling the service quest to the Orchestrator at %s\n", oURL) - return servLocation, err - } - // prepare the request - req, err := http.NewRequest(http.MethodPost, oURL, bytes.NewBuffer(jsonQF)) - if err != nil { - return servLocation, err + return } - req.Header.Set("Content-Type", "application/json") // set the Content-Type header - req = req.WithContext(ctx) // associate the cancellable context with the request - - // Send the request ///////////////////////////////// - client := &http.Client{} - resp, err := client.Do(req) + resp, err := sendHTTPReq(http.MethodPost, orURL, jsonQF) if err != nil { - return servLocation, err + return } - defer resp.Body.Close() // Read the response ///////////////////////////////// body, err := io.ReadAll(resp.Body) if err != nil { - return servLocation, err - } - servLocation, err = ExtractDiscoveryForm(body) - if err != nil { - return servLocation, err + return } - return servLocation, err + return ExtractDiscoveryForm(body) } // Search4Services requests from the core systems the address of resources' services that meet the need @@ -140,71 +112,104 @@ func Search4Services(cer *components.Cervice, sys *components.System) (err error Details: cer.Details, Version: "ServiceQuest_v1", } - //pack the service quest form qf, err := Pack(&questForm, "application/json") if err != nil { return err } - // Search for an Orchestrator system within the local cloud - var orchestratorPointer *components.CoreSystem - for _, cSys := range sys.CoreS { - if cSys.Name == "orchestrator" { - orchestratorPointer = cSys - } + orURL, err := components.GetRunningCoreSystemURL(sys, "orchestrator") + if err != nil { + return err } - if orchestratorPointer == nil { - err = errors.New("failed to locate an Orchestrator") + if orURL == "" { + return fmt.Errorf("failed to locate an orchestrator") + } + orURL = orURL + "/squest" + // Prepare the request to the orchestrator + resp, err := sendHTTPReq(http.MethodPost, orURL, qf) + if err != nil { return err } - oURL := orchestratorPointer.Url + "/squest" - - // Prepare the request to the Orchestrator - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) // Create a new context, with a 2-second timeout - defer cancel() - req, err := http.NewRequest(http.MethodPost, oURL, bytes.NewBuffer(qf)) + defer resp.Body.Close() + // Read the response ///////////////////////////////// + bodyBytes, err := io.ReadAll(resp.Body) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") // set the Content-Type header - req = req.WithContext(ctx) // associate the cancellable context with the request - - // Send the request to the Orchestrator ///////////////////////////////// - client := &http.Client{} - resp, err := client.Do(req) + headerContentType := resp.Header.Get("Content-Type") + discoveryForm, err := Unpack(bodyBytes, headerContentType) if err != nil { return err } - - defer resp.Body.Close() - - // Check if the status code indicates an error (anything outside the 200–299 range) - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("received non-2xx status code: %d, response: %s from the Orchestrator", resp.StatusCode, http.StatusText(resp.StatusCode)) + // Perform a type assertion to convert the returned Form to ServicePoint_v1 + df, ok := discoveryForm.(*forms.ServicePoint_v1) + if !ok { + return fmt.Errorf("unable to unpack discovery request form") } + cer.Nodes[df.ServNode] = append(cer.Nodes[df.ServNode], df.ServLocation) + return nil +} +func Search4MultipleServices(cer *components.Cervice, sys *components.System) (err error) { + questForm := forms.ServiceQuest_v1{ + SysId: 0, + RequesterName: sys.Name, + ServiceDefinition: cer.Definition, + Protocol: "http", + Details: cer.Details, + Version: "ServiceQuest_v1", + } + // Pack the service quest form + qf, err := Pack(&questForm, "application/json") + if err != nil { + return err + } + // Search for an Orchestrator system within the local cloud + orURL, err := components.GetRunningCoreSystemURL(sys, "orchestrator") + if err != nil { + return err + } + if orURL == "" { + return fmt.Errorf("failed to locate an orchestrator") + } + orURL = orURL + "/squests" + // Prepare the request to the orchestrator + resp, err := sendHTTPReq(http.MethodPost, orURL, qf) + if err != nil { + return err + } + defer resp.Body.Close() // Read the response ///////////////////////////////// bodyBytes, err := io.ReadAll(resp.Body) if err != nil { return err } - - headerContentTtype := resp.Header.Get("Content-Type") - discoveryForm, err := Unpack(bodyBytes, headerContentTtype) + headerContentType := resp.Header.Get("Content-Type") + discoveryForm, err := Unpack(bodyBytes, headerContentType) if err != nil { - log.Printf("error extracting the discovery request %v\n", err) + return err } - - // Perform a type assertion to convert the returned Form to ServicePoint_v1 - df, ok := discoveryForm.(*forms.ServicePoint_v1) + srList, ok := discoveryForm.(*forms.ServiceRecordList_v1) if !ok { - fmt.Println("Problem unpacking the service discovery request form") - return + return fmt.Errorf("unable to unpack discovery request form") } + for _, values := range srList.List { + sp := convertToServicePoint(values) + cer.Nodes[sp.ServNode] = append(cer.Nodes[sp.ServNode], sp.ServLocation) + } + return nil +} - cer.Nodes[df.ServNode] = append(cer.Nodes[df.ServNode], df.ServLocation) - return err +func convertToServicePoint(sr forms.ServiceRecord_v1) (sp forms.ServicePoint_v1) { + rec := sr + sp.NewForm() + sp.ProviderName = rec.SystemName + sp.ServiceDefinition = rec.ServiceDefinition + sp.Details = rec.Details + sp.ServLocation = "http://" + rec.IPAddresses[0] + ":" + strconv.Itoa(rec.ProtoPort["http"]) + "/" + rec.SystemName + "/" + rec.SubPath + sp.ServNode = rec.ServiceNode + return } // FillDiscoveredServices returns a json data byte array with a slice of matching services (e.g., Service Registrar) @@ -218,7 +223,7 @@ func FillDiscoveredServices(dsList []forms.ServiceRecord_v1, version string) (f dslForm.List = append(dslForm.List, *sf) } default: - err = errors.New("unsupported service registration form version") + err = fmt.Errorf("unsupported service registration form version") return } return @@ -229,12 +234,12 @@ func ExtractDiscoveryForm(bodyBytes []byte) (sLoc forms.ServicePoint_v1, err err var jsonData map[string]interface{} err = json.Unmarshal(bodyBytes, &jsonData) if err != nil { - log.Printf("Error unmarshaling JSON data: %v", err) + err = fmt.Errorf("unmarshalling JSON data: %v", err) return } formVersion, ok := jsonData["version"].(string) if !ok { - log.Printf("Error: 'version' key not found in JSON data") + err = fmt.Errorf("'version' key not found in JSON data") return } switch formVersion { @@ -243,12 +248,12 @@ func ExtractDiscoveryForm(bodyBytes []byte) (sLoc forms.ServicePoint_v1, err err f.NewForm() err = json.Unmarshal(bodyBytes, &f) if err != nil { - log.Println("Unable to extract registration request ") + err = fmt.Errorf("unmarshalling JSON data: %v", err) return } sLoc = f default: - err = errors.New("unsupported service discovery form version") + err = fmt.Errorf("unsupported service discovery form version") } return } diff --git a/usecases/service_discovery_test.go b/usecases/service_discovery_test.go new file mode 100644 index 0000000..7dc0822 --- /dev/null +++ b/usecases/service_discovery_test.go @@ -0,0 +1,606 @@ +package usecases + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "strings" + "testing" + + "github.com/sdoque/mbaigo/components" + "github.com/sdoque/mbaigo/forms" +) + +type testBodyHasProtocol struct { + Version string `json:"version"` + Protocol int `json:"protocol"` +} + +type testBodyHasVersion struct { + Version string `json:"version"` +} + +type testBodyNoVersion struct{} + +func createTestBodyHasProtocol(proto int, version string, errRead bool) ([]byte, error) { + if errRead == true { + return json.Marshal(errReader(0)) + } + body := testBodyHasProtocol{ + Protocol: proto, + Version: version, + } + return json.Marshal(body) +} + +func createTestBodyHasVersion(proto int, version string, errRead bool) ([]byte, error) { + if errRead == true { + return json.Marshal(errReader(0)) + } + body := testBodyHasVersion{ + Version: version, + } + return json.Marshal(body) +} + +func createTestBodyHasNoVersion(proto int, version string, errRead bool) ([]byte, error) { + if errRead == true { + return json.Marshal(errReader(0)) + } + body := testBodyNoVersion{} + return json.Marshal(body) +} + +type extractQuestFormParams struct { + expectedError bool + errRead bool + proto int + version string + f func(int, string, bool) ([]byte, error) + testCase string +} + +func TestExtractQuestForm(t *testing.T) { + testParams := []extractQuestFormParams{ + {false, false, -1, "ServiceQuest_v1", createTestBodyHasVersion, "No errors"}, + {true, true, -1, "ServiceQuest_v1", createTestBodyHasVersion, "Error during Unmarshal"}, + {true, false, -1, "", createTestBodyHasNoVersion, "Missing version"}, + {true, false, 123, "ServiceQuest_v1", createTestBodyHasProtocol, "Error while writing to correct form"}, + {true, false, -1, "", createTestBodyHasVersion, "Error Unsupported version"}, + } + for _, x := range testParams { + data, err := x.f(x.proto, x.version, x.errRead) + if err != nil { + t.Errorf("---\tError occurred while creating test data") + } + // Do the test + _, err = ExtractQuestForm(data) + if x.expectedError == false && err != nil { + t.Errorf("Expected no errors in '%s', got: %v ", x.testCase, err) + } + if x.expectedError == true && err == nil { + t.Errorf("Expected errors in '%s'", x.testCase) + + } + } +} + +// Creates a ServicePoint_v1 form with test values +func createServicePointTestForm() forms.ServicePoint_v1 { + var f forms.ServicePoint_v1 + f.NewForm() + f.Version = "ServicePoint_v1" + f.ServLocation = "TestService" + f.ServiceDefinition = "TestService" + f.Details = map[string][]string{ + "Details": {"detail_1", "detail_2"}, + } + return f +} + +type sendHttpReqParams struct { + testCase string + method string + url string + data []byte + ctx context.Context + respError bool + expectError bool +} + +func testSystemSetup() (resp func() *http.Response, data []byte, ctx context.Context, err error) { + ctx = context.Background() + var form forms.ServiceQuest_v1 + form.NewForm() + resp = func() *http.Response { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(string("test body"))), + } + } + data, err = json.MarshalIndent(form, "", " ") + if err != nil { + return nil, nil, ctx, errors.New("---\tError occurred while marshalling in test system setup") + } + return +} + +func TestSendHttpReq(t *testing.T) { + resp, data, ctx, err := testSystemSetup() + newMockTransport(resp, 0, nil) + if err != nil { + t.Errorf("Error occurred while starting test system: %e", err) + } + params := []sendHttpReqParams{ + // {testCase, method, url, data, ctx, respError, expectError} + {"No errors", http.MethodPost, "http://test", data, ctx, false, false}, + {"Error creating new request", http.MethodPost, brokenUrl, data, ctx, false, true}, + {"DefaultClient returns error", http.MethodPost, "http://test", data, ctx, true, true}, + } + var lastLoopErr bool + for _, c := range params { + // Make sure the the mockTransport doesn't return an error unless needed by the test + if (lastLoopErr == true) && (c.respError == false) { + newMockTransport(resp, 0, nil) + lastLoopErr = false + } + // Make a new mockTransport with an error response if the test needs it + if (lastLoopErr == false) && (c.respError == true) { + newMockTransport(resp, 1, errHTTP) + lastLoopErr = true + } + // Run the test + _, err = sendHTTPReq(c.method, c.url, c.data) + if c.expectError == false && err != nil { + t.Errorf("Unexpected error in '%s' test case: %e", c.testCase, err) + } + if c.expectError == true && err == nil { + t.Errorf("Expected error in '%s' test case, got none", c.testCase) + } + } +} + +// --------------------------------------------------------- // +// Helper functions and structs for testing Search4Service() +// --------------------------------------------------------- // + +type search4ServiceParams struct { + expectError bool + response func() *http.Response + transport func(func() *http.Response) *mockTransport + testCase string +} + +// This function returns different http responses depending on the number of times it's read +// allowedReads takes a positive number, and will count back from that number until it reaches 0, then return a +// http.Response with errReader() in body, given 0 or negative number it'll always return a functioning http.Response +func createMultiHttpResp(statusCode int, broken bool, allowedReads int) func() *http.Response { + f := createServicePointTestForm() + // Create mock response from orchestrator + fakeBody, err := json.Marshal(f) + if err != nil { + log.Println("Fail Marshal at start of test") + } + count := allowedReads + return func() *http.Response { + count-- + if broken == true && count == 0 { + return &http.Response{ + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: errReader(0), + } + } + return &http.Response{ + StatusCode: statusCode, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(string(fakeBody))), + } + } +} + +func TestSearch4Service(t *testing.T) { + // Test parameters + params := []search4ServiceParams{ + { + false, + createMultiHttpResp(200, false, 0), + func(resp func() *http.Response) *mockTransport { return newMockTransport(resp, 0, nil) }, + "Best case", + }, + { + true, + createMultiHttpResp(200, false, 0), + func(resp func() *http.Response) *mockTransport { return newMockTransport(resp, 1, errHTTP) }, + "Bad case, error getting core system url", + }, + { + true, + createMultiHttpResp(200, false, 0), + func(resp func() *http.Response) *mockTransport { return newMockTransport(resp, 1, errHTTP) }, + "Bad case, error sending http request", + }, + { + true, + createMultiHttpResp(200, true, 1), + func(resp func() *http.Response) *mockTransport { return newMockTransport(resp, 0, nil) }, + "Bad case, error reading response body", + }, + } + testSys := createTestSystem(false) + for _, c := range params { + // Setup + c.transport(c.response) + var qForm forms.ServiceQuest_v1 + qForm.NewForm() + + // Test + _, err := Search4Service(qForm, &testSys) + if c.expectError == false && err != nil { + t.Errorf("Expected no errors in testcase '%s', got: %v", c.testCase, err) + } + if c.expectError == true && err == nil { + t.Errorf("Expected errors in testcase '%s'", c.testCase) + } + } +} + +// --------------------------------------------------------- // +// Helper functions and structs for testing Search4Services() +// --------------------------------------------------------- // + +type search4ServicesParams struct { + expectError bool + setup func() (*components.Cervice, components.System) + response func() *http.Response + transport func(func() *http.Response) *mockTransport + testCase string +} + +func TestSearch4Services(t *testing.T) { + params := []search4ServicesParams{ + { + false, + func() (cer *components.Cervice, sys components.System) { + sys = createTestSystem(false) + cer = (*sys.UAssets["testUnitAsset"]).GetCervices()["testCerv"] + return + }, + createMultiHttpResp(200, false, 0), + func(resp func() *http.Response) *mockTransport { return newMockTransport(resp, 0, nil) }, + "Best case, no errors", + }, + { + true, + func() (cer *components.Cervice, sys components.System) { + sys = createTestSystem(false) + cer = (*sys.UAssets["testUnitAsset"]).GetCervices()["testCerv"] + return + }, + createMultiHttpResp(200, false, 0), + func(resp func() *http.Response) *mockTransport { return newMockTransport(resp, 1, errHTTP) }, + "Bad case, GetRunningCoreSystemURL() returns error", + }, + { + true, + func() (cer *components.Cervice, sys components.System) { + sys = createTestSystem(false) + for i, cs := range sys.CoreS { + if cs.Name == "orchestrator" { + (*sys.CoreS[i]).Url = "" + } + } + cer = (*sys.UAssets["testUnitAsset"]).GetCervices()["testCerv"] + return + }, + createMultiHttpResp(200, false, 0), + func(resp func() *http.Response) *mockTransport { return newMockTransport(resp, 0, nil) }, + "Bad case, Orchestrator url is empty", + }, + { + true, + func() (cer *components.Cervice, sys components.System) { + sys = createTestSystem(false) + cer = (*sys.UAssets["testUnitAsset"]).GetCervices()["testCerv"] + return + }, + createMultiHttpResp(200, false, 0), + func(resp func() *http.Response) *mockTransport { return newMockTransport(resp, 1, errHTTP) }, + "Bad case, sendHttpReq() returns an error", + }, + { + true, + func() (cer *components.Cervice, sys components.System) { + sys = createTestSystem(false) + cer = (*sys.UAssets["testUnitAsset"]).GetCervices()["testCerv"] + return + }, + createMultiHttpResp(200, true, 1), + func(resp func() *http.Response) *mockTransport { return newMockTransport(resp, 0, nil) }, + "Bad case, error while reading body", + }, + { + true, + func() (cer *components.Cervice, sys components.System) { + sys = createTestSystem(false) + cer = (*sys.UAssets["testUnitAsset"]).GetCervices()["testCerv"] + return + }, + func() *http.Response { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"Error"}}, + Body: io.NopCloser(strings.NewReader(string(""))), + } + }, + func(resp func() *http.Response) *mockTransport { return newMockTransport(resp, 0, nil) }, + "Bad case, error during Unpack", + }, + { + true, + func() (cer *components.Cervice, sys components.System) { + sys = createTestSystem(false) + cer = (*sys.UAssets["testUnitAsset"]).GetCervices()["testCerv"] + return + }, + func() *http.Response { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(string(`{"version":"SignalA_v1.0"}`))), + } + }, + func(resp func() *http.Response) *mockTransport { return newMockTransport(resp, 0, nil) }, + "Bad case, error during type conversion", + }, + } + + for _, c := range params { + // Setup + c.transport(c.response) + cer, sys := c.setup() + + // Test + err := Search4Services(cer, &sys) + if (c.expectError == false) && (err != nil) { + t.Errorf("Expected no errors in '%s', got: %v", c.testCase, err) + } + if (c.expectError == true) && (err == nil) { + t.Errorf("Expected errors in '%s'", c.testCase) + } + } +} + +func createTestServiceRecord(number int) (f forms.ServiceRecord_v1) { + f.Id = number + f.ServiceDefinition = fmt.Sprintf("testDefinition%d", number) + f.SystemName = fmt.Sprintf("testSystem%d", number) + f.ServiceNode = fmt.Sprintf("test%d", number) + f.IPAddresses = []string{fmt.Sprintf("test%d", number), fmt.Sprintf("test%d", number+1)} + f.ProtoPort = map[string]int{"test": 1} + f.Details = map[string][]string{"Details": {fmt.Sprintf("Detail%d", number), fmt.Sprintf("Detail%d", number+1)}} + f.Certificate = fmt.Sprintf("Certificate%d", number) + f.SubPath = fmt.Sprintf("Subpath%d", number) + f.RegLife = number + f.Version = "ServiceRecord_v1" + f.Created = fmt.Sprintf("Created%d", number) + f.Updated = fmt.Sprintf("Updated%d", number) + f.EndOfValidity = fmt.Sprintf("EoV%d", number) + f.SubscribeAble = true + f.ACost = float64(number) + f.CUnit = fmt.Sprintf("CUnit%d", number) + return +} + +// FillDiscoveredServices(dsList []forms.ServiceRecord_v1, version string) (f forms.Form, err error) +func TestFillDiscoveredServices(t *testing.T) { + // Create a bunch of service records contained in a list + dsList := []forms.ServiceRecord_v1{} + for i := range 10 { + record := createTestServiceRecord(i) + dsList = append(dsList, record) + } + versionList := []string{"ServiceRecordList_v1", "default"} + for _, version := range versionList { + _, err := FillDiscoveredServices(dsList, version) + if version != "ServiceRecordList_v1" && err == nil { + t.Errorf("Expected error in default case") + } + if version == "ServiceRecordList_v1" && err != nil { + t.Errorf("Unexpected error during testing: %v", err) + } + } +} + +// --------------------------------------------------------------- // +// Helper functions and structs for testing ExtractDiscoveryForm() +// --------------------------------------------------------------- // + +type extractDiscoveryFormParams struct { + expectError bool + data func() any + testCase string +} + +// ExtractDiscoveryForm(bodyBytes []byte) (sLoc forms.ServicePoint_v1, err error) +func TestExtractDiscoveryForm(t *testing.T) { + params := []extractDiscoveryFormParams{ + { + false, + func() any { return createServicePointTestForm() }, + "Best case", + }, + { + true, + func() any { + return "" + }, + "Bad case, Unmarshal breaks", + }, + { + true, + func() any { + form := createServicePointTestForm() + form.Version = "" + return form + }, + "Bad case, wrong form version", + }, + { + true, + func() any { return nil }, + "Bad case, version key missing", + }, + { + true, + func() any { + wrongForm := make(map[string]any) + wrongForm["version"] = "ServicePoint_v1" + wrongForm["serviceId"] = false // Target field is an int + return wrongForm + }, + "Bad case, can't unmarshal to ServicePoint_v1 (field type mismatch)", + }, + } + + for _, c := range params { + // Setup + data, err := json.Marshal(c.data()) + if err != nil { + t.Errorf("couldn't marshal data in '%s'", c.testCase) + } + // Test + _, err = ExtractDiscoveryForm(data) + if (c.expectError == false) && (err != nil) { + t.Errorf("Expected no errors in '%s', got: %v", c.testCase, err) + } + if (c.expectError == true) && (err == nil) { + t.Errorf("Expected errors in '%s'", c.testCase) + } + } +} + +func createServiceRecordListTestForm() forms.ServiceRecordList_v1 { + var f forms.ServiceRecordList_v1 + f.NewForm() + f.List = make([]forms.ServiceRecord_v1, 1) + f.List[0].IPAddresses = []string{"123.456.789"} + f.List[0].ProtoPort = map[string]int{"http": 123} + return f +} + +func createMultiHttpRespWithServRecList(statusCode int, broken bool, allowedReads int) func() *http.Response { + f := createServiceRecordListTestForm() + // Create mock response from orchestrator + fakeBody, err := json.Marshal(f) + if err != nil { + log.Println("Fail Marshal at start of test") + } + count := allowedReads + return func() *http.Response { + count-- + if broken == true && count == 0 { + return &http.Response{ + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: errReader(0), + } + } + return &http.Response{ + StatusCode: statusCode, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(string(fakeBody))), + } + } +} + +func createUnpackErrorBody() func() *http.Response { + return func() *http.Response { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"Error"}}, + Body: io.NopCloser(strings.NewReader(string(""))), + } + } +} + +func createTypeConversionErrorBody() func() *http.Response { + return func() *http.Response { + return &http.Response{ + Status: "200 OK", + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(string(`{"version":"SignalA_v1.0"}`))), + } + } +} + +type search4MultipleServicesStruct struct { + expectError bool + emptyUrl bool + response func() *http.Response + mockTransportErr int + errHTTP error + testName string +} + +var search4MultipleServicesParams = []search4MultipleServicesStruct{ + {false, false, createMultiHttpRespWithServRecList(200, false, 0), 0, nil, + "Best case, no errors", + }, + {true, false, createMultiHttpRespWithServRecList(200, false, 0), 1, errHTTP, + "Bad case, GetRunningCoreSystemURL() returns error", + }, + {true, true, createMultiHttpRespWithServRecList(200, false, 0), 0, nil, + "Bad case, Orchestrator url is empty", + }, + {true, false, createMultiHttpRespWithServRecList(200, false, 0), 1, errHTTP, + "Bad case, sendHttpReq() returns an error", + }, + {true, false, createMultiHttpRespWithServRecList(200, true, 1), 0, nil, + "Bad case, error while reading body", + }, + {true, false, createUnpackErrorBody(), 0, nil, + "Bad case, error during Unpack", + }, + {true, false, createTypeConversionErrorBody(), 0, nil, + "Bad case, error during type conversion", + }, +} + +func TestSearch4MultipleServices(t *testing.T) { + + for _, testCase := range search4MultipleServicesParams { + // Setup + testSys := createTestSystem(false) + testCer := (*testSys.UAssets["testUnitAsset"]).GetCervices()["testCerv"] + + if testCase.emptyUrl == true { + for i, cs := range testSys.CoreS { + if cs.Name == "orchestrator" { + (*testSys.CoreS[i]).Url = "" + } + } + } + + newMockTransport(testCase.response, testCase.mockTransportErr, testCase.errHTTP) + + // Test + err := Search4MultipleServices(testCer, &testSys) + if (testCase.expectError == false) && (err != nil) { + t.Errorf("Expected no errors in '%s', got: %v", testCase.testName, err) + } + if (testCase.expectError == true) && (err == nil) { + t.Errorf("Expected errors in '%s'", testCase.testName) + } + } +} diff --git a/usecases/packing.go b/usecases/utilities.go similarity index 53% rename from usecases/packing.go rename to usecases/utilities.go index 6990274..91ac5a3 100644 --- a/usecases/packing.go +++ b/usecases/utilities.go @@ -23,16 +23,17 @@ import ( "bytes" "encoding/json" "encoding/xml" - "errors" "fmt" - "log" + "net/http" "reflect" "strings" + "time" + "unicode" "github.com/sdoque/mbaigo/forms" ) -// Pack serializes a form to a byte array for payolad shipment with serializaton format (sf) request +// Pack serializes a form to a byte array for payload shipment with serialization format (sf) request func Pack(f forms.Form, contentType string) (data []byte, err error) { switch contentType { case "application/xml": @@ -61,16 +62,14 @@ func Unpack(data []byte, contentType string) (forms.Form, error) { if len(trimmed) > 0 { switch trimmed[0] { case '{', '[': - log.Println("Detected JSON in text/plain payload.") contentType = "application/json" case '<': - log.Println("Detected XML in text/plain payload.") contentType = "application/xml" default: - return nil, errors.New("plain text content is neither valid JSON nor XML") + return nil, fmt.Errorf("plain text content is neither valid JSON nor XML") } } else { - return nil, errors.New("empty payload with content type text/plain") + return nil, fmt.Errorf("empty payload with content type text/plain") } } @@ -78,28 +77,26 @@ func Unpack(data []byte, contentType string) (forms.Form, error) { switch { case strings.Contains(contentType, "application/json"): if err := json.Unmarshal(data, &rawData); err != nil { - log.Printf("Error unmarshaling JSON: %v", err) - return nil, err + return nil, fmt.Errorf("error unmarshalling JSON: %w", err) } case strings.Contains(contentType, "application/xml"): if err := xml.Unmarshal(data, &rawData); err != nil { - log.Printf("Error unmarshaling XML: %v", err) - return nil, err + return nil, fmt.Errorf("error unmarshalling XML: %w", err) } default: - return nil, errors.New("unsupported content type") + return nil, fmt.Errorf("unsupported content type") } // Retrieve form version formVersion, ok := rawData["version"].(string) if !ok { - return nil, errors.New("'version' key not found in data") + return nil, fmt.Errorf("'version' key not found in data") } // Look up the form type in the map formType, exists := forms.FormTypeMap[formVersion] if !exists { - return nil, errors.New("unsupported form version: " + formVersion) + return nil, fmt.Errorf("unsupported form version: %s", formVersion) } // Create a new instance of the form @@ -109,15 +106,91 @@ func Unpack(data []byte, contentType string) (forms.Form, error) { switch { case strings.Contains(contentType, "application/json"): if err := json.Unmarshal(data, formInstance); err != nil { - log.Printf("Error unmarshaling JSON into form: %v", err) - return nil, err + return nil, fmt.Errorf("error unmarshalling JSON into form: %w", err) } case strings.Contains(contentType, "application/xml"): if err := xml.Unmarshal(data, formInstance); err != nil { - log.Printf("Error unmarshaling XML into form: %v", err) - return nil, err + return nil, fmt.Errorf("error unmarshalling XML into form: %w", err) } } return formInstance, nil } + +// ------- Naming Conventions Tools ------- + +// ToCamel converts PascalCase to camelCase. +func ToCamel(s string) string { + if s == "" { + return s + } + runes := []rune(s) + runes[0] = unicode.ToLower(runes[0]) + return string(runes) +} + +// ToPascal converts camelCase to PascalCase. +func ToPascal(s string) string { + if s == "" { + return s + } + runes := []rune(s) + runes[0] = unicode.ToUpper(runes[0]) + return string(runes) +} + +// IsFirstLetterUpper returns true if the first rune is uppercase. +func IsFirstLetterUpper(s string) bool { + if s == "" { + return false + } + return unicode.IsUpper([]rune(s)[0]) +} + +// IsFirstLetterLower returns true if the first rune is lowercase. +func IsFirstLetterLower(s string) bool { + if s == "" { + return false + } + return unicode.IsLower([]rune(s)[0]) +} + +// IsPascalCase returns true if the string starts with an uppercase letter. +func IsPascalCase(s string) bool { + return IsFirstLetterUpper(s) +} + +// IsCamelCase returns true if the string starts with a lowercase letter. +func IsCamelCase(s string) bool { + return IsFirstLetterLower(s) +} + +// ------- HTTP Client Tools ------- + +func init() { + // Sets up a new global client with better defaults + // (the tests depends on this client too, and sometimes + // replaces it with a mock). + http.DefaultClient = &http.Client{ + Timeout: time.Second * 30, + } +} + +const userAgent string = "mbaigo" + +func sendHTTPReq(method string, url string, data []byte) (*http.Response, error) { + req, err := http.NewRequest(method, url, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", userAgent) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("bad response: %d %s", resp.StatusCode, resp.Status) + } + return resp, nil +} diff --git a/usecases/utilities_test.go b/usecases/utilities_test.go new file mode 100644 index 0000000..4ceb703 --- /dev/null +++ b/usecases/utilities_test.go @@ -0,0 +1,265 @@ +package usecases + +import ( + "encoding/json" + "math" + "strings" + "testing" + + "github.com/sdoque/mbaigo/forms" +) + +type packParams struct { + contentType string + expectedError bool + form mockForm + expectedValue string + expectedVersion string + testCase string +} + +func TestPack(t *testing.T) { + params := []packParams{ + {"application/xml", false, mockForm{Value: 123, Version: "testVersion"}, "123", "testVersion", "Best case, xml"}, + {"application/json", false, mockForm{Value: 123, Version: "testVersion"}, `"value": 123`, `"version": "testVersion"`, "Best case, json"}, + {"application/xml", true, mockForm{Value: complex(1, 2), Version: "testVersion"}, "", "", "Bad case, xml"}, + {"application/json", true, mockForm{Value: complex(1, 2), Version: "testVersion"}, "", "", "Bad case, json"}, + } + for _, c := range params { + data, err := Pack(c.form, c.contentType) + if c.expectedError == false { + if err != nil { + t.Errorf("failed in testcase '%s' with error: %v", c.testCase, err) + } + if strings.Contains(string(data), c.expectedValue) != true { + t.Errorf("value missing or wrong in testcase '%s'", c.testCase) + } + if strings.Contains(string(data), c.expectedVersion) != true { + t.Errorf("version missing or wrong in testcase '%s'", c.testCase) + } + + } else { + if err == nil { + t.Errorf("expected error in testcase '%s', got none", c.testCase) + } + } + } +} + +// This covers the case of it having a version but is not present in the form type map +type testFormHasVersion struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// This covers the case of it not having a version +type testFormNoVersion struct { + Name string `json:"name"` +} + +type unpackParams struct { + expectError bool + testCase string + contentType string + setup func() (data []byte, err error) +} + +func TestUnpack(t *testing.T) { + testParams := []unpackParams{ + //{expectError, testCase, contentType, setup()} + {false, "Best case, json", "text/plain", func() (data []byte, err error) { + var f forms.SignalA_v1a + f.NewForm() + data, err = json.Marshal(f) + return + }}, + // TODO: The following test can't be done because xml.Unmarshal() can't unmarshal to map[] + // fails with "error unmarshalling XML: unknown type map[string]interface {}" + /*{false, "Best case, xml", "text/plain", func() (data []byte, err error) { + var f forms.SignalA_v1a + f.NewForm() + data, err = xml.Marshal(f) + return + }},*/ + {true, "Bad case, not json/xml", "text/plain", func() (data []byte, err error) { return []byte("TEST123"), nil }}, + {true, "Bad case, empty []byte", "text/plain", func() (data []byte, err error) { return []byte(""), nil }}, + {true, "Bad case, unsupported content type", "unknown", func() (data []byte, err error) { return []byte("test"), nil }}, + {true, "Bad case, missing version", "application/json", func() (data []byte, err error) { + f := &testFormNoVersion{ + Name: "testName", + } + data, err = json.Marshal(f) + return data, err + }}, + {true, "Bad case, unsupported form version", "application/json", func() (data []byte, err error) { + f := &testFormHasVersion{ + Name: "testName", + Version: "testVersion", + } + data, err = json.Marshal(f) + return data, err + }}, + {true, "Bad case, broken unmarshal in json", "application/json", func() (data []byte, err error) { + data = append(data, byte(math.NaN())) + return data, err + }}, + {true, "Bad case, broken unmarshal in xml", "application/xml", func() (data []byte, err error) { + data = append(data, byte(math.NaN())) + return data, err + }}, + // TODO: Refactor code so we can do another test: currently can't reach second unmarshal for json to break it this way, moving on. + // TODO: Refactor code so we can do another test: currently can't reach second unmarshal for xml to break it this way, moving on. + } + + for _, c := range testParams { + // Setup + data, err := c.setup() + if err != nil { + t.Errorf("unexpected error in setup of testcase '%s': %v", c.testCase, err) + } + + // Test + _, err = Unpack(data, c.contentType) + if c.expectError != true { + if err != nil { + t.Errorf("error occurred in testcase '%s', got:\n %v", c.testCase, err) + } + } else { + if err == nil { + t.Errorf("expected errors in testcase '%s', got none", c.testCase) + } + } + } +} + +type toCamelParams struct { + expectedString string + testString string + testCase string +} + +func TestToCamel(t *testing.T) { + testParams := []toCamelParams{ + {"testString", "TestString", "Best case"}, + {"", "", "Empty string"}, + } + for _, c := range testParams { + generatedStr := ToCamel(c.testString) + if generatedStr != c.expectedString { + t.Errorf("expected both strings to be %s, generated string was: %s", c.expectedString, generatedStr) + } + } +} + +type toPascalParams struct { + expectedString string + testString string + testCase string +} + +func TestToPascal(t *testing.T) { + testParams := []toPascalParams{ + {"TestString", "testString", "Best case"}, + {"", "", "Empty string"}, + } + for _, c := range testParams { + generatedStr := ToPascal(c.testString) + if generatedStr != c.expectedString { + t.Errorf("expected both strings to be %s in testcase '%s', generated string was: %s", c.expectedString, c.testCase, generatedStr) + } + } +} + +type isFirstUpperParams struct { + expectedUpper bool + testString string + testCase string +} + +func TestIsFirstLetterUpper(t *testing.T) { + testParams := []isFirstUpperParams{ + {true, "FirstUpper", "First letter is uppercase"}, + {false, "firstUpper", "First letter is not uppercase"}, + {false, "", "Empty string"}, + } + for _, c := range testParams { + isUpper := IsFirstLetterUpper(c.testString) + if isUpper != c.expectedUpper { + if c.expectedUpper == true { + t.Errorf("expected first letter to be uppercase in testcase '%s'", c.testCase) + } else { + t.Errorf("expected first letter to be lowercase in testcase '%s'", c.testCase) + } + } + } +} + +type isFirstLowerParams struct { + expectedLower bool + testString string + testCase string +} + +func TestIsFirstLetterLower(t *testing.T) { + testParams := []isFirstLowerParams{ + {true, "firstLower", "First letter is lowercase"}, + {false, "FirstLower", "First letter is not lowercase"}, + {false, "", "Empty string"}, + } + for _, c := range testParams { + isLower := IsFirstLetterLower(c.testString) + if isLower != c.expectedLower { + if c.expectedLower == true { + t.Errorf("expected first letter to be lowercase in testcase '%s'", c.testCase) + } else { + t.Errorf("expected first letter to be uppercase in testcase '%s'", c.testCase) + } + } + } +} + +type isPascalCaseParams struct { + expectedPascal bool + testString string + testCase string +} + +func TestIsPascalCase(t *testing.T) { + testParams := []isPascalCaseParams{ + {true, "IsPascal", "Is Pascal"}, + {false, "isPascal", "Not Pascal"}, + } + for _, c := range testParams { + isPascal := IsPascalCase(c.testString) + if isPascal != c.expectedPascal { + if c.expectedPascal == true { + t.Errorf("expected first letter to be uppercase in testcase '%s'", c.testCase) + } else { + t.Errorf("expected first letter to be lowercase in testcase '%s'", c.testCase) + } + } + } +} + +type isCamelCaseParams struct { + expectedCamel bool + testString string + testCase string +} + +func TestICamelCase(t *testing.T) { + testParams := []isCamelCaseParams{ + {true, "isCamel", "Is Camel"}, + {false, "IsCamel", "Not Camel"}, + } + for _, c := range testParams { + isCamel := IsCamelCase(c.testString) + if isCamel != c.expectedCamel { + if c.expectedCamel == true { + t.Errorf("expected first letter to be lowercase in testcase '%s'", c.testCase) + } else { + t.Errorf("expected first letter to be uppercase in testcase '%s'", c.testCase) + } + } + } +}