stuff-from-scratch/src/base/compression/huffman/HuffmanCodeLengthTable.cpp
2023-01-17 10:13:25 +00:00

253 lines
7.3 KiB
C++

#include "HuffmanCodeLengthTable.h"
#include "ByteUtils.h"
#include "RunLengthEncoder.h"
#include "BitStream.h"
#include <algorithm>
#include <sstream>
#include <iostream>
/// Run-length compresses mInputLengthSequence into mCompressedLengthSequence
/// using the code-length alphabet (literal lengths plus repeat symbols 16/17/18),
/// then tallies how often each alphabet symbol occurs into mCompressedLengthCounts.
/// Each output entry is {symbol, extra-bits value}; literal lengths carry 0 as
/// the extra value.
void HuffmanCodeLengthTable::buildCompressedLengthSequence()
{
    // Repeat symbols and their run limits (the offsets 3 and 11 and the
    // maxima 6 and 138 match the Deflate code-length alphabet).
    constexpr unsigned kRepeatPrevious{16};    // repeat previous length, extra = run - 3
    constexpr unsigned kRepeatZeroShort{17};   // repeat zero 3-10 times, extra = run - 3
    constexpr unsigned kRepeatZeroLong{18};    // repeat zero 11-138 times, extra = run - 11
    constexpr std::size_t kMinRepeat{3};       // shorter runs are emitted as literals
    constexpr std::size_t kMaxRepeatPrevious{6};
    constexpr std::size_t kMinLongZeroRun{11};
    constexpr std::size_t kMaxLongZeroRun{138};
    constexpr std::size_t kAlphabetSize{19};   // symbols 0..18

    RunLengthEncoder rl_encoder;
    const auto rle_encoded = rl_encoder.encode(mInputLengthSequence);

    mCompressedLengthSequence.clear();
    for (const auto& entry : rle_encoded)
    {
        const auto length = entry.first;
        const auto count = entry.second;
        if (count < kMinRepeat)
        {
            // Too short to benefit from a repeat symbol.
            for (std::size_t idx = 0; idx < count; idx++)
            {
                mCompressedLengthSequence.push_back({length, 0});
            }
        }
        else if (length == 0)
        {
            // Zero runs use 17/18 directly, without a leading literal.
            const std::size_t num_full_runs = count / kMaxLongZeroRun;
            for (std::size_t idx = 0; idx < num_full_runs; idx++)
            {
                mCompressedLengthSequence.push_back({kRepeatZeroLong, 127});  // 138 - 11
            }
            const auto remainder = count % kMaxLongZeroRun;
            if (remainder >= kMinLongZeroRun)
            {
                mCompressedLengthSequence.push_back({kRepeatZeroLong, remainder - kMinLongZeroRun});
            }
            else if (remainder >= kMinRepeat)
            {
                mCompressedLengthSequence.push_back({kRepeatZeroShort, remainder - kMinRepeat});
            }
            else
            {
                for (std::size_t idx = 0; idx < remainder; idx++)
                {
                    mCompressedLengthSequence.push_back({0, 0});
                }
            }
        }
        else
        {
            // Non-zero runs: emit the length once, then repeat it with symbol 16.
            mCompressedLengthSequence.push_back({length, 0});
            const auto repeats = count - 1;
            const auto num_full_blocks = repeats / kMaxRepeatPrevious;
            for (std::size_t idx = 0; idx < num_full_blocks; idx++)
            {
                mCompressedLengthSequence.push_back({kRepeatPrevious, 3});  // 6 - 3
            }
            const auto remaining = repeats % kMaxRepeatPrevious;
            if (remaining >= kMinRepeat)
            {
                mCompressedLengthSequence.push_back({kRepeatPrevious, remaining - kMinRepeat});
            }
            else
            {
                for (std::size_t idx = 0; idx < remaining; idx++)
                {
                    mCompressedLengthSequence.push_back({length, 0});
                }
            }
        }
    }

    mCompressedLengthCounts.assign(kAlphabetSize, 0);
    for (const auto& entry : mCompressedLengthSequence)
    {
        const auto symbol = static_cast<std::size_t>(entry.first);
        // An input length >= 19 would otherwise index past the table (UB);
        // grow the table instead of writing out of bounds.
        if (symbol >= mCompressedLengthCounts.size())
        {
            mCompressedLengthCounts.resize(symbol + 1, 0);
        }
        mCompressedLengthCounts[symbol]++;
    }
}
const std::vector<HuffmanCodeLengthTable::CompressedSequenceEntry>& HuffmanCodeLengthTable::getCompressedLengthSequence() const
{
return mCompressedLengthSequence;
}
/// Returns a copy of the per-symbol occurrence counts tallied by
/// buildCompressedLengthSequence().
auto HuffmanCodeLengthTable::getCompressedLengthCounts() const -> const std::vector<std::size_t>
{
    const auto& counts = mCompressedLengthCounts;
    return counts;
}
/// Looks up the prefix code assigned to `symbol` by delegating to mTree.getCode().
auto HuffmanCodeLengthTable::getCodeForSymbol(unsigned symbol) const -> std::optional<PrefixCode>
{
    return mTree.getCode(symbol);
}
/// Decodes one symbol from `stream` by reading bits incrementally: for each
/// code length stored in mTree (indexed by working_index), it reads just the
/// extra bits needed to reach that length and asks findMatch() whether the
/// accumulated bits form a code of that length.
///
/// @param result  Receives the decoded symbol; left untouched on failure.
/// @param stream  Bit source; may be null, in which case nothing is decoded.
/// @return true if a symbol was decoded; false if there are no code lengths,
///         the stream is null or runs out of bits, or no code matches.
bool HuffmanCodeLengthTable::readNextSymbol(unsigned& result, BitStream* stream)
{
    const auto num_code_lengths = getNumCodeLengths();
    if (stream == nullptr || num_code_lengths == 0)
    {
        return false;
    }

    // NOTE(review): bits are read into an unsigned char, so a single read
    // presumably cannot deliver more than 8 bits; a first code length (or a
    // jump between consecutive lengths) above 8 would be truncated. Confirm
    // against BitStream::readNextNBits.
    unsigned char buffer{0};
    uint32_t working_bits{0};
    unsigned length{0};
    for (std::size_t working_index = 0; working_index < num_code_lengths; working_index++)
    {
        const auto new_length = getCodeLength(working_index);
        const auto delta = new_length - length;
        // Previously the read status was ignored, so an exhausted stream kept
        // "decoding" stale buffer contents.
        if (!stream->readNextNBits(delta, buffer))
        {
            return false;
        }
        // New bits are placed above the bits already accumulated.
        working_bits |= static_cast<uint32_t>(buffer) << length;
        length = new_length;

        if (const auto symbol = findMatch(working_index, working_bits))
        {
            result = *symbol;
            return true;
        }
    }
    return false;
}
void HuffmanCodeLengthTable::buildPrefixCodes()
{
if(mInputLengthSequence.empty())
{
return;
}
unsigned char max_length = *std::max_element(mInputLengthSequence.begin(), mInputLengthSequence.end());
std::vector<unsigned> counts(max_length+1, 0);
for (const auto length : mInputLengthSequence)
{
counts[length]++;
}
counts[0] = 0;
uint32_t code{0};
std::vector<uint32_t> next_code(max_length + 1, 0);
for (unsigned bits = 1; bits <= max_length; bits++)
{
code = (code + counts[bits-1]) << 1;
//std::cout << "Start code for bit " << bits << " is " << ByteUtils::toString(code) << " | dec " << code << " count " << counts[bits-1] << std::endl;
next_code[bits] = code;
}
for(std::size_t idx=0; idx<mInputLengthSequence.size(); idx++)
{
if (const auto length = mInputLengthSequence[idx]; length != 0)
{
const auto code = next_code[length];
next_code[length]++;
auto prefix_code = PrefixCode(code, length);
mTree.addCodeLengthEntry(length, {PrefixCode(code, length), static_cast<unsigned>(idx)});
mCodes.push_back(prefix_code);
}
}
mTree.sortTable();
//std::cout << dumpPrefixCodes();
}
const PrefixCode& HuffmanCodeLengthTable::getCode(std::size_t index) const
{
return mCodes[index];
}
/// Human-readable dump of the code table, as produced by mTree.dump().
auto HuffmanCodeLengthTable::dumpPrefixCodes() const -> std::string
{
    return mTree.dump();
}
/// Maps a position in the transmitted code-length order to its entry in
/// DEFLATE_PERMUTATION. Returns 0 for positions at or beyond
/// DEFLATE_PERMUTATION_SIZE.
/// NOTE(review): 0 may also be a legitimate mapped value, so callers cannot
/// tell an out-of-range position from a real mapping.
auto HuffmanCodeLengthTable::mapToDeflateIndex(std::size_t index) const -> std::size_t
{
    if (index < DEFLATE_PERMUTATION_SIZE)
    {
        return DEFLATE_PERMUTATION[index];
    }
    return 0;
}
/// Number of code-length entries held by mTree (delegates to mTree).
auto HuffmanCodeLengthTable::getNumCodeLengths() const -> std::size_t
{
    return mTree.getNumCodeLengths();
}
std::optional<HuffmanTree::Symbol> HuffmanCodeLengthTable::findMatch(std::size_t treeIndex, uint32_t code) const
{
return mTree.findMatch(treeIndex, code);
}
/// Code length (in bits) of mTree's index-th code-length entry.
auto HuffmanCodeLengthTable::getCodeLength(std::size_t index) const -> unsigned
{
    return mTree.getCodeLength(index);
}
/// Stores the code lengths to build codes from.
///
/// @param sequence       Code lengths. When targetDeflate is false they are
///                       used as-is, indexed by symbol.
/// @param targetDeflate  When true, `sequence` is taken in transmitted order
///                       and each entry is scattered to mapToDeflateIndex(i)
///                       in a table of DEFLATE_PERMUTATION_SIZE zeros.
///
/// Entries beyond DEFLATE_PERMUTATION_SIZE are ignored. Previously,
/// mapToDeflateIndex returned 0 for them, so they silently overwrote slot 0.
void HuffmanCodeLengthTable::setInputLengthSequence(const std::vector<unsigned char>& sequence, bool targetDeflate)
{
    mTargetDeflate = targetDeflate;
    if (!targetDeflate)
    {
        mInputLengthSequence = sequence;
        return;
    }

    mInputLengthSequence.assign(DEFLATE_PERMUTATION_SIZE, 0);
    const std::size_t num_mapped = std::min<std::size_t>(sequence.size(), DEFLATE_PERMUTATION_SIZE);
    for (std::size_t idx = 0; idx < num_mapped; idx++)
    {
        mInputLengthSequence[mapToDeflateIndex(idx)] = sequence[idx];
    }
}