stuff-from-scratch/src/compression/huffman/HuffmanStream.cpp
2022-11-30 20:53:17 +00:00

346 lines
9.7 KiB
C++

#include "HuffmanStream.h"
#include "ByteUtils.h"
#include "HuffmanFixedCodes.h"
#include <iostream>
#include <algorithm>
#include <unordered_map>
#include <sstream>
// Base distances for DEFLATE distance codes 4..29 (index = code - 4), per RFC 1951 3.2.5.
// Codes 0..3 map directly to distances 1..4 and are handled without this table.
// Code c (c >= 4) carries (c - 4) / 2 + 1 extra bits, so each entry equals the previous
// entry plus 2^(extra bits of the previous code). Code 16 starts at 257 (193 + 64), not 258.
std::vector<unsigned> DISTANCE_OFFSETS
{
    5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
    257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097,
    6145, 8193, 12289, 16385, 24577
};
HuffmanStream::HuffmanStream(BitStream* inputStream, BitStream* outputStream)
: mInputStream(inputStream),
mOutputStream(outputStream)
{
}
// Switches the stream into fixed-code mode: raises mUsingFixedCodes and builds
// mCodeLengthTable's prefix codes from HuffmanFixedCodes' DEFLATE fixed lengths.
void HuffmanStream::generateFixedCodeMapping()
{
    mUsingFixedCodes = true;

    auto&& fixed_lengths = HuffmanFixedCodes::getDeflateFixedHuffmanCodes();
    mCodeLengthTable.setInputLengthSequence(fixed_lengths, false);
    mCodeLengthTable.buildPrefixCodes();
}
// Decodes one symbol from mInputStream using mCodeLengthTable.
// The decoded symbol is written to `buffer`; returns the table's validity flag.
bool HuffmanStream::readNextCodeLengthSymbol(unsigned& buffer)
{
    const bool decoded = mCodeLengthTable.readNextSymbol(buffer, mInputStream);
    return decoded;
}
// Decodes one literal/length symbol from mInputStream using mLiteralTable.
// The decoded symbol is written to `buffer`; returns the table's validity flag.
bool HuffmanStream::readNextLiteralSymbol(unsigned& buffer)
{
    const bool decoded = mLiteralTable.readNextSymbol(buffer, mInputStream);
    return decoded;
}
// Decodes one DEFLATE back-reference distance (RFC 1951 3.2.5) into `buffer`.
// Reads a distance code via mDistanceTable, then that code's extra bits from mInputStream.
// Returns false if the code could not be decoded or is one of the reserved codes 30/31;
// `buffer` is left untouched in that case.
bool HuffmanStream::readNextDistanceSymbol(unsigned& buffer)
{
    unsigned base_symbol{0};
    if (!mDistanceTable.readNextSymbol(base_symbol, mInputStream))
    {
        return false;
    }

    // Codes 0-3 carry no extra bits and map directly to distances 1-4.
    if (base_symbol <= 3)
    {
        buffer = 1 + base_symbol;
        return true;
    }

    // Codes 30 and 31 are reserved; they would index past the offset table.
    const unsigned table_index = base_symbol - 4;
    if (table_index >= DISTANCE_OFFSETS.size())
    {
        return false;
    }

    // Codes 4-29 carry (code - 4) / 2 + 1 extra bits, i.e. 1..13.
    const unsigned num_extra_bits = table_index / 2 + 1;

    unsigned extra_sum{0};
    unsigned char chunk{0};
    if (num_extra_bits > 8)
    {
        // Extra bits are packed LSB-first: the first 8 bits read are the low byte and
        // the remainder are the high bits, so the high part is shifted by 8.
        // The low byte goes through the bit reader (which already handles 8-bit reads
        // for codes 18/19) rather than readNextByte(), which by its name reads at byte
        // granularity and may ignore the current bit offset — TODO confirm in BitStream.
        mInputStream->readNextNBits(8, chunk);
        extra_sum = chunk;

        chunk = 0;
        mInputStream->readNextNBits(num_extra_bits - 8, chunk);
        extra_sum |= static_cast<unsigned>(chunk) << 8;
    }
    else
    {
        mInputStream->readNextNBits(num_extra_bits, chunk);
        extra_sum = chunk;
    }

    buffer = DISTANCE_OFFSETS[table_index] + extra_sum;
    return true;
}
// Appends one decoded code length at position `count` of the combined sequence.
// Positions [0, numLiterals) go to `literals`; later positions go to `distances`
// at index count - numLiterals. Always records `value` in `lastValue` (used by
// repeat code 16) and advances `count`.
// Values whose slot falls outside the destination vector are dropped rather than
// written out of bounds; a malformed repeat run can otherwise overrun `distances`.
void HuffmanStream::addValue(unsigned value, unsigned& count, unsigned& lastValue, std::vector<unsigned char>& literals, unsigned numLiterals, std::vector<unsigned char>& distances)
{
    if (count < numLiterals)
    {
        if (count < literals.size())
        {
            literals[count] = static_cast<unsigned char>(value);
        }
    }
    else
    {
        const unsigned distance_index = count - numLiterals;
        if (distance_index < distances.size())
        {
            distances[distance_index] = static_cast<unsigned char>(value);
        }
    }
    lastValue = value;
    count++;
}
// Reads the HLIT + HDIST code lengths of a dynamic block using mCodeLengthTable,
// then builds mLiteralTable and mDistanceTable from them.
// Symbols 0-15 are literal lengths. 16 repeats the previous length 3-6 times (2 extra bits).
// 17 writes 3-10 zeros (3 extra bits). 18 writes 11-138 zeros (7 extra bits).
// Decoding stops early on an undecodable or out-of-alphabet symbol; the lengths read so far
// (zero elsewhere) are still used to build the tables, as before.
void HuffmanStream::readCodeLengths()
{
    std::vector<unsigned char> literal_lengths(288, 0);
    std::vector<unsigned char> distance_lengths(32, 0);

    const unsigned total_lengths = mNumLiterals + mNumDistances;
    unsigned symbol{0};
    unsigned count{0};
    unsigned last_value{0};

    // Appends `value` up to `reps` times, never past the declared total, so a
    // malformed repeat code cannot run beyond HLIT + HDIST entries.
    auto append_repeated = [&](unsigned value, unsigned reps)
    {
        for (unsigned idx = 0; idx < reps && count < total_lengths; idx++)
        {
            addValue(value, count, last_value, literal_lengths, mNumLiterals, distance_lengths);
        }
    };

    while (count < total_lengths)
    {
        if (!readNextCodeLengthSymbol(symbol))
        {
            break;
        }

        if (symbol < 16)
        {
            addValue(symbol, count, last_value, literal_lengths, mNumLiterals, distance_lengths);
        }
        else if (symbol == 16)
        {
            unsigned char num_reps{0};
            mInputStream->readNextNBits(2, num_reps);
            append_repeated(last_value, 3 + unsigned(num_reps));
        }
        else if (symbol == 17)
        {
            unsigned char num_reps{0};
            mInputStream->readNextNBits(3, num_reps);
            append_repeated(0, 3 + unsigned(num_reps));
        }
        else if (symbol == 18)
        {
            unsigned char num_reps{0};
            mInputStream->readNextNBits(7, num_reps);
            append_repeated(0, 11 + unsigned(num_reps));
        }
        else
        {
            // The code-length alphabet only has symbols 0-18.
            break;
        }
    }

    mLiteralTable.setInputLengthSequence(literal_lengths, false);
    mLiteralTable.buildPrefixCodes();
    mDistanceTable.setInputLengthSequence(distance_lengths, false);
    mDistanceTable.buildPrefixCodes();
}
// Resolves a <length, distance> back-reference. Copies `length` bytes starting
// `distance` bytes back from the end of mBuffer, writing each one to mOutputStream and
// appending it to mBuffer.
// Distance 1 refers to the most recently decoded byte (index size() - 1), so the source
// start is size() - distance. A distance of 0 or one reaching before the start of mBuffer
// cannot come from a valid stream; it is ignored instead of underflowing the index.
void HuffmanStream::copyFromBuffer(unsigned length, unsigned distance)
{
    if (distance == 0 || distance > mBuffer.size())
    {
        return;
    }

    const std::size_t offset = mBuffer.size() - distance;
    for (unsigned idx = 0; idx < length; idx++)
    {
        // Copied by value: when length > distance, the source is a byte appended earlier
        // in this same loop, and push_back may reallocate the buffer.
        const auto symbol = mBuffer[offset + idx];
        mOutputStream->writeByte(symbol);
        mBuffer.push_back(symbol);
    }
}
// Decodes literal/length symbols until end-of-block (256) or a decode failure.
// Literals (0-255) are written to mOutputStream and appended to mBuffer.
// Length symbols (257-285) read their extra bits and a distance, then copy from mBuffer.
// Stops on reserved symbols 286/287 or on an undecodable distance, instead of copying
// with a stale distance.
void HuffmanStream::readSymbols()
{
    // Base match length and extra-bit count for length symbols 257..285 (index = symbol - 257),
    // per RFC 1951 3.2.5.
    static constexpr unsigned LENGTH_BASES[29] = {
        3, 4, 5, 6, 7, 8, 9, 10,
        11, 13, 15, 17,
        19, 23, 27, 31,
        35, 43, 51, 59,
        67, 83, 99, 115,
        131, 163, 195, 227,
        258};
    static constexpr unsigned LENGTH_EXTRA_BITS[29] = {
        0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1,
        2, 2, 2, 2,
        3, 3, 3, 3,
        4, 4, 4, 4,
        5, 5, 5, 5,
        0};

    unsigned symbol{0};
    unsigned distance{0};
    while (readNextLiteralSymbol(symbol))
    {
        if (symbol <= 255)
        {
            mOutputStream->writeByte(symbol);
            mBuffer.push_back(symbol);
            continue;
        }
        if (symbol == 256)
        {
            // End of block.
            break;
        }
        if (symbol > 285)
        {
            // Symbols 286 and 287 never occur in valid compressed data.
            break;
        }

        const unsigned index = symbol - 257;
        unsigned length = LENGTH_BASES[index];
        if (LENGTH_EXTRA_BITS[index] > 0)
        {
            unsigned char extra{0};
            mInputStream->readNextNBits(LENGTH_EXTRA_BITS[index], extra);
            length += extra;
        }

        if (!readNextDistanceSymbol(distance))
        {
            break;
        }
        copyFromBuffer(length, distance);
    }
}
bool HuffmanStream::decode()
{
if (!mUsingFixedCodes)
{
readCodingsTable();
readSymbols();
//std::cout << "Got final buffer size " << mBuffer.size() << std::endl;
for(unsigned idx=0; idx< 100; idx++)
{
//std::cout << idx << " | " << mBuffer[idx] << std::endl;
}
}
else
{
bool found_end_seq{false};
unsigned symbol{0};
while(!found_end_seq)
{
bool valid = readNextCodeLengthSymbol(symbol);
if (!valid)
{
//std::cout << "Hit unknown symbol - bailing out" << std::endl;
break;
}
if (symbol == 256)
{
found_end_seq = true;
break;
}
}
}
return false;
}
// Reads the dynamic-block header and then the code tables.
// The header holds HLIT (5 bits, +257 -> mNumLiterals), HDIST (5 bits, +1 -> mNumDistances)
// and HCLEN (4 bits, +4). Next come HCLEN + 4 three-bit code-length code lengths, which
// build mCodeLengthTable. Finally readCodeLengths() reads the literal and distance lengths.
void HuffmanStream::readCodingsTable()
{
    unsigned char hlit_bits{0};
    mInputStream->readNextNBits(5, hlit_bits);
    mNumLiterals = 257 + hlit_bits;

    unsigned char hdist_bits{0};
    mInputStream->readNextNBits(5, hdist_bits);
    mNumDistances = 1 + hdist_bits;

    unsigned char hclen_bits{0};
    mInputStream->readNextNBits(4, hclen_bits);
    const unsigned code_length_count = 4 + hclen_bits;

    std::vector<unsigned char> code_length_lengths(code_length_count, 0);
    unsigned char field{0};
    for (auto& length : code_length_lengths)
    {
        mInputStream->readNextNBits(3, field);
        length = field;
    }

    // The `true` flag presumably selects the code-length alphabet's special symbol
    // ordering inside the table — confirm against HuffmanCodeLengthTable.
    mCodeLengthTable.setInputLengthSequence(code_length_lengths, true);
    mCodeLengthTable.buildPrefixCodes();
    readCodeLengths();
}