--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+module thrift.protocol.compact;
+
+import std.array : uninitializedArray;
+import std.typetuple : allSatisfy, TypeTuple;
+import thrift.protocol.base;
+import thrift.transport.base;
+import thrift.internal.endian;
+
+/**
+ * D implementation of the Compact protocol.
+ *
+ * See THRIFT-110 for a protocol description. This implementation is based on
+ * the C++ one.
+ */
+final class TCompactProtocol(Transport = TTransport) if (
+  isTTransport!Transport
+) : TProtocol {
+  /**
+   * Constructs a new instance.
+   *
+   * Params:
+   *   trans = The transport to use.
+   *   containerSizeLimit = If positive, the container size is limited to the
+   *     given number of items.
+   *   stringSizeLimit = If positive, the string length is limited to the
+   *     given number of bytes.
+   */
+  this(Transport trans, int containerSizeLimit = 0, int stringSizeLimit = 0) {
+    trans_ = trans;
+    this.containerSizeLimit = containerSizeLimit;
+    this.stringSizeLimit = stringSizeLimit;
+  }
+
+  /// The transport this protocol instance reads from and writes to.
+  Transport transport() @property {
+    return trans_;
+  }
+
+  /**
+   * Clears the field id tracking and any pending boolean field/value, so
+   * that the instance can be reused for a fresh message.
+   */
+  void reset() {
+    lastFieldId_ = 0;
+    fieldIdStack_ = null;
+    booleanField_ = TField.init;
+    hasBoolValue_ = false;
+  }
+
+  /**
+   * If positive, limits the number of items of deserialized containers to the
+   * given amount.
+   *
+   * This is useful to avoid allocating excessive amounts of memory when broken
+   * data is received. If the limit is exceeded, a SIZE_LIMIT-type
+   * TProtocolException is thrown.
+   *
+   * Defaults to zero (no limit).
+   */
+  int containerSizeLimit;
+
+  /**
+   * If positive, limits the length of deserialized strings/binary data to the
+   * given number of bytes.
+   *
+   * This is useful to avoid allocating excessive amounts of memory when broken
+   * data is received. If the limit is exceeded, a SIZE_LIMIT-type
+   * TProtocolException is thrown.
+   *
+   * Defaults to zero (no limit).
+   */
+  int stringSizeLimit;
+
+  /*
+   * Writing methods.
+   */
+
+  /**
+   * Writes a boolean.
+   *
+   * If writeFieldBegin() deferred a BOOL field header, the value is encoded
+   * in the type nibble of that header and no separate value byte is written.
+   * Otherwise (e.g. inside a container) a single byte is written.
+   */
+  void writeBool(bool b) {
+    if (booleanField_.name !is null) {
+      // we haven't written the field header yet
+      writeFieldBeginInternal(booleanField_,
+        b ? CType.BOOLEAN_TRUE : CType.BOOLEAN_FALSE);
+      booleanField_.name = null;
+    } else {
+      // we're not part of a field, so just write the value
+      writeByte(b ? CType.BOOLEAN_TRUE : CType.BOOLEAN_FALSE);
+    }
+  }
+
+  /// Writes a single byte verbatim.
+  void writeByte(byte b) {
+    trans_.write((cast(ubyte*)&b)[0..1]);
+  }
+
+  /// Writes a 16 bit integer as a zigzag-encoded varint.
+  void writeI16(short i16) {
+    writeVarint32(i32ToZigzag(i16));
+  }
+
+  /// Writes a 32 bit integer as a zigzag-encoded varint.
+  void writeI32(int i32) {
+    writeVarint32(i32ToZigzag(i32));
+  }
+
+  /// Writes a 64 bit integer as a zigzag-encoded varint.
+  void writeI64(long i64) {
+    writeVarint64(i64ToZigzag(i64));
+  }
+
+  /// Writes the 8 byte bit pattern of the double after hostToLe() conversion.
+  void writeDouble(double dub) {
+    ulong bits = hostToLe(*cast(ulong*)(&dub));
+    trans_.write((cast(ubyte*)&bits)[0 .. 8]);
+  }
+
+  /// Writes a string using the same encoding as writeBinary().
+  void writeString(string str) {
+    writeBinary(cast(ubyte[])str);
+  }
+
+  /// Writes the length of buf as a varint, followed by its contents.
+  void writeBinary(ubyte[] buf) {
+    assert(buf.length <= int.max);
+    writeVarint32(cast(int)buf.length);
+    trans_.write(buf);
+  }
+
+  /**
+   * Writes the message header: protocol id byte, combined version/type byte,
+   * sequence id as a varint and the method name as a string.
+   */
+  void writeMessageBegin(TMessage msg) {
+    writeByte(cast(byte)PROTOCOL_ID);
+    writeByte(cast(byte)((VERSION_N & VERSION_MASK) |
+      ((cast(int)msg.type << TYPE_SHIFT_AMOUNT) & TYPE_MASK)));
+    writeVarint32(msg.seqid);
+    writeString(msg.name);
+  }
+  void writeMessageEnd() {}
+
+  /**
+   * Starts a struct: nothing is written, but the last field id of the
+   * enclosing struct is saved and the delta base is reset to zero.
+   */
+  void writeStructBegin(TStruct tstruct) {
+    fieldIdStack_ ~= lastFieldId_;
+    lastFieldId_ = 0;
+  }
+
+  /// Ends a struct by restoring the last field id of the enclosing struct.
+  void writeStructEnd() {
+    lastFieldId_ = fieldIdStack_[$ - 1];
+    fieldIdStack_ = fieldIdStack_[0 .. $ - 1];
+    // Allow the next push to reuse the slot we just popped.
+    fieldIdStack_.assumeSafeAppend();
+  }
+
+  /**
+   * Writes a field header. For BOOL fields the header is deferred until
+   * writeBool() is called, because the value is part of the header.
+   */
+  void writeFieldBegin(TField field) {
+    if (field.type == TType.BOOL) {
+      booleanField_.name = field.name;
+      booleanField_.type = field.type;
+      booleanField_.id = field.id;
+    } else {
+      writeFieldBeginInternal(field);
+    }
+  }
+  void writeFieldEnd() {}
+
+  /// Writes the stop byte terminating the field list of a struct.
+  void writeFieldStop() {
+    writeByte(TType.STOP);
+  }
+
+  /// Writes a list header (element type and size).
+  void writeListBegin(TList list) {
+    writeCollectionBegin(list.elemType, list.size);
+  }
+  void writeListEnd() {}
+
+  /**
+   * Writes a map header. An empty map is a single zero byte; otherwise the
+   * size is written as a varint, followed by one byte holding the key type
+   * in the high nibble and the value type in the low nibble.
+   */
+  void writeMapBegin(TMap map) {
+    if (map.size == 0) {
+      writeByte(0);
+    } else {
+      assert(map.size <= int.max);
+      writeVarint32(cast(int)map.size);
+      writeByte(cast(byte)(toCType(map.keyType) << 4 | toCType(map.valueType)));
+    }
+  }
+  void writeMapEnd() {}
+
+  /// Writes a set header (element type and size).
+  void writeSetBegin(TSet set) {
+    writeCollectionBegin(set.elemType, set.size);
+  }
+  void writeSetEnd() {}
+
+
+  /*
+   * Reading methods.
+   */
+
+  /**
+   * Reads a boolean. If readFieldBegin() already decoded the value from a
+   * field header, that value is returned; otherwise a byte is read.
+   */
+  bool readBool() {
+    if (hasBoolValue_ == true) {
+      hasBoolValue_ = false;
+      return boolValue_;
+    }
+
+    return readByte() == CType.BOOLEAN_TRUE;
+  }
+
+  /// Reads a single byte.
+  byte readByte() {
+    ubyte[1] b = void;
+    trans_.readAll(b);
+    return cast(byte)b[0];
+  }
+
+  /// Reads a zigzag-encoded varint and truncates it to 16 bits.
+  short readI16() {
+    return cast(short)zigzagToI32(readVarint32());
+  }
+
+  /// Reads a zigzag-encoded 32 bit varint.
+  int readI32() {
+    return zigzagToI32(readVarint32());
+  }
+
+  /// Reads a zigzag-encoded 64 bit varint.
+  long readI64() {
+    return zigzagToI64(readVarint64());
+  }
+
+  /// Reads 8 bytes and reinterprets them (after leToHost()) as a double.
+  double readDouble() {
+    IntBuf!long b = void;
+    trans_.readAll(b.bytes);
+    b.value = leToHost(b.value);
+    return *cast(double*)(&b.value);
+  }
+
+  /// Reads a string using the same encoding as readBinary().
+  string readString() {
+    return cast(string)readBinary();
+  }
+
+  /**
+   * Reads a varint length followed by that many bytes.
+   *
+   * Returns: The data read, or null if the length is zero.
+   * Throws: TProtocolException if the length is negative or exceeds a
+   *   positive stringSizeLimit.
+   */
+  ubyte[] readBinary() {
+    auto size = readVarint32();
+    checkSize(size, stringSizeLimit);
+
+    if (size == 0) {
+      return null;
+    }
+
+    auto buf = uninitializedArray!(ubyte[])(size);
+    trans_.readAll(buf);
+    return buf;
+  }
+
+  /**
+   * Reads and validates the message header.
+   *
+   * Throws: TProtocolException (BAD_VERSION) if the protocol id or the
+   *   version does not match.
+   */
+  TMessage readMessageBegin() {
+    TMessage msg = void;
+
+    auto protocolId = readByte();
+    if (protocolId != cast(byte)PROTOCOL_ID) {
+      throw new TProtocolException("Bad protocol identifier",
+        TProtocolException.Type.BAD_VERSION);
+    }
+
+    auto versionAndType = readByte();
+    auto ver = versionAndType & VERSION_MASK;
+    if (ver != VERSION_N) {
+      throw new TProtocolException("Bad protocol version",
+        TProtocolException.Type.BAD_VERSION);
+    }
+
+    // versionAndType is signed, so mask after shifting to drop any
+    // sign-extended bits.
+    msg.type = cast(TMessageType)((versionAndType >> TYPE_SHIFT_AMOUNT) & TYPE_BITS);
+    msg.seqid = readVarint32();
+    msg.name = readString();
+
+    return msg;
+  }
+  void readMessageEnd() {}
+
+  /**
+   * Starts reading a struct: nothing is read, but the last field id of the
+   * enclosing struct is saved and the delta base is reset to zero.
+   */
+  TStruct readStructBegin() {
+    fieldIdStack_ ~= lastFieldId_;
+    lastFieldId_ = 0;
+    return TStruct();
+  }
+
+  /// Ends a struct by restoring the last field id of the enclosing struct.
+  void readStructEnd() {
+    lastFieldId_ = fieldIdStack_[$ - 1];
+    fieldIdStack_ = fieldIdStack_[0 .. $ - 1];
+    // Same as in writeStructEnd(): allow the next push to reuse the slot
+    // instead of reallocating the stack.
+    fieldIdStack_.assumeSafeAppend();
+  }
+
+  /**
+   * Reads a field header.
+   *
+   * The low nibble of the header byte is the wire type, the high nibble an
+   * optional field id delta (zero means the id follows as an i16). For
+   * boolean fields, the value is taken from the wire type and kept for the
+   * following readBool() call.
+   *
+   * Throws: TProtocolException (INVALID_DATA) on an unknown wire type.
+   */
+  TField readFieldBegin() {
+    TField f = void;
+    f.name = null;
+
+    auto bite = readByte();
+    auto type = cast(CType)(bite & 0x0f);
+
+    if (type == CType.STOP) {
+      // Struct stop byte, nothing more to do.
+      f.id = 0;
+      f.type = TType.STOP;
+      return f;
+    }
+
+    // Mask off the 4 MSB of the type header, which could contain a field id
+    // delta.
+    auto modifier = cast(short)((bite & 0xf0) >> 4);
+    if (modifier > 0) {
+      f.id = cast(short)(lastFieldId_ + modifier);
+    } else {
+      // Delta encoding not used, just read the id as usual.
+      f.id = readI16();
+    }
+    f.type = getTType(type);
+
+    if (type == CType.BOOLEAN_TRUE || type == CType.BOOLEAN_FALSE) {
+      // For boolean fields, the value is encoded in the type – keep it around
+      // for the readBool() call.
+      hasBoolValue_ = true;
+      boolValue_ = (type == CType.BOOLEAN_TRUE ? true : false);
+    }
+
+    lastFieldId_ = f.id;
+    return f;
+  }
+  void readFieldEnd() {}
+
+  /**
+   * Reads a list header. Sizes below 15 are stored in the high nibble of the
+   * header byte; the value 15 signals that the size follows as a varint.
+   *
+   * Throws: TProtocolException if the size is negative or exceeds a positive
+   *   containerSizeLimit, or if the element wire type is unknown.
+   */
+  TList readListBegin() {
+    auto sizeAndType = readByte();
+
+    auto lsize = (sizeAndType >> 4) & 0xf;
+    if (lsize == 0xf) {
+      lsize = readVarint32();
+    }
+    checkSize(lsize, containerSizeLimit);
+
+    TList l = void;
+    l.elemType = getTType(cast(CType)(sizeAndType & 0x0f));
+    l.size = cast(size_t)lsize;
+
+    return l;
+  }
+  void readListEnd() {}
+
+  /**
+   * Reads a map header. The key/value type byte is only present if the size
+   * is non-zero.
+   *
+   * Throws: TProtocolException if the size is negative or exceeds a positive
+   *   containerSizeLimit, or if a key/value wire type is unknown.
+   */
+  TMap readMapBegin() {
+    TMap m = void;
+
+    auto size = readVarint32();
+    ubyte kvType;
+    if (size != 0) {
+      kvType = readByte();
+    }
+    checkSize(size, containerSizeLimit);
+
+    m.size = size;
+    m.keyType = getTType(cast(CType)(kvType >> 4));
+    m.valueType = getTType(cast(CType)(kvType & 0xf));
+
+    return m;
+  }
+  void readMapEnd() {}
+
+  /**
+   * Reads a set header, which uses the same layout as a list header.
+   *
+   * Throws: TProtocolException if the size is negative or exceeds a positive
+   *   containerSizeLimit, or if the element wire type is unknown.
+   */
+  TSet readSetBegin() {
+    auto sizeAndType = readByte();
+
+    auto lsize = (sizeAndType >> 4) & 0xf;
+    if (lsize == 0xf) {
+      lsize = readVarint32();
+    }
+    checkSize(lsize, containerSizeLimit);
+
+    TSet s = void;
+    s.elemType = getTType(cast(CType)(sizeAndType & 0xf));
+    s.size = cast(size_t)lsize;
+
+    return s;
+  }
+  void readSetEnd() {}
+
+private:
+  /*
+   * Writes a field header, using typeOverride as wire type unless it is -1
+   * (used by writeBool() to encode the value in the type).
+   */
+  void writeFieldBeginInternal(TField field, byte typeOverride = -1) {
+    // If there's a type override, use that.
+    auto typeToWrite = (typeOverride == -1 ? toCType(field.type) : typeOverride);
+
+    // check if we can use delta encoding for the field id
+    if (field.id > lastFieldId_ && (field.id - lastFieldId_) <= 15) {
+      // write them together
+      writeByte(cast(byte)((field.id - lastFieldId_) << 4 | typeToWrite));
+    } else {
+      // write them separate
+      writeByte(cast(byte)typeToWrite);
+      writeI16(field.id);
+    }
+
+    lastFieldId_ = field.id;
+  }
+
+
+  /*
+   * Writes a list/set header. Sizes up to 14 are packed into the high nibble
+   * of the type byte; larger sizes use 0xf there and follow as a varint.
+   */
+  void writeCollectionBegin(TType elemType, size_t size) {
+    if (size <= 14) {
+      writeByte(cast(byte)(size << 4 | toCType(elemType)));
+    } else {
+      assert(size <= int.max);
+      writeByte(cast(byte)(0xf0 | toCType(elemType)));
+      writeVarint32(cast(int)size);
+    }
+  }
+
+  /*
+   * Write an i32 as a varint. Results in 1-5 bytes on the wire.
+   */
+  void writeVarint32(uint n) {
+    ubyte[5] buf = void;
+    ubyte wsize;
+
+    while (true) {
+      if ((n & ~0x7F) == 0) {
+        // Last group: high bit clear.
+        buf[wsize++] = cast(ubyte)n;
+        break;
+      } else {
+        // More groups follow: emit low 7 bits with the continuation bit set.
+        buf[wsize++] = cast(ubyte)((n & 0x7F) | 0x80);
+        n >>= 7;
+      }
+    }
+
+    trans_.write(buf[0 .. wsize]);
+  }
+
+  /*
+   * Write an i64 as a varint. Results in 1-10 bytes on the wire.
+   */
+  void writeVarint64(ulong n) {
+    ubyte[10] buf = void;
+    ubyte wsize;
+
+    while (true) {
+      if ((n & ~0x7FL) == 0) {
+        buf[wsize++] = cast(ubyte)n;
+        break;
+      } else {
+        buf[wsize++] = cast(ubyte)((n & 0x7F) | 0x80);
+        n >>= 7;
+      }
+    }
+
+    trans_.write(buf[0 .. wsize]);
+  }
+
+  /*
+   * Convert l into a zigzag long. This allows negative numbers to be
+   * represented compactly as a varint.
+   */
+  ulong i64ToZigzag(long l) {
+    return (l << 1) ^ (l >> 63);
+  }
+
+  /*
+   * Convert n into a zigzag int. This allows negative numbers to be
+   * represented compactly as a varint.
+   */
+  uint i32ToZigzag(int n) {
+    return (n << 1) ^ (n >> 31);
+  }
+
+  /*
+   * Maps a Thrift type to its compact protocol wire type. BOOL maps to
+   * BOOLEAN_TRUE; TType.VOID is not representable and asserts.
+   */
+  CType toCType(TType type) {
+    final switch (type) {
+      case TType.STOP:
+        return CType.STOP;
+      case TType.BOOL:
+        return CType.BOOLEAN_TRUE;
+      case TType.BYTE:
+        return CType.BYTE;
+      case TType.DOUBLE:
+        return CType.DOUBLE;
+      case TType.I16:
+        return CType.I16;
+      case TType.I32:
+        return CType.I32;
+      case TType.I64:
+        return CType.I64;
+      case TType.STRING:
+        return CType.BINARY;
+      case TType.STRUCT:
+        return CType.STRUCT;
+      case TType.MAP:
+        return CType.MAP;
+      case TType.SET:
+        return CType.SET;
+      case TType.LIST:
+        return CType.LIST;
+      case TType.VOID:
+        assert(false, "Invalid type passed.");
+    }
+  }
+
+  /*
+   * Read an i32 from the wire as a varint. Any bits beyond the low 32 of
+   * the decoded value are discarded.
+   */
+  int readVarint32() {
+    return cast(int)readVarint64();
+  }
+
+  /*
+   * Read an i64 from the wire as a varint (at most 10 bytes).
+   *
+   * Throws: TProtocolException (INVALID_DATA) if the continuation bit is
+   *   still set on the tenth byte.
+   */
+  long readVarint64() {
+    ulong val;
+    ubyte shift;
+    ubyte[10] buf = void; // 64 bits / (7 bits/byte) = 10 bytes.
+    auto bufSize = buf.sizeof;
+    // NOTE(review): borrow() presumably yields null when the transport
+    // cannot provide bufSize bytes without blocking – confirm.
+    auto borrowed = trans_.borrow(buf.ptr, bufSize);
+
+    ubyte rsize;
+
+    if (borrowed) {
+      // Fast path.
+      while (true) {
+        auto bite = borrowed[rsize];
+        rsize++;
+        val |= cast(ulong)(bite & 0x7f) << shift;
+        shift += 7;
+        if (!(bite & 0x80)) {
+          // Only consume the bytes that actually belonged to the varint.
+          trans_.consume(rsize);
+          return val;
+        }
+        // Have to check for invalid data so we don't crash.
+        if (rsize == buf.sizeof) {
+          // Message first, then type – the same argument order as used in
+          // readMessageBegin().
+          throw new TProtocolException("Variable-length int over 10 bytes.",
+            TProtocolException.Type.INVALID_DATA);
+        }
+      }
+    } else {
+      // Slow path.
+      while (true) {
+        ubyte[1] bite;
+        trans_.readAll(bite);
+        ++rsize;
+
+        val |= cast(ulong)(bite[0] & 0x7f) << shift;
+        shift += 7;
+        if (!(bite[0] & 0x80)) {
+          return val;
+        }
+
+        // Might as well check for invalid data on the slow path too.
+        if (rsize >= buf.sizeof) {
+          throw new TProtocolException("Variable-length int over 10 bytes.",
+            TProtocolException.Type.INVALID_DATA);
+        }
+      }
+    }
+  }
+
+  /*
+   * Convert from zigzag int to int.
+   */
+  int zigzagToI32(uint n) {
+    return (n >> 1) ^ -(n & 1);
+  }
+
+  /*
+   * Convert from zigzag long to long.
+   */
+  long zigzagToI64(ulong n) {
+    return (n >> 1) ^ -(n & 1);
+  }
+
+  /*
+   * Maps a compact protocol wire type to the corresponding Thrift type.
+   *
+   * The argument is cast from a 4 bit value read off the wire, so it is not
+   * guaranteed to be a declared CType member. Unknown values are reported as
+   * INVALID_DATA instead of hitting a final switch without a matching case.
+   */
+  TType getTType(CType type) {
+    switch (type) {
+      case CType.STOP:
+        return TType.STOP;
+      case CType.BOOLEAN_FALSE:
+        return TType.BOOL;
+      case CType.BOOLEAN_TRUE:
+        return TType.BOOL;
+      case CType.BYTE:
+        return TType.BYTE;
+      case CType.I16:
+        return TType.I16;
+      case CType.I32:
+        return TType.I32;
+      case CType.I64:
+        return TType.I64;
+      case CType.DOUBLE:
+        return TType.DOUBLE;
+      case CType.BINARY:
+        return TType.STRING;
+      case CType.LIST:
+        return TType.LIST;
+      case CType.SET:
+        return TType.SET;
+      case CType.MAP:
+        return TType.MAP;
+      case CType.STRUCT:
+        return TType.STRUCT;
+      default:
+        throw new TProtocolException("Unknown compact protocol wire type.",
+          TProtocolException.Type.INVALID_DATA);
+    }
+  }
+
+  /*
+   * Throws a NEGATIVE_SIZE TProtocolException if size is negative, or a
+   * SIZE_LIMIT one if limit is positive and size exceeds it.
+   */
+  void checkSize(int size, int limit) {
+    if (size < 0) {
+      throw new TProtocolException(TProtocolException.Type.NEGATIVE_SIZE);
+    } else if (limit > 0 && size > limit) {
+      throw new TProtocolException(TProtocolException.Type.SIZE_LIMIT);
+    }
+  }
+
+  // First byte of every message.
+  enum PROTOCOL_ID = 0x82;
+  // Protocol version, stored in the low 5 bits of the second header byte.
+  enum VERSION_N = 1;
+  enum VERSION_MASK = 0b0001_1111;
+  // The message type occupies the high 3 bits of the second header byte.
+  enum TYPE_MASK = 0b1110_0000;
+  enum TYPE_BITS = 0b0000_0111;
+  enum TYPE_SHIFT_AMOUNT = 5;
+
+  // Last field ids of the enclosing structs, innermost last.
+  // Probably need to implement a better stack at some point.
+  short[] fieldIdStack_;
+  // Id of the last field read/written in the current struct (delta base).
+  short lastFieldId_;
+
+  // BOOL field whose header write has been deferred to writeBool(); a
+  // non-null name marks it as pending.
+  TField booleanField_;
+
+  // Boolean value decoded from a field header, pending for readBool().
+  bool hasBoolValue_;
+  bool boolValue_;
+
+  Transport trans_;
+}
+
+/**
+ * TCompactProtocol construction helper to avoid having to explicitly specify
+ * the transport type, i.e. to allow the constructor being called using IFTI
+ * (see $(LINK2 http://d.puremagic.com/issues/show_bug.cgi?id=6082, D Bugzilla
+ * enhancement request 6082)).
+ */
+TCompactProtocol!Transport tCompactProtocol(Transport)(Transport trans,
+  int containerSizeLimit = 0, int stringSizeLimit = 0
+) if (isTTransport!Transport)
+{
+  // Forward all arguments unchanged; the only purpose of this helper is to
+  // have the Transport type deduced from the argument.
+  auto protocol = new TCompactProtocol!Transport(
+    trans,
+    containerSizeLimit,
+    stringSizeLimit
+  );
+  return protocol;
+}
+
+private {
+ // Type identifiers as used on the wire by the compact protocol. They are
+ // stored in a single nibble of the field/collection header bytes, hence the
+ // static assert below. There is no plain BOOL entry: BOOLEAN_TRUE and
+ // BOOLEAN_FALSE both map to TType.BOOL in getTType().
+ enum CType : ubyte {
+ STOP = 0x0,
+ BOOLEAN_TRUE = 0x1,
+ BOOLEAN_FALSE = 0x2,
+ BYTE = 0x3,
+ I16 = 0x4,
+ I32 = 0x5,
+ I64 = 0x6,
+ DOUBLE = 0x7,
+ BINARY = 0x8,
+ LIST = 0x9,
+ SET = 0xa,
+ MAP = 0xb,
+ STRUCT = 0xc
+ }
+ // Guards against adding members that no longer fit into the 4 bit slot.
+ static assert(CType.max <= 0xf,
+ "Compact protocol wire type representation must fit into 4 bits.");
+}
+
+unittest {
+  import std.exception;
+  import thrift.transport.memory;
+
+  // Expected wire image of the header of a CALL message named "foo" with
+  // sequence id 0.
+  immutable ubyte[] expectedHeader = [
+    130, // Protocol id.
+    33, // Version/type byte.
+    0, // Sequence id.
+    3, 102, 111, 111 // Method name.
+  ];
+
+  // Serialize the message header into an in-memory transport.
+  auto memory = new TMemoryBuffer;
+  auto protocol = tCompactProtocol(memory);
+  protocol.writeMessageBegin(TMessage("foo", TMessageType.CALL, 0));
+
+  // Read the raw bytes back and compare them with the expected image.
+  auto actualHeader = new ubyte[7];
+  memory.readAll(actualHeader);
+  enforce(actualHeader == expectedHeader);
+}
+
+unittest {
+ // Runs the shared protocol test helpers against the default
+ // (generic TTransport) instantiation of TCompactProtocol.
+ // NOTE(review): presumably these verify that containerSizeLimit and
+ // stringSizeLimit are enforced when reading – confirm against
+ // thrift.internal.test.protocol.
+ import thrift.internal.test.protocol;
+ testContainerSizeLimit!(TCompactProtocol!())();
+ testStringSizeLimit!(TCompactProtocol!())();
+}
+
+/**
+ * TProtocolFactory creating a TCompactProtocol instance for passed in
+ * transports.
+ *
+ * The optional Transports template tuple parameter can be used to specify
+ * one or more TTransport implementations to specifically instantiate
+ * TCompactProtocol for. If the actual transport types encountered at
+ * runtime match one of the transports in the list, a specialized protocol
+ * instance is created. Otherwise, a generic TTransport version is used.
+ */
+class TCompactProtocolFactory(Transports...) if (
+  allSatisfy!(isTTransport, Transports)
+) : TProtocolFactory {
+  /**
+   * Constructs a new factory.
+   *
+   * Params:
+   *   containerSizeLimit = Container size limit passed on to every protocol
+   *     instance created (zero or negative means no limit).
+   *   stringSizeLimit = String size limit passed on to every protocol
+   *     instance created (zero or negative means no limit).
+   */
+  this(int containerSizeLimit = 0, int stringSizeLimit = 0) {
+    // Store the limits that were actually requested instead of discarding
+    // them.
+    containerSizeLimit_ = containerSizeLimit;
+    stringSizeLimit_ = stringSizeLimit;
+  }
+
+  /**
+   * Creates a TCompactProtocol for the given transport.
+   *
+   * The transport is matched against the Transports list in order, and the
+   * first type the cast succeeds for determines the protocol instantiation;
+   * TTransport itself serves as the fallback.
+   *
+   * Throws: TProtocolException if trans is null (no cast succeeds).
+   */
+  TProtocol getProtocol(TTransport trans) const {
+    foreach (Transport; TypeTuple!(Transports, TTransport)) {
+      auto concreteTrans = cast(Transport)trans;
+      if (concreteTrans) {
+        // Hand the configured limits to the new instance.
+        return new TCompactProtocol!Transport(concreteTrans,
+          containerSizeLimit_, stringSizeLimit_);
+      }
+    }
+    throw new TProtocolException(
+      "Passed null transport to TCompactProtocolFactory.");
+  }
+
+  // Limits applied to every protocol instance created by getProtocol().
+  int containerSizeLimit_;
+  int stringSizeLimit_;
+}