diff --git a/ac_encode.py b/ac_encode.py new file mode 100644 index 0000000..9ed7171 --- /dev/null +++ b/ac_encode.py @@ -0,0 +1,251 @@ +# Arithmetic coding compressor and decompressor for binary strings. +# via: http://www.inference.org.uk/mackay/python/compress/ac/ac_encode.py +# main page: http://www.inference.org.uk/mackay/python/compress/ +# this has been cleaned up (passes pycodestyle) and ported to python 3. + +# default prior distribution +BETA0 = 1 +BETA1 = 1 + +M = 30 +ONE = 1 << M +HALF = 1 << (M - 1) +QUARTER = 1 << (M - 2) +THREEQU = HALF + QUARTER + + +def clear(c, charstack): + # print out character c, and other queued characters + a = repr(c) + repr(1 - c) * charstack[0] + charstack[0] = 0 + return a + + +def encode(string, c0=BETA0, c1=BETA1, adaptive=True): + assert c0 > 0 + assert c1 > 0 + + b = ONE + a = 0 + if not adaptive: + p0 = c0 / (c0 + c1) + ans = "" + charstack = [0] # how many undecided characters remain to print + + for c in string: + w = b - a + if adaptive: + cT = c0 + c1 + p0 = c0 / cT + boundary = a + int(p0 * w) + + # these warnings mean that some of the probabilities + # requested by the probabilistic model are so small + # (compared to our integers) that we had to round them up + # to bigger values. + if boundary == a: + boundary += 1 + print("warningA") + if boundary == b: + boundary -= 1 + print("warningB") + + if c == '1': + a = boundary + if adaptive: + c1 += 1 + elif c == '0': + b = boundary + if adaptive: + c0 += 1 + # ignore other characters + + while a >= HALF or b <= HALF: # output bits + if a >= HALF: + ans += clear(1, charstack) + a -= HALF + b -= HALF + else: + ans += clear(0, charstack) + a *= 2 + b *= 2 + + assert a <= HALF + assert b >= HALF + assert a >= 0 + assert b <= ONE + + # if the gap a-b is getting small, rescale it + while a > QUARTER and b < THREEQU: + charstack[0] += 1 + a *= 2 + b *= 2 + a -= HALF + b -= HALF + + assert a <= HALF + assert b >= HALF + assert a >= 0 + assert b <= ONE + + # terminate + if HALF - a > b - HALF: + w = HALF - a + ans += clear(0, charstack) + while w < HALF: + ans += clear(1, charstack) + w *= 2 + else: + w = b - HALF + ans += clear(1, charstack) + while w < HALF: + ans += clear(0, charstack) + w *= 2 + + return ans + + +def decode(string, N, c0=BETA0, c1=BETA1, adaptive=True): + # must supply N, the number of source characters remaining. + assert c0 > 0 + assert c1 > 0 + + b = ONE + a = 0 + model_needs_updating = True + if not adaptive: + p0 = c0 / (c0 + c1) + ans = "" + + u = 0 + v = ONE + for c in string: + if N <= 0: + break # out of the string-reading loop + assert N > 0 + + # (u,v) is the current "encoded alphabet" binary interval, + # and halfway is its midpoint. + # (a,b) is the current "source alphabet" interval, + # and boundary is the "midpoint" + assert u >= 0 + assert v <= ONE + halfway = u + (v - u) / 2 + if c == '1': + u = halfway + elif c == '0': + v = halfway + + # Read bits until we can decide what the source symbol was. + # Then emulate the encoder's computations, + # and tie (u,v) to tag along for the ride. + while 1: # do-while + if model_needs_updating: + w = b - a + if adaptive: + cT = c0 + c1 + p0 = c0 / cT + boundary = a + int(p0 * w) + if boundary == a: + boundary += 1 + print("warningA") + if boundary == b: + boundary -= 1 + print("warningB") + model_needs_updating = False + + if boundary <= u: + ans += "1" + if adaptive: + c1 += 1 + a = boundary + model_needs_updating = True + N -= 1 + elif boundary >= v: + ans += "0" + if adaptive: + c0 += 1 + b = boundary + model_needs_updating = True + N -= 1 + else: + # not enough bits have yet been read to know the decision. + pass + + # emulate outputting of bits by the encoder, + # and tie (u,v) to tag along for the ride. + while a >= HALF or b <= HALF: + if a >= HALF: + a -= HALF + b -= HALF + u -= HALF + v -= HALF + a *= 2 + b *= 2 + u *= 2 + v *= 2 + model_needs_updating = True + + assert a <= HALF + assert b >= HALF + assert a >= 0 + assert b <= ONE + + # if the gap a-b is getting small, rescale it + while a > QUARTER and b < THREEQU: + a *= 2 + b *= 2 + u *= 2 + v *= 2 + a -= HALF + b -= HALF + u -= HALF + v -= HALF + + # this is the condition for this do-while loop + if not (N > 0 and model_needs_updating): + break + + return ans + + +def test(): + tests = [ + "1010", + "111", + "00001000000000000000", + "1", + "10", + "01", + "0", + "0000000", + """ + 00000000000000010000000000000000 + 00000000000000001000000000000000 + 00011000000 + """, + ] + + for s in tests: + # an ugly way to remove whitespace and newlines from the test strings: + s = "".join(s.split()) + + N = len(s) # required for decoding later. + print("original:", s) + + e = encode(s, c0=10, c1=1) + print("encoded: ", e) + + ds = decode(e, N, c0=10, c1=1) + print("decoded: ", ds) + + if ds != s: + print("FAIL") + else: + print("PASS") + + print() + + +if __name__ == '__main__': + test()