Skip to content

Commit dc77acd

Browse files
committed
fix tests
1 parent 6b09b11 commit dc77acd

File tree

3 files changed

+44
-75
lines changed

3 files changed

+44
-75
lines changed

cshake.go

Lines changed: 30 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -35,48 +35,22 @@ func SumSHAKE256(data []byte, length int) []byte {
3535
return out
3636
}
3737

38-
// SupportsSHAKE128 returns true if the SHAKE128 extendable output function is
39-
// supported.
40-
func SupportsSHAKE128() bool {
41-
return supportsSHAKE(128)
42-
}
43-
44-
// SupportsSHAKE256 returns true if the SHAKE256 extendable output function is
45-
// supported.
46-
func SupportsSHAKE256() bool {
47-
return supportsSHAKE(256)
48-
}
49-
50-
// SupportsCSHAKE128 returns true if the CSHAKE128 extendable output function is
51-
// supported.
52-
func SupportsCSHAKE128() bool {
53-
return false
54-
}
55-
56-
// SupportsCSHAKE256 returns true if the CSHAKE256 extendable output function is
57-
// supported.
58-
func SupportsCSHAKE256() bool {
59-
return false
60-
}
61-
62-
// cacheSHAKESupported is a cache of SHAKE size support.
63-
var cacheSHAKESupported sync.Map
64-
65-
// SupportsSHAKE returns true if the SHAKE extendable output function is
66-
// supported.
67-
func supportsSHAKE(size int) bool {
38+
// SupportsSHAKE returns true if the SHAKE extendable output functions
39+
// with the given securityBits are supported.
40+
func SupportsSHAKE(securityBits int) bool {
6841
if vMajor == 1 || (vMajor == 3 && vMinor < 3) {
6942
// SHAKE MD's are supported since OpenSSL 1.1.1,
7043
// but EVP_DigestSqueeze is only supported since 3.3,
7144
// and we need it to implement [sha3.SHAKE].
7245
return false
7346
}
74-
if v, ok := cacheSHAKESupported.Load(size); ok {
75-
return v.(bool)
76-
}
77-
supported := loadShake(size) != nil
78-
cacheSHAKESupported.Store(size, supported)
79-
return supported
47+
return loadShake(securityBits) != nil
48+
}
49+
50+
// SupportsCSHAKE returns true if the CSHAKE extendable output functions
51+
// with the given securityBits are supported.
52+
func SupportsCSHAKE(securityBits int) bool {
53+
return false
8054
}
8155

8256
// SHAKE is an instance of a SHAKE extendable output function.
@@ -203,35 +177,32 @@ type shakeAlgorithm struct {
203177
}
204178

205179
// loadShake converts a crypto.Hash to a EVP_MD.
206-
func loadShake(xofLength int) *shakeAlgorithm {
207-
if v, ok := cacheMD.Load(xofLength); ok {
180+
func loadShake(securityBits int) (alg *shakeAlgorithm) {
181+
if v, ok := cacheMD.Load(securityBits); ok {
208182
return v.(*shakeAlgorithm)
209183
}
184+
defer func() {
185+
cacheMD.Store(securityBits, alg)
186+
}()
210187

211-
var shake shakeAlgorithm
212-
switch xofLength {
188+
var name *C.char
189+
switch securityBits {
213190
case 128:
214-
if versionAtOrAbove(1, 1, 0) {
215-
shake.md = C.go_openssl_EVP_shake128()
216-
}
191+
name = C.CString("SHAKE-128")
217192
case 256:
218-
if versionAtOrAbove(1, 1, 0) {
219-
shake.md = C.go_openssl_EVP_shake256()
220-
}
221-
}
222-
if shake.md == nil {
223-
cacheMD.Store(xofLength, (*hashAlgorithm)(nil))
193+
name = C.CString("SHAKE-256")
194+
default:
224195
return nil
225196
}
226-
shake.blockSize = int(C.go_openssl_EVP_MD_get_block_size(shake.md))
227-
if vMajor == 3 {
228-
md := C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(shake.md), nil)
229-
// Don't overwrite md in case it can't be fetched, as the md may still be used
230-
// outside of EVP_MD_CTX.
231-
if md != nil {
232-
shake.md = md
233-
}
197+
defer C.free(unsafe.Pointer(name))
198+
199+
md := C.go_openssl_EVP_MD_fetch(nil, name, nil)
200+
if md == nil {
201+
return nil
234202
}
235-
cacheMD.Store(xofLength, &shake)
236-
return &shake
203+
204+
alg = new(shakeAlgorithm)
205+
alg.md = md
206+
alg.blockSize = int(C.go_openssl_EVP_MD_get_block_size(md))
207+
return alg
237208
}

cshake_test.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ func skipCSHAKEIfNotSupported(t *testing.T, algo string) {
3030
var supported bool
3131
switch algo {
3232
case "SHAKE128":
33-
supported = openssl.SupportsSHAKE128()
33+
supported = openssl.SupportsSHAKE(128)
3434
case "SHAKE256":
35-
supported = openssl.SupportsSHAKE256()
35+
supported = openssl.SupportsSHAKE(256)
3636
case "CSHAKE128":
37-
supported = openssl.SupportsCSHAKE128()
37+
supported = openssl.SupportsCSHAKE(128)
3838
case "CSHAKE256":
39-
supported = openssl.SupportsCSHAKE256()
39+
supported = openssl.SupportsCSHAKE(256)
4040
}
4141
if !supported {
4242
t.Skip("skipping: not supported")
@@ -94,7 +94,7 @@ func TestCSHAKEReset(t *testing.T) {
9494
skipCSHAKEIfNotSupported(t, algo)
9595

9696
// Calculate hash for the first time
97-
c := v.constructor(nil, []byte{0x99, 0x98})
97+
c := v.constructor(nil, []byte(v.defCustomStr))
9898
c.Write(sequentialBytes(0x100))
9999
c.Read(out1)
100100

@@ -112,14 +112,14 @@ func TestCSHAKEReset(t *testing.T) {
112112

113113
func TestCSHAKEAccumulated(t *testing.T) {
114114
t.Run("CSHAKE128", func(t *testing.T) {
115-
if !openssl.SupportsSHAKE128() {
115+
if !openssl.SupportsCSHAKE(128) {
116116
t.Skip("skipping: not supported")
117117
}
118118
testCSHAKEAccumulated(t, openssl.NewCSHAKE128, (1600-256)/8,
119119
"bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252")
120120
})
121121
t.Run("CSHAKE256", func(t *testing.T) {
122-
if !openssl.SupportsSHAKE256() {
122+
if !openssl.SupportsCSHAKE(256) {
123123
t.Skip("skipping: not supported")
124124
}
125125
testCSHAKEAccumulated(t, openssl.NewCSHAKE256, (1600-512)/8,
@@ -158,7 +158,7 @@ func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *openssl.SH
158158
}
159159

160160
func TestCSHAKELargeS(t *testing.T) {
161-
if !openssl.SupportsSHAKE128() {
161+
if !openssl.SupportsCSHAKE(128) {
162162
t.Skip("skipping: not supported")
163163
}
164164
const s = (1<<32)/8 + 1000 // s * 8 > 2^32
@@ -178,11 +178,11 @@ func TestCSHAKELargeS(t *testing.T) {
178178

179179
func TestCSHAKESum(t *testing.T) {
180180
const testString = "hello world"
181-
t.Run("CSHAKE128", func(t *testing.T) {
182-
if !openssl.SupportsSHAKE128() {
181+
t.Run("SHAKE128", func(t *testing.T) {
182+
if !openssl.SupportsSHAKE(128) {
183183
t.Skip("skipping: not supported")
184184
}
185-
h := openssl.NewCSHAKE128(nil, nil)
185+
h := openssl.NewSHAKE128()
186186
h.Write([]byte(testString[:5]))
187187
h.Write([]byte(testString[5:]))
188188
want := make([]byte, 32)
@@ -192,11 +192,11 @@ func TestCSHAKESum(t *testing.T) {
192192
t.Errorf("got:%x want:%x", got, want)
193193
}
194194
})
195-
t.Run("CSHAKE256", func(t *testing.T) {
196-
if !openssl.SupportsSHAKE256() {
195+
t.Run("SHAKE256", func(t *testing.T) {
196+
if !openssl.SupportsSHAKE(256) {
197197
t.Skip("skipping: not supported")
198198
}
199-
h := openssl.NewCSHAKE256(nil, nil)
199+
h := openssl.NewSHAKE256()
200200
h.Write([]byte(testString[:5]))
201201
h.Write([]byte(testString[5:]))
202202
want := make([]byte, 32)

shims.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,6 @@ DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_224, (void), ()) \
252252
DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_256, (void), ()) \
253253
DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_384, (void), ()) \
254254
DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_sha3_512, (void), ()) \
255-
DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_shake128, (void), ()) \
256-
DEFINEFUNC_1_1_1(const GO_EVP_MD_PTR, EVP_shake256, (void), ()) \
257255
DEFINEFUNC_LEGACY_1_0(void, HMAC_CTX_init, (GO_HMAC_CTX_PTR arg0), (arg0)) \
258256
DEFINEFUNC_LEGACY_1_0(void, HMAC_CTX_cleanup, (GO_HMAC_CTX_PTR arg0), (arg0)) \
259257
DEFINEFUNC_LEGACY_1(int, HMAC_Init_ex, (GO_HMAC_CTX_PTR arg0, const void *arg1, int arg2, const GO_EVP_MD_PTR arg3, GO_ENGINE_PTR arg4), (arg0, arg1, arg2, arg3, arg4)) \

0 commit comments

Comments
 (0)