@@ -31,68 +31,81 @@ void hex256(hash_bytes &in, hash_str &out) {
31
31
}
32
32
}
33
33
34
- const EVP_CIPHER *GetCipher (const string &key) {
35
- // For now, we only support GCM ciphers
36
- switch (key.size ()) {
37
- case 16 :
38
- return EVP_aes_128_gcm ();
39
- case 24 :
40
- return EVP_aes_192_gcm ();
41
- case 32 :
42
- return EVP_aes_256_gcm ();
43
- default :
44
- throw InternalException (" Invalid AES key length" );
45
- }
46
- }
47
-
48
- AESGCMStateSSL::AESGCMStateSSL () : gcm_context(EVP_CIPHER_CTX_new()) {
49
- if (!(gcm_context)) {
34
+ AESStateSSL::AESStateSSL (const std::string *key) : context(EVP_CIPHER_CTX_new()) {
35
+ if (!(context)) {
50
36
throw InternalException (" AES GCM failed with initializing context" );
51
37
}
52
38
}
53
39
54
- AESGCMStateSSL ::~AESGCMStateSSL () {
40
+ AESStateSSL ::~AESStateSSL () {
55
41
// Clean up
56
- EVP_CIPHER_CTX_free (gcm_context );
42
+ EVP_CIPHER_CTX_free (context );
57
43
}
58
44
59
- bool AESGCMStateSSL::IsOpenSSL () {
60
- return ssl;
45
+ const EVP_CIPHER *AESStateSSL::GetCipher (const string &key) {
46
+
47
+ switch (cipher) {
48
+ case GCM:
49
+ switch (key.size ()) {
50
+ case 16 :
51
+ return EVP_aes_128_gcm ();
52
+ case 24 :
53
+ return EVP_aes_192_gcm ();
54
+ case 32 :
55
+ return EVP_aes_256_gcm ();
56
+ default :
57
+ throw InternalException (" Invalid AES key length" );
58
+ }
59
+ case CTR:
60
+ switch (key.size ()) {
61
+ case 16 :
62
+ return EVP_aes_128_ctr ();
63
+ case 24 :
64
+ return EVP_aes_192_ctr ();
65
+ case 32 :
66
+ return EVP_aes_256_ctr ();
67
+ default :
68
+ throw InternalException (" Invalid AES key length" );
69
+ }
70
+
71
+ default :
72
+ throw duckdb::InternalException (" Invalid Encryption/Decryption Cipher: %d" , static_cast <int >(cipher));
73
+ }
61
74
}
62
75
63
- void AESGCMStateSSL ::GenerateRandomData (data_ptr_t data, idx_t len) {
76
+ void AESStateSSL ::GenerateRandomData (data_ptr_t data, idx_t len) {
64
77
// generate random bytes for nonce
65
78
RAND_bytes (data, len);
66
79
}
67
80
68
- void AESGCMStateSSL ::InitializeEncryption (const_data_ptr_t iv, idx_t iv_len, const string *key) {
81
+ void AESStateSSL ::InitializeEncryption (const_data_ptr_t iv, idx_t iv_len, const string *key) {
69
82
mode = ENCRYPT;
70
83
71
- if (1 != EVP_EncryptInit_ex (gcm_context , GetCipher (*key), NULL , const_data_ptr_cast (key->data ()), iv)) {
84
+ if (1 != EVP_EncryptInit_ex (context , GetCipher (*key), NULL , const_data_ptr_cast (key->data ()), iv)) {
72
85
throw InternalException (" EncryptInit failed" );
73
86
}
74
87
}
75
88
76
- void AESGCMStateSSL ::InitializeDecryption (const_data_ptr_t iv, idx_t iv_len, const string *key) {
89
+ void AESStateSSL ::InitializeDecryption (const_data_ptr_t iv, idx_t iv_len, const string *key) {
77
90
mode = DECRYPT;
78
91
79
- if (1 != EVP_DecryptInit_ex (gcm_context , GetCipher (*key), NULL , const_data_ptr_cast (key->data ()), iv)) {
92
+ if (1 != EVP_DecryptInit_ex (context , GetCipher (*key), NULL , const_data_ptr_cast (key->data ()), iv)) {
80
93
throw InternalException (" DecryptInit failed" );
81
94
}
82
95
}
83
96
84
- size_t AESGCMStateSSL ::Process (const_data_ptr_t in, idx_t in_len, data_ptr_t out, idx_t out_len) {
97
+ size_t AESStateSSL ::Process (const_data_ptr_t in, idx_t in_len, data_ptr_t out, idx_t out_len) {
85
98
86
99
switch (mode) {
87
100
case ENCRYPT:
88
- if (1 != EVP_EncryptUpdate (gcm_context , data_ptr_cast (out), reinterpret_cast <int *>(&out_len),
101
+ if (1 != EVP_EncryptUpdate (context , data_ptr_cast (out), reinterpret_cast <int *>(&out_len),
89
102
const_data_ptr_cast (in), (int )in_len)) {
90
103
throw InternalException (" EncryptUpdate failed" );
91
104
}
92
105
break ;
93
106
94
107
case DECRYPT:
95
- if (1 != EVP_DecryptUpdate (gcm_context , data_ptr_cast (out), reinterpret_cast <int *>(&out_len),
108
+ if (1 != EVP_DecryptUpdate (context , data_ptr_cast (out), reinterpret_cast <int *>(&out_len),
96
109
const_data_ptr_cast (in), (int )in_len)) {
97
110
98
111
throw InternalException (" DecryptUpdate failed" );
@@ -107,30 +120,30 @@ size_t AESGCMStateSSL::Process(const_data_ptr_t in, idx_t in_len, data_ptr_t out
107
120
return out_len;
108
121
}
109
122
110
- size_t AESGCMStateSSL::Finalize (data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_t tag_len) {
123
+ size_t AESStateSSL::FinalizeGCM (data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_t tag_len) {
111
124
auto text_len = out_len;
112
125
113
126
switch (mode) {
114
- case ENCRYPT:
115
- {
116
- if (1 != EVP_EncryptFinal_ex (gcm_context, data_ptr_cast (out) + out_len, reinterpret_cast <int *>(&out_len))) {
127
+ case ENCRYPT: {
128
+ if (1 != EVP_EncryptFinal_ex (context, data_ptr_cast (out) + out_len, reinterpret_cast <int *>(&out_len))) {
117
129
throw InternalException (" EncryptFinal failed" );
118
130
}
119
131
text_len += out_len;
132
+
120
133
// The computed tag is written at the end of a chunk
121
- if (1 != EVP_CIPHER_CTX_ctrl (gcm_context , EVP_CTRL_GCM_GET_TAG, tag_len, tag)) {
134
+ if (1 != EVP_CIPHER_CTX_ctrl (context , EVP_CTRL_GCM_GET_TAG, tag_len, tag)) {
122
135
throw InternalException (" Calculating the tag failed" );
123
136
}
124
137
return text_len;
125
138
}
126
- case DECRYPT:
127
- {
139
+ case DECRYPT: {
128
140
// Set expected tag value
129
- if (!EVP_CIPHER_CTX_ctrl (gcm_context , EVP_CTRL_GCM_SET_TAG, tag_len, tag)) {
141
+ if (!EVP_CIPHER_CTX_ctrl (context , EVP_CTRL_GCM_SET_TAG, tag_len, tag)) {
130
142
throw InternalException (" Finalizing tag failed" );
131
143
}
144
+
132
145
// EVP_DecryptFinal() will return an error code if final block is not correctly formatted.
133
- int ret = EVP_DecryptFinal_ex (gcm_context , data_ptr_cast (out) + out_len, reinterpret_cast <int *>(&out_len));
146
+ int ret = EVP_DecryptFinal_ex (context , data_ptr_cast (out) + out_len, reinterpret_cast <int *>(&out_len));
134
147
text_len += out_len;
135
148
136
149
if (ret > 0 ) {
@@ -144,12 +157,46 @@ size_t AESGCMStateSSL::Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, i
144
157
}
145
158
}
146
159
160
+ size_t AESStateSSL::Finalize (data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_t tag_len) {
161
+
162
+ if (cipher == GCM) {
163
+ return FinalizeGCM (out, out_len, tag, tag_len);
164
+ }
165
+
166
+ auto text_len = out_len;
167
+ switch (mode) {
168
+
169
+ case ENCRYPT: {
170
+ if (1 != EVP_EncryptFinal_ex (context, data_ptr_cast (out) + out_len, reinterpret_cast <int *>(&out_len))) {
171
+ throw InternalException (" EncryptFinal failed" );
172
+ }
173
+
174
+ return text_len += out_len;
175
+ }
176
+
177
+ case DECRYPT: {
178
+ // EVP_DecryptFinal() will return an error code if final block is not correctly formatted.
179
+ int ret = EVP_DecryptFinal_ex (context, data_ptr_cast (out) + out_len, reinterpret_cast <int *>(&out_len));
180
+ text_len += out_len;
181
+
182
+ if (ret > 0 ) {
183
+ // success
184
+ return text_len;
185
+ }
186
+
187
+ throw InvalidInputException (" Computed AES tag differs from read AES tag, are you using the right key?" );
188
+ }
189
+ default :
190
+ throw InternalException (" Unhandled encryption mode %d" , static_cast <int >(mode));
191
+ }
192
+ }
193
+
147
194
} // namespace duckdb
148
195
149
196
extern " C" {
150
197
151
198
// Call the member function through the factory object
152
- DUCKDB_EXTENSION_API AESGCMStateSSLFactory *CreateSSLFactory () {
153
- return new AESGCMStateSSLFactory ();
199
+ DUCKDB_EXTENSION_API AESStateSSLFactory *CreateSSLFactory () {
200
+ return new AESStateSSLFactory ();
154
201
};
155
202
}
0 commit comments