]> git.proxmox.com Git - ceph.git/blame - ceph/src/rgw/jwt-cpp/base.h
Import ceph 15.2.8
[ceph.git] / ceph / src / rgw / jwt-cpp / base.h
CommitLineData
f91f0fd5
TL
1#pragma once
#include <array>
#include <cstdint>
#include <stdexcept>
#include <string>
4
5namespace jwt {
namespace alphabet {
	/// Standard base64 alphabet (RFC 4648 section 4); incomplete groups are padded with '='.
	struct base64 {
		/// Symbol table indexed by sextet value 0..63.
		static const std::array<char, 64>& data() {
			static const std::array<char, 64> symbols{{
				'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
				'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
				'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
				'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}};
			return symbols;
		}

		/// Padding token appended once per missing symbol of the final group.
		static const std::string& fill() {
			static const std::string padding{"="};
			return padding;
		}
	};

	/// URL-safe base64 alphabet (RFC 4648 section 5): '+' and '/' are replaced by '-' and '_'.
	/// The padding token is "%3d", i.e. '=' (0x3D) in percent-encoded form.
	struct base64url {
		/// Symbol table indexed by sextet value 0..63.
		static const std::array<char, 64>& data() {
			static const std::array<char, 64> symbols{{
				'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
				'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
				'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
				'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_'}};
			return symbols;
		}

		/// Padding token appended once per missing symbol of the final group.
		static const std::string& fill() {
			static const std::string padding{"%3d"};
			return padding;
		}
	};
}
36
/// Base64-family encoder/decoder parameterized by an alphabet type.
///
/// The alphabet type T must provide:
///   static const std::array<char, 64>& data();  // symbol for each sextet value 0..63
///   static const std::string& fill();           // padding token; "" means unpadded
class base {
public:
	/// Encode arbitrary bytes with alphabet T.
	/// @param bin  raw input bytes (may contain NULs and bytes >= 0x80)
	/// @return     encoded text; a final group of 1 or 2 bytes is completed with
	///             two or one copies of T::fill() respectively
	template<typename T>
	static std::string encode(const std::string& bin) {
		return encode(bin, T::data(), T::fill());
	}

	/// Decode text produced with alphabet T.
	/// @param base  encoded text; must carry T::fill() padding unless T::fill() is empty
	/// @return      decoded bytes
	/// @throws std::runtime_error("Invalid input") on an impossible length, more than
	///         two padding tokens, or a symbol that is not in T::data()
	template<typename T>
	static std::string decode(const std::string& base) {
		return decode(base, T::data(), T::fill());
	}

private:
	/// Byte i of s, zero-extended (plain char may be signed).
	static uint32_t octet(const std::string& s, size_t i) {
		return static_cast<unsigned char>(s[i]);
	}

	/// Symbol for the 6-bit field of `triple` that starts `shift` bits from the right.
	static char symbol(const std::array<char, 64>& alphabet, uint32_t triple, unsigned shift) {
		return alphabet[(triple >> shift) & 0x3F];
	}

	static std::string encode(const std::string& bin, const std::array<char, 64>& alphabet, const std::string& fill) {
		const size_t size = bin.size();
		const size_t mod = size % 3;
		const size_t fast_size = size - mod;

		std::string res;
		// Upper bound: 4 symbols per started 3-byte group plus at most two fill tokens.
		res.reserve((size + 2) / 3 * 4 + 2 * fill.size());

		// Complete 3-byte groups -> 4 symbols each.
		for (size_t i = 0; i < fast_size; i += 3) {
			const uint32_t triple = (octet(bin, i) << 16) | (octet(bin, i + 1) << 8) | octet(bin, i + 2);
			res += symbol(alphabet, triple, 18);
			res += symbol(alphabet, triple, 12);
			res += symbol(alphabet, triple, 6);
			res += symbol(alphabet, triple, 0);
		}

		if (mod == 0)
			return res;

		// 1 or 2 leftover bytes: missing bytes count as zero, and each symbol that
		// would only carry those zero bits is replaced by `fill`.
		const uint32_t triple = (octet(bin, fast_size) << 16)
			| (mod == 2 ? octet(bin, fast_size + 1) << 8 : 0u);
		res += symbol(alphabet, triple, 18);
		res += symbol(alphabet, triple, 12);
		if (mod == 2) {
			res += symbol(alphabet, triple, 6);
			res += fill;
		} else {
			res += fill;
			res += fill;
		}
		return res;
	}

	static std::string decode(const std::string& encoded, const std::array<char, 64>& alphabet, const std::string& fill) {
		size_t size = encoded.size();  // length of the data part, excluding padding
		size_t fill_cnt = 0;           // number of padding tokens (0..2)

		if (fill.empty()) {
			// Unpadded alphabet: infer how many tokens were omitted from the length.
			// Without this, an empty token would "match" forever and always throw.
			// A remainder of 1 can never be produced by encode().
			switch (size % 4) {
			case 0: fill_cnt = 0; break;
			case 2: fill_cnt = 2; break;
			case 3: fill_cnt = 1; break;
			default: throw std::runtime_error("Invalid input");
			}
		} else {
			// Strip trailing fill tokens (never consuming the whole input).
			while (size > fill.size() && encoded.compare(size - fill.size(), fill.size(), fill) == 0) {
				fill_cnt++;
				size -= fill.size();
				if (fill_cnt > 2)
					throw std::runtime_error("Invalid input");
			}
			// Each fill token stands in for exactly one symbol of a 4-symbol group.
			if ((size + fill_cnt) % 4 != 0)
				throw std::runtime_error("Invalid input");
		}

		// Reverse lookup byte -> sextet value (-1 = not in the alphabet). Filled back to
		// front so a duplicated symbol maps to its lowest index, like a forward scan would.
		std::array<int, 256> sextet_of;
		sextet_of.fill(-1);
		for (size_t i = alphabet.size(); i-- > 0;)
			sextet_of[static_cast<unsigned char>(alphabet[i])] = static_cast<int>(i);

		auto get_sextet = [&](size_t offset) -> uint32_t {
			const int value = sextet_of[static_cast<unsigned char>(encoded[offset])];
			if (value < 0)
				throw std::runtime_error("Invalid input");
			return static_cast<uint32_t>(value);
		};

		std::string res;
		res.reserve(size / 4 * 3 + 2);

		// Complete 4-symbol groups -> 3 bytes each.
		const size_t fast_size = size - size % 4;
		for (size_t i = 0; i < fast_size; i += 4) {
			const uint32_t triple = (get_sextet(i) << 18) | (get_sextet(i + 1) << 12)
				| (get_sextet(i + 2) << 6) | get_sextet(i + 3);
			res += static_cast<char>((triple >> 16) & 0xFF);
			res += static_cast<char>((triple >> 8) & 0xFF);
			res += static_cast<char>(triple & 0xFF);
		}

		if (fill_cnt == 0)
			return res;

		// Final group has 3 symbols (fill_cnt == 1 -> 2 bytes) or 2 symbols (fill_cnt == 2 -> 1 byte).
		uint32_t triple = (get_sextet(fast_size) << 18) | (get_sextet(fast_size + 1) << 12);
		if (fill_cnt == 1)
			triple |= get_sextet(fast_size + 2) << 6;
		res += static_cast<char>((triple >> 16) & 0xFF);
		if (fill_cnt == 1)
			res += static_cast<char>((triple >> 8) & 0xFF);
		return res;
	}
};
168}