Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion kms/capi/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ const (
HashArg = "sha1"
StoreLocationArg = "store-location" // 'machine', 'user', etc
StoreNameArg = "store" // 'MY', 'CA', 'ROOT', etc
FriendlyNameArg = "friendly-name"
DescriptionArg = "description"
KeyIDArg = "key-id"
SubjectCNArg = "cn"
SerialNumberArg = "serial"
Expand Down Expand Up @@ -315,6 +317,8 @@ func (k *CAPIKMS) getCertContext(req *apiv1.LoadCertificateRequest) (*windows.Ce
issuerName := u.Get(IssuerNameArg)
subjectCN := u.Get(SubjectCNArg)
serialNumber := u.Get(SerialNumberArg)
friendlyName := u.Get(FriendlyNameArg)
description := u.Get(DescriptionArg)

// default to the user store
var storeLocation string
Expand Down Expand Up @@ -390,7 +394,7 @@ func (k *CAPIKMS) getCertContext(req *apiv1.LoadCertificateRequest) (*windows.Ce
if handle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", KeyIDArg, keyID)}
}
case issuerName != "" && (serialNumber != "" || subjectCN != ""):
case issuerName != "" && (serialNumber != "" || subjectCN != "" || friendlyName != "" || description != ""):
var prevCert *windows.CertContext
for {
handle, err = findCertificateInStore(st,
Expand Down Expand Up @@ -439,6 +443,24 @@ func (k *CAPIKMS) getCertContext(req *apiv1.LoadCertificateRequest) (*windows.Ce
if x509Cert.Subject.CommonName == subjectCN {
return handle, nil
}
case len(friendlyName) > 0:
val, err := cryptFindCertificateFriendlyName(handle)
if err != nil {
return nil, fmt.Errorf("cryptFindCertificateFriendlyName failed: %w", err)
}

if val == friendlyName {
return handle, nil
}
case len(description) > 0:
val, err := cryptFindCertificateDescription(handle)
if err != nil {
return nil, fmt.Errorf("cryptFindCertificateDescription failed: %w", err)
}

if val == description {
return handle, nil
}
}

prevCert = handle
Expand Down Expand Up @@ -748,6 +770,14 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
cryptFindCertificateKeyProvInfo(certContext)
}

if friendlyName := u.Get(FriendlyNameArg); friendlyName != "" {
cryptSetCertificateFriendlyName(certContext, friendlyName)
}

if description := u.Get(DescriptionArg); description != "" {
cryptSetCertificateDescription(certContext, description)
}

st, err := windows.CertOpenStore(
certStoreProvSystem,
0,
Expand Down
101 changes: 101 additions & 0 deletions kms/capi/ncrypt_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ const (
compareShift = 16 // CERT_COMPARE_SHIFT
compareSHA1Hash = 1 // CERT_COMPARE_SHA1_HASH
compareCertID = 16 // CERT_COMPARE_CERT_ID
compareProp = 5 // CERT_COMPARE_CERT_ID
findIssuerStr = compareNameStrW<<compareShift | infoIssuerFlag // CERT_FIND_ISSUER_STR_W
findIssuerName = compareName<<compareShift | infoIssuerFlag // CERT_FIND_ISSUER_NAME
findHash = compareSHA1Hash << compareShift // CERT_FIND_HASH
findProperty = compareProp << compareShift // CERT_FIND_PROPERTY
findCertID = compareCertID << compareShift // CERT_FIND_CERT_ID

signatureKeyUsage = 0x80 // CERT_DIGITAL_SIGNATURE_KEY_USAGE
Expand All @@ -83,6 +85,8 @@ const (
CERT_ID_SHA1_HASH = uint32(3)

CERT_KEY_PROV_INFO_PROP_ID = uint32(2)
CERT_FRIENDLY_NAME_PROP_ID = uint32(11)
CERT_DESCRIPTION_PROP_ID = uint32(13)

CERT_NAME_STR_COMMA_FLAG = uint32(0x04000000)
CERT_SIMPLE_NAME_STR = uint32(1)
Expand Down Expand Up @@ -152,6 +156,7 @@ var (
procCertFindCertificateInStore = crypt32.MustFindProc("CertFindCertificateInStore")
procCryptFindCertificateKeyProvInfo = crypt32.MustFindProc("CryptFindCertificateKeyProvInfo")
procCertGetCertificateContextProperty = crypt32.MustFindProc("CertGetCertificateContextProperty")
procCertSetCertificateContextProperty = crypt32.MustFindProc("CertSetCertificateContextProperty")
procCertStrToName = crypt32.MustFindProc("CertStrToNameW")
)

Expand Down Expand Up @@ -607,6 +612,102 @@ func cryptFindCertificateKeyContainerName(certContext *windows.CertContext) (str
return "", nil
}

func certSetCertificateContextProperty(certContext *windows.CertContext, propID uint32, pvData uintptr) error {
r0, _, err := procCertSetCertificateContextProperty.Call(
uintptr(unsafe.Pointer(certContext)),
uintptr(propID),
0,
pvData,
)

if r0 == 0 {
return err
}
return nil
}

func cryptSetCertificateFriendlyName(certContext *windows.CertContext, val string) error {
data := CRYPTOAPI_BLOB{
len: uint32(len(val)+1) * 2,
data: uintptr(unsafe.Pointer(wide(val))),
}

return certSetCertificateContextProperty(certContext, CERT_FRIENDLY_NAME_PROP_ID, uintptr(unsafe.Pointer(&data)))
}

func cryptSetCertificateDescription(certContext *windows.CertContext, val string) error {
data := CRYPTOAPI_BLOB{
len: uint32(len(val)+1) * 2,
data: uintptr(unsafe.Pointer(wide(val))),
}

return certSetCertificateContextProperty(certContext, CERT_DESCRIPTION_PROP_ID, uintptr(unsafe.Pointer(&data)))
}

func certGetCertificateContextProperty(certContext *windows.CertContext, propID uint32, pvData *byte, pcbData *uint32) error {
r0, _, err := procCertGetCertificateContextProperty.Call(
uintptr(unsafe.Pointer(certContext)),
uintptr(propID),
uintptr(unsafe.Pointer(pvData)),
uintptr(unsafe.Pointer(pcbData)),
)
if r0 == 0 {
return err
}
return nil
}

func cryptFindCertificateFriendlyName(certContext *windows.CertContext) (string, error) {
var size uint32

err := certGetCertificateContextProperty(certContext, CERT_FRIENDLY_NAME_PROP_ID, nil, &size)
if err != nil {
if errno, ok := err.(windows.Errno); ok && uint32(errno) == CRYPT_E_NOT_FOUND {
return "", nil
}

return "", err
}

if size == 0 {
return "", nil
}

buf := make([]byte, size)
err = certGetCertificateContextProperty(certContext, CERT_FRIENDLY_NAME_PROP_ID, &buf[0], &size)
if err != nil {
return "", err
}

uc := bytes.ReplaceAll(buf, []byte{0x00}, []byte(""))
return string(uc), nil
}

func cryptFindCertificateDescription(certContext *windows.CertContext) (string, error) {
var size uint32

err := certGetCertificateContextProperty(certContext, CERT_DESCRIPTION_PROP_ID, nil, &size)
if err != nil {
if errno, ok := err.(windows.Errno); ok && uint32(errno) == CRYPT_E_NOT_FOUND {
return "", nil
}

return "", err
}
if size == 0 {
return "", nil
}

buf := make([]byte, size)
err = certGetCertificateContextProperty(certContext, CERT_DESCRIPTION_PROP_ID, &buf[0], &size)
if err != nil {
return "", err
}

uc := bytes.ReplaceAll(buf, []byte{0x00}, []byte(""))
return string(uc), nil
}

func certStrToName(x500Str string) ([]byte, error) {
var size uint32

Expand Down
32 changes: 24 additions & 8 deletions kms/tpmkms/tpmkms.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,11 +707,14 @@
)

func (k *TPMKMS) loadCertificateChainFromWindowsCertificateStore(req *apiv1.LoadCertificateRequest) ([]*x509.Certificate, error) {
pub, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
var subjectKeyID string
if pub, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
Name: req.Name,
})
if err != nil {
return nil, fmt.Errorf("failed retrieving public key: %w", err)
}); err == nil {
subjectKeyID, err = generateWindowsSubjectKeyID(pub)
if err != nil {
return nil, fmt.Errorf("failed generating subject key id: %w", err)
}
}

o, err := parseNameURI(req.Name)
Expand All @@ -728,13 +731,24 @@
store = o.store
}

subjectKeyID, err := generateWindowsSubjectKeyID(pub)
if err != nil {
return nil, fmt.Errorf("failed generating subject key id: %w", err)
uv := url.Values{
"store-location": []string{location},
"store": []string{store},

Check failure on line 736 in kms/tpmkms/tpmkms.go

View workflow job for this annotation

GitHub Actions / ci / lint / lint

File is not properly formatted (goimports)
}

switch {
case subjectKeyID != "":
uv.Set("key-id", subjectKeyID)
case o.issuer != "":
uv.Set("issuer", o.issuer)
case o.friendlyName != "":
uv.Set("friendly-name", o.friendlyName)
case o.description != "":
uv.Set("description", o.description)
}

cert, err := k.windowsCertificateManager.LoadCertificate(&apiv1.LoadCertificateRequest{
Name: fmt.Sprintf("capi:key-id=%s;store-location=%s;store=%s;", subjectKeyID, location, store),
Name: uri.New("capi", uv).String(),
})
if err != nil {
return nil, fmt.Errorf("failed retrieving certificate using Windows platform cryptography provider: %w", err)
Expand Down Expand Up @@ -875,6 +889,8 @@
uv.Set("sha1", fp)
uv.Set("store-location", location)
uv.Set("store", store)
uv.Set("friendly-name", o.friendlyName)
uv.Set("description", o.description)
uv.Set("skip-find-certificate-key", skipFindCertificateKey)

if err := k.windowsCertificateManager.StoreCertificate(&apiv1.StoreCertificateRequest{
Expand Down
6 changes: 6 additions & 0 deletions kms/tpmkms/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ type objectProperties struct {
path string
storeLocation string
store string
friendlyName string
description string
intermediateStoreLocation string
intermediateStore string
skipFindCertificateKey bool
Expand Down Expand Up @@ -60,8 +62,12 @@ func parseNameURI(nameURI string) (o objectProperties, err error) {

// store location and store options are used on Windows to override
// which store(s) are used for storing and loading (intermediate) certificates
// friendly-name and description are used on Windows to populate additional certificate
// context properties to aid in retrieval
o.storeLocation = u.Get("store-location")
o.store = u.Get("store")
o.friendlyName = u.Get("friendly-name")
o.description = u.Get("description")
o.intermediateStoreLocation = u.Get("intermediate-store-location")
o.intermediateStore = u.Get("intermediate-store")
o.skipFindCertificateKey = u.GetBool("skip-find-certificate-key")
Expand Down
Loading