--- /dev/null
+#include <stdio.h>
+#include <stdlib.h>
+
+#define __USE_GNU 1 // Pull in RTLD_NEXT
+#include <dlfcn.h>
+
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+
+#include <fcntl.h>
+#include <errno.h>
+#include <string.h>
+#include <resolv.h> /* base64 encode/decode */
+#include "md5.h"
+
+//#define DO_DEBUG 1
+
+#ifdef DO_DEBUG
+#define DEBUG(...) \
+ if (DO_DEBUG) { \
+ fprintf(stderr, "wswrapper: "); \
+ fprintf(stderr, __VA_ARGS__); \
+ }
+#else
+#define DEBUG(...)
+#endif
+
+#define MSG(...) \
+ fprintf(stderr, "wswrapper: "); \
+ fprintf(stderr, __VA_ARGS__);
+
+#define RET_ERROR(eno, ...) \
+ fprintf(stderr, "wswrapper error: "); \
+ fprintf(stderr, __VA_ARGS__); \
+ errno = eno; \
+ return -1;
+
+
+
+const char _WS_response[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n\
+Upgrade: WebSocket\r\n\
+Connection: Upgrade\r\n\
+%sWebSocket-Origin: %s\r\n\
+%sWebSocket-Location: %s://%s%s\r\n\
+%sWebSocket-Protocol: sample\r\n\
+\r\n%s";
+
+/* WARNING: threading not supported */
+int _WS_bufsize = 65536;
+char *_WS_rbuf = NULL;
+char *_WS_sbuf = NULL;
+int _WS_rcarry_cnt = 0;
+char _WS_rcarry[3] = "";
+int _WS_newframe = 1;
+int _WS_sockfd = 0;
+
+int _WS_init() {
+ if (! (_WS_rbuf = malloc(_WS_bufsize)) ) {
+ return 0;
+ }
+ if (! (_WS_sbuf = malloc(_WS_bufsize)) ) {
+ return 0;
+ }
+}
+
+int _WS_gen_md5(char *key1, char *key2, char *key3, char *target) {
+ unsigned int i, spaces1 = 0, spaces2 = 0;
+ unsigned long num1 = 0, num2 = 0;
+ unsigned char buf[17];
+ for (i=0; i < strlen(key1); i++) {
+ if (key1[i] == ' ') {
+ spaces1 += 1;
+ }
+ if ((key1[i] >= 48) && (key1[i] <= 57)) {
+ num1 = num1 * 10 + (key1[i] - 48);
+ }
+ }
+ num1 = num1 / spaces1;
+
+ for (i=0; i < strlen(key2); i++) {
+ if (key2[i] == ' ') {
+ spaces2 += 1;
+ }
+ if ((key2[i] >= 48) && (key2[i] <= 57)) {
+ num2 = num2 * 10 + (key2[i] - 48);
+ }
+ }
+ num2 = num2 / spaces2;
+
+ /* Pack it big-endian */
+ buf[0] = (num1 & 0xff000000) >> 24;
+ buf[1] = (num1 & 0xff0000) >> 16;
+ buf[2] = (num1 & 0xff00) >> 8;
+ buf[3] = num1 & 0xff;
+
+ buf[4] = (num2 & 0xff000000) >> 24;
+ buf[5] = (num2 & 0xff0000) >> 16;
+ buf[6] = (num2 & 0xff00) >> 8;
+ buf[7] = num2 & 0xff;
+
+ strncpy(buf+8, key3, 8);
+ buf[16] = '\0';
+
+ md5_buffer(buf, 16, target);
+ target[16] = '\0';
+
+ return 1;
+}
+
+
+int _WS_handshake(int sockfd)
+{
+ int sz = 0, len, idx;
+ int ret = -1, save_errno = EPROTO;
+ char *last, *start, *end;
+ long flags;
+ char handshake[4096], response[4096],
+ path[1024], prefix[5] = "", scheme[10] = "ws", host[1024],
+ origin[1024], key1[100], key2[100], key3[9], chksum[17];
+
+ static void * (*rfunc)(), * (*wfunc)();
+ if (!rfunc) rfunc = (void *(*)()) dlsym(RTLD_NEXT, "recv");
+ if (!wfunc) wfunc = (void *(*)()) dlsym(RTLD_NEXT, "send");
+ DEBUG("_WS_handshake starting\n");
+
+ /* Disable NONBLOCK if set */
+ flags = fcntl(sockfd, F_GETFL, 0);
+ if (flags & O_NONBLOCK) {
+ fcntl(sockfd, F_SETFL, flags^O_NONBLOCK);
+ }
+
+ while (1) {
+ len = (int) rfunc(sockfd, handshake+sz, 4095, 0);
+ if (len < 1) {
+ ret = len;
+ save_errno = errno;
+ break;
+ }
+ sz += len;
+ handshake[sz] = '\x00';
+ if (sz < 4) {
+ // Not enough yet
+ continue;
+ }
+ if (strstr(handshake, "GET ") != handshake) {
+ // We got something but it wasn't a WebSockets client
+ break;
+ }
+ last = strstr(handshake, "\r\n\r\n");
+ if (! last) {
+ continue;
+ }
+ if (! strstr(handshake, "Upgrade: WebSocket\r\n")) {
+ MSG("Invalid WebSockets handshake\n");
+ break;
+ }
+
+ // Now parse out the data
+ start = handshake+4;
+ end = strstr(start, " HTTP/1.1");
+ if (!end) { break; }
+ snprintf(path, end-start+1, "%s", start);
+
+ start = strstr(handshake, "\r\nHost: ");
+ if (!start) { break; }
+ start += 8;
+ end = strstr(start, "\r\n");
+ snprintf(host, end-start+1, "%s", start);
+
+ start = strstr(handshake, "\r\nOrigin: ");
+ if (!start) { break; }
+ start += 10;
+ end = strstr(start, "\r\n");
+ snprintf(origin, end-start+1, "%s", start);
+
+ start = strstr(handshake, "\r\n\r\n") + 4;
+ if (strlen(start) == 8) {
+ sprintf(prefix, "Sec-");
+
+ snprintf(key3, 8+1, "%s", start);
+
+ start = strstr(handshake, "\r\nSec-WebSocket-Key1: ");
+ if (!start) { break; }
+ start += 22;
+ end = strstr(start, "\r\n");
+ snprintf(key1, end-start+1, "%s", start);
+
+ start = strstr(handshake, "\r\nSec-WebSocket-Key2: ");
+ if (!start) { break; }
+ start += 22;
+ end = strstr(start, "\r\n");
+ snprintf(key2, end-start+1, "%s", start);
+
+ _WS_gen_md5(key1, key2, key3, chksum);
+
+ //DEBUG("Got handshake (v76): %s\n", handshake);
+ MSG("Got handshake (v76)\n");
+
+ } else {
+ sprintf(prefix, "");
+ sprintf(key1, "");
+ sprintf(key2, "");
+ sprintf(key3, "");
+ sprintf(chksum, "");
+
+ //DEBUG("Got handshake (v75): %s\n", handshake);
+ MSG("Got handshake (v75)\n");
+ }
+ sprintf(response, _WS_response, prefix, origin, prefix, scheme,
+ host, path, prefix, chksum);
+ //DEBUG("Handshake response: %s\n", response);
+ wfunc(sockfd, response, strlen(response), 0);
+ save_errno = 0;
+ ret = 0;
+ break;
+ }
+
+ /* Re-enable NONBLOCK if it was set */
+ if (flags & O_NONBLOCK) {
+ fcntl(sockfd, F_SETFL, flags);
+ }
+ errno = save_errno;
+ return ret;
+}
+
+ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
+ size_t len, int flags)
+{
+ int rawcount, deccount, left, rawlen, retlen, decodelen;
+ int sockflags;
+ int i;
+ char * fstart, * fend, * cstart;
+
+ static void * (*rfunc)(), * (*rfunc2)();
+ if (!rfunc) rfunc = (void *(*)()) dlsym(RTLD_NEXT, "recv");
+ if (!rfunc2) rfunc2 = (void *(*)()) dlsym(RTLD_NEXT, "read");
+
+ if (len == 0) {
+ return 0;
+ }
+
+ if ((_WS_sockfd == 0) || (_WS_sockfd != sockfd)) {
+ // Not our file descriptor, just pass through
+ if (recvf) {
+ return (ssize_t) rfunc(sockfd, buf, len, flags);
+ } else {
+ return (ssize_t) rfunc2(sockfd, buf, len);
+ }
+ }
+ DEBUG("_WS_recv(%d, _, %d) called\n", sockfd, len);
+
+ sockflags = fcntl(sockfd, F_GETFL, 0);
+ left = len;
+ retlen = 0;
+
+ // first copy in any carry-over bytes
+ if (_WS_rcarry_cnt) {
+ if (_WS_rcarry_cnt == 1) {
+ DEBUG("Using carry byte: %u (", _WS_rcarry[0]);
+ } else if (_WS_rcarry_cnt == 2) {
+ DEBUG("Using carry bytes: %u,%u (", _WS_rcarry[0],
+ _WS_rcarry[1]);
+ } else {
+ RET_ERROR(EIO, "Too many carry-over bytes\n");
+ }
+ if (len <= _WS_rcarry_cnt) {
+ DEBUG("final)\n");
+ memcpy((char *) buf, _WS_rcarry, len);
+ _WS_rcarry_cnt -= len;
+ return len;
+ } else {
+ DEBUG("prepending)\n");
+ memcpy((char *) buf, _WS_rcarry, _WS_rcarry_cnt);
+ retlen += _WS_rcarry_cnt;
+ left -= _WS_rcarry_cnt;
+ _WS_rcarry_cnt = 0;
+ }
+ }
+
+ // Determine the number of base64 encoded bytes needed
+ rawcount = (left * 4) / 3 + 3;
+ rawcount -= rawcount%4;
+
+ if (rawcount > _WS_bufsize - 1) {
+ RET_ERROR(ENOMEM, "recv of %d bytes is larger than buffer\n", rawcount);
+ }
+
+ i = 0;
+ while (1) {
+ // Peek at everything available
+ rawlen = (int) rfunc(sockfd, _WS_rbuf, _WS_bufsize-1,
+ flags | MSG_PEEK);
+ if (rawlen <= 0) {
+ DEBUG("_WS_recv: returning because rawlen %d\n", rawlen);
+ return (ssize_t) rawlen;
+ }
+ fstart = _WS_rbuf;
+
+ /*
+ while (rawlen >= 2 && fstart[0] == '\x00' && fstart[1] == '\xff') {
+ fstart += 2;
+ rawlen -= 2;
+ }
+ */
+ if (rawlen >= 2 && fstart[0] == '\x00' && fstart[1] == '\xff') {
+ rawlen = (int) rfunc(sockfd, _WS_rbuf, 2, flags);
+ if (rawlen != 2) {
+ RET_ERROR(EIO, "Could not strip empty frame headers\n");
+ }
+ continue;
+ }
+
+ fstart[rawlen] = '\x00';
+
+ if (rawlen - _WS_newframe >= 4) {
+ // We have enough to base64 decode at least 1 byte
+ break;
+ }
+ // Not enough to base64 decode
+ if (sockflags & O_NONBLOCK) {
+ // Just tell the caller to call again
+ DEBUG("_WS_recv: returning because O_NONBLOCK, rawlen %d\n", rawlen);
+ errno = EAGAIN;
+ return -1;
+ }
+ // Repeat until at least 1 byte (4 raw bytes) to decode
+ i++;
+ if (i > 1000000) {
+ MSG("Could not send final part of frame\n");
+ }
+ }
+
+ /*
+ DEBUG("_WS_recv, left: %d, len: %d, rawlen: %d, newframe: %d, raw: ",
+ left, len, rawlen, _WS_newframe);
+ for (i = 0; i < rawlen; i++) {
+ DEBUG("%u,", (unsigned char) ((char *) fstart)[i]);
+ }
+ DEBUG("\n");
+ */
+
+ if (_WS_newframe) {
+ if (fstart[0] != '\x00') {
+ RET_ERROR(EPROTO, "Missing frame start\n");
+ }
+ fstart++;
+ rawlen--;
+ _WS_newframe = 0;
+ }
+
+ fend = memchr(fstart, '\xff', rawlen);
+
+ if (fend) {
+ _WS_newframe = 1;
+ if ((fend - fstart) % 4) {
+ RET_ERROR(EPROTO, "Frame length is not multiple of 4\n");
+ }
+ } else {
+ fend = fstart + rawlen - (rawlen % 4);
+ if (fend - fstart < 4) {
+ RET_ERROR(EPROTO, "Frame too short\n");
+ }
+ }
+
+ // How much should we consume
+ if (rawcount < fend - fstart) {
+ _WS_newframe = 0;
+ deccount = rawcount;
+ } else {
+ deccount = fend - fstart;
+ }
+
+ // Now consume what we processed
+ if (flags & MSG_PEEK) {
+ MSG("*** Got MSG_PEEK ***\n");
+ } else {
+ rfunc(sockfd, _WS_rbuf, fstart - _WS_rbuf + deccount + _WS_newframe, flags);
+ }
+
+ fstart[deccount] = '\x00'; // base64 terminator
+
+ // Do direct base64 decode, instead of decode()
+ decodelen = b64_pton(fstart, (char *) buf + retlen, deccount);
+ if (decodelen <= 0) {
+ RET_ERROR(EPROTO, "Base64 decode error\n");
+ }
+
+ if (decodelen <= left) {
+ retlen += decodelen;
+ } else {
+ retlen += left;
+
+ if (! (flags & MSG_PEEK)) {
+ // Add anything left over to the carry-over
+ _WS_rcarry_cnt = decodelen - left;
+ if (_WS_rcarry_cnt > 2) {
+ RET_ERROR(EPROTO, "Got too much base64 data\n");
+ }
+ memcpy(_WS_rcarry, buf + retlen, _WS_rcarry_cnt);
+ if (_WS_rcarry_cnt == 1) {
+ DEBUG("Saving carry byte: %u\n", _WS_rcarry[0]);
+ } else if (_WS_rcarry_cnt == 2) {
+ DEBUG("Saving carry bytes: %u,%u\n", _WS_rcarry[0],
+ _WS_rcarry[1]);
+ } else {
+ MSG("Waah2!\n");
+ }
+ }
+ }
+ ((char *) buf)[retlen] = '\x00';
+
+ /*
+ DEBUG("*** recv %s as ", fstart);
+ for (i = 0; i < retlen; i++) {
+ DEBUG("%u,", (unsigned char) ((char *) buf)[i]);
+ }
+ DEBUG(" (%d -> %d): %d\n", deccount, decodelen, retlen);
+ */
+ return retlen;
+}
+
+ssize_t _WS_send(int sendf, int sockfd, const void *buf,
+ size_t len, int flags)
+{
+ int rawlen, enclen, rlen, over, left, clen, retlen, dbufsize;
+ int sockflags;
+ char * target;
+ int i;
+ static void * (*sfunc)(), * (*sfunc2)();
+ if (!sfunc) sfunc = (void *(*)()) dlsym(RTLD_NEXT, "send");
+ if (!sfunc2) sfunc2 = (void *(*)()) dlsym(RTLD_NEXT, "write");
+
+ if ((_WS_sockfd == 0) || (_WS_sockfd != sockfd)) {
+ // Not our file descriptor, just pass through
+ if (sendf) {
+ return (ssize_t) sfunc(sockfd, buf, len, flags);
+ } else {
+ return (ssize_t) sfunc2(sockfd, buf, len);
+ }
+ }
+ DEBUG("_WS_send(%d, _, %d) called\n", sockfd, len);
+
+ sockflags = fcntl(sockfd, F_GETFL, 0);
+
+ dbufsize = (_WS_bufsize * 3)/4 - 2;
+ if (len > dbufsize) {
+ RET_ERROR(ENOMEM, "send of %d bytes is larger than send buffer\n", len);
+ }
+
+ // base64 encode and add frame markers
+ rawlen = 0;
+ _WS_sbuf[rawlen++] = '\x00';
+ enclen = b64_ntop(buf, len, _WS_sbuf+rawlen, _WS_bufsize-rawlen);
+ if (enclen < 0) {
+ RET_ERROR(EPROTO, "Base64 encoding error\n");
+ }
+ rawlen += enclen;
+ _WS_sbuf[rawlen++] = '\xff';
+
+ rlen = (int) sfunc(sockfd, _WS_sbuf, rawlen, flags);
+
+ if (rlen <= 0) {
+ return rlen;
+ } else if (rlen < rawlen) {
+ // Spin until we can send a whole base64 chunck and frame end
+ over = (rlen - 1) % 4;
+ left = (4 - over) % 4 + 1; // left to send
+ DEBUG("_WS_send: rlen: %d (over: %d, left: %d), rawlen: %d\n", rlen, over, left, rawlen);
+ rlen += left;
+ _WS_sbuf[rlen-1] = '\xff';
+ i = 0;
+ do {
+ i++;
+ clen = (int) sfunc(sockfd, _WS_sbuf + rlen - left, left, flags);
+ if (clen > 0) {
+ left -= clen;
+ } else if (clen == 0) {
+ MSG("_WS_send: got clen %d\n", clen);
+ } else if (!(sockflags & O_NONBLOCK)) {
+ MSG("_WS_send: clen %d\n", clen);
+ return clen;
+ }
+ if (i > 1000000) {
+ MSG("Could not send final part of frame\n");
+ }
+ } while (left > 0);
+ DEBUG("_WS_send: spins until finished %d\n", i);
+ }
+
+
+ /*
+ * Report back the number of original characters sent,
+ * not the raw number sent
+ */
+ // Adjust for framing
+ retlen = rlen - 2;
+ // Adjust for base64 padding
+ if (_WS_sbuf[rlen-1] == '=') { retlen --; }
+ if (_WS_sbuf[rlen-2] == '=') { retlen --; }
+
+ // Adjust for base64 encoding
+ retlen = (retlen*3)/4;
+
+ /*
+ DEBUG("*** send ");
+ for (i = 0; i < retlen; i++) {
+ DEBUG("%u,", (unsigned char) ((char *)buf)[i]);
+ }
+ DEBUG(" as '%s' (%d)\n", _WS_sbuf+1, rlen);
+ */
+ return (ssize_t) retlen;
+}
+
+
+/* Override network routines */
+
+/*
+int socket(int domain, int type, int protocol)
+{
+ static void * (*func)();
+ if (!func) func = (void *(*)()) dlsym(RTLD_NEXT, "socket");
+ DEBUG("socket(_, %d, _) called\n", type);
+
+ return (int) func(domain, type, protocol);
+}
+
+int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen)
+{
+ static void * (*func)();
+ if (!func) func = (void *(*)()) dlsym(RTLD_NEXT, "bind");
+ DEBUG("bind(%d, _, %d) called\n", sockfd, addrlen);
+
+ return (int) func(sockfd, addr, addrlen);
+}
+*/
+
+int accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
+{
+ int fd, ret, envfd;
+ static void * (*func)();
+ if (!func) func = (void *(*)()) dlsym(RTLD_NEXT, "accept");
+ DEBUG("accept(%d, _, _) called\n", sockfd);
+
+ fd = (int) func(sockfd, addr, addrlen);
+
+ if (_WS_sockfd == 0) {
+ _WS_sockfd = fd;
+
+ if (!_WS_rbuf) {
+ if (! _WS_init()) {
+ RET_ERROR(ENOMEM, "Could not allocate interposer buffer\n");
+ }
+ }
+
+ ret = _WS_handshake(_WS_sockfd);
+ if (ret < 0) {
+ errno = EPROTO;
+ return ret;
+ }
+ MSG("interposing on fd %d\n", _WS_sockfd);
+ } else {
+ DEBUG("already interposing on fd %d\n", _WS_sockfd);
+ }
+
+ return fd;
+}
+
+int close(int fd)
+{
+ static void * (*func)();
+ if (!func) func = (void *(*)()) dlsym(RTLD_NEXT, "close");
+
+ if ((_WS_sockfd != 0) && (_WS_sockfd == fd)) {
+ MSG("finished interposing on fd %d\n", _WS_sockfd);
+ _WS_sockfd = 0;
+ }
+ return (int) func(fd);
+}
+
+
+ssize_t read(int fd, void *buf, size_t count)
+{
+ //DEBUG("read(%d, _, %d) called\n", fd, count);
+ return (ssize_t) _WS_recv(0, fd, buf, count, 0);
+}
+
+ssize_t write(int fd, const void *buf, size_t count)
+{
+ //DEBUG("write(%d, _, %d) called\n", fd, count);
+ return (ssize_t) _WS_send(0, fd, buf, count, 0);
+}
+
+ssize_t recv(int sockfd, void *buf, size_t len, int flags)
+{
+ //DEBUG("recv(%d, _, %d, %d) called\n", sockfd, len, flags);
+ return (ssize_t) _WS_recv(1, sockfd, buf, len, flags);
+}
+
+ssize_t send(int sockfd, const void *buf, size_t len, int flags)
+{
+ //DEBUG("send(%d, _, %d, %d) called\n", sockfd, len, flags);
+ return (ssize_t) _WS_send(1, sockfd, buf, len, flags);
+}
+