#
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math

class Bitstream:
    """State shared by the bitstream reader and writer.

    Two cursors walk the same byte buffer:

    * ``bp_bw`` / ``mask_bw`` -- the bit-wise cursor, starting at bit 0
      (mask ``1``) of the last byte;
    * ``bp`` with the ``low`` / ``range`` registers -- the arithmetic-coder
      cursor, starting at the first byte, with ``range`` initialised to
      ``0xffffff``.
    """

    def __init__(self, data):

        # Arithmetic-coder cursor and registers.
        self.bp = 0
        self.low = 0
        self.range = 0xffffff

        # Bit-wise cursor: last byte of the buffer, least significant bit.
        self.bp_bw = len(data) - 1
        self.mask_bw = 1

        self.bytes = data

    def dump(self):
        """Print the buffer as lowercase hex, 20 bytes per line.

        Every byte is followed by a single space, including the last one
        on each line.
        """

        width = 20
        buf = self.bytes
        for offset in range(0, len(buf), width):
            line = buf[offset:offset + width]
            print(' '.join('{:02x}'.format(v) for v in line) + ' ')

class BitstreamReader(Bitstream):
    """Decode a two-ended bitstream.

    Raw bits are read from the last byte of the buffer backwards
    (``read_bit``, ``read_uint``). Arithmetic-coded symbols are read from
    the first byte forwards (``ac_decode``).
    """

    def __init__(self, data):

        super().__init__(data)

        # Prime the 24-bit ``low`` register with the first three bytes,
        # most significant first; the forward cursor then points past them.
        first, second, third = self.bytes[0], self.bytes[1], self.bytes[2]
        self.low = (first << 16) | (second << 8) | third
        self.bp = 3

    def read_bit(self):
        """Return the next raw bit as a ``bool`` and advance the cursor.

        Bits are taken from each byte least significant first; after bit 7
        the cursor moves to the preceding byte.
        """

        mask = self.mask_bw
        bit = (self.bytes[self.bp_bw] & mask) != 0

        mask <<= 1
        if mask == 0x100:
            mask = 1
            self.bp_bw -= 1
        self.mask_bw = mask

        return bit

    def read_uint(self, nbits):
        """Read an ``nbits``-wide unsigned integer, least significant bit first."""

        return sum(self.read_bit() << shift for shift in range(nbits))

    def ac_decode(self, cum_freqs, sym_freqs):
        """Decode one arithmetic-coded symbol and return its index.

        ``cum_freqs[i]`` / ``sym_freqs[i]`` are the cumulative and own
        frequency of symbol ``i``. Frequencies are scaled by
        ``range >> 10``, so they presumably total 1024 -- confirm against
        the tables used by callers.

        Raises:
            ValueError: ``low`` is not below the scaled total range.
            IndexError: renormalization runs past the end of the buffer.
        """

        scale = self.range >> 10
        if self.low >= scale << 10:
            raise ValueError('Invalid ac bitstream')

        # Pick the highest symbol whose scaled cumulative bound is <= low.
        sym = len(cum_freqs) - 1
        while self.low < scale * cum_freqs[sym]:
            sym -= 1

        self.low -= scale * cum_freqs[sym]
        self.range = scale * sym_freqs[sym]

        # Renormalize one byte at a time until range is at least 2^16.
        while self.range < 0x10000:
            self.range <<= 8
            self.low = (self.low << 8) & 0xffffff
            self.low += self.bytes[self.bp]
            self.bp += 1

        return sym

    def get_bits_left(self):
        """Return how many bits of the buffer have not been consumed yet.

        The bits used by the backward bit reader and by the arithmetic
        decoder (excluding its three priming bytes) are both subtracted
        from the total buffer size.
        """

        total = 8 * len(self.bytes)

        bit_pos = int(math.log2(self.mask_bw))
        used_by_bits = total - (8 * self.bp_bw + 8 - bit_pos)

        used_by_ac = 8 * (self.bp - 3)
        used_by_ac += 25 - int(math.floor(math.log2(self.range)))

        return total - used_by_bits - used_by_ac

class BitstreamWriter(Bitstream):
    """Encode into a fixed-size, two-ended bitstream.

    Raw bits are written from the last byte backwards (``write_bit``,
    ``write_uint``). Arithmetic-coded symbols are emitted from the first
    byte forwards (``ac_encode``). Call ``terminate`` once to flush the
    arithmetic coder and obtain the finished buffer.
    """

    def __init__(self, nbytes):
        """Create a writer over a zero-filled buffer of ``nbytes`` bytes."""

        super().__init__(bytearray(nbytes))

        # Arithmetic-coder output state:
        #   cache       -- last byte produced but not yet stored (-1: none)
        #   carry       -- pending +1 to add to ``cache`` when it is stored
        #   carry_count -- number of deferred bytes following ``cache``;
        #                  stored as 0xff, or 0x00 if a carry arrives
        self.cache = -1
        self.carry = 0
        self.carry_count = 0

    def write_bit(self, bit):
        """Store one raw bit at the backward cursor and advance it.

        Bits fill each byte from least to most significant. After bit 7 the
        cursor moves to the preceding byte. A ``bit`` equal to 0 clears the
        target bit; any other value sets it.
        """

        mask = self.mask_bw
        bp = self.bp_bw

        if bit == 0:
            self.bytes[bp] &= ~mask
        else:
            self.bytes[bp] |= mask

        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            self.mask_bw = 1
            self.bp_bw -= 1

    def write_uint(self, val, nbits):
        """Write the low ``nbits`` bits of ``val``, least significant first."""

        for _ in range(nbits):
            self.write_bit(val & 1)
            val >>= 1

    def ac_shift(self):
        """Move the top byte of the 24-bit ``low`` register out of the coder.

        When the top byte is below 0xff, or a carry is pending, the
        previous output is settled. The cached byte (plus carry) and all
        deferred bytes are stored, and the top byte becomes the new cache.
        Otherwise the top byte is 0xff and could still be changed by a
        later carry, so it is only counted in ``carry_count``.
        """

        if self.low < 0xff0000 or self.carry == 1:

            if self.cache >= 0:
                self.bytes[self.bp] = self.cache + self.carry
                self.bp += 1

            # Deferred bytes are 0xff, or wrap to 0x00 when a carry is set.
            while self.carry_count > 0:
                self.bytes[self.bp] = (self.carry + 0xff) & 0xff
                self.bp += 1
                self.carry_count -= 1

            self.cache = self.low >> 16
            self.carry = 0

        else:
            self.carry_count += 1

        # Keep ``low`` as a 24-bit register.
        self.low <<= 8
        self.low &= 0xffffff

    def ac_encode(self, cum_freq, sym_freq):
        """Encode one symbol from its cumulative and own frequency.

        Frequencies are scaled by ``range >> 10``, so they are presumably
        expected to total 1024 -- confirm against the caller's tables.
        """

        r = self.range >> 10
        self.low += r * cum_freq
        if (self.low >> 24) != 0:
            # ``low`` overflowed 24 bits: propagate +1 into the cached byte.
            self.carry = 1

        self.low &= 0xffffff
        self.range = r * sym_freq

        # Renormalize: shift out bytes until range is at least 2^16.
        while self.range < 0x10000:
            self.range <<= 8
            self.ac_shift()

    def get_bits_left(self):
        """Return how many bits of the buffer are still free.

        Counts the bits used by the backward bit writer and by the
        arithmetic coder, including its cached and deferred bytes.
        """

        nbits = 8 * len(self.bytes)

        nbits_bw = nbits - \
            (8*self.bp_bw + 8 - int(math.log2(self.mask_bw)))

        nbits_ac = 8 * self.bp + (25 - int(math.floor(math.log2(self.range))))
        if self.cache >= 0:
            nbits_ac += 8
        # ``carry_count`` is never negative, so no guard is needed here.
        nbits_ac += 8 * self.carry_count

        return nbits - (nbits_bw + nbits_ac)

    def terminate(self):
        """Flush the arithmetic coder and return the finished buffer.

        Returns:
            The underlying ``bytearray``, modified in place.
        """

        # Smallest ``bits`` >= 1 such that ``range`` has a set bit among
        # the top ``bits`` bits of the 24-bit register.
        bits = 1
        while self.range >> (24 - bits) == 0:
            bits += 1

        # Round ``low`` up to a multiple of (mask + 1), so only the top
        # ``bits`` bits carry information. over1/over2 record whether the
        # rounded value and the interval top overflowed 24 bits.
        mask = 0xffffff >> bits
        val = self.low + mask

        over1 = val >> 24
        val &= 0x00ffffff
        high = self.low + self.range
        over2 = high >> 24
        high &= 0x00ffffff
        val = val & ~mask

        if over1 == over2:

            # If the rounded block [val, val + mask] reaches ``high``,
            # use one more bit of precision.
            if val + mask >= high:
                bits += 1
                mask >>= 1
                val = ((self.low + mask) & 0x00ffffff) & ~mask

            # The rounded value can only be below ``low`` if it wrapped
            # past 2^24, which means a carry is pending.
            if val < self.low:
                self.carry = 1

        # Shift out enough bytes to cover ``bits``. Afterwards ``bits`` is
        # the number of significant bits (1..8) of the final partial byte.
        self.low = val
        while bits > 0:
            self.ac_shift()
            bits -= 8
        bits += 8

        val = self.cache

        if self.carry_count > 0:
            self.bytes[self.bp] = self.cache
            self.bp += 1

            while self.carry_count > 1:
                self.bytes[self.bp] = 0xff
                self.bp += 1
                self.carry_count -= 1

            val = 0xff >> (8 - bits)

        # Write the top ``bits`` bits of ``val`` MSB-first into the current
        # byte. The byte's other bits are left untouched, presumably because
        # they may already hold bits written by write_bit -- confirm.
        mask = 0x80
        for _ in range(bits):

            if val & mask == 0:
                self.bytes[self.bp] &= ~mask
            else:
                self.bytes[self.bp] |= mask

            mask >>= 1

        return self.bytes
