]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, software | |
12 | // distributed under the License is distributed on an "AS IS" BASIS, | |
13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
14 | // See the License for the specific language governing permissions and | |
15 | // limitations under the License. | |
16 | ||
17 | package ipc // import "github.com/apache/arrow/go/v6/arrow/ipc" | |
18 | ||
19 | import ( | |
20 | "encoding/binary" | |
21 | "fmt" | |
22 | "io" | |
23 | "sync/atomic" | |
24 | ||
25 | "github.com/apache/arrow/go/v6/arrow/internal/debug" | |
26 | "github.com/apache/arrow/go/v6/arrow/internal/flatbuf" | |
27 | "github.com/apache/arrow/go/v6/arrow/memory" | |
28 | "golang.org/x/xerrors" | |
29 | ) | |
30 | ||
31 | // MetadataVersion represents the Arrow metadata version. | |
32 | type MetadataVersion flatbuf.MetadataVersion | |
33 | ||
34 | const ( | |
35 | MetadataV1 = MetadataVersion(flatbuf.MetadataVersionV1) // version for Arrow-0.1.0 | |
36 | MetadataV2 = MetadataVersion(flatbuf.MetadataVersionV2) // version for Arrow-0.2.0 | |
37 | MetadataV3 = MetadataVersion(flatbuf.MetadataVersionV3) // version for Arrow-0.3.0 to 0.7.1 | |
38 | MetadataV4 = MetadataVersion(flatbuf.MetadataVersionV4) // version for >= Arrow-0.8.0 | |
39 | MetadataV5 = MetadataVersion(flatbuf.MetadataVersionV5) // version for >= Arrow-1.0.0, backward compatible with v4 | |
40 | ) | |
41 | ||
42 | func (m MetadataVersion) String() string { | |
43 | if v, ok := flatbuf.EnumNamesMetadataVersion[flatbuf.MetadataVersion(m)]; ok { | |
44 | return v | |
45 | } | |
46 | return fmt.Sprintf("MetadataVersion(%d)", int16(m)) | |
47 | } | |
48 | ||
49 | // MessageType represents the type of Message in an Arrow format. | |
50 | type MessageType flatbuf.MessageHeader | |
51 | ||
52 | const ( | |
53 | MessageNone = MessageType(flatbuf.MessageHeaderNONE) | |
54 | MessageSchema = MessageType(flatbuf.MessageHeaderSchema) | |
55 | MessageDictionaryBatch = MessageType(flatbuf.MessageHeaderDictionaryBatch) | |
56 | MessageRecordBatch = MessageType(flatbuf.MessageHeaderRecordBatch) | |
57 | MessageTensor = MessageType(flatbuf.MessageHeaderTensor) | |
58 | MessageSparseTensor = MessageType(flatbuf.MessageHeaderSparseTensor) | |
59 | ) | |
60 | ||
61 | func (m MessageType) String() string { | |
62 | if v, ok := flatbuf.EnumNamesMessageHeader[flatbuf.MessageHeader(m)]; ok { | |
63 | return v | |
64 | } | |
65 | return fmt.Sprintf("MessageType(%d)", int(m)) | |
66 | } | |
67 | ||
// maxNestingDepth is an arbitrary limit meant to catch user mistakes.
// Users with deeply nested schemas are expected to state the maximum
// allowed recursion depth explicitly.
const maxNestingDepth = 64
74 | ||
75 | // Message is an IPC message, including metadata and body. | |
76 | type Message struct { | |
77 | refCount int64 | |
78 | msg *flatbuf.Message | |
79 | meta *memory.Buffer | |
80 | body *memory.Buffer | |
81 | } | |
82 | ||
83 | // NewMessage creates a new message from the metadata and body buffers. | |
84 | // NewMessage panics if any of these buffers is nil. | |
85 | func NewMessage(meta, body *memory.Buffer) *Message { | |
86 | if meta == nil || body == nil { | |
87 | panic("arrow/ipc: nil buffers") | |
88 | } | |
89 | meta.Retain() | |
90 | body.Retain() | |
91 | return &Message{ | |
92 | refCount: 1, | |
93 | msg: flatbuf.GetRootAsMessage(meta.Bytes(), 0), | |
94 | meta: meta, | |
95 | body: body, | |
96 | } | |
97 | } | |
98 | ||
99 | func newMessageFromFB(meta *flatbuf.Message, body *memory.Buffer) *Message { | |
100 | if meta == nil || body == nil { | |
101 | panic("arrow/ipc: nil buffers") | |
102 | } | |
103 | body.Retain() | |
104 | return &Message{ | |
105 | refCount: 1, | |
106 | msg: meta, | |
107 | meta: memory.NewBufferBytes(meta.Table().Bytes), | |
108 | body: body, | |
109 | } | |
110 | } | |
111 | ||
112 | // Retain increases the reference count by 1. | |
113 | // Retain may be called simultaneously from multiple goroutines. | |
114 | func (msg *Message) Retain() { | |
115 | atomic.AddInt64(&msg.refCount, 1) | |
116 | } | |
117 | ||
118 | // Release decreases the reference count by 1. | |
119 | // Release may be called simultaneously from multiple goroutines. | |
120 | // When the reference count goes to zero, the memory is freed. | |
121 | func (msg *Message) Release() { | |
122 | debug.Assert(atomic.LoadInt64(&msg.refCount) > 0, "too many releases") | |
123 | ||
124 | if atomic.AddInt64(&msg.refCount, -1) == 0 { | |
125 | msg.meta.Release() | |
126 | msg.body.Release() | |
127 | msg.msg = nil | |
128 | msg.meta = nil | |
129 | msg.body = nil | |
130 | } | |
131 | } | |
132 | ||
133 | func (msg *Message) Version() MetadataVersion { | |
134 | return MetadataVersion(msg.msg.Version()) | |
135 | } | |
136 | ||
137 | func (msg *Message) Type() MessageType { | |
138 | return MessageType(msg.msg.HeaderType()) | |
139 | } | |
140 | ||
141 | func (msg *Message) BodyLen() int64 { | |
142 | return msg.msg.BodyLength() | |
143 | } | |
144 | ||
145 | type MessageReader interface { | |
146 | Message() (*Message, error) | |
147 | Release() | |
148 | Retain() | |
149 | } | |
150 | ||
151 | // MessageReader reads messages from an io.Reader. | |
152 | type messageReader struct { | |
153 | r io.Reader | |
154 | ||
155 | refCount int64 | |
156 | msg *Message | |
157 | } | |
158 | ||
159 | // NewMessageReader returns a reader that reads messages from an input stream. | |
160 | func NewMessageReader(r io.Reader) MessageReader { | |
161 | return &messageReader{r: r, refCount: 1} | |
162 | } | |
163 | ||
164 | // Retain increases the reference count by 1. | |
165 | // Retain may be called simultaneously from multiple goroutines. | |
166 | func (r *messageReader) Retain() { | |
167 | atomic.AddInt64(&r.refCount, 1) | |
168 | } | |
169 | ||
170 | // Release decreases the reference count by 1. | |
171 | // When the reference count goes to zero, the memory is freed. | |
172 | // Release may be called simultaneously from multiple goroutines. | |
173 | func (r *messageReader) Release() { | |
174 | debug.Assert(atomic.LoadInt64(&r.refCount) > 0, "too many releases") | |
175 | ||
176 | if atomic.AddInt64(&r.refCount, -1) == 0 { | |
177 | if r.msg != nil { | |
178 | r.msg.Release() | |
179 | r.msg = nil | |
180 | } | |
181 | } | |
182 | } | |
183 | ||
184 | // Message returns the current message that has been extracted from the | |
185 | // underlying stream. | |
186 | // It is valid until the next call to Message. | |
187 | func (r *messageReader) Message() (*Message, error) { | |
188 | var buf = make([]byte, 4) | |
189 | _, err := io.ReadFull(r.r, buf) | |
190 | if err != nil { | |
191 | return nil, xerrors.Errorf("arrow/ipc: could not read continuation indicator: %w", err) | |
192 | } | |
193 | var ( | |
194 | cid = binary.LittleEndian.Uint32(buf) | |
195 | msgLen int32 | |
196 | ) | |
197 | switch cid { | |
198 | case 0: | |
199 | // EOS message. | |
200 | return nil, io.EOF // FIXME(sbinet): send nil instead? or a special EOS error? | |
201 | case kIPCContToken: | |
202 | _, err = io.ReadFull(r.r, buf) | |
203 | if err != nil { | |
204 | return nil, xerrors.Errorf("arrow/ipc: could not read message length: %w", err) | |
205 | } | |
206 | msgLen = int32(binary.LittleEndian.Uint32(buf)) | |
207 | if msgLen == 0 { | |
208 | // optional 0 EOS control message | |
209 | return nil, io.EOF // FIXME(sbinet): send nil instead? or a special EOS error? | |
210 | } | |
211 | ||
212 | default: | |
213 | // ARROW-6314: backwards compatibility for reading old IPC | |
214 | // messages produced prior to version 0.15.0 | |
215 | msgLen = int32(cid) | |
216 | } | |
217 | ||
218 | buf = make([]byte, msgLen) | |
219 | _, err = io.ReadFull(r.r, buf) | |
220 | if err != nil { | |
221 | return nil, xerrors.Errorf("arrow/ipc: could not read message metadata: %w", err) | |
222 | } | |
223 | ||
224 | meta := flatbuf.GetRootAsMessage(buf, 0) | |
225 | bodyLen := meta.BodyLength() | |
226 | ||
227 | buf = make([]byte, bodyLen) | |
228 | _, err = io.ReadFull(r.r, buf) | |
229 | if err != nil { | |
230 | return nil, xerrors.Errorf("arrow/ipc: could not read message body: %w", err) | |
231 | } | |
232 | body := memory.NewBufferBytes(buf) | |
233 | ||
234 | if r.msg != nil { | |
235 | r.msg.Release() | |
236 | r.msg = nil | |
237 | } | |
238 | r.msg = newMessageFromFB(meta, body) | |
239 | ||
240 | return r.msg, nil | |
241 | } |