]> git.proxmox.com Git - ceph.git/blobdiff - ceph/src/arrow/go/arrow/ipc/message.go
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / go / arrow / ipc / message.go
diff --git a/ceph/src/arrow/go/arrow/ipc/message.go b/ceph/src/arrow/go/arrow/ipc/message.go
new file mode 100644 (file)
index 0000000..4f23fd6
--- /dev/null
@@ -0,0 +1,241 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipc // import "github.com/apache/arrow/go/v6/arrow/ipc"
+
+import (
+       "encoding/binary"
+       "fmt"
+       "io"
+       "sync/atomic"
+
+       "github.com/apache/arrow/go/v6/arrow/internal/debug"
+       "github.com/apache/arrow/go/v6/arrow/internal/flatbuf"
+       "github.com/apache/arrow/go/v6/arrow/memory"
+       "golang.org/x/xerrors"
+)
+
+// MetadataVersion represents the Arrow metadata version.
+type MetadataVersion flatbuf.MetadataVersion
+
+const (
+       MetadataV1 = MetadataVersion(flatbuf.MetadataVersionV1) // version for Arrow-0.1.0
+       MetadataV2 = MetadataVersion(flatbuf.MetadataVersionV2) // version for Arrow-0.2.0
+       MetadataV3 = MetadataVersion(flatbuf.MetadataVersionV3) // version for Arrow-0.3.0 to 0.7.1
+       MetadataV4 = MetadataVersion(flatbuf.MetadataVersionV4) // version for >= Arrow-0.8.0
+       MetadataV5 = MetadataVersion(flatbuf.MetadataVersionV5) // version for >= Arrow-1.0.0, backward compatible with v4
+)
+
+func (m MetadataVersion) String() string {
+       if v, ok := flatbuf.EnumNamesMetadataVersion[flatbuf.MetadataVersion(m)]; ok {
+               return v
+       }
+       return fmt.Sprintf("MetadataVersion(%d)", int16(m))
+}
+
+// MessageType represents the type of Message in an Arrow format.
+type MessageType flatbuf.MessageHeader
+
+const (
+       MessageNone            = MessageType(flatbuf.MessageHeaderNONE)
+       MessageSchema          = MessageType(flatbuf.MessageHeaderSchema)
+       MessageDictionaryBatch = MessageType(flatbuf.MessageHeaderDictionaryBatch)
+       MessageRecordBatch     = MessageType(flatbuf.MessageHeaderRecordBatch)
+       MessageTensor          = MessageType(flatbuf.MessageHeaderTensor)
+       MessageSparseTensor    = MessageType(flatbuf.MessageHeaderSparseTensor)
+)
+
+func (m MessageType) String() string {
+       if v, ok := flatbuf.EnumNamesMessageHeader[flatbuf.MessageHeader(m)]; ok {
+               return v
+       }
+       return fmt.Sprintf("MessageType(%d)", int(m))
+}
+
+const (
+       // maxNestingDepth is an arbitrary value to catch user mistakes.
+       // For deeply nested schemas, it is expected the user will indicate
+       // explicitly the maximum allowed recursion depth.
+       maxNestingDepth = 64
+)
+
+// Message is an IPC message, including metadata and body.
+type Message struct {
+       refCount int64
+       msg      *flatbuf.Message
+       meta     *memory.Buffer
+       body     *memory.Buffer
+}
+
+// NewMessage creates a new message from the metadata and body buffers.
+// NewMessage panics if any of these buffers is nil.
+func NewMessage(meta, body *memory.Buffer) *Message {
+       if meta == nil || body == nil {
+               panic("arrow/ipc: nil buffers")
+       }
+       meta.Retain()
+       body.Retain()
+       return &Message{
+               refCount: 1,
+               msg:      flatbuf.GetRootAsMessage(meta.Bytes(), 0),
+               meta:     meta,
+               body:     body,
+       }
+}
+
+func newMessageFromFB(meta *flatbuf.Message, body *memory.Buffer) *Message {
+       if meta == nil || body == nil {
+               panic("arrow/ipc: nil buffers")
+       }
+       body.Retain()
+       return &Message{
+               refCount: 1,
+               msg:      meta,
+               meta:     memory.NewBufferBytes(meta.Table().Bytes),
+               body:     body,
+       }
+}
+
+// Retain increases the reference count by 1.
+// Retain may be called simultaneously from multiple goroutines.
+func (msg *Message) Retain() {
+       atomic.AddInt64(&msg.refCount, 1)
+}
+
+// Release decreases the reference count by 1.
+// Release may be called simultaneously from multiple goroutines.
+// When the reference count goes to zero, the memory is freed.
+func (msg *Message) Release() {
+       debug.Assert(atomic.LoadInt64(&msg.refCount) > 0, "too many releases")
+
+       if atomic.AddInt64(&msg.refCount, -1) == 0 {
+               msg.meta.Release()
+               msg.body.Release()
+               msg.msg = nil
+               msg.meta = nil
+               msg.body = nil
+       }
+}
+
+func (msg *Message) Version() MetadataVersion {
+       return MetadataVersion(msg.msg.Version())
+}
+
+func (msg *Message) Type() MessageType {
+       return MessageType(msg.msg.HeaderType())
+}
+
+func (msg *Message) BodyLen() int64 {
+       return msg.msg.BodyLength()
+}
+
+type MessageReader interface {
+       Message() (*Message, error)
+       Release()
+       Retain()
+}
+
+// MessageReader reads messages from an io.Reader.
+type messageReader struct {
+       r io.Reader
+
+       refCount int64
+       msg      *Message
+}
+
+// NewMessageReader returns a reader that reads messages from an input stream.
+func NewMessageReader(r io.Reader) MessageReader {
+       return &messageReader{r: r, refCount: 1}
+}
+
+// Retain increases the reference count by 1.
+// Retain may be called simultaneously from multiple goroutines.
+func (r *messageReader) Retain() {
+       atomic.AddInt64(&r.refCount, 1)
+}
+
+// Release decreases the reference count by 1.
+// When the reference count goes to zero, the memory is freed.
+// Release may be called simultaneously from multiple goroutines.
+func (r *messageReader) Release() {
+       debug.Assert(atomic.LoadInt64(&r.refCount) > 0, "too many releases")
+
+       if atomic.AddInt64(&r.refCount, -1) == 0 {
+               if r.msg != nil {
+                       r.msg.Release()
+                       r.msg = nil
+               }
+       }
+}
+
+// Message returns the current message that has been extracted from the
+// underlying stream.
+// It is valid until the next call to Message.
+func (r *messageReader) Message() (*Message, error) {
+       var buf = make([]byte, 4)
+       _, err := io.ReadFull(r.r, buf)
+       if err != nil {
+               return nil, xerrors.Errorf("arrow/ipc: could not read continuation indicator: %w", err)
+       }
+       var (
+               cid    = binary.LittleEndian.Uint32(buf)
+               msgLen int32
+       )
+       switch cid {
+       case 0:
+               // EOS message.
+               return nil, io.EOF // FIXME(sbinet): send nil instead? or a special EOS error?
+       case kIPCContToken:
+               _, err = io.ReadFull(r.r, buf)
+               if err != nil {
+                       return nil, xerrors.Errorf("arrow/ipc: could not read message length: %w", err)
+               }
+               msgLen = int32(binary.LittleEndian.Uint32(buf))
+               if msgLen == 0 {
+                       // optional 0 EOS control message
+                       return nil, io.EOF // FIXME(sbinet): send nil instead? or a special EOS error?
+               }
+
+       default:
+               // ARROW-6314: backwards compatibility for reading old IPC
+               // messages produced prior to version 0.15.0
+               msgLen = int32(cid)
+       }
+
+       buf = make([]byte, msgLen)
+       _, err = io.ReadFull(r.r, buf)
+       if err != nil {
+               return nil, xerrors.Errorf("arrow/ipc: could not read message metadata: %w", err)
+       }
+
+       meta := flatbuf.GetRootAsMessage(buf, 0)
+       bodyLen := meta.BodyLength()
+
+       buf = make([]byte, bodyLen)
+       _, err = io.ReadFull(r.r, buf)
+       if err != nil {
+               return nil, xerrors.Errorf("arrow/ipc: could not read message body: %w", err)
+       }
+       body := memory.NewBufferBytes(buf)
+
+       if r.msg != nil {
+               r.msg.Release()
+               r.msg = nil
+       }
+       r.msg = newMessageFromFB(meta, body)
+
+       return r.msg, nil
+}