diff --git a/examples/canonicalization_test.go b/examples/canonicalization_test.go index 2d20161..e7f7055 100644 --- a/examples/canonicalization_test.go +++ b/examples/canonicalization_test.go @@ -3,7 +3,7 @@ package examples import ( "testing" - "github.com/moov-io/signedxml" + "github.com/leifj/signedxml" . "github.com/smartystreets/goconvey/convey" ) diff --git a/examples/examples_validate.go b/examples/examples_validate.go index 22ec0bf..d7357d6 100644 --- a/examples/examples_validate.go +++ b/examples/examples_validate.go @@ -5,7 +5,7 @@ import ( "io" "os" - "github.com/moov-io/signedxml" + "github.com/leifj/signedxml" ) func ExampleValidate() { diff --git a/go.mod b/go.mod index 49a6396..d5d14a8 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/moov-io/signedxml +module github.com/leifj/signedxml go 1.21.0 diff --git a/go.sum b/go.sum index 00a9097..e5e223d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/beevik/etree v1.5.0 h1:iaQZFSDS+3kYZiGoc9uKeOkUY3nYMXOKLl6KIJxiJWs= -github.com/beevik/etree v1.5.0/go.mod h1:gPNJNaBGVZ9AwsidazFZyygnd+0pAU38N4D+WemwKNs= github.com/beevik/etree v1.5.1 h1:TC3zyxYp+81wAmbsi8SWUpZCurbxa6S8RITYRSkNRwo= github.com/beevik/etree v1.5.1/go.mod h1:gPNJNaBGVZ9AwsidazFZyygnd+0pAU38N4D+WemwKNs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= diff --git a/signedxml.go b/signedxml.go index de193f2..5ad6932 100644 --- a/signedxml.go +++ b/signedxml.go @@ -8,6 +8,7 @@ import ( "encoding/pem" "errors" "fmt" + "log" "strings" "github.com/beevik/etree" @@ -220,6 +221,38 @@ func (s *signatureData) parseCanonAlgorithm() error { "CanonicalizationMethod") } +func findNs(in *etree.Element, ns map[string]string) { + ns[in.Space] = in.NamespaceURI() + for _, c := range in.ChildElements() { + findNs(c, ns) + } +} + +func findNamespaces(in *etree.Document) map[string]string { + var ns = make(map[string]string) + findNs(in.Root(), ns) + return ns +} + +func fixNs(e *etree.Element, ns map[string]string) { + if e.NamespaceURI() == "" && e.Space != "" { + if uri, ok := ns[e.Space]; ok { + e.CreateAttr(fmt.Sprintf("xmlns:%s", e.Space), uri) + } else { + log.Printf("signedxml: Missing namespace tag %s\n", e.Space) + } + } + + for _, c := range e.ChildElements() { + fixNs(c, ns) + } +} + +func fixNamespaces(in *etree.Document, out *etree.Document) { + ns := findNamespaces(in) + fixNs(out.Root(), ns) +} + func (s *signatureData) getReferencedXML(reference *etree.Element, inputDoc *etree.Document) (outputDoc *etree.Document, err error) { uri := reference.SelectAttrValue("URI", "") uri = strings.Replace(uri, "#", "", 1) @@ -251,6 +284,8 @@ func (s *signatureData) getReferencedXML(reference *etree.Element, inputDoc *etr return nil, errors.New("signedxml: unable to find refereced xml") } + fixNamespaces(inputDoc, outputDoc) + return outputDoc, nil } @@ -270,39 +305,46 @@ func getCertFromPEMString(pemString string) (*x509.Certificate, error) { return cert, err } +const ALL_TRANSFORMS string = "" + func processTransform(transform *etree.Element, - docIn *etree.Document) (docOut *etree.Document, err error) { + docIn *etree.Document, onlyIfContains string) (docOut *etree.Document, err error) { transformAlgoURI := transform.SelectAttrValue("Algorithm", "") if transformAlgoURI == "" { return nil, errors.New("signedxml: unable to find Algorithm in Transform") } - transformAlgo, ok := CanonicalizationAlgorithms[transformAlgoURI] - if !ok { - return nil, fmt.Errorf("signedxml: unable to find matching transform"+ - "algorithm for %s in CanonicalizationAlgorithms", transformAlgoURI) - } + if onlyIfContains == "" || strings.Contains(transformAlgoURI, onlyIfContains) { + + transformAlgo, ok := CanonicalizationAlgorithms[transformAlgoURI] + if !ok { + return nil, fmt.Errorf("signedxml: unable to find matching transform"+ + "algorithm for %s in CanonicalizationAlgorithms", transformAlgoURI) + } + + var transformContent string - var transformContent string + if transform.ChildElements() != nil { + tDoc := etree.NewDocument() + tDoc.SetRoot(transform.Copy()) + transformContent, err = tDoc.WriteToString() + if err != nil { + return nil, err + } + } - if transform.ChildElements() != nil { - tDoc := etree.NewDocument() - tDoc.SetRoot(transform.Copy()) - transformContent, err = tDoc.WriteToString() + docString, err := transformAlgo.ProcessDocument(docIn, transformContent) if err != nil { return nil, err } - } - docString, err := transformAlgo.ProcessDocument(docIn, transformContent) - if err != nil { - return nil, err + docOut = etree.NewDocument() + docOut.ReadFromString(docString) + } else { + docOut = docIn } - docOut = etree.NewDocument() - docOut.ReadFromString(docString) - return docOut, nil } diff --git a/signedxml_test.go b/signedxml_test.go index 46f86e9..6085f4a 100644 --- a/signedxml_test.go +++ b/signedxml_test.go @@ -128,7 +128,9 @@ func TestSign(t *testing.T) { So(len(refs), ShouldEqual, 1) }) Convey("And the signature should be valid, but validation fail if referenceIDAttribute NOT SET", func() { - validator, _ := NewValidator(xmlStr) + validator, err := NewValidator(xmlStr) + So(err, ShouldBeNil) + So(validator, ShouldNotBeNil) validator.Certificates = append(validator.Certificates, *cert) refs, err := validator.ValidateReferences() So(err, ShouldNotBeNil) @@ -157,9 +159,12 @@ func TestSign(t *testing.T) { signer, _ := NewSigner(string(xml)) signer.SetReferenceIDAttribute("Id") xmlStr, err := signer.Sign(key) + t.Logf("%#v", xmlStr) So(err, ShouldBeNil) + So(xmlStr, ShouldNotBeNil) - validator, _ := NewValidator(xmlStr) + validator, err := NewValidator(xmlStr) + So(err, ShouldBeNil) validator.SetReferenceIDAttribute("Id") validator.Certificates = append(validator.Certificates, *cert) refs, err := validator.ValidateReferences() @@ -249,7 +254,7 @@ func TestValidate(t *testing.T) { refs, err := validator.ValidateReferences() Convey("Then an error occurs", func() { So(err, ShouldNotBeNil) - So(err.Error(), ShouldContainSubstring, "signedxml:") + So(err.Error(), ShouldContainSubstring, "signedxml") t.Logf("%v - %d", description, len(refs)) So(len(refs), ShouldEqual, 0) }) diff --git a/signer.go b/signer.go index a8913db..2057b93 100644 --- a/signer.go +++ b/signer.go @@ -105,7 +105,7 @@ func (s *Signer) setDigest() (err error) { transforms := ref.SelectElement("Transforms") if transforms != nil { for _, transform := range transforms.SelectElements("Transform") { - doc, err = processTransform(transform, doc) + doc, err = processTransform(transform, doc, ALL_TRANSFORMS) if err != nil { return err } diff --git a/tests/issue55_test.go b/tests/issue55_test.go index eb4f15b..def2e81 100644 --- a/tests/issue55_test.go +++ b/tests/issue55_test.go @@ -5,7 +5,7 @@ import ( "path/filepath" "testing" - "github.com/moov-io/signedxml" + "github.com/leifj/signedxml" "github.com/stretchr/testify/require" ) @@ -30,6 +30,6 @@ func TestIssue55(t *testing.T) { validator.Certificates = append(validator.Certificates, *cert) refs, err := validator.ValidateReferences() - require.Contains(t, err.Error(), "signedxml: Calculated digest does not match the expected digestvalue of") + require.Contains(t, err.Error(), "does not match the expected digestvalue of") require.Len(t, refs, 0) } diff --git a/validator.go b/validator.go index 2ca9c22..be32551 100644 --- a/validator.go +++ b/validator.go @@ -111,34 +111,43 @@ func (v *Validator) validateReferences() (referenced []*etree.Document, err erro transforms := ref.SelectElement("Transforms") if transforms != nil { for _, transform := range transforms.SelectElements("Transform") { - doc, err = processTransform(transform, doc) + doc, err = processTransform(transform, doc, ALL_TRANSFORMS) if err != nil { return nil, err } } } + refUri := ref.SelectAttrValue("URI", "") doc, err = v.getReferencedXML(ref, doc) if err != nil { return nil, err } + if transforms != nil { + for _, transform := range transforms.SelectElements("Transform") { + doc, err = processTransform(transform, doc, "c14n") + if err != nil { + return nil, err + } + } + } + referenced = append(referenced, doc) digestValueElement := ref.SelectElement("DigestValue") if digestValueElement == nil { - return nil, errors.New("signedxml: unable to find DigestValue") + return nil, fmt.Errorf("signedxml [%s]: unable to find DigestValue", refUri) } digestValue := digestValueElement.Text() - calculatedValue, err := calculateHash(ref, doc) if err != nil { return nil, err } if calculatedValue != digestValue { - return nil, fmt.Errorf("signedxml: Calculated digest does not match the"+ - " expected digestvalue of %s", digestValue) + return nil, fmt.Errorf("signedxml [%s]: Calculated digest (%s) does not match the"+ + " expected digestvalue of %s", refUri, calculatedValue, digestValue) } } return referenced, nil