# Arithmetic coding compressor and decompressor for binary strings.
# via: http://www.inference.org.uk/mackay/python/compress/ac/ac_encode.py
# main page: http://www.inference.org.uk/mackay/python/compress/
# this has been cleaned up (passes pycodestyle) and ported to python 3.

# default prior distribution (initial pseudo-counts for '0' and '1')
BETA0 = 1
BETA1 = 1

# The coder works on integer intervals inside [0, ONE) with M bits of
# precision; HALF, QUARTER and THREEQU are the renormalisation thresholds.
M = 30
ONE = 1 << M
HALF = 1 << (M - 1)
QUARTER = 1 << (M - 2)
THREEQU = HALF + QUARTER


def clear(c, charstack):
    """Return bit ``c`` followed by all queued opposite bits, then reset the queue.

    ``c`` is 0 or 1.  ``charstack[0]`` counts the "undecided" bits queued by
    the encoder's middle-interval rescaling; once the next definite bit ``c``
    is known, each of them is emitted as ``1 - c``.  ``charstack`` is
    mutated in place: its count is reset to 0.
    """
    out = repr(c) + repr(1 - c) * charstack[0]
    charstack[0] = 0
    return out


def _boundary(a, b, p0):
    """Split the source interval [a, b) where P('0') == p0.

    '0' owns [a, boundary) and '1' owns [boundary, b).  Shared by ``encode``
    and ``decode`` so both sides compute exactly the same split.
    """
    boundary = a + int(p0 * (b - a))
    # These warnings mean that some of the probabilities requested by the
    # probabilistic model are so small (compared to our integers) that we
    # had to round them up to bigger values.
    if boundary == a:
        boundary += 1
        print("warningA")
    if boundary == b:
        boundary -= 1
        print("warningB")
    return boundary


def encode(string, c0=BETA0, c1=BETA1, adaptive=True):
    """Arithmetic-encode a string of '0'/'1' characters into a bit string.

    Args:
        string: the source text; characters other than '0' and '1' are
            skipped (they do not change the coding interval).
        c0, c1: initial pseudo-counts for '0' and '1'; both must be > 0.
        adaptive: if True, the counts are incremented after every coded
            symbol, so P('0') = c0 / (c0 + c1) tracks the data.  If False,
            P('0') stays fixed at its initial value.

    Returns:
        The encoded bits as a string of '0'/'1' characters.  Decoding needs
        the number of source symbols as well, because it is not stored.
    """
    assert c0 > 0
    assert c1 > 0
    b = ONE
    a = 0
    if not adaptive:
        p0 = c0 / (c0 + c1)
    bits = []
    charstack = [0]  # how many undecided characters remain to print
    for c in string:
        if adaptive:
            p0 = c0 / (c0 + c1)
        boundary = _boundary(a, b, p0)
        if c == '1':
            a = boundary
            if adaptive:
                c1 += 1
        elif c == '0':
            b = boundary
            if adaptive:
                c0 += 1
        # ignore other characters
        while a >= HALF or b <= HALF:  # output bits
            if a >= HALF:
                bits.append(clear(1, charstack))
                a -= HALF
                b -= HALF
            else:
                bits.append(clear(0, charstack))
            a *= 2
            b *= 2
        assert a <= HALF
        assert b >= HALF
        assert a >= 0
        assert b <= ONE
        # If the interval straddles the midpoint but is getting small,
        # expand it around HALF.  The bit this implies is not yet known,
        # so it is queued in charstack and emitted later by clear().
        while a > QUARTER and b < THREEQU:
            charstack[0] += 1
            a = 2 * a - HALF
            b = 2 * b - HALF
        assert a <= HALF
        assert b >= HALF
        assert a >= 0
        assert b <= ONE
    # Terminate: emit just enough bits to pick a binary interval lying
    # inside the final [a, b), on whichever side of HALF has more room.
    if HALF - a > b - HALF:
        w = HALF - a
        bits.append(clear(0, charstack))
        while w < HALF:
            bits.append(clear(1, charstack))
            w *= 2
    else:
        w = b - HALF
        bits.append(clear(1, charstack))
        while w < HALF:
            bits.append(clear(0, charstack))
            w *= 2
    return "".join(bits)


def decode(string, N, c0=BETA0, c1=BETA1, adaptive=True):
    """Decode ``N`` source symbols from the bit string produced by ``encode``.

    Args:
        string: the encoded bits; characters other than '0'/'1' are skipped.
        N: the number of source characters to recover.  It must be supplied
            because the encoding does not record it.
        c0, c1, adaptive: must match the values given to ``encode``, because
            the decoder replays the encoder's model step by step.

    Returns:
        The decoded symbols as a string of '0'/'1' characters.  It is shorter
        than N only if ``string`` runs out before N symbols are determined.
    """
    assert c0 > 0
    assert c1 > 0
    b = ONE
    a = 0
    model_needs_updating = True
    if not adaptive:
        p0 = c0 / (c0 + c1)
    symbols = []
    # (u, v) is the current "encoded alphabet" binary interval.
    u = 0
    v = ONE
    for c in string:
        if N <= 0:
            break  # out of the string-reading loop
        # (u, v) is the current "encoded alphabet" binary interval, and
        # halfway is its midpoint.  (a, b) is the current "source alphabet"
        # interval, and boundary is its "midpoint".
        assert u >= 0
        assert v <= ONE
        # Integer halving: v - u is always a power of two, at least 2 here,
        # so floor division is exact and (u, v) stay integers like (a, b).
        halfway = u + (v - u) // 2
        if c == '1':
            u = halfway
        elif c == '0':
            v = halfway
        # Read bits until we can decide what the source symbol was.
        # Then emulate the encoder's computations,
        # and tie (u, v) to tag along for the ride.
        while True:  # do-while
            if model_needs_updating:
                if adaptive:
                    p0 = c0 / (c0 + c1)
                boundary = _boundary(a, b, p0)
                model_needs_updating = False
            if boundary <= u:
                symbols.append("1")
                if adaptive:
                    c1 += 1
                a = boundary
                model_needs_updating = True
                N -= 1
            elif boundary >= v:
                symbols.append("0")
                if adaptive:
                    c0 += 1
                b = boundary
                model_needs_updating = True
                N -= 1
            # else: not enough bits have yet been read to know the decision.

            # Emulate outputting of bits by the encoder,
            # and tie (u, v) to tag along for the ride.
            while a >= HALF or b <= HALF:
                if a >= HALF:
                    a -= HALF
                    b -= HALF
                    u -= HALF
                    v -= HALF
                a *= 2
                b *= 2
                u *= 2
                v *= 2
                model_needs_updating = True
            assert a <= HALF
            assert b >= HALF
            assert a >= 0
            assert b <= ONE
            # Mirror the encoder's middle-interval expansion.
            while a > QUARTER and b < THREEQU:
                a = 2 * a - HALF
                b = 2 * b - HALF
                u = 2 * u - HALF
                v = 2 * v - HALF
            # this is the condition for this do-while loop
            if not (N > 0 and model_needs_updating):
                break
    return "".join(symbols)


def test():
    """Round-trip a few strings through encode/decode and print PASS/FAIL."""
    tests = [
        "1010",
        "111",
        "00001000000000000000",
        "1",
        "10",
        "01",
        "0",
        "0000000",
        """
00000000000000010000000000000000
00000000000000001000000000000000
00011000000
""",
    ]
    for s in tests:
        # an ugly way to remove whitespace and newlines from the test strings:
        s = "".join(s.split())
        N = len(s)  # required for decoding later.
        print("original:", s)
        e = encode(s, 10, 1)
        print("encoded: ", e)
        ds = decode(e, N, 10, 1)
        print("decoded: ", ds)
        if ds != s:
            print("FAIL")
        else:
            print("PASS")
        print()


if __name__ == '__main__':
    test()