1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
23 "github.com/apache/arrow/go/v6/parquet/internal/debug"
24 "github.com/klauspost/compress/zstd"
// zstdCodec is the Zstandard implementation registered in the package's codec
// table (see the Codecs.Zstd registration). It is stateless: the shared
// one-shot encoder and decoder are package-level values reached through
// getencoder and getdecoder.
27 type zstdCodec struct{}
// zstdcloser adapts a streaming *zstd.Decoder so that NewReader can hand it
// out as an io.ReadCloser; it supplies the Close method.
// NOTE(review): NewReader builds this positionally from a single
// *zstd.Decoder; Read is presumably promoted by embedding that decoder —
// confirm against the field list.
29 type zstdcloser struct {
40 func getencoder() *zstd.Encoder {
41 initEncoder.Do(func() {
42 enc, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true))
47 func getdecoder() *zstd.Decoder {
48 initDecoder.Do(func() {
49 dec, _ = zstd.NewReader(nil)
// Decode decompresses the whole of src with the shared decoder from
// getdecoder and returns the decompressed bytes.
// NOTE(review): Decode has no error result, so the DecodeAll error has to be
// handled by the statements that follow — confirm how corrupt input is
// reported to callers.
54 func (zstdCodec) Decode(dst, src []byte) []byte {
// DecodeAll appends to dst[:0], so dst's existing contents are overwritten
// (its backing array is reused when large enough) rather than appended to.
55 dst, err := getdecoder().DecodeAll(src, dst[:0])
// Close completes io.Closer for the io.ReadCloser returned by NewReader.
// NOTE(review): presumably this releases the wrapped zstd decoder — confirm
// against the method body.
62 func (z *zstdcloser) Close() error {
67 func (zstdCodec) NewReader(r io.Reader) io.ReadCloser {
68 ret, _ := zstd.NewReader(r)
69 return &zstdcloser{ret}
72 func (zstdCodec) NewWriter(w io.Writer) io.WriteCloser {
73 ret, _ := zstd.NewWriter(w)
77 func (zstdCodec) NewWriterLevel(w io.Writer, level int) (io.WriteCloser, error) {
78 var compressLevel zstd.EncoderLevel
79 if level == DefaultCompressionLevel {
80 compressLevel = zstd.SpeedDefault
82 compressLevel = zstd.EncoderLevelFromZstd(level)
84 return zstd.NewWriter(w, zstd.WithEncoderLevel(compressLevel))
87 func (z zstdCodec) Encode(dst, src []byte) []byte {
88 return getencoder().EncodeAll(src, dst[:0])
91 func (z zstdCodec) EncodeLevel(dst, src []byte, level int) []byte {
92 compressLevel := zstd.EncoderLevelFromZstd(level)
93 if level == DefaultCompressionLevel {
94 compressLevel = zstd.SpeedDefault
96 enc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true), zstd.WithEncoderLevel(compressLevel))
97 return enc.EncodeAll(src, dst[:0])
100 // from zstd.h, ZSTD_COMPRESSBOUND
101 func (zstdCodec) CompressBound(len int64) int64 {
102 debug.Assert(len > 0, "len for zstd CompressBound should be > 0")
103 extra := ((128 << 10) - len) >> 11
104 if len >= (128 << 10) {
107 return len + (len >> 8) + extra
// Register the zstd implementation in the package codec table under
// Codecs.Zstd.
111 codecs[Codecs.Zstd] = zstdCodec{}