1 module fastjwt.jwt;
2 
3 import vibe.data.json;
4 
5 //import fastjwt.stringbuf;
6 import stringbuffer;
7 
8 version(unittest) {
9 	import std.stdio;
10 }
11 
/// Signing algorithms supported by this module.
/// The ordinal value of each member is used as an index into
/// `base64HeaderStrings`, so the member order must match that table.
enum JWTAlgorithm {
	NONE = 0,  /// unsecured token: the signature segment is empty
	HS256 = 1, /// HMAC with SHA-256
	HS384 = 2, /// HMAC with SHA-384
	HS512 = 3  /// HMAC with SHA-512
}
18 
/** Append the URL-safe, unpadded Base64 encoding of the HMAC over `data`,
keyed with `secret` and using hash function `H`, to `buf`.

Shared implementation of the HMAC cases of `hash`.

Params:
	buf = sink providing `insertBack` for a character slice
	data = the bytes to authenticate
	secret = the HMAC key
*/
private void appendHmacBase64(H, Out)(ref Out buf, string data, string secret) {
	import std.base64 : Base64, Base64Impl;
	import std.digest.hmac : HMAC;
	import std.string : representation;

	alias URLSafeBase64 = Base64Impl!('-', '_', Base64.NoPadding);

	auto mac = HMAC!H(secret.representation);
	mac.put(data.representation);
	// Keep the fixed-size digest in a local and encode a slice of it, rather
	// than handing the temporary returned by finish() to the encoder.
	auto digest = mac.finish();
	buf.insertBack(URLSafeBase64.encode(digest[]));
}

/** Sign `data` with `secret` using `alg` and append the signature, encoded as
URL-safe Base64 without padding, to `buf`.

Params:
	buf = buffer receiving the encoded signature
	data = the signing input (the `header.payload` part of a token)
	secret = the HMAC key
	alg = the signing algorithm; `JWTAlgorithm.NONE` appends nothing, so the
		signature is empty
*/
void hash(ref StringBuffer buf, string data, string secret, JWTAlgorithm alg) {
	import std.digest.sha : SHA256, SHA384, SHA512;

	final switch(alg) {
		case JWTAlgorithm.HS256:
			appendHmacBase64!SHA256(buf, data, secret);
			break;
		case JWTAlgorithm.HS384:
			appendHmacBase64!SHA384(buf, data, secret);
			break;
		case JWTAlgorithm.HS512:
			appendHmacBase64!SHA512(buf, data, secret);
			break;
		case JWTAlgorithm.NONE:
			// Unsecured token: the signature segment stays empty.
			break;
	}
}
47 
48 import std.base64;
49 
/** Pre-computed JWT headers in standard Base64, one per `JWTAlgorithm`,
in enum order: NONE, HS256, HS384, HS512.

NOTE(review): JWS (RFC 7515) specifies base64url without padding. This table
uses the padded standard alphabet, which is what `decodeJWTToken` decodes with.
*/
const base64HeaderStrings = [
	Base64.encode(cast(immutable(ubyte)[]) `{"alg":"none","typ":"JWT"}`),
	Base64.encode(cast(immutable(ubyte)[]) `{"alg":"HS256","typ":"JWT"}`),
	Base64.encode(cast(immutable(ubyte)[]) `{"alg":"HS384","typ":"JWT"}`),
	Base64.encode(cast(immutable(ubyte)[]) `{"alg":"HS512","typ":"JWT"}`)
];
56 
/** Append the pre-computed, Base64-encoded JWT header for `alg` to `output`.

Params:
	alg = the signing algorithm; its ordinal indexes `base64HeaderStrings`
	output = sink providing `insertBack`, e.g. `StringBuffer`
*/
void headerBase64(Out)(const JWTAlgorithm alg, ref Out output) {
	const encodedHeader = base64HeaderStrings[alg];
	output.insertBack(encodedHeader);
}
60 
unittest {
	StringBuffer buf;
	headerBase64(JWTAlgorithm.HS256, buf);

	// Standard Base64 of {"alg":"HS256","typ":"JWT"}.
	assert(buf.getData() == "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9");
}
67 
/** Serialize `payload` with `writeJsonString` and append the standard Base64
encoding of the resulting JSON text to `output`.

Params:
	output = buffer providing `writer()`, e.g. `StringBuffer`
	payload = the claims to encode
*/
void payloadToBase64(Out)(ref Out output, const(Json) payload) {
	StringBuffer serialized;
	auto textSink = serialized.writer();
	writeJsonString(textSink, payload);

	const(ubyte)[] rawJson = serialized.getData!(ubyte[])();
	Base64.encode(rawJson, output.writer());
}
74 
/** Build a flat JSON object from alternating key/value arguments and append
the standard Base64 encoding of its text to `output`.

Keys are converted to strings. Values are written as follows:
- integral and floating point types become JSON numbers;
- string types become JSON strings;
- `bool` becomes a JSON boolean;
- any other type goes through `serializeToJson`.

All keys and string values are JSON-escaped.

Params:
	output = buffer providing `writer()`, e.g. `StringBuffer`
	args = a non-empty, even number of arguments: key, value, key, value, ...
*/
void payloadToBase64(Out,Args...)(ref Out output, Args args)
		if(Args.length > 0 && Args.length % 2 == 0 && !is(Args[0] : const(Json)))
{
	import std.conv : to;
	import std.format : formattedWrite;
	import std.traits : isFloatingPoint, isIntegral, isSomeString;

	// Writes one `"key":value` member (preceded by ',' unless it is the
	// first) and then recurses over the remaining pairs.
	void impl(W, K, V, Rest...)(ref W sink, bool first, K key, V value,
			Rest rest)
	{
		if(!first) {
			sink.put(',');
		}

		// Escape through vibe's writer: splicing raw text between quotes
		// would let a '"' or '\' in a key or value break the JSON or inject
		// additional claims into the token.
		writeJsonString(sink, Json(key.to!string));
		sink.put(':');

		static if(isIntegral!V) {
			formattedWrite(sink, "%d", value);
		} else static if(isFloatingPoint!V) {
			// NOTE(review): "%f" keeps only six decimals and prints nan/inf,
			// which are not valid JSON numbers.
			formattedWrite(sink, "%f", value);
		} else static if(isSomeString!V) {
			writeJsonString(sink, Json(value.to!string));
		} else static if(is(V == bool)) {
			formattedWrite(sink, "%s", value);
		} else {
			// Any other type is serialized so the member always has a value.
			writeJsonString(sink, serializeToJson(value));
		}

		static if(Rest.length > 0) {
			impl(sink, false, rest);
		}
	}

	StringBuffer jsonString;
	auto w = jsonString.writer();
	w.put("{");
	impl(w, true, args);
	w.put("}");

	Base64.encode(jsonString.getData!(ubyte[])(), output.writer());
}
108 
unittest {
	Json j1 = Json(["field1": Json("foo"), "field2": Json(42),
			"field3": Json(true)]
		);

	StringBuffer buf;
	payloadToBase64(buf, j1);

	StringBuffer buf2;
	payloadToBase64(buf2, "field1", "foo", "field2", 42, "field3", true);

	// Decode and parse both payloads before comparing: the key order of the
	// serialized Json object is not guaranteed to match the argument order.
	auto a = parseJsonString(cast(string) Base64.decode(buf.getData()));
	auto b = parseJsonString(cast(string) Base64.decode(buf2.getData()));

	assert(a == b);
}
125 
/** Encode key/value pairs as a signed JWT and append it to `output`.

The token has the form `header.payload.signature`. The header and payload are
Base64-encoded, and the signature is computed over `header.payload` with
`algo` and `secret`.

Params:
	output = the output range receiving the token; must provide `insertBack`
	algo = the algorithm used to sign the token
	secret = the secret used to sign the token
	args = the values to encode, given as pairs: a string key followed by
		its value
*/
void encodeJWTToken(Out, Args...)(ref Out output, JWTAlgorithm algo,
		string secret, Args args)
{
	StringBuffer signingInput;
	headerBase64(algo, signingInput);
	signingInput.insertBack('.');
	payloadToBase64(signingInput, args);
	auto headerAndPayload = signingInput.getData();

	StringBuffer signature;
	hash(signature, headerAndPayload, secret, algo);

	output.insertBack(headerAndPayload);
	output.insertBack('.');
	output.insertBack(signature.getData());
}
151 
/** Encode a JSON object as a signed JWT and append it to `output`.

Produces the same `header.payload.signature` layout as the key/value
overload, but takes the claims as a prebuilt `Json` value.

Params:
	output = the output range receiving the token; must provide `insertBack`
	algo = the algorithm used to sign the token
	secret = the secret used to sign the token
	args = the claims to encode
*/
void encodeJWTToken(Out)(ref Out output, JWTAlgorithm algo,
		string secret, const(Json) args)
{
	StringBuffer signingInput;
	headerBase64(algo, signingInput);
	signingInput.insertBack('.');
	payloadToBase64(signingInput, args);
	auto headerAndPayload = signingInput.getData();

	StringBuffer signature;
	hash(signature, headerAndPayload, secret, algo);

	output.insertBack(headerAndPayload);
	output.insertBack('.');
	output.insertBack(signature.getData());
}
168 
///
unittest {
	string secret = "supersecret";
	foreach(alg; [JWTAlgorithm.HS384, JWTAlgorithm.HS256, JWTAlgorithm.HS512]) {
		StringBuffer buf;
		encodeJWTToken(buf, alg, secret, "id", 1337);

		StringBuffer buf2;
		Json j = Json(["id" : Json(1337)]);
		encodeJWTToken(buf2, alg, secret, j);

		// With a single claim both overloads serialize the payload to the
		// same JSON text, so the complete tokens must be identical.
		assert(buf.getData() == buf2.getData());
	}
}
181 
/** Verify and decode a JWT.

The signature is recomputed over the `header.payload` part with `algo` and
`secret` and compared against the token's last segment. The header and
payload are decoded only after the signature matches. The algorithm named
inside the token header is not consulted.

Params:
	encodedToken = the token, `header.payload.signature`
	secret = the secret used to sign the token
	algo = the algorithm used to sign the token
	header = the buffer receiving the decoded header
	payload = the buffer receiving the decoded payload

Returns: 0 if the token is valid; any other value means it is not:
	1 = no '.' separator,
	2 = no second '.' separator,
	3 = signature mismatch,
	4 = header or payload is not valid Base64 (`header` may then already
	    hold partially decoded data)
*/
int decodeJWTToken(string encodedToken, string secret,
		JWTAlgorithm algo, ref StringBuffer header, ref StringBuffer payload)
{
	import std.base64 : Base64, Base64Exception;
	import std.digest : secureEqual;
	import std.string : indexOf, representation;

	immutable ptrdiff_t firstDot = encodedToken.indexOf('.');
	if(firstDot == -1) {
		return 1;
	}

	immutable ptrdiff_t secondDot = encodedToken.indexOf('.', firstDot + 1);
	if(secondDot == -1) {
		return 2;
	}

	StringBuffer h;
	hash(h, encodedToken[0 .. secondDot], secret, algo);

	// Constant-time comparison: a short-circuiting `!=` lets response timing
	// reveal how many leading characters of a forged signature are correct.
	if(!secureEqual(h.getData().representation,
			encodedToken[secondDot + 1 .. $].representation))
	{
		return 3;
	}

	// Malformed Base64 is reported as an error code rather than an escaping
	// exception; this is reachable with JWTAlgorithm.NONE, whose signature
	// is empty.
	try {
		Base64.decode(encodedToken[0 .. firstDot], header.writer());
		Base64.decode(encodedToken[firstDot + 1 .. secondDot], payload.writer());
	} catch(Base64Exception) {
		return 4;
	}

	return 0;
}
223 
///
unittest {
	import std.format : format;

	string secret = "supersecret";
	auto alg = JWTAlgorithm.HS256;

	StringBuffer buf;
	encodeJWTToken(buf, alg, secret, "sub", 1337);

	StringBuffer header;
	StringBuffer payload;

	int rslt = decodeJWTToken(buf.getData(), secret, alg, header, payload);
	assert(rslt == 0, format("%d", rslt));

	// The decoded segments are exactly the JSON texts that were encoded.
	assert(header.getData() == `{"alg":"HS256","typ":"JWT"}`);
	assert(payload.getData() == `{"sub":1337}`);
}
240 
unittest {
	// Each malformed token fails a different check in decodeJWTToken; the
	// expected error code is the token's position in the list plus one.
	immutable tokens = ["asldjasldj","aslkdjas.asdlj","asdlj..alsdj"];
	immutable secret = "secret";
	immutable alg = JWTAlgorithm.HS256;

	foreach(idx, token; tokens) {
		StringBuffer header;
		StringBuffer payload;

		immutable code = decodeJWTToken(token, secret, alg, header, payload);
		assert(code == cast(int)(idx + 1));
	}
}
254 
unittest {
	import std.format : format;

	// Repeated encode/decode round trips with a fixed claim.
	enum iterations = 1024 * 2;
	immutable alg = JWTAlgorithm.HS256;
	string secret = "supersecret";

	foreach(_; 0 .. iterations) {
		StringBuffer token;
		encodeJWTToken(token, alg, secret, "id", 1337);

		StringBuffer header, payload;
		immutable int status = decodeJWTToken(token.getData(), secret, alg,
				header, payload);
		assert(status == 0, format("%d", status));
	}
}