diff --git a/hkdf.go b/hkdf.go index f2ff598..d4f8aa6 100644 --- a/hkdf.go +++ b/hkdf.go @@ -151,6 +151,43 @@ func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) { } } +// ExpandHKDFOneShot derives a key from the given hash, key, and optional context info. +func ExpandHKDFOneShot(h func() hash.Hash, pseudorandomKey, info []byte, keyLength int) ([]byte, error) { + if !SupportsHKDF() { + return nil, errUnsupportedVersion() + } + + md, err := hashFuncToMD(h) + if err != nil { + return nil, err + } + + out := make([]byte, keyLength) + switch vMajor { + case 1: + ctx, err := newHKDFCtx1(md, C.GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY, nil, nil, pseudorandomKey, info) + if err != nil { + return nil, err + } + defer C.go_openssl_EVP_PKEY_CTX_free(ctx) + if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(out), C.size_t(keyLength)).result != 1 { + return nil, newOpenSSLError("EVP_PKEY_derive") + } + case 3: + ctx, err := newHKDFCtx3(md, C.GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY, nil, nil, pseudorandomKey, info) + if err != nil { + return nil, err + } + defer C.go_openssl_EVP_KDF_CTX_free(ctx) + if C.go_openssl_EVP_KDF_derive(ctx, base(out), C.size_t(keyLength), nil) != 1 { + return nil, newOpenSSLError("EVP_KDF_derive") + } + default: + panic(errUnsupportedVersion()) + } + return out, nil +} + func ExpandHKDF(h func() hash.Hash, pseudorandomKey, info []byte) (io.Reader, error) { if !SupportsHKDF() { return nil, errUnsupportedVersion() diff --git a/hkdf_test.go b/hkdf_test.go index 42c2460..cd5cb0c 100644 --- a/hkdf_test.go +++ b/hkdf_test.go @@ -405,3 +405,46 @@ func TestHKDFUnsupportedHash(t *testing.T) { t.Error("expected error for unsupported hash") } } +func TestExpandHKDFOneShot(t *testing.T) { + if !openssl.SupportsHKDF() { + t.Skip("HKDF is not supported") + } + for i, tt := range hkdfTests { + out, err := openssl.ExpandHKDFOneShot(tt.hash, tt.prk, tt.info, len(tt.out)) + if err != nil { + t.Errorf("test %d: error expanding HKDF one-shot: %v.", i, err) + continue + } + if !bytes.Equal(out, tt.out) { + t.Errorf("test %d: incorrect output from ExpandHKDFOneShot: have %v, need %v.", i, out, tt.out) + } + } +} + +func TestExpandHKDFOneShotLimit(t *testing.T) { + if !openssl.SupportsHKDF() { + t.Skip("HKDF is not supported") + } + hash := openssl.NewSHA1 + master := []byte{0x00, 0x01, 0x02, 0x03} + info := []byte{} + + prk, err := openssl.ExtractHKDF(hash, master, nil) + if err != nil { + t.Fatalf("error extracting HKDF: %v.", err) + } + limit := hash().Size() * 255 + out, err := openssl.ExpandHKDFOneShot(hash, prk, info, limit) + if err != nil { + t.Errorf("error expanding HKDF one-shot: %v.", err) + } + if len(out) != limit { + t.Errorf("incorrect output length: have %d, need %d.", len(out), limit) + } + + // Expanding one more byte should fail + _, err = openssl.ExpandHKDFOneShot(hash, prk, info, limit+1) + if err == nil { + t.Errorf("expected error for key expansion overflow") + } +}