#include "HuffmanStream.h"

#include "ByteUtils.h"
#include "HuffmanFixedCodes.h"

// NOTE(review): the names of the four standard headers included here were lost
// from this copy of the file; the list below is a superset guess - confirm
// against version control.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <sstream>
#include <unordered_map>
#include <vector>

// Base distances for DEFLATE distance symbols 4..29 (symbols 0..3 map directly
// to distances 1..4 and carry no extra bits). Entry i belongs to symbol i + 4.
// NOTE(review): the element type was lost from this copy of the file;
// `unsigned` is assumed.
std::vector<unsigned> DISTANCE_OFFSETS
{
    5, 7, 9, 13, 17, 25,
    33, 49, 65, 97, 129, 193,
    257, 385, 513, 769, 1025, 1537,
    2049, 3073, 4097, 6145, 8193, 12289,
    16385, 24577
};

// Stores the (non-owned) input and output bit streams used by the decoder.
HuffmanStream::HuffmanStream(BitStream* inputStream, BitStream* outputStream)
    : mInputStream(inputStream),
      mOutputStream(outputStream)
{
}

// Switches the stream to fixed-code mode and loads the DEFLATE fixed Huffman
// code lengths into the code length table.
void HuffmanStream::generateFixedCodeMapping()
{
    mUsingFixedCodes = true;
    mCodeLengthTable.setInputLengthSequence(HuffmanFixedCodes::getDeflateFixedHuffmanCodes(), false);
    mCodeLengthTable.buildPrefixCodes();
}

// Reads the next symbol using the code length table.
// Returns false if the table could not match a symbol.
bool HuffmanStream::readNextCodeLengthSymbol(unsigned& buffer)
{
    return mCodeLengthTable.readNextSymbol(buffer, mInputStream);
}

// Reads the next symbol using the literal/length table.
// Returns false if the table could not match a symbol.
bool HuffmanStream::readNextLiteralSymbol(unsigned& buffer)
{
    return mLiteralTable.readNextSymbol(buffer, mInputStream);
}

// Reads a distance symbol plus its extra bits and writes the resulting
// back-reference distance to `buffer`.
// Returns false (leaving `buffer` untouched) if no symbol could be read or if
// the symbol has no entry in DISTANCE_OFFSETS (symbols 30 and 31).
bool HuffmanStream::readNextDistanceSymbol(unsigned& buffer)
{
    unsigned base_symbol{0};
    if (!mDistanceTable.readNextSymbol(base_symbol, mInputStream))
    {
        return false;
    }

    // Symbols 0..3 encode distances 1..4 directly.
    if (base_symbol <= 3)
    {
        buffer = 1 + base_symbol;
        return true;
    }

    const std::size_t offset_index = base_symbol - 4;
    if (offset_index >= DISTANCE_OFFSETS.size())
    {
        return false;
    }

    // Symbols 4,5 carry 1 extra bit, 6,7 carry 2, ... 28,29 carry 13.
    const auto num_extra_bits = static_cast<unsigned>(offset_index / 2 + 1);

    unsigned char low_bits{0};
    unsigned extra_sum{0};
    if (num_extra_bits > 8)
    {
        // readNextNBits fills a single byte, so read in two steps. The first
        // 8 bits are the low-order byte, so the remainder is shifted by 8.
        // NOTE(review): assumes readNextNBits returns the value in DEFLATE's
        // LSB-first order, as the <= 8 bit branch already relies on - confirm
        // against BitStream.
        unsigned char high_bits{0};
        mInputStream->readNextNBits(8, low_bits);
        mInputStream->readNextNBits(num_extra_bits - 8, high_bits);
        extra_sum = (static_cast<unsigned>(high_bits) << 8) | low_bits;
    }
    else
    {
        mInputStream->readNextNBits(num_extra_bits, low_bits);
        extra_sum = low_bits;
    }

    buffer = DISTANCE_OFFSETS[offset_index] + extra_sum;
    return true;
}

// Appends one code length to the combined literal/distance sequence.
// The first `numLiterals` entries go to `literals`, later ones to `distances`.
// Out-of-range positions are dropped rather than written, but `count` and
// `lastValue` are always updated so the caller's loop still advances.
// NOTE(review): the vector element type was lost from this copy of the file;
// `unsigned char` is assumed and must match the declaration in HuffmanStream.h.
void HuffmanStream::addValue(unsigned value, unsigned& count, unsigned& lastValue, std::vector<unsigned char>& literals, unsigned numLiterals, std::vector<unsigned char>& distances)
{
    if (count < numLiterals)
    {
        if (count < literals.size())
        {
            literals[count] = static_cast<unsigned char>(value);
        }
    }
    else
    {
        const std::size_t distance_index = count - numLiterals;
        if (distance_index < distances.size())
        {
            distances[distance_index] = static_cast<unsigned char>(value);
        }
    }
    lastValue = value;
    count++;
}
void HuffmanStream::readCodeLengths() { std::vector literal_lengths(288, 0); std::vector distance_lengths(32, 0); unsigned symbol{0}; unsigned count{0}; unsigned last_value{0}; while(count < mNumLiterals + mNumDistances) { bool valid = readNextCodeLengthSymbol(symbol); if (!valid) { //std::cout << "Hit unknown symbol - bailing out" << std::endl; break; } if (symbol < 16) { addValue(symbol, count, last_value, literal_lengths, mNumLiterals, distance_lengths); } else if(symbol == 16) { unsigned char num_reps{0}; mInputStream->readNextNBits(2, num_reps); //std::cout << "Got val 16 doing " << 3 + num_reps << std::endl; for(unsigned char idx=0; idx< 3 + num_reps; idx++) { addValue(last_value, count, last_value, literal_lengths, mNumLiterals, distance_lengths); } } else if(symbol == 17) { unsigned char num_reps{0}; mInputStream->readNextNBits(3, num_reps); //std::cout << "Got val 17 doing " << 3 + num_reps << std::endl; for(unsigned char idx=0; idx< 3 + num_reps; idx++) { addValue(0, count, last_value, literal_lengths, mNumLiterals, distance_lengths); } } else if(symbol == 18) { unsigned char num_reps{0}; mInputStream->readNextNBits(7, num_reps); //std::cout << "Got val 18 doing " << 11 + num_reps << std::endl; for(unsigned idx=0; idx< 11 + unsigned(num_reps); idx++) { addValue(0, count, last_value, literal_lengths, mNumLiterals, distance_lengths); } } } //std::cout << "Got final literal length sequence " << std::endl; for(unsigned idx=0; idx(literal_lengths[idx]) << "," ; } //std::cout << std::endl; //std::cout << "Got final distance length sequence " << std::endl; for(unsigned idx=0; idx(distance_lengths[idx]) << "," ; } //std::cout << std::endl; mLiteralTable.setInputLengthSequence(literal_lengths, false); mLiteralTable.buildPrefixCodes(); mDistanceTable.setInputLengthSequence(distance_lengths, false); mDistanceTable.buildPrefixCodes(); } void HuffmanStream::copyFromBuffer(unsigned length, unsigned distance) { std::size_t offset = mBuffer.size() - 1 - distance; 
for(unsigned idx=0; idxwriteByte(symbol); mBuffer.push_back(symbol); } } void HuffmanStream::readSymbols() { bool hit_end_stream{false}; unsigned symbol{0}; unsigned distance{0}; while(!hit_end_stream) { const auto valid = readNextLiteralSymbol(symbol); if (!valid) { //std::cout << "Hit unknown symbol - bailing out" << std::endl; break; } //std::cout << "Got symbol " << symbol << std::endl; if(symbol <= 255) { mOutputStream->writeByte(symbol); mBuffer.push_back(symbol); } else if(symbol == 256) { hit_end_stream = true; break; } else if (symbol <= 264) { auto length = 3 + symbol - 257; const auto valid_dist = readNextDistanceSymbol(distance); copyFromBuffer(length, distance); } else if (symbol <= 268) { unsigned char extra{0}; mInputStream->readNextNBits(1, extra); auto length = 11 + 2*(symbol - 265) + extra; const auto valid_dist = readNextDistanceSymbol(distance); copyFromBuffer(length, distance); } else if (symbol <= 272) { unsigned char extra{0}; mInputStream->readNextNBits(2, extra); auto length = 19 + 4*(symbol - 269) + extra; const auto valid_dist = readNextDistanceSymbol(distance); copyFromBuffer(length, distance); } else if (symbol <= 276) { unsigned char extra{0}; mInputStream->readNextNBits(3, extra); auto length = 35 + 8*(symbol - 273) + extra; const auto valid_dist = readNextDistanceSymbol(distance); copyFromBuffer(length, distance); } else if (symbol <= 280) { unsigned char extra{0}; mInputStream->readNextNBits(4, extra); auto length = 67 + 16*(symbol - 277) + extra; const auto valid_dist = readNextDistanceSymbol(distance); copyFromBuffer(length, distance); } else if (symbol <= 284) { unsigned char extra{0}; mInputStream->readNextNBits(5, extra); auto length = 131 + 32*(symbol - 281) + extra; const auto valid_dist = readNextDistanceSymbol(distance); copyFromBuffer(length, distance); } else if (symbol == 285) { auto length = 258; const auto valid_dist = readNextDistanceSymbol(distance); copyFromBuffer(length, distance); } } if (hit_end_stream) { 
//std::cout << "Found end of stream ok" << std::endl; } } bool HuffmanStream::decode() { if (!mUsingFixedCodes) { readCodingsTable(); readSymbols(); //std::cout << "Got final buffer size " << mBuffer.size() << std::endl; for(unsigned idx=0; idx< 100; idx++) { //std::cout << idx << " | " << mBuffer[idx] << std::endl; } } else { bool found_end_seq{false}; unsigned symbol{0}; while(!found_end_seq) { bool valid = readNextCodeLengthSymbol(symbol); if (!valid) { //std::cout << "Hit unknown symbol - bailing out" << std::endl; break; } if (symbol == 256) { found_end_seq = true; break; } } } return false; } void HuffmanStream::readCodingsTable() { unsigned char h_lit{0}; mInputStream->readNextNBits(5, h_lit); mNumLiterals = h_lit + 257; //std::cout << "Got HLIT " << mNumLiterals << std::endl; unsigned char h_dist{0}; mInputStream->readNextNBits(5, h_dist); mNumDistances = h_dist + 1; //std::cout << "Got HDIST " << mNumDistances << std::endl; unsigned char h_clen{0}; mInputStream->readNextNBits(4, h_clen); unsigned num_code_lengths = h_clen + 4; //std::cout << "Got HCLEN " << num_code_lengths << std::endl; auto sequence = std::vector(num_code_lengths, 0); unsigned char buffer{0}; for(unsigned idx = 0; idx< num_code_lengths; idx++) { mInputStream->readNextNBits(3, buffer); //std::cout << "Got coding table value " << idx << " | " << static_cast(buffer) << " | " << ByteUtils::toString(buffer) << std::endl; sequence[idx] = buffer; } mCodeLengthTable.setInputLengthSequence(sequence, true); mCodeLengthTable.buildPrefixCodes(); readCodeLengths(); }