1 module jwtd.jwt;
2 
3 import std.json;
4 import std.base64;
5 import std.algorithm;
6 import std.array : split;
7 
// Base64 variant with the URL-safe alphabet ('-' and '_' in place of '+' and '/')
// and no '=' padding; used below to encode and decode every token segment.
private alias Base64URLNoPadding = Base64Impl!('-', '_', Base64.NoPadding);
9 
10 version(UseOpenSSL) {
11 	public import jwtd.jwt_openssl;
12 }
13 version(UseBotan) {
14 	public import jwtd.jwt_botan;
15 }
16 version(UsePhobos) {
17 	public import jwtd.jwt_phobos;
18 }
19 
/**
  Signing algorithms known to this module.

  Each member's value is the string written to the JOSE header's "alg" field by
  `encode`. When parsing, `decode`/`verify` upper-case the header's "alg" and convert
  it by member *name*, so member names must stay the upper-cased values.
  Which algorithms actually work depends on the backend selected by the
  UseOpenSSL / UseBotan / UsePhobos version.
*/
enum JWTAlgorithm : string {
	NONE  = "none",   // unsecured token: the signature segment is empty
	HS256 = "HS256",
	HS384 = "HS384",
	HS512 = "HS512",
	RS256 = "RS256",
	RS384 = "RS384",
	RS512 = "RS512",
	ES256 = "ES256",
	ES256K = "ES256K",
	ES384 = "ES384",
	ES512 = "ES512"
}
33 
/**
  Signals a failure while signing a token.

  NOTE(review): not thrown in this module; presumably raised by the backend `sign`
  implementations — confirm there.
*/
class SignException : Exception {
	/**
	  Params:
	    s = error message
	    file = source file of the throw site (defaults to the caller's file)
	    line = source line of the throw site (defaults to the caller's line)
	    next = optional chained cause
	*/
	this(string s, string file = __FILE__, size_t line = __LINE__, Throwable next = null) pure nothrow @nogc @safe {
		// Forward file/line explicitly: with a bare super(s) every instance would
		// report this constructor's location instead of where it was thrown.
		super(s, file, line, next);
	}
}
37 
/**
  Thrown by `decode`, `decodeHeader` and `verify` when a token is malformed or
  fails verification.
*/
class VerifyException : Exception {
	/**
	  Params:
	    s = error message
	    file = source file of the throw site (defaults to the caller's file)
	    line = source line of the throw site (defaults to the caller's line)
	    next = optional chained cause
	*/
	this(string s, string file = __FILE__, size_t line = __LINE__, Throwable next = null) pure nothrow @nogc @safe {
		// Forward file/line explicitly: with a bare super(s) every instance would
		// report this constructor's location instead of where it was thrown.
		super(s, file, line, next);
	}
}
41 
/**
  Simple version: payload and extra header fields are flat maps of string values.

  Both maps are converted to JSON objects and handed to the JSONValue overload.

  Params:
    payload = claims to place in the token
    key = signing key
    algo = signing algorithm (HS256 unless given)
    header_fields = additional header entries; null means none

  Returns: the compact token `header.payload.signature`
*/
string encode(string[string] payload, string key, JWTAlgorithm algo = JWTAlgorithm.HS256, string[string] header_fields = null) {
	auto claims = JSONValue(payload);
	return encode(claims, key, algo, JSONValue(header_fields));
}
51 
/**
  full version that accepts JSONValue tree as payload and header fields

  The payload is serialized with `toString()` and its bytes are encoded by the
  ubyte[] overload. The payload is only read, so it is taken by const value;
  this also lets callers pass temporaries such as `JSONValue(["a": "b"])`.

  Params:
    payload = claims as a JSON tree
    key = signing key
    algo = signing algorithm (HS256 unless given)
    header_fields = extra header fields as a JSON object, or a null JSONValue for none

  Returns: the compact token `header.payload.signature`
*/
string encode(const JSONValue payload, string key, JWTAlgorithm algo = JWTAlgorithm.HS256, JSONValue header_fields = null) {
	// const cast: the serialized text is immutable and only read by the encoder.
	return encode(cast(const(ubyte)[])payload.toString(), key, algo, header_fields);
}
58 
/**
  full version that accepts ubyte[] as payload and JSONValue tree as header fields

  Builds the JOSE header from a copy of `header_fields` with "alg" set to `algo`
  and "typ" set to "JWT". It base64url-encodes header and payload, and signs
  `header ~ "." ~ payload` with the backend's `sign`.

  The caller's `header_fields` is never modified. Only the default header (no extra
  fields) is memoized, because it depends on `algo` alone. Headers with custom
  fields are rebuilt on every call, so later changes to the caller's JSONValue
  cannot yield a stale cached header.

  Params:
    payload = raw payload bytes, encoded as-is into the second segment
    key = key passed to `sign`
    algo = signing algorithm, also written to the "alg" header field
    header_fields = extra header fields as a JSON object, or a null JSONValue for none;
                    any "alg"/"typ" entries in it are overridden

  Returns: the compact token `header.payload.signature`

  Throws: JSONException if `header_fields` is neither null nor a JSON object.
*/
string encode(in ubyte[] payload, string key, JWTAlgorithm algo = JWTAlgorithm.HS256, JSONValue header_fields = null) {
	import std.functional : memoize;

	// Serializes `extra` plus "alg"/"typ" into a fresh object and base64url-encodes it.
	// `extra` is only read, so the caller's associative array is left untouched.
	static string buildHeader(JWTAlgorithm alg, JSONValue[string] extra) {
		JSONValue[string] fields;
		foreach (name, value; extra)
			fields[name] = value;
		fields["alg"] = cast(string)alg;
		fields["typ"] = "JWT";
		return Base64URLNoPadding.encode(cast(const(ubyte)[])JSONValue(fields).toString()).idup;
	}

	// Header without extra fields: a pure function of `alg`, so it is safe to cache.
	static string defaultHeader(JWTAlgorithm alg) {
		return buildHeader(alg, null);
	}

	// `.object` throws JSONException for non-object, non-null values (as before).
	immutable hasExtraFields = header_fields.type != JSONType.null_
		&& header_fields.object.length != 0;

	string encodedHeader = hasExtraFields
		? buildHeader(algo, header_fields.object)
		: memoize!(defaultHeader, 64)(algo);
	string encodedPayload = Base64URLNoPadding.encode(payload);

	string signingInput = encodedHeader ~ "." ~ encodedPayload;
	string signature = Base64URLNoPadding.encode(cast(ubyte[])sign(signingInput, key, algo));

	return signingInput ~ "." ~ signature;
}
82 
unittest {
	import jwtd.test;

	// Coverage case: encode() must cope with header_fields of NULL type.
	JSONValue nullHeader;
	assert(nullHeader.type == JSONType.null_);
	auto claims = JSONValue([ "a" : "b" ]);
	encode(claims, public256, JWTAlgorithm.HS256, nullHeader);
}
92 
93 
/**
  Decodes and verifies `token` with `key`, accepting every `JWTAlgorithm`
  (including NONE).

  NOTE(review): because any algorithm is accepted, the token header decides how
  `key` is used (e.g. an HS256 token MACed with an RSA public key as secret).
  Whether that is exploitable depends on the backend's `verifySignature`; confirm it.
  Prefer the overload taking `supportedAlgorithms` when the expected algorithm is known.

  Returns: the parsed payload
  Throws: VerifyException when the token is malformed or fails verification.
*/
JSONValue decode(string token, string key)
{
	import std.traits : EnumMembers;

	// Derived from the enum so the list stays complete (and free of duplicates)
	// when algorithms are added.
	return decode(token, key, [EnumMembers!JWTAlgorithm]);
}
/**
  simple version that knows which key was used to encode the token

  Forwards to the full `decode`, supplying `key` regardless of the header contents.

  Params:
    token = compact token `header.payload.signature`
    key = verification key
    supportedAlgorithms = algorithms accepted in the header's "alg" field

  Returns: the parsed payload
  Throws: VerifyException when the token is malformed, unsupported or fails verification.
*/
JSONValue decode(string token, string key, JWTAlgorithm[] supportedAlgorithms) {
	// Key provider that ignores the header and always yields the fixed key.
	string fixedKey(ref JSONValue) { return key; }

	return decode(token, &fixedKey, supportedAlgorithms);
}
107 
/**
  Decodes the header (first segment) of `token` without verifying anything.

  Only the first segment is decoded; the signature is not checked, so the result
  must not be trusted on its own.

  Returns: the header as a JSON object
  Throws: VerifyException if `token` contains no '.', if the first segment is not
          base64url-encoded JSON, or if that JSON is not an object.
*/
JSONValue decodeHeader(string token)
{
	if (count(token, ".") < 1)
		throw new VerifyException("Token is incorrect.");

	string[] tokenParts = split(token, ".");

	JSONValue header;
	try
	{
		header = parseJSON(urlsafeB64Decode(tokenParts[0]));
	}
	catch (Exception)
	{
		throw new VerifyException("Header is incorrect.");
	}

	// A header must be a JSON object; anything else (string, number, array, ...)
	// would make callers' header["..."] lookups throw JSONException instead.
	if (header.type != JSONType.object)
		throw new VerifyException("Header is incorrect.");

	return header;
}
128 
/**
  full version where the key is provided after decoding the JOSE header

  Params:
    token = compact token `header.payload.signature`
    lazyKey = called with the parsed header (after the algorithm and type checks)
              and must return the key used to verify the signature
    supportedAlgorithms = algorithms accepted in the header's "alg" field

  Returns: the parsed payload

  Throws: VerifyException for every rejected token:
    - the token does not have exactly two dots;
    - the header or payload is not base64url-encoded JSON;
    - "alg" is unknown or not in `supportedAlgorithms`;
    - "typ" is present but not the string "JWT";
    - the signature segment is not valid base64url or does not verify.
*/
JSONValue decode(string token, string delegate(ref JSONValue jose) lazyKey, JWTAlgorithm[] supportedAlgorithms) {
	import std.algorithm.searching : canFind, count;
	import std.conv : to;
	import std.uni : toUpper;

	if(count(token, ".") != 2)
		throw new VerifyException("Token is incorrect.");

	string[] tokenParts = split(token, ".");

	JSONValue header;
	try {
		header = parseJSON(urlsafeB64Decode(tokenParts[0]));
	} catch(Exception) {
		throw new VerifyException("Header is incorrect.");
	}

	JWTAlgorithm alg;
	try {
		// Conversion is by enum member *name*; toUpper maps "none" to NONE.
		alg = to!(JWTAlgorithm)(toUpper(header["alg"].str()));
	} catch(Exception) {
		throw new VerifyException("Algorithm is incorrect.");
	}

	if (!canFind(supportedAlgorithms, alg))
		throw new VerifyException("Invalid Token: Security is None and Security is Required.");

	if (auto typ = ("typ" in header)) {
		// A non-string "typ" is as wrong as an unexpected one; report it the same
		// way instead of leaking JSONException from str().
		if (typ.type != JSONType.string)
			throw new VerifyException("Type is incorrect.");
		string typ_str = typ.str();
		if(typ_str !is null && typ_str != "JWT")
			throw new VerifyException("Type is incorrect.");
	}

	const key = lazyKey(header);

	// Decoded after lazyKey (same order as before), but a malformed segment now
	// yields VerifyException instead of escaping as Base64Exception.
	string signature;
	try {
		signature = urlsafeB64Decode(tokenParts[2]);
	} catch(Exception) {
		throw new VerifyException("Signature is incorrect.");
	}

	if(!verifySignature(signature, tokenParts[0]~"."~tokenParts[1], key, alg))
		throw new VerifyException("Signature is incorrect.");

	JSONValue payload;

	try {
		payload = parseJSON(urlsafeB64Decode(tokenParts[1]));
	} catch(Exception) {
		// Catches base64 as well as JSON errors. Normally unreachable: a corrupted
		// payload already fails the signature check above.
		throw new VerifyException("Payload JSON is incorrect.");
	}

	return payload;
}
183 
// Round-trip test: for each algorithm, encode a payload with the private/secret key
// and check that decode() with the matching public/secret key returns it unchanged.
unittest {
    import jwtd.test;
    // NOTE(review): EnumMembers is imported but not used in this block.
    import std.traits : EnumMembers;

    // Key material for one algorithm. For HMAC algorithms only `priv` is given and
    // `pub` falls back to it.
    struct Keys {
        string priv;
        string pub;

        this (string priv, string pub = null) {
            this.priv = priv;
            // No separate public key given: use the signing key for verification too.
            this.pub = (pub ? pub : priv);
        }
    }

    // Algorithms exercised regardless of backend; NONE uses null keys.
    auto commonAlgos = [
        JWTAlgorithm.NONE  : Keys(),
        JWTAlgorithm.HS256 : Keys("my key"),
        JWTAlgorithm.HS384 : Keys("his key"),
        JWTAlgorithm.HS512 : Keys("her key"),
    ];

    // Each backend version block below declares its own `specialAlgos`. The test
    // assumes exactly one backend version is enabled: with none it is undeclared,
    // with several it is declared twice.
    version (UseOpenSSL) {
        Keys[JWTAlgorithm] specialAlgos = [
            JWTAlgorithm.RS256 : Keys(private256, public256),
            // TODO: Find key pairs for RS384 and RS512
            // JWTAlgorithm.RS384 : Keys(private384, public384),
            // JWTAlgorithm.RS512 : Keys(private512, public512),
            JWTAlgorithm.ES256 : Keys(es256_private, es256_public),
            JWTAlgorithm.ES384 : Keys(es384_private, es384_public),
            JWTAlgorithm.ES512 : Keys(es512_private, es512_public),
        ];
    }

    version (UseBotan) {
        Keys[JWTAlgorithm] specialAlgos = [
            JWTAlgorithm.RS256 : Keys(private256, public256),
            // TODO: Find key pairs for the following
            // JWTAlgorithm.RS384 : Keys(private384, public384),
            // JWTAlgorithm.RS512 : Keys(private512, public512),
            // JWTAlgorithm.ES256 : Keys(es256_private, es256_public),
            // JWTAlgorithm.ES384 : Keys(es384_private, es384_public),
            // JWTAlgorithm.ES512 : Keys(es512_private, es512_public),
        ];
    }

    // Empty else branch of version (UseBotan); it has no effect.
    else {
    }

    // The Phobos backend contributes no extra algorithms here.
    version (UsePhobos) {
        Keys[JWTAlgorithm] specialAlgos;
    }

    // Encodes {"claim":"value"} with each algorithm's `priv` key and asserts that
    // decode() with its `pub` key yields the same JSON.
    void testWith(Keys[JWTAlgorithm] keys) {
        foreach (algo, k; keys) {
            auto payload = JSONValue([ "claim" : "value" ]);
            const encoded = encode(payload, k.priv, algo);
            const decoded = decode(encoded, k.pub);
            assert(decoded == payload);
        }
    }

    testWith(commonAlgos);
    testWith(specialAlgos);
}
248 
version (unittest) {
	/**
	  Test helper: produces an HS256 token for {"my":"payload"} signed with "key",
	  then overwrites one string field of one segment.

	  Params:
	    part = index of the segment to alter (0 = header, 1 = payload)
	    field = JSON field to overwrite in that segment
	    badValue = new string value for the field

	  Returns: the re-assembled token (its signature no longer matches)
	*/
	string corruptEncodedString(size_t part, string field, string badValue) {
		import std.conv : text;

		auto segments = encode([ "my" : "payload" ], "key").split(".");
		auto tampered = segments[part].urlsafeB64Decode.parseJSON;
		tampered[field] = badValue;
		segments[part] = tampered.toString.urlsafeB64Encode;
		return segments.joiner(".").text;
	}
}
261 
unittest {
	import std.exception : assertThrown;

	// decode() has to reject each of the invalid tokens below with VerifyException.

	// Anything other than exactly two dots.
	foreach (malformed; ["nodot", "one.dot", "thr.e.e.dots"])
		assertThrown!VerifyException(decode(malformed, "key"));

	// A header segment that does not decode to JSON.
	assertThrown!VerifyException(decode("corrupt.encoding.blah", "key"));

	// An "alg" that is not a JWTAlgorithm.
	assertThrown!VerifyException(decode(corruptEncodedString(0, "alg", "bogus_alg"), "key"));

	// A "typ" other than "JWT".
	assertThrown!VerifyException(decode(corruptEncodedString(0, "typ", "JWX"), "key"));

	// A signature truncated by one character.
	string token = encode([ "my" : "payload" ], "key");
	assertThrown!VerifyException(decode(token[0 .. $ - 1], "key"));
}
285 
/**
  Checks the signature of `token` against `key` without decoding the payload.

  Returns: true if the signature verifies. Returns false if it does not, or if the
           signature segment is not valid base64url.

  Throws: VerifyException when:
    - the token does not have exactly two dots;
    - the header is not base64url-encoded JSON;
    - "alg" is missing or unknown;
    - "typ" is present but not the string "JWT".

  NOTE(review): unlike `decode`, this accepts whatever algorithm the header names;
  whether a token can therefore choose how `key` is used depends on the backend's
  `verifySignature` — confirm before relying on this for untrusted tokens.
*/
bool verify(string token, string key) {
	import std.algorithm.searching : count;
	import std.conv : to;
	import std.uni : toUpper;

	if(count(token, ".") != 2)
		throw new VerifyException("Token is incorrect.");

	string[] tokenParts = split(token, ".");

	// Report an undecodable header the same way decode() does instead of letting
	// Base64Exception / JSONException escape.
	JSONValue header;
	try {
		header = parseJSON(urlsafeB64Decode(tokenParts[0]));
	} catch(Exception) {
		throw new VerifyException("Header is incorrect.");
	}

	JWTAlgorithm alg;
	try {
		// Conversion is by enum member *name*; toUpper maps "none" to NONE.
		alg = to!(JWTAlgorithm)(toUpper(header["alg"].str()));
	} catch(Exception) {
		throw new VerifyException("Algorithm is incorrect.");
	}

	if (auto typ = ("typ" in header)) {
		// A non-string "typ" is rejected like a wrong one instead of leaking JSONException.
		if (typ.type != JSONType.string)
			throw new VerifyException("Type is incorrect.");
		string typ_str = typ.str();
		if(typ_str !is null && typ_str != "JWT")
			throw new VerifyException("Type is incorrect.");
	}

	// A signature that is not even valid base64url cannot verify.
	string signature;
	try {
		signature = urlsafeB64Decode(tokenParts[2]);
	} catch(Exception) {
		return false;
	}

	return verifySignature(signature, tokenParts[0]~"."~tokenParts[1], key, alg);
}
315 
unittest {
	import std.exception : assertThrown;

	// verify() has to reject structurally invalid tokens with VerifyException.

	// Anything other than exactly two dots.
	foreach (malformed; ["nodot", "one.dot", "thr.e.e.dots"])
		assertThrown!VerifyException(verify(malformed, "key"));

	// Unknown "alg", then a "typ" other than "JWT".
	assertThrown!VerifyException(verify(corruptEncodedString(0, "alg", "bogus_alg"), "key"));
	assertThrown!VerifyException(verify(corruptEncodedString(0, "typ", "JWX"), "key"));
}
330 
/**
 * Base64url-encodes the bytes of `inp` ('-' and '_' alphabet, no '=' padding).
 */
string urlsafeB64Encode(string inp) pure nothrow {
	const(ubyte)[] octets = cast(const(ubyte)[]) inp;
	return Base64URLNoPadding.encode(octets);
}
337 
/**
 * Decodes base64url text ('-' and '_' alphabet, no '=' padding).
 *
 * The decoded bytes are reinterpreted as a string without UTF-8 validation.
 * Malformed input raises whatever `Base64URLNoPadding.decode` throws
 * (std.base64 uses Base64Exception).
 */
string urlsafeB64Decode(string inp) pure {
	ubyte[] octets = Base64URLNoPadding.decode(inp);
	return cast(string) octets;
}
344 
unittest {
	import jwtd.test;

	// Known-answer tests: each token must match the expected string exactly and
	// verify with the matching key.
	string secret = "secret";

	// none: empty signature segment; verifies only with an empty key
	string unsecured = encode(["language": "D"], "", JWTAlgorithm.NONE);
	assert(unsecured == "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJsYW5ndWFnZSI6IkQifQ.");
	assert(verify(unsecured, ""));
	assert(!verify(unsecured, "somesecret"));

	// HS256
	string hmac256 = encode(["language": "D"], secret, JWTAlgorithm.HS256);
	assert(hmac256 == "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJsYW5ndWFnZSI6IkQifQ.utQLevAUK97y-e6B3-EnSofvTNAfSXNuSbu4moAh-hY");
	assert(verify(hmac256, secret));

	// HS512
	string hmac512 = encode(["language": "D"], secret, JWTAlgorithm.HS512);
	assert(hmac512 == "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJsYW5ndWFnZSI6IkQifQ.tDRXngYs15t6Q-9AortMxXNfvTgVjaQGD9VTlwL3JD6Xxab8ass2ekCoom8uOiRdpZ772ajLQD42RXMuALct1Q");
	assert(verify(hmac512, secret));

	version(UsePhobos) {
		// RS/ES algorithms are not supported by the Phobos backend.
	} else {
		// RS256
		string rsa256 = encode(["language": "D"], private256, JWTAlgorithm.RS256);
		assert(rsa256 == "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJsYW5ndWFnZSI6IkQifQ.BYpRNUNsho1Yquq7Uolp31K2Ng90h0hRlMV6J6d9WSSIYf7s2MBX2xgDlBuHtB-Yb9dkbkfdxqjYCQdWejiMc_II6dn72ZSBwBCyWdPPRNbTRA2DNlsoKFBS5WMp7iYordfD9KE0LowK61n_Z7AHNAiOop5Ka1xTKH8cqEo8s3ItgoxZt8mzAfhIYNogGown6sYytqg1I72UHsEX9KAuP7sCxCbxZ9cSVg2f4afEuwwo08AdG3hW_LXhT7VD-EweDmvF2JLAyf1_rW66PMgiZZCLQ6kf2hQRsa56xRDmo5qC98wDseBHx9f3PsTsracTKojwQUdezDmbHv90vCt-Iw");
		assert(verify(rsa256, public256));

		// ES256: only verification is checked, no fixed expected token
		// (presumably because ECDSA signatures are randomized).
		string ecdsa256 = encode(["language": "D"], es256_private, JWTAlgorithm.ES256);
		assert(verify(ecdsa256, es256_public));
	}
}