1 module jwtd.jwt_openssl;
2 
3 version(UseOpenSSL) {
4 
5 	import deimos.openssl.ssl;
6 	import deimos.openssl.pem;
7 	import deimos.openssl.rsa;
8 	import deimos.openssl.hmac;
9 	import deimos.openssl.err;
10 
11 	import jwtd.jwt : JWTAlgorithm, SignException, VerifyException;
12 
13 	EC_KEY* getESKeypair(uint curve_type, string key) {
14 		EC_GROUP* curve;
15 		EVP_PKEY* pktmp;
16 		BIO* bpo;
17 		EC_POINT* pub;
18 
19 		if(null == (curve = EC_GROUP_new_by_curve_name(curve_type)))
20 			throw new Exception("Unsupported curve.");
21 		scope(exit) EC_GROUP_free(curve);
22 
23 		bpo = BIO_new_mem_buf(cast(char*)key.ptr, -1);
24 		if(bpo is null) {
25 			throw new Exception("Can't load the key.");
26 		}
27 		scope(exit) BIO_free(bpo);
28 
29 		pktmp = PEM_read_bio_PrivateKey(bpo, null, null, null);
30 		if(pktmp is null) {
31 			throw new Exception("Can't load the evp_pkey.");
32 		}
33 		scope(exit) EVP_PKEY_free(pktmp);
34 
35 		EC_KEY* eckey;
36 		eckey = EVP_PKEY_get1_EC_KEY(pktmp);
37 		if(eckey is null) {
38 			throw new Exception("Can't convert evp_pkey to EC_KEY.");
39 		}
40 		scope(failure) EC_KEY_free(eckey);
41 
42 		if(1 != EC_KEY_set_group(eckey, curve)) {
43 			throw new Exception("Can't associate group with the key.");
44 		}
45 
46 		const BIGNUM *prv = EC_KEY_get0_private_key(eckey);
47 		if(null == prv) {
48 			throw new Exception("Can't get private key.");
49 		}
50 
51 		pub = EC_POINT_new(curve);
52 		if(null == pub) {
53 			throw new Exception("Can't allocate EC point.");
54 		}
55 		scope(exit) EC_POINT_free(pub);
56 
57 		if (1 != EC_POINT_mul(curve, pub, prv, null, null, null)) {
58 			throw new Exception("Can't calculate public key.");
59 		}
60 
61 		if(1 != EC_KEY_set_public_key(eckey, pub)) {
62 			throw new Exception("Can't set public key.");
63 		}
64 
65 		return eckey;
66 	}
67 
68     unittest {
69         import jwtd.test;
70         import std.exception : assertThrown;
71 
72         assertThrown(getESKeypair(0, "key"));
73         assertThrown(getESKeypair(NID_secp256k1, "bogus_key"));
74         assertThrown(getESKeypair(NID_secp256k1, null));
75         assertThrown(getESKeypair(NID_secp256k1, private256));
76     }
77 
78 	EC_KEY* getESPrivateKey(uint curve_type, string key) {
79 		EC_GROUP* curve;
80 		EVP_PKEY* pktmp;
81 		BIO* bpo;
82 
83 		if(null == (curve = EC_GROUP_new_by_curve_name(curve_type)))
84 			throw new Exception("Unsupported curve.");
85 		scope(exit) EC_GROUP_free(curve);
86 
87 		bpo = BIO_new_mem_buf(cast(char*)key.ptr, -1);
88 		if(bpo is null) {
89 			throw new Exception("Can't load the key.");
90 		}
91 		scope(exit) BIO_free(bpo);
92 
93 		pktmp = PEM_read_bio_PrivateKey(bpo, null, null, null);
94 		if(pktmp is null) {
95 			throw new Exception("Can't load the evp_pkey.");
96 		}
97 		scope(exit) EVP_PKEY_free(pktmp);
98 
99 		EC_KEY * eckey;
100 	 	eckey = EVP_PKEY_get1_EC_KEY(pktmp);
101 		if(eckey is null) {
102 			throw new Exception("Can't convert evp_pkey to EC_KEY.");
103 		}
104 
105 		scope(failure) EC_KEY_free(eckey);
106 		if(1 != EC_KEY_set_group(eckey, curve)) {
107 			throw new Exception("Can't associate group with the key.");
108 		}
109 
110 		return eckey;
111 	}
112 
113     unittest {
114         import std.exception : assertThrown;
115         assertThrown(getESPrivateKey(0, "key"));
116         assertThrown(getESPrivateKey(NID_secp256k1, "bogus_key"));
117         assertThrown(getESPrivateKey(NID_secp256k1, null));
118     }
119 
120 	EC_KEY* getESPublicKey(uint curve_type, string key) {
121 		EC_GROUP* curve;
122 
123 		if(null == (curve = EC_GROUP_new_by_curve_name(curve_type)))
124 			throw new Exception("Unsupported curve.");
125 		scope(exit) EC_GROUP_free(curve);
126 
127 		EC_KEY* eckey;
128 
129 		BIO* bpo = BIO_new_mem_buf(cast(char*)key.ptr, -1);
130 		if(bpo is null) {
131 			throw new Exception("Can't load the key.");
132 		}
133 		scope(exit) BIO_free(bpo);
134 
135 		eckey = PEM_read_bio_EC_PUBKEY(bpo, null, null, null);
136 		scope(failure) EC_KEY_free(eckey);
137 
138 		if(1 != EC_KEY_set_group(eckey, curve)) {
139 			throw new Exception("Can't associate group with the key.");
140 		}
141 
142 		if(0 == EC_KEY_check_key(eckey))
143 			throw new Exception("Public key is not valid.");
144 
145 		return eckey;
146 	}
147 
148     unittest {
149         import jwtd.test;
150         import std.exception : assertThrown;
151 
152         assertThrown(getESPublicKey(0, "key"));
153 
154         auto eckey = getESPublicKey(NID_X9_62_prime256v1, es256_public);
155         EC_KEY_free(eckey);
156         assertThrown(getESPublicKey(NID_X9_62_prime256v1, null));
157     }
158 
159 	string sign(string msg, string key, JWTAlgorithm algo = JWTAlgorithm.HS256) {
160 		ubyte[] sign;
161 
162 		void sign_hs(const(EVP_MD)* evp, uint signLen) {
163 			sign = new ubyte[signLen];
164 
165 			HMAC_CTX * ctx = HMAC_CTX_new();
166 			scope(exit) HMAC_CTX_free(ctx);
167 
168 			HMAC_CTX_reset(ctx);
169 			if(0 == HMAC_Init_ex(ctx, key.ptr, cast(int)key.length, evp, null)) {
170 				throw new Exception("Can't initialize HMAC context.");
171 			}
172 			if(0 == HMAC_Update(ctx, cast(const(ubyte)*)msg.ptr, cast(ulong)msg.length)) {
173 				throw new Exception("Can't update HMAC.");
174 			}
175 			if(0 == HMAC_Final(ctx, cast(ubyte*)sign.ptr, &signLen)) {
176 				throw new Exception("Can't finalize HMAC.");
177 			}
178 		}
179 
180 		void sign_rs(ubyte* hash, int type, uint hashLen) 
181 		{
182 			
183 			RSA* rsa_private = RSA_new();
184 			scope(exit) RSA_free(rsa_private);
185 
186 			BIO* bpo = BIO_new_mem_buf(cast(char*)key.ptr, -1);
187 			if(bpo is null)
188 				throw new Exception("Can't load the key.");
189 			scope(exit) BIO_free(bpo);
190 
191 			RSA* rsa = PEM_read_bio_RSAPrivateKey(bpo, &rsa_private, null, null);
192 			if(rsa is null) {
193 				throw new Exception("Can't create RSA key.");
194 			}
195 
196 			uint len	= cast(uint) RSA_size(rsa_private);
197 			sign		= new ubyte[len];
198 
199 			if(0 == RSA_sign(type, hash, hashLen, sign.ptr, &len, rsa_private)) {
200 				throw new Exception("Can't sign RSA message digest.");
201 			}
202 		}
203 
204 		void sign_es(uint curve_type, ubyte* hash, int hashLen) 
205 		{
206 			EC_KEY* eckey = getESPrivateKey(curve_type, key);
207 			scope(exit) EC_KEY_free(eckey);
208 
209 			ECDSA_SIG* sig = ECDSA_do_sign(hash, hashLen, eckey);
210 
211 			if(sig is null) {
212 				throw new Exception("Digest sign failed.");
213 			}
214 			scope(exit) ECDSA_SIG_free(sig);
215 
216 			int keySize		= ECDSA_size(eckey);
217 			int sigPartSize	= (keySize - 8) / 2;
218 					
219 			uint rSize	= BN_num_bytes(sig.r);
220 			uint sSize	= BN_num_bytes(sig.s);
221 
222 			sign = new ubyte[sigPartSize * 2];
223 
224 			ubyte* c = sign.ptr;
225 
226 			int rPadding	= sigPartSize - rSize;
227 			int sPadding	= sigPartSize - sSize;
228 			
229 			BN_bn2bin(sig.r, c + rPadding);
230 			BN_bn2bin(sig.s, c + (sigPartSize + sPadding));
231 
232 			/* Signature not in DER format. DER is not defined by JWT
233 			if(!i2d_ECDSA_SIG(sig, &c)) {
234 				throw new Exception("Convert sign to DER format failed.");
235 			}*/
236 		}
237 
238 		switch(algo) {
239 			case JWTAlgorithm.NONE: {
240 				break;
241 			}
242 			case JWTAlgorithm.HS256: {
243 				sign_hs(EVP_sha256(), SHA256_DIGEST_LENGTH);
244 				break;
245 			}
246 			case JWTAlgorithm.HS384: {
247 				sign_hs(EVP_sha384(), SHA384_DIGEST_LENGTH);
248 				break;
249 			}
250 			case JWTAlgorithm.HS512: {
251 				sign_hs(EVP_sha512(), SHA512_DIGEST_LENGTH);
252 				break;
253 			}
254 			case JWTAlgorithm.RS256: {
255 				ubyte[] hash = new ubyte[SHA256_DIGEST_LENGTH];
256 				SHA256(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
257 				sign_rs(hash.ptr, NID_sha256, SHA256_DIGEST_LENGTH);
258 				break;
259 			}
260 			case JWTAlgorithm.RS384: {
261 				ubyte[] hash = new ubyte[SHA384_DIGEST_LENGTH];
262 				SHA384(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
263 				sign_rs(hash.ptr, NID_sha384, SHA384_DIGEST_LENGTH);
264 				break;
265 			}
266 			case JWTAlgorithm.RS512: {
267 				ubyte[] hash = new ubyte[SHA512_DIGEST_LENGTH];
268 				SHA512(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
269 				sign_rs(hash.ptr, NID_sha512, SHA512_DIGEST_LENGTH);
270 				break;
271 			}
272 			case JWTAlgorithm.ES256: {
273 				ubyte[] hash = new ubyte[SHA256_DIGEST_LENGTH];
274 				SHA256(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
275 				sign_es(NID_X9_62_prime256v1, hash.ptr, SHA256_DIGEST_LENGTH);
276 				break;
277 			}
278 			case JWTAlgorithm.ES256K: {
279 				ubyte[] hash = new ubyte[SHA256_DIGEST_LENGTH];
280 				SHA256(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
281 				sign_es(NID_secp256k1, hash.ptr, SHA256_DIGEST_LENGTH);
282 				break;
283 			}
284 
285 			case JWTAlgorithm.ES384: {
286 				ubyte[] hash = new ubyte[SHA384_DIGEST_LENGTH];
287 				SHA384(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
288 				sign_es(NID_secp384r1, hash.ptr, SHA384_DIGEST_LENGTH);
289 				break;
290 			}
291 			case JWTAlgorithm.ES512: {
292 				ubyte[] hash = new ubyte[SHA512_DIGEST_LENGTH];
293 				SHA512(cast(const(ubyte)*)msg.ptr, msg.length, hash.ptr);
294 				sign_es(NID_secp521r1, hash.ptr, SHA512_DIGEST_LENGTH);
295 				break;
296 			}
297 
298 			default:
299 				throw new SignException("Wrong algorithm.");
300 		}
301 
302 		return cast(string)sign;
303 	}
304 
305 
306 	bool verifySignature(string signature, string signing_input, string key, JWTAlgorithm algo = JWTAlgorithm.HS256) {
307 
308 		bool verify_rs(ubyte* hash, int type, uint hashLen) {
309 			RSA* rsa_public = RSA_new();
310 			scope(exit) RSA_free(rsa_public);
311 
312 			BIO* bpo = BIO_new_mem_buf(cast(char*)key.ptr, -1);
313 			if(bpo is null)
314 				throw new Exception("Can't load key to the BIO.");
315 			scope(exit) BIO_free(bpo);
316 
317 			RSA* rsa = PEM_read_bio_RSA_PUBKEY(bpo, &rsa_public, null, null);
318 			if(rsa is null) {
319 				throw new Exception("Can't create RSA key.");
320 			}
321 
322 			ubyte[] sign	= cast(ubyte[])signature;
323 			uint len		= cast(uint) RSA_size(rsa_public);
324 
325 			int ret = RSA_verify(type, hash, hashLen, sign.ptr, len, rsa_public);
326 			return ret == 1;
327 		}
328 
329 		bool verify_es(uint curve_type, ubyte* hash, int hashLen ) {
330 			EC_KEY* eckey = getESPublicKey(curve_type, key);
331 			scope(exit) EC_KEY_free(eckey);
332 
333 			int keySize		= ECDSA_size(eckey);
334 			int sigPartSize	= (keySize - 8) / 2;
335 
336 			ubyte * sigPointer = cast(ubyte *) signature.ptr;
337 
338 			ECDSA_SIG* sig = ECDSA_SIG_new();
339 			scope(exit) ECDSA_SIG_free(sig);
340 						
341 			if (null == (sig.r = BN_bin2bn(sigPointer, sigPartSize, sig.r)))
342 				throw new Exception("Can't decode ECDSA signature.");
343 
344 			int remaining	= cast(int) signature.length - sigPartSize;
345 
346 			if (null == (sig.s = BN_bin2bn(sigPointer + sigPartSize, remaining, sig.s)))
347 				throw new Exception("Can't decode ECDSA signature.");
348 
349 			int ret =  ECDSA_do_verify(hash, hashLen, sig, eckey);
350 			return ret == 1;
351 		}
352 
353 		switch(algo) {
354 			case JWTAlgorithm.NONE: {
355 				return key.length == 0;
356 			}
357 			case JWTAlgorithm.HS256:
358 			case JWTAlgorithm.HS384:
359 			case JWTAlgorithm.HS512: {
360 				return signature == sign(signing_input, key, algo);
361 			}
362 			case JWTAlgorithm.RS256: {
363 				ubyte[] hash = new ubyte[SHA256_DIGEST_LENGTH];
364 				SHA256(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
365 				return verify_rs(hash.ptr, NID_sha256, SHA256_DIGEST_LENGTH);
366 			}
367 			case JWTAlgorithm.RS384: {
368 				ubyte[] hash = new ubyte[SHA384_DIGEST_LENGTH];
369 				SHA384(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
370 				return verify_rs(hash.ptr, NID_sha384, SHA384_DIGEST_LENGTH);
371 			}
372 			case JWTAlgorithm.RS512: {
373 				ubyte[] hash = new ubyte[SHA512_DIGEST_LENGTH];
374 				SHA512(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
375 				return verify_rs(hash.ptr, NID_sha512, SHA512_DIGEST_LENGTH);
376 			}
377 
378 			case JWTAlgorithm.ES256:{
379 				ubyte[] hash = new ubyte[SHA256_DIGEST_LENGTH];
380 				SHA256(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
381 				//return verify_es(NID_secp256k1, hash.ptr, SHA256_DIGEST_LENGTH );
382 				return verify_es(NID_X9_62_prime256v1, hash.ptr, SHA256_DIGEST_LENGTH ); // The ES256 Spec uses the prime256v1 Curve
383 			}
384 
385 			case JWTAlgorithm.ES256K: // Added ES256K to use the secp256k1 curve
386 			{
387 				ubyte[] hash = new ubyte[SHA256_DIGEST_LENGTH];
388 				SHA256(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
389 				return verify_es(NID_secp256k1, hash.ptr, SHA256_DIGEST_LENGTH );
390 			}
391 
392 			case JWTAlgorithm.ES384:{
393 				ubyte[] hash = new ubyte[SHA384_DIGEST_LENGTH];
394 				SHA384(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
395 				return verify_es(NID_secp384r1, hash.ptr, SHA384_DIGEST_LENGTH );
396 			}
397 			case JWTAlgorithm.ES512: {
398 				ubyte[] hash = new ubyte[SHA512_DIGEST_LENGTH];
399 				SHA512(cast(const(ubyte)*)signing_input.ptr, signing_input.length, hash.ptr);
400 				return verify_es(NID_secp521r1, hash.ptr, SHA512_DIGEST_LENGTH );
401 			}
402 
403 			default:
404 				throw new VerifyException("Wrong algorithm.");
405 		}
406 	}
407 }
408 
// An algorithm value outside the handled cases must raise the dedicated
// exception type from both entry points.
unittest {
    version (UseOpenSSL) {
        import std.exception : assertThrown;

        auto bogus = cast(JWTAlgorithm)"bogus_algo";

        assertThrown!SignException(sign("message", "key", bogus));
        assertThrown!VerifyException(
            verifySignature("signature", "signing_input", "key", bogus));
    }
}
416