// zig/lib/std/compress/zstandard/decode/block.zig
//
// prepare(): Prepare the decoder to decode a compressed block. Loads the
// literals stream and Huffman tree from `literals` and reads the FSE tables
// from `source`. Errors returned:
//   - error.BitStreamHasNoStartBit if the (reversed) literal bitstream's first
//     byte does not have any bits set
//   - error.TreelessLiteralsFirst if `literals` is a treeless literals section
//     and the decode state does not have a Huffman tree from a previous block
//   - error.RepeatModeFirst on the first call if one of the sequence FSE tables
//     is set to repeat mode
//   - error.MalformedAccuracyLog if an FSE table has an invalid accuracy log
//   - error.MalformedFseTable if there are errors decoding an FSE table
//   - error.EndOfStream if `source` ends before all FSE tables are read

const std = @import("std");
const assert = std.debug.assert;
const RingBuffer = std.RingBuffer;

// Error
//
// readInitialFseState(): Read the initial FSE states for sequence decoding.
// Errors returned:
//   - error.EndOfStream if `bit_reader` does not contain enough bits


const types = @import("../types.zig");
const frame = types.frame;
const Table = types.compressed_block.Table;
const LiteralsSection = types.compressed_block.LiteralsSection;
const SequencesSection = types.compressed_block.SequencesSection;

// DecodeState
//
// decodeSequenceSlice(): Decode one sequence from `bit_reader` into `dest`,
// written starting at `write_pos`, and update the FSE states if
// `last_sequence` is false. prepare() must be called for the block before
// attempting to decode sequences. Errors returned:
//   - error.MalformedSequence if the decompressed sequence would be longer
//     than `sequence_size_limit` or the sequence's offset is too large
//   - error.UnexpectedEndOfLiteralStream if the decoder state's literal
//     streams do not contain enough literals for the sequence (this may mean
//     the literal stream or the sequence is malformed)
//   - error.InvalidBitStream if the FSE sequence bitstream is malformed
//   - error.EndOfStream if `bit_reader` does not contain enough bits
//   - error.DestTooSmall if `dest` is not large enough to hold the
//     decompressed sequence


const huffman = @import("huffman.zig");
const readers = @import("../readers.zig");

// init()
//
// decodeSequenceRingBuffer(): Decode one sequence from `bit_reader` into
// `dest`; see decodeSequenceSlice().


const decodeFseTable = @import("fse.zig").decodeFseTable;

// prepare()
//
// decodeLiteralsSlice(): Decode `len` bytes of literals into `dest`.
// Errors returned:
//   - error.MalformedLiteralsLength if the number of literal bytes already
//     decoded plus `len` is greater than the regenerated size of the literals
//   - error.UnexpectedEndOfLiteralStream and error.NotFound if there are
//     problems decoding Huffman-compressed literals


const readInt = std.mem.readIntLittle;

// readInitialFseState()
//
// decodeLiteralsRingBuffer(): Decode literals into `dest`; see
// decodeLiteralsSlice().


/// Errors returned by `decodeBlock()` (together with `error.DestTooSmall`)
/// and by `decodeBlockRingBuffer()`.
pub const Error = error{
    BlockSizeOverMaximum,
    MalformedBlockSize,
    MalformedRleBlock,
    MalformedCompressedBlock,
    ReservedBlock,
};

// decodeSequenceSlice()
//
// decodeBlock(): Decode a single block from `src` into `dest`. The beginning
// of `src` must be the start of the block content (i.e. directly after the
// block header). Increments `consumed_count` by the number of bytes read from
// `src` to decode the block and returns the decompressed size of the block.
// Errors returned:
//   - error.BlockSizeOverMaximum if the block's size is larger than
//     `block_size_max`
//   - error.MalformedBlockSize if `src.len` is smaller than the block size and
//     the block is a raw or compressed block
//   - error.ReservedBlock if the block is a reserved block
//   - error.MalformedRleBlock if the block is an RLE block and `src.len < 1`
//   - error.MalformedCompressedBlock if there are errors decoding a compressed
//     block
//   - error.DestTooSmall if `dest` is not large enough to hold the
//     decompressed block


/// Decoder state carried from one compressed block to the next within a
/// frame: the repeat offsets, the FSE states and tables for offsets, match
/// lengths and literal lengths, the literals of the block being decoded, and
/// the most recent Huffman tree (reused by treeless literals sections).
pub const DecodeState = struct {
    /// Repeated offsets; index 0 is the most recently used one.
    repeat_offsets: [3]u32,

    // FSE state, table and accuracy log for each of the three sequence fields.
    offset: StateData(8),
    match: StateData(9),
    literal: StateData(9),

    // Caller-provided storage for FSE tables decoded in `.fse` mode.
    offset_fse_buffer: []Table.Fse,
    match_fse_buffer: []Table.Fse,
    literal_fse_buffer: []Table.Fse,

    /// `true` until `prepare()` has loaded the sequence tables at least once;
    /// while set, a table in `.repeat` mode is rejected.
    fse_tables_undefined: bool,

    // Reader over the current Huffman-coded literal stream; with four streams,
    // `literal_stream_index` says which one is being read.
    literal_stream_reader: readers.ReverseBitReader,
    literal_stream_index: usize,
    literal_streams: LiteralsSection.Streams,
    literal_header: LiteralsSection.Header,
    /// Most recent Huffman tree, kept across blocks for treeless literals.
    huffman_tree: ?LiteralsSection.HuffmanTree,

    /// Literal bytes emitted so far for the current block (reset by `prepare()`).
    literal_written_count: usize,
    /// Running total of bytes written through this state.
    written_count: usize = 0,

    /// Current FSE state together with the table it indexes and the table's
    /// accuracy log. `State` is wide enough for `max_accuracy_log` bits.
    fn StateData(comptime max_accuracy_log: comptime_int) type {
        return struct {
            state: State,
            table: Table,
            accuracy_log: u8,

            const State = std.meta.Int(.unsigned, max_accuracy_log);
        };
    }

    /// Create a decode state using the given buffers to hold decoded FSE
    /// tables. Repeat offsets start at the format's initial values and no
    /// Huffman tree or FSE table is loaded yet.
    pub fn init(
        literal_fse_buffer: []Table.Fse,
        match_fse_buffer: []Table.Fse,
        offset_fse_buffer: []Table.Fse,
    ) DecodeState {
        return DecodeState{
            .repeat_offsets = .{
                types.compressed_block.start_repeated_offset_1,
                types.compressed_block.start_repeated_offset_2,
                types.compressed_block.start_repeated_offset_3,
            },

            .offset = undefined,
            .match = undefined,
            .literal = undefined,

            .literal_fse_buffer = literal_fse_buffer,
            .match_fse_buffer = match_fse_buffer,
            .offset_fse_buffer = offset_fse_buffer,

            .fse_tables_undefined = true,

            .literal_written_count = 0,
            .literal_header = undefined,
            .literal_streams = undefined,
            .literal_stream_reader = undefined,
            .literal_stream_index = undefined,
            .huffman_tree = null,

            .written_count = 0,
        };
    }

    /// Prepare the decoder to decode a compressed block. Loads the literals
    /// stream and Huffman tree from `literals` and reads the FSE tables from
    /// `source`.
    ///
    /// Errors returned:
    ///   - `error.BitStreamHasNoStartBit` if the (reversed) literal bitstream's
    ///     first byte does not have any bits set
    ///   - `error.TreelessLiteralsFirst` if `literals` is a treeless literals
    ///     section and the decode state does not have a Huffman tree from a
    ///     previous block
    ///   - `error.RepeatModeFirst` on the first call if one of the sequence FSE
    ///     tables is set to repeat mode
    ///   - `error.MalformedAccuracyLog` if an FSE table has an invalid accuracy
    ///     log
    ///   - `error.MalformedFseTable` if there are errors decoding an FSE table
    ///   - `error.EndOfStream` if `source` ends before all FSE tables are read
    pub fn prepare(
        self: *DecodeState,
        source: anytype,
        literals: LiteralsSection,
        sequences_header: SequencesSection.Header,
    ) !void {
        self.literal_written_count = 0;
        self.literal_header = literals.header;
        self.literal_streams = literals.streams;

        if (literals.huffman_tree) |tree| {
            self.huffman_tree = tree;
        } else if (literals.header.block_type == .treeless and self.huffman_tree == null) {
            return error.TreelessLiteralsFirst;
        }

        switch (literals.header.block_type) {
            .raw, .rle => {},
            .compressed, .treeless => {
                self.literal_stream_index = 0;
                switch (literals.streams) {
                    .one => |slice| try self.initLiteralStream(slice),
                    .four => |streams| try self.initLiteralStream(streams[0]),
                }
            },
        }

        // FSE table descriptions are only present when there are sequences.
        if (sequences_header.sequence_count > 0) {
            try self.updateFseTable(source, .literal, sequences_header.literal_lengths);
            try self.updateFseTable(source, .offset, sequences_header.offsets);
            try self.updateFseTable(source, .match, sequences_header.match_lengths);
            self.fse_tables_undefined = false;
        }
    }

    /// Read initial FSE states for sequence decoding.
    ///
    /// Errors returned:
    ///   - `error.EndOfStream` if `bit_reader` does not contain enough bits.
    pub fn readInitialFseState(self: *DecodeState, bit_reader: *readers.ReverseBitReader) error{EndOfStream}!void {
        self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log);
        self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log);
        self.match.state = try bit_reader.readBitsNoEof(u9, self.match.accuracy_log);
    }

    /// Push `offset` to the front of the repeat offsets, dropping the oldest.
    fn updateRepeatOffset(self: *DecodeState, offset: u32) void {
        self.repeat_offsets[2] = self.repeat_offsets[1];
        self.repeat_offsets[1] = self.repeat_offsets[0];
        self.repeat_offsets[0] = offset;
    }

    /// Move repeat offset `index` to the front (the others keep their relative
    /// order behind it) and return it.
    fn useRepeatOffset(self: *DecodeState, index: usize) u32 {
        if (index == 1)
            std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[1])
        else if (index == 2) {
            std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[2]);
            std.mem.swap(u32, &self.repeat_offsets[1], &self.repeat_offsets[2]);
        }
        return self.repeat_offsets[0];
    }

    const DataType = enum { offset, match, literal };

    /// Advance the FSE state for `choice` by reading the number of bits the
    /// current table entry asks for. RLE tables have no state to advance.
    fn updateState(
        self: *DecodeState,
        comptime choice: DataType,
        bit_reader: *readers.ReverseBitReader,
    ) error{ MalformedFseBits, EndOfStream }!void {
        switch (@field(self, @tagName(choice)).table) {
            .rle => {},
            .fse => |table| {
                const data = table[@field(self, @tagName(choice)).state];
                const T = @TypeOf(@field(self, @tagName(choice))).State;
                const bits_summand = try bit_reader.readBitsNoEof(T, data.bits);
                const next_state = std.math.cast(
                    @TypeOf(@field(self, @tagName(choice))).State,
                    data.baseline + bits_summand,
                ) orelse return error.MalformedFseBits;
                @field(self, @tagName(choice)).state = next_state;
            },
        }
    }

    const FseTableError = error{
        MalformedFseTable,
        MalformedAccuracyLog,
        RepeatModeFirst,
        EndOfStream,
    };

    /// Load the table for `choice` according to `mode`: a predefined table,
    /// a single RLE symbol byte, an FSE table description read from `source`,
    /// or (`.repeat`) the table already loaded by a previous block.
    fn updateFseTable(
        self: *DecodeState,
        source: anytype,
        comptime choice: DataType,
        mode: SequencesSection.Header.Mode,
    ) !void {
        const field_name = @tagName(choice);
        switch (mode) {
            .predefined => {
                @field(self, field_name).accuracy_log =
                    @field(types.compressed_block.default_accuracy_log, field_name);

                @field(self, field_name).table =
                    @field(types.compressed_block, "predefined_" ++ field_name ++ "_fse_table");
            },
            .rle => {
                @field(self, field_name).accuracy_log = 0;
                @field(self, field_name).table = .{ .rle = try source.readByte() };
            },
            .fse => {
                var bit_reader = readers.bitReader(source);

                const table_size = try decodeFseTable(
                    &bit_reader,
                    @field(types.compressed_block.table_symbol_count_max, field_name),
                    @field(types.compressed_block.table_accuracy_log_max, field_name),
                    @field(self, field_name ++ "_fse_buffer"),
                );
                @field(self, field_name).table = .{
                    .fse = @field(self, field_name ++ "_fse_buffer")[0..table_size],
                };
                @field(self, field_name).accuracy_log = std.math.log2_int_ceil(usize, table_size);
            },
            .repeat => if (self.fse_tables_undefined) return error.RepeatModeFirst,
        }
    }

    const Sequence = struct {
        literal_length: u32,
        match_length: u32,
        offset: u32,
    };

    /// Decode the next sequence's offset, match length and literal length
    /// (in that bit order) from `bit_reader`, resolving repeat offsets.
    fn nextSequence(
        self: *DecodeState,
        bit_reader: *readers.ReverseBitReader,
    ) error{ InvalidBitStream, EndOfStream }!Sequence {
        const raw_code = self.getCode(.offset);
        const offset_code = std.math.cast(u5, raw_code) orelse {
            return error.InvalidBitStream;
        };
        const offset_value = (@as(u32, 1) << offset_code) + try bit_reader.readBitsNoEof(u32, offset_code);

        const match_code = self.getCode(.match);
        if (match_code >= types.compressed_block.match_length_code_table.len)
            return error.InvalidBitStream;
        const match = types.compressed_block.match_length_code_table[match_code];
        const match_length = match[0] + try bit_reader.readBitsNoEof(u32, match[1]);

        const literal_code = self.getCode(.literal);
        if (literal_code >= types.compressed_block.literals_length_code_table.len)
            return error.InvalidBitStream;
        const literal = types.compressed_block.literals_length_code_table[literal_code];
        const literal_length = literal[0] + try bit_reader.readBitsNoEof(u32, literal[1]);

        // Offset values 1-3 refer to repeat offsets; which one depends on
        // whether the sequence has any literals.
        const offset = if (offset_value > 3) offset: {
            const offset = offset_value - 3;
            self.updateRepeatOffset(offset);
            break :offset offset;
        } else offset: {
            if (literal_length == 0) {
                if (offset_value == 3) {
                    const offset = self.repeat_offsets[0] - 1;
                    self.updateRepeatOffset(offset);
                    break :offset offset;
                }
                break :offset self.useRepeatOffset(offset_value);
            }
            break :offset self.useRepeatOffset(offset_value - 1);
        };

        if (offset == 0) return error.InvalidBitStream;

        return .{
            .literal_length = literal_length,
            .match_length = match_length,
            .offset = offset,
        };
    }

    /// Write the sequence's literals at `write_pos`, then copy `match_length`
    /// bytes from `offset` bytes back. The copy proceeds forward one byte at a
    /// time, so a match overlapping its own output repeats the pattern.
    fn executeSequenceSlice(
        self: *DecodeState,
        dest: []u8,
        write_pos: usize,
        sequence: Sequence,
    ) (error{MalformedSequence} || DecodeLiteralsError)!void {
        if (sequence.offset > write_pos + sequence.literal_length) return error.MalformedSequence;

        try self.decodeLiteralsSlice(dest[write_pos..], sequence.literal_length);
        const copy_start = write_pos + sequence.literal_length - sequence.offset;
        for (
            dest[write_pos + sequence.literal_length ..][0..sequence.match_length],
            dest[copy_start..][0..sequence.match_length],
        ) |*d, s| d.* = s;
        self.written_count += sequence.match_length;
    }

    /// Ring-buffer counterpart of `executeSequenceSlice()`; the offset may not
    /// reach further back than the buffer's capacity or the bytes written so far.
    fn executeSequenceRingBuffer(
        self: *DecodeState,
        dest: *RingBuffer,
        sequence: Sequence,
    ) (error{MalformedSequence} || DecodeLiteralsError)!void {
        if (sequence.offset > @min(dest.data.len, self.written_count + sequence.literal_length))
            return error.MalformedSequence;

        try self.decodeLiteralsRingBuffer(dest, sequence.literal_length);
        const copy_start = dest.write_index + dest.data.len - sequence.offset;
        const copy_slice = dest.sliceAt(copy_start, sequence.match_length);
        for (copy_slice.first) |b| dest.writeAssumeCapacity(b);
        for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
        self.written_count += sequence.match_length;
    }

    const DecodeSequenceError = error{
        InvalidBitStream,
        EndOfStream,
        MalformedSequence,
        MalformedFseBits,
    } || DecodeLiteralsError;

    /// Decode one sequence from `bit_reader` into `dest`, written starting at
    /// `write_pos`, and update the FSE states if `last_sequence` is `false`.
    /// `prepare()` must be called for the block before attempting to decode
    /// sequences. Returns the number of bytes the sequence produced.
    ///
    /// Errors returned:
    ///   - `error.MalformedSequence` if the decompressed sequence would be
    ///     longer than `sequence_size_limit` or the sequence's offset is too
    ///     large
    ///   - `error.UnexpectedEndOfLiteralStream` if the decoder state's literal
    ///     streams do not contain enough literals for the sequence (this may
    ///     mean the literal stream or the sequence is malformed).
    ///   - `error.InvalidBitStream` if the FSE sequence bitstream is malformed
    ///   - `error.EndOfStream` if `bit_reader` does not contain enough bits
    ///   - `error.DestTooSmall` if `dest` is not large enough to hold the
    ///     decompressed sequence
    pub fn decodeSequenceSlice(
        self: *DecodeState,
        dest: []u8,
        write_pos: usize,
        bit_reader: *readers.ReverseBitReader,
        sequence_size_limit: usize,
        last_sequence: bool,
    ) (error{DestTooSmall} || DecodeSequenceError)!usize {
        const sequence = try self.nextSequence(bit_reader);
        const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
        if (sequence_length > sequence_size_limit) return error.MalformedSequence;
        if (sequence_length > dest[write_pos..].len) return error.DestTooSmall;

        try self.executeSequenceSlice(dest, write_pos, sequence);
        if (!last_sequence) {
            try self.updateState(.literal, bit_reader);
            try self.updateState(.match, bit_reader);
            try self.updateState(.offset, bit_reader);
        }
        return sequence_length;
    }

    /// Decode one sequence from `bit_reader` into `dest`; see
    /// `decodeSequenceSlice`.
    pub fn decodeSequenceRingBuffer(
        self: *DecodeState,
        dest: *RingBuffer,
        bit_reader: anytype,
        sequence_size_limit: usize,
        last_sequence: bool,
    ) DecodeSequenceError!usize {
        const sequence = try self.nextSequence(bit_reader);
        const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
        if (sequence_length > sequence_size_limit) return error.MalformedSequence;

        try self.executeSequenceRingBuffer(dest, sequence);
        if (!last_sequence) {
            try self.updateState(.literal, bit_reader);
            try self.updateState(.match, bit_reader);
            try self.updateState(.offset, bit_reader);
        }
        return sequence_length;
    }

    /// Switch to the next of the four literal streams.
    fn nextLiteralMultiStream(
        self: *DecodeState,
    ) error{BitStreamHasNoStartBit}!void {
        self.literal_stream_index += 1;
        try self.initLiteralStream(self.literal_streams.four[self.literal_stream_index]);
    }

    fn initLiteralStream(self: *DecodeState, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
        try self.literal_stream_reader.init(bytes);
    }

    /// Whether every Huffman-coded literal stream has been fully consumed.
    fn isLiteralStreamEmpty(self: *DecodeState) bool {
        switch (self.literal_streams) {
            .one => return self.literal_stream_reader.isEmpty(),
            .four => return self.literal_stream_index == 3 and self.literal_stream_reader.isEmpty(),
        }
    }

    const LiteralBitsError = error{
        BitStreamHasNoStartBit,
        UnexpectedEndOfLiteralStream,
    };

    /// Read `bit_count_to_read` bits of literal data. When the current stream
    /// cannot supply them and one of four streams is active, move on to the
    /// next stream and retry once there.
    fn readLiteralsBits(
        self: *DecodeState,
        bit_count_to_read: usize,
    ) LiteralBitsError!u16 {
        return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: {
            if (self.literal_streams == .four and self.literal_stream_index < 3) {
                try self.nextLiteralMultiStream();
                break :bits self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch
                    return error.UnexpectedEndOfLiteralStream;
            } else {
                return error.UnexpectedEndOfLiteralStream;
            }
        };
    }

    const DecodeLiteralsError = error{
        MalformedLiteralsLength,
        NotFound,
    } || LiteralBitsError;

    /// Decode a single Huffman-coded literal byte from the literal stream(s).
    /// Starts by reading `starting_bit_count` bits (the code length of the
    /// tree's last node), then reads further bits as `query` directs the walk
    /// to deeper nodes, until a symbol is reached.
    fn decodeHuffmanLiteral(
        self: *DecodeState,
        huffman_tree: *const LiteralsSection.HuffmanTree,
        starting_bit_count: u4,
    ) DecodeLiteralsError!u8 {
        var prefix: u16 = 0;
        var bits_read: u4 = 0;
        var tree_index: usize = huffman_tree.symbol_count_minus_one;
        var bit_count_to_read: u4 = starting_bit_count;
        while (true) {
            const new_bits = try self.readLiteralsBits(bit_count_to_read);
            prefix <<= bit_count_to_read;
            prefix |= new_bits;
            bits_read += bit_count_to_read;

            switch (try huffman_tree.query(tree_index, prefix)) {
                .symbol => |sym| return sym,
                .index => |index| {
                    tree_index = index;
                    const bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
                        huffman_tree.nodes[index].weight,
                        huffman_tree.max_bit_count,
                    );
                    // Only read the bits not already part of `prefix`.
                    bit_count_to_read = bit_count - bits_read;
                },
            }
        }
    }

    /// Decode `len` bytes of literals into `dest`.
    ///
    /// Errors returned:
    ///   - `error.MalformedLiteralsLength` if the number of literal bytes
    ///     decoded by `self` plus `len` is greater than the regenerated size of
    ///     `literals`
    ///   - `error.UnexpectedEndOfLiteralStream` and `error.NotFound` if there
    ///     are problems decoding Huffman compressed literals
    pub fn decodeLiteralsSlice(
        self: *DecodeState,
        dest: []u8,
        len: usize,
    ) DecodeLiteralsError!void {
        if (self.literal_written_count + len > self.literal_header.regenerated_size)
            return error.MalformedLiteralsLength;

        switch (self.literal_header.block_type) {
            .raw => {
                const literal_data = self.literal_streams.one[self.literal_written_count..][0..len];
                @memcpy(dest[0..len], literal_data);
            },
            .rle => {
                for (dest[0..len]) |*byte| byte.* = self.literal_streams.one[0];
            },
            .compressed, .treeless => {
                // prepare() guarantees a tree for compressed/treeless literals.
                const huffman_tree = self.huffman_tree.?;
                const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
                    huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
                    huffman_tree.max_bit_count,
                );
                for (dest[0..len]) |*byte| {
                    byte.* = try self.decodeHuffmanLiteral(&huffman_tree, starting_bit_count);
                }
            },
        }
        self.literal_written_count += len;
        self.written_count += len;
    }

    /// Decode literals into `dest`; see `decodeLiteralsSlice()`.
    pub fn decodeLiteralsRingBuffer(
        self: *DecodeState,
        dest: *RingBuffer,
        len: usize,
    ) DecodeLiteralsError!void {
        if (self.literal_written_count + len > self.literal_header.regenerated_size)
            return error.MalformedLiteralsLength;

        switch (self.literal_header.block_type) {
            .raw => {
                const literals_end = self.literal_written_count + len;
                const literal_data = self.literal_streams.one[self.literal_written_count..literals_end];
                dest.writeSliceAssumeCapacity(literal_data);
            },
            .rle => {
                for (0..len) |_| {
                    dest.writeAssumeCapacity(self.literal_streams.one[0]);
                }
            },
            .compressed, .treeless => {
                // prepare() guarantees a tree for compressed/treeless literals.
                const huffman_tree = self.huffman_tree.?;
                const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
                    huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
                    huffman_tree.max_bit_count,
                );
                for (0..len) |_| {
                    dest.writeAssumeCapacity(try self.decodeHuffmanLiteral(&huffman_tree, starting_bit_count));
                }
            },
        }
        self.literal_written_count += len;
        self.written_count += len;
    }

    /// Current symbol for `choice`: the RLE byte, or the symbol of the FSE
    /// table entry selected by the current state.
    fn getCode(self: *DecodeState, comptime choice: DataType) u32 {
        return switch (@field(self, @tagName(choice)).table) {
            .rle => |value| value,
            .fse => |table| table[@field(self, @tagName(choice)).state].symbol,
        };
    }
};

/// Decode a single block from `src` into `dest`. The beginning of `src` must be
/// the start of the block content (i.e. directly after the block header).
/// Increments `consumed_count` by the number of bytes read from `src` to decode
/// the block and returns the decompressed size of the block. Output is written
/// to `dest` starting at index `written_count`.
///
/// Errors returned:
///
///   - `error.BlockSizeOverMaximum` if the block's size is larger than
///     `block_size_max`
///   - `error.MalformedBlockSize` if `src.len` is smaller than the block size
///     and the block is a raw or compressed block
///   - `error.ReservedBlock` if the block is a reserved block
///   - `error.MalformedRleBlock` if the block is an RLE block and `src.len < 1`
///   - `error.MalformedCompressedBlock` if there are errors decoding a
///     compressed block, including a compressed block that would decompress
///     to more than `block_size_max` bytes
///   - `error.DestTooSmall` if `dest` is not large enough to hold the
///     decompressed block
pub fn decodeBlock(
    dest: []u8,
    src: []const u8,
    block_header: frame.Zstandard.Block.Header,
    decode_state: *DecodeState,
    consumed_count: *usize,
    block_size_max: usize,
    written_count: usize,
) (error{DestTooSmall} || Error)!usize {
    const block_size = block_header.block_size;
    if (block_size_max < block_size) return error.BlockSizeOverMaximum;
    switch (block_header.block_type) {
        .raw => {
            if (src.len < block_size) return error.MalformedBlockSize;
            if (dest[written_count..].len < block_size) return error.DestTooSmall;
            @memcpy(dest[written_count..][0..block_size], src[0..block_size]);
            consumed_count.* += block_size;
            decode_state.written_count += block_size;
            return block_size;
        },
        .rle => {
            // An RLE block's content is one byte repeated `block_size` times.
            if (src.len < 1) return error.MalformedRleBlock;
            if (dest[written_count..].len < block_size) return error.DestTooSmall;
            @memset(dest[written_count..][0..block_size], src[0]);
            consumed_count.* += 1;
            decode_state.written_count += block_size;
            return block_size;
        },
        .compressed => {
            if (src.len < block_size) return error.MalformedBlockSize;
            var bytes_read: usize = 0;
            const literals = decodeLiteralsSectionSlice(src[0..block_size], &bytes_read) catch
                return error.MalformedCompressedBlock;
            var fbs = std.io.fixedBufferStream(src[bytes_read..block_size]);
            const fbs_reader = fbs.reader();
            const sequences_header = decodeSequencesHeader(fbs_reader) catch
                return error.MalformedCompressedBlock;

            decode_state.prepare(fbs_reader, literals, sequences_header) catch
                return error.MalformedCompressedBlock;

            bytes_read += fbs.pos;

            var bytes_written: usize = 0;
            // Decompressed bytes this block may still produce. Block_Maximum_Size
            // bounds the decompressed size of the whole block (sequences plus
            // trailing literals), not just its compressed size.
            var size_limit: usize = block_size_max;
            {
                const bit_stream_bytes = src[bytes_read..block_size];
                var bit_stream: readers.ReverseBitReader = undefined;
                bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock;

                if (sequences_header.sequence_count > 0) {
                    decode_state.readInitialFseState(&bit_stream) catch
                        return error.MalformedCompressedBlock;

                    for (0..sequences_header.sequence_count) |i| {
                        const write_pos = written_count + bytes_written;
                        const decompressed_size = decode_state.decodeSequenceSlice(
                            dest,
                            write_pos,
                            &bit_stream,
                            size_limit,
                            i == sequences_header.sequence_count - 1,
                        ) catch |err| switch (err) {
                            error.DestTooSmall => return error.DestTooSmall,
                            else => return error.MalformedCompressedBlock,
                        };
                        bytes_written += decompressed_size;
                        size_limit -= decompressed_size;
                    }
                }

                // The sequence bitstream must be consumed exactly.
                if (!bit_stream.isEmpty()) {
                    return error.MalformedCompressedBlock;
                }
            }

            // Literals not used by any sequence follow the last sequence.
            if (decode_state.literal_written_count < literals.header.regenerated_size) {
                const len = literals.header.regenerated_size - decode_state.literal_written_count;
                if (len > size_limit) return error.MalformedCompressedBlock;
                if (len > dest[written_count + bytes_written ..].len) return error.DestTooSmall;
                decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], len) catch
                    return error.MalformedCompressedBlock;
                bytes_written += len;
            }

            // Huffman-coded literal streams must be fully consumed as well.
            switch (decode_state.literal_header.block_type) {
                .treeless, .compressed => {
                    if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
                },
                .raw, .rle => {},
            }

            consumed_count.* += block_size;
            return bytes_written;
        },
        .reserved => return error.ReservedBlock,
    }
}

/// Decode a single block from `src` into `dest`; see `decodeBlock()`. Returns
/// the size of the decompressed block, which can be used with `dest.sliceLast()`
/// to get the decompressed bytes. `error.BlockSizeOverMaximum` is returned if
/// the block's compressed or decompressed size is larger than `block_size_max`.
///
/// `consumed_count` is incremented by the number of bytes of `src` that the
/// block occupies, and only once the block has decoded successfully; on error
/// it is left untouched.
pub fn decodeBlockRingBuffer(
    dest: *RingBuffer,
    src: []const u8,
    block_header: frame.Zstandard.Block.Header,
    decode_state: *DecodeState,
    consumed_count: *usize,
    block_size_max: usize,
) Error!usize {
    const block_size = block_header.block_size;
    if (block_size_max < block_size) return error.BlockSizeOverMaximum;
    switch (block_header.block_type) {
        .raw => {
            if (src.len < block_size) return error.MalformedBlockSize;
            const data = src[0..block_size];
            dest.writeSliceAssumeCapacity(data);
            consumed_count.* += block_size;
            decode_state.written_count += block_size;
            return block_size;
        },
        .rle => {
            // An RLE block is a single byte in `src` repeated `block_size` times.
            if (src.len < 1) return error.MalformedRleBlock;
            for (0..block_size) |_| {
                dest.writeAssumeCapacity(src[0]);
            }
            consumed_count.* += 1;
            decode_state.written_count += block_size;
            return block_size;
        },
        .compressed => {
            if (src.len < block_size) return error.MalformedBlockSize;
            var bytes_read: usize = 0;
            const literals = decodeLiteralsSectionSlice(src[0..block_size], &bytes_read) catch
                return error.MalformedCompressedBlock;
            var fbs = std.io.fixedBufferStream(src[bytes_read..block_size]);
            const fbs_reader = fbs.reader();
            const sequences_header = decodeSequencesHeader(fbs_reader) catch
                return error.MalformedCompressedBlock;

            decode_state.prepare(fbs_reader, literals, sequences_header) catch
                return error.MalformedCompressedBlock;

            bytes_read += fbs.pos;

            var bytes_written: usize = 0;
            {
                // The rest of the block is the (reversed) sequence bitstream.
                const bit_stream_bytes = src[bytes_read..block_size];
                var bit_stream: readers.ReverseBitReader = undefined;
                bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock;

                if (sequences_header.sequence_count > 0) {
                    decode_state.readInitialFseState(&bit_stream) catch
                        return error.MalformedCompressedBlock;

                    var sequence_size_limit = block_size_max;
                    for (0..sequences_header.sequence_count) |i| {
                        const decompressed_size = decode_state.decodeSequenceRingBuffer(
                            dest,
                            &bit_stream,
                            sequence_size_limit,
                            i == sequences_header.sequence_count - 1,
                        ) catch return error.MalformedCompressedBlock;
                        bytes_written += decompressed_size;
                        sequence_size_limit -= decompressed_size;
                    }
                }

                if (!bit_stream.isEmpty()) {
                    return error.MalformedCompressedBlock;
                }
            }

            // Literals not consumed by any sequence are appended after the last one.
            if (decode_state.literal_written_count < literals.header.regenerated_size) {
                const len = literals.header.regenerated_size - decode_state.literal_written_count;
                decode_state.decodeLiteralsRingBuffer(dest, len) catch
                    return error.MalformedCompressedBlock;
                bytes_written += len;
            }

            switch (decode_state.literal_header.block_type) {
                .treeless, .compressed => {
                    if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
                },
                .raw, .rle => {},
            }

            // Validate the decompressed size before committing `consumed_count`
            // so that a failing block does not advance the caller's position.
            if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
            consumed_count.* += block_size;
            return bytes_written;
        },
        .reserved => return error.ReservedBlock,
    }
}

/// Decode a single block from `source` into `dest`. Literal and sequence data
/// from the block is copied into `literals_buffer` and `sequence_buffer`, which
/// must be large enough or `error.LiteralsBufferTooSmall` and
/// `error.SequenceBufferTooSmall` are returned (the maximum block size is an
/// upper bound for the size of both buffers). See `decodeBlock`
/// and `decodeBlockRingBuffer` for functions that can decode a block without
/// these extra copies. `error.EndOfStream` is returned if `source` does not
/// contain enough bytes.
pub fn decodeBlockReader(
    dest: *RingBuffer,
    source: anytype,
    block_header: frame.Zstandard.Block.Header,
    decode_state: *DecodeState,
    block_size_max: usize,
    literals_buffer: []u8,
    sequence_buffer: []u8,
) !void {
    const block_size = block_header.block_size;
    // All reads of a compressed block go through this limited reader so that
    // nothing past the end of the block is consumed from `source`.
    var block_reader_limited = std.io.limitedReader(source, block_size);
    const block_reader = block_reader_limited.reader();
    if (block_size_max < block_size) return error.BlockSizeOverMaximum;
    switch (block_header.block_type) {
        .raw => {
            if (block_size == 0) return;
            // Read straight into the ring buffer's free space (possibly wrapping).
            const slice = dest.sliceAt(dest.write_index, block_size);
            try source.readNoEof(slice.first);
            try source.readNoEof(slice.second);
            dest.write_index = dest.mask2(dest.write_index + block_size);
            decode_state.written_count += block_size;
        },
        .rle => {
            const byte = try source.readByte();
            for (0..block_size) |_| {
                dest.writeAssumeCapacity(byte);
            }
            decode_state.written_count += block_size;
        },
        .compressed => {
            const literals = try decodeLiteralsSection(block_reader, literals_buffer);
            const sequences_header = try decodeSequencesHeader(block_reader);

            try decode_state.prepare(block_reader, literals, sequences_header);

            var bytes_written: usize = 0;
            {
                // Everything remaining in the block is the sequence bitstream.
                const size = try block_reader.readAll(sequence_buffer);
                if (block_reader_limited.bytes_left != 0) {
                    // `readAll` only returns fewer bytes than the buffer holds
                    // when the reader is exhausted; with block bytes still
                    // outstanding that means `source` ended early.
                    if (size < sequence_buffer.len) return error.EndOfStream;
                    // Otherwise the buffer filled before the block ended, so the
                    // bitstream would be truncated.
                    if (sequences_header.sequence_count > 0)
                        return error.SequenceBufferTooSmall;
                }

                var bit_stream: readers.ReverseBitReader = undefined;
                try bit_stream.init(sequence_buffer[0..size]);

                if (sequences_header.sequence_count > 0) {
                    decode_state.readInitialFseState(&bit_stream) catch
                        return error.MalformedCompressedBlock;

                    var sequence_size_limit = block_size_max;
                    for (0..sequences_header.sequence_count) |i| {
                        const decompressed_size = decode_state.decodeSequenceRingBuffer(
                            dest,
                            &bit_stream,
                            sequence_size_limit,
                            i == sequences_header.sequence_count - 1,
                        ) catch return error.MalformedCompressedBlock;
                        sequence_size_limit -= decompressed_size;
                        bytes_written += decompressed_size;
                    }
                }

                if (!bit_stream.isEmpty()) {
                    return error.MalformedCompressedBlock;
                }
            }

            // Literals not consumed by any sequence are appended after the last one.
            if (decode_state.literal_written_count < literals.header.regenerated_size) {
                const len = literals.header.regenerated_size - decode_state.literal_written_count;
                decode_state.decodeLiteralsRingBuffer(dest, len) catch
                    return error.MalformedCompressedBlock;
                bytes_written += len;
            }

            switch (decode_state.literal_header.block_type) {
                .treeless, .compressed => {
                    if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
                },
                .raw, .rle => {},
            }

            if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
            if (block_reader_limited.bytes_left != 0) return error.MalformedCompressedBlock;
            decode_state.literal_written_count = 0;
        },
        .reserved => return error.ReservedBlock,
    }
}

/// Decode the 3-byte block header in `src`.
///
/// The three bytes are combined into a little-endian 24-bit word laid out as
/// `Last_Block` (bit 0), `Block_Type` (bits 1-2) and `Block_Size` (bits 3-23).
pub fn decodeBlockHeader(src: *const [3]u8) frame.Zstandard.Block.Header {
    const word: u24 = @as(u24, src[0]) |
        (@as(u24, src[1]) << 8) |
        (@as(u24, src[2]) << 16);

    return .{
        .last_block = (word & 1) != 0,
        .block_type = @as(frame.Zstandard.Block.Type, @enumFromInt(@as(u2, @truncate(word >> 1)))),
        .block_size = @as(u21, @truncate(word >> 3)),
    };
}

/// Decode the block header found in the first three bytes of `src`.
///
/// Errors returned:
///   - `error.EndOfStream` if `src.len < 3`
pub fn decodeBlockHeaderSlice(src: []const u8) error{EndOfStream}!frame.Zstandard.Block.Header {
    return if (src.len >= 3) decodeBlockHeader(src[0..3]) else error.EndOfStream;
}

/// Decode a `LiteralsSection` from the start of `src` without copying; the
/// returned streams are slices of `src`. On success `consumed_count` is
/// incremented by the number of bytes the section uses, header included.
///
/// Errors returned:
///   - `error.MalformedLiteralsHeader` if the header is invalid
///   - `error.MalformedLiteralsSection` if there are decoding errors
///   - `error.MalformedAccuracyLog` if compressed literals have invalid
///     accuracy
///   - `error.MalformedFseTable` if compressed literals have invalid FSE table
///   - `error.MalformedHuffmanTree` if there are errors decoding a Huffman tree
///   - `error.EndOfStream` if there are not enough bytes in `src`
pub fn decodeLiteralsSectionSlice(
    src: []const u8,
    consumed_count: *usize,
) (error{ MalformedLiteralsHeader, MalformedLiteralsSection, EndOfStream } || huffman.Error)!LiteralsSection {
    var header_stream = std.io.fixedBufferStream(src);
    const header = decodeLiteralsHeader(header_stream.reader()) catch
        return error.MalformedLiteralsHeader;
    // Offset of the next unread byte of `src`; starts just past the header.
    var pos: usize = header_stream.pos;

    if (header.block_type == .raw or header.block_type == .rle) {
        // Raw literals are stored verbatim; RLE literals are a single byte.
        const payload_len: usize = if (header.block_type == .raw) header.regenerated_size else 1;
        const payload_end = pos + payload_len;
        if (src.len < payload_end) return error.MalformedLiteralsSection;

        consumed_count.* += payload_end;
        return LiteralsSection{
            .header = header,
            .huffman_tree = null,
            .streams = .{ .one = src[pos..payload_end] },
        };
    }

    // Compressed literals carry a Huffman tree, which counts towards
    // `compressed_size`; treeless literals have none in this section.
    const tree_start = pos;
    const huffman_tree = if (header.block_type == .compressed)
        try huffman.decodeHuffmanTreeSlice(src[pos..], &pos)
    else
        null;

    const streams_len = std.math.sub(usize, header.compressed_size.?, pos - tree_start) catch
        return error.MalformedLiteralsSection;
    const streams_end = pos + streams_len;
    if (src.len < streams_end) return error.MalformedLiteralsSection;

    const streams = try decodeStreams(header.size_format, src[pos..streams_end]);
    consumed_count.* += streams_end;
    return LiteralsSection{
        .header = header,
        .huffman_tree = huffman_tree,
        .streams = streams,
    };
}

/// Decode a `LiteralsSection` from the reader `source`, copying the literal
/// data into `buffer`; the returned section's streams are slices of `buffer`.
/// See `decodeLiteralsSectionSlice()` for a variant that decodes directly from
/// a slice without copying.
///
/// Errors returned:
///   - `error.LiteralsBufferTooSmall` if `buffer` cannot hold the literal data
///   - `error.MalformedLiteralsSection` if the compressed size is smaller than
///     the Huffman tree or the stream jump table is inconsistent
///   - errors from `decodeLiteralsHeader()`, `huffman.decodeHuffmanTree()` and
///     `source` (e.g. `error.EndOfStream`)
pub fn decodeLiteralsSection(
    source: anytype,
    buffer: []u8,
) !LiteralsSection {
    const header = try decodeLiteralsHeader(source);
    switch (header.block_type) {
        .raw => {
            // `regenerated_size` comes from the input, so it must be checked
            // against `buffer` before slicing.
            if (header.regenerated_size > buffer.len) return error.LiteralsBufferTooSmall;
            const stream = buffer[0..header.regenerated_size];
            try source.readNoEof(stream);
            return LiteralsSection{
                .header = header,
                .huffman_tree = null,
                .streams = .{ .one = stream },
            };
        },
        .rle => {
            if (buffer.len < 1) return error.LiteralsBufferTooSmall;
            buffer[0] = try source.readByte();
            return LiteralsSection{
                .header = header,
                .huffman_tree = null,
                .streams = .{ .one = buffer[0..1] },
            };
        },
        .compressed, .treeless => {
            // Count the bytes used by the Huffman tree so the size of the
            // streams that follow can be derived from `compressed_size`.
            var counting_reader = std.io.countingReader(source);
            const huffman_tree = if (header.block_type == .compressed)
                try huffman.decodeHuffmanTree(counting_reader.reader(), buffer)
            else
                null;
            const huffman_tree_size = @as(usize, @intCast(counting_reader.bytes_read));
            const total_streams_size = std.math.sub(usize, header.compressed_size.?, huffman_tree_size) catch
                return error.MalformedLiteralsSection;

            if (total_streams_size > buffer.len) return error.LiteralsBufferTooSmall;
            try source.readNoEof(buffer[0..total_streams_size]);
            const stream_data = buffer[0..total_streams_size];

            const streams = try decodeStreams(header.size_format, stream_data);
            return LiteralsSection{
                .header = header,
                .huffman_tree = huffman_tree,
                .streams = streams,
            };
        },
    }
}

/// Split Huffman-coded literal data into its streams.
///
/// A `size_format` of 0 means the data is a single stream. Any other value
/// means four streams: a 6-byte jump table gives the lengths of the first
/// three as `u16` values (decoded with `readInt`), and the fourth stream takes
/// whatever remains. `error.MalformedLiteralsSection` is returned if the data
/// is too short for the jump table or for the lengths it declares.
fn decodeStreams(size_format: u2, stream_data: []const u8) !LiteralsSection.Streams {
    if (size_format == 0) return .{ .one = stream_data };

    const jump_table_size = 6;
    if (stream_data.len < jump_table_size) return error.MalformedLiteralsSection;
    const jump_table = stream_data[0..jump_table_size];

    // End offsets (within `stream_data`) of the first three streams.
    const end_1 = jump_table_size + @as(usize, readInt(u16, jump_table[0..2]));
    const end_2 = end_1 + @as(usize, readInt(u16, jump_table[2..4]));
    const end_3 = end_2 + @as(usize, readInt(u16, jump_table[4..6]));
    if (end_3 > stream_data.len) return error.MalformedLiteralsSection;

    return .{ .four = .{
        stream_data[jump_table_size..end_1],
        stream_data[end_1..end_2],
        stream_data[end_2..end_3],
        stream_data[end_3..],
    } };
}

/// Decode a literals section header from `source`.
///
/// The two low bits of the first byte select the block type and the next two
/// the size format. Together they determine the header length and how the
/// sizes are packed:
///   - raw/RLE headers take 1 to 3 bytes and carry only a regenerated size
///     (`compressed_size` is `null`);
///   - compressed/treeless headers take 3 to 5 bytes and carry both a
///     regenerated and a compressed size.
///
/// Errors returned:
///   - `error.EndOfStream` if there are not enough bytes in `source`
pub fn decodeLiteralsHeader(source: anytype) !LiteralsSection.Header {
    const byte0 = try source.readByte();
    const block_type: LiteralsSection.BlockType = @enumFromInt(byte0 & 0b11);
    const size_format: u2 = @truncate(byte0 >> 2);

    switch (block_type) {
        .raw, .rle => {
            const regenerated_size: u20 = switch (size_format) {
                // 1-byte header, 5-bit size.
                0, 2 => byte0 >> 3,
                // 2-byte header, 12-bit size.
                1 => (byte0 >> 4) + (@as(u20, try source.readByte()) << 4),
                // 3-byte header, 20-bit size.
                3 => blk: {
                    const byte1 = try source.readByte();
                    const byte2 = try source.readByte();
                    break :blk (byte0 >> 4) + (@as(u20, byte1) << 4) + (@as(u20, byte2) << 12);
                },
            };
            return LiteralsSection.Header{
                .block_type = block_type,
                .size_format = size_format,
                .regenerated_size = regenerated_size,
                .compressed_size = null,
            };
        },
        .compressed, .treeless => {
            const byte1 = try source.readByte();
            const byte2 = try source.readByte();

            const Sizes = struct { regenerated: u20, compressed: u18 };
            const sizes: Sizes = switch (size_format) {
                // 3-byte header, 10-bit sizes.
                0, 1 => Sizes{
                    .regenerated = (byte0 >> 4) + ((@as(u20, byte1) & 0b00111111) << 4),
                    .compressed = ((byte1 & 0b11000000) >> 6) + (@as(u18, byte2) << 2),
                },
                // 4-byte header, 14-bit sizes.
                2 => blk: {
                    const byte3 = try source.readByte();
                    break :blk Sizes{
                        .regenerated = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00000011) << 12),
                        .compressed = ((byte2 & 0b11111100) >> 2) + (@as(u18, byte3) << 6),
                    };
                },
                // 5-byte header, 18-bit sizes.
                3 => blk: {
                    const byte3 = try source.readByte();
                    const byte4 = try source.readByte();
                    break :blk Sizes{
                        .regenerated = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00111111) << 12),
                        .compressed = ((byte2 & 0b11000000) >> 6) + (@as(u18, byte3) << 2) + (@as(u18, byte4) << 10),
                    };
                },
            };
            return LiteralsSection.Header{
                .block_type = block_type,
                .size_format = size_format,
                .regenerated_size = sizes.regenerated,
                .compressed_size = sizes.compressed,
            };
        },
    }
}

/// Decode a sequences section header from `source`.
///
/// The sequence count occupies 1 to 3 bytes depending on the first byte. A
/// count of zero ends the header immediately and leaves the three mode fields
/// `undefined`. Otherwise a compression-modes byte follows. From high bits to
/// low it holds the literal-length, offset and match-length modes (two bits
/// each), then two reserved bits.
///
/// Errors returned:
///   - `error.ReservedBitSet` if either reserved bit is set
///   - `error.EndOfStream` if there are not enough bytes in `source`
pub fn decodeSequencesHeader(
    source: anytype,
) !SequencesSection.Header {
    const Mode = SequencesSection.Header.Mode;

    const byte0 = try source.readByte();
    const sequence_count: u24 = switch (byte0) {
        0 => return SequencesSection.Header{
            .sequence_count = 0,
            .offsets = undefined,
            .match_lengths = undefined,
            .literal_lengths = undefined,
        },
        1...127 => byte0,
        128...254 => (@as(u24, byte0 - 128) << 8) + try source.readByte(),
        255 => blk: {
            const low = try source.readByte();
            const high = try source.readByte();
            break :blk low + (@as(u24, high) << 8) + 0x7F00;
        },
    };

    const modes = try source.readByte();
    if (modes & 0b11 != 0) return error.ReservedBitSet;

    return SequencesSection.Header{
        .sequence_count = sequence_count,
        .offsets = @as(Mode, @enumFromInt((modes >> 4) & 0b11)),
        .match_lengths = @as(Mode, @enumFromInt((modes >> 2) & 0b11)),
        .literal_lengths = @as(Mode, @enumFromInt(modes >> 6)),
    };
}