Merge remote-tracking branch 'arithmetic_coding/master'

This commit is contained in:
Connor Olding 2018-10-11 16:45:29 +02:00
commit e9c0c0b245

251
ac_encode.py Normal file
View File

@ -0,0 +1,251 @@
# Arithmetic coding compressor and decompressor for binary strings.
# via: http://www.inference.org.uk/mackay/python/compress/ac/ac_encode.py
# main page: http://www.inference.org.uk/mackay/python/compress/
# this has been cleaned up (passes pycodestyle) and ported to python 3.
# Default prior counts used as the initial c0 / c1 of the binary model.
BETA0 = 1
BETA1 = 1

# Fixed-point scale: coding intervals are integer ranges inside [0, ONE],
# where ONE == 2**M.  The remaining names are the quarter points of that range.
M = 30
ONE = 2 ** M
HALF = ONE // 2
QUARTER = ONE // 4
THREEQU = 3 * QUARTER
def clear(c, charstack):
    """Return bit ``c`` followed by all queued opposite bits, as text.

    ``charstack`` is a one-element list holding the number of queued
    (still undecided) bits.  Each queued bit is written as ``1 - c``.
    The counter is reset to zero in place.
    """
    bit = repr(c)
    flipped = repr(1 - c)
    run = flipped * charstack[0]
    charstack[0] = 0  # the queue has now been flushed
    return bit + run
def encode(string, c0=BETA0, c1=BETA1, adaptive=True):
    """Arithmetic-encode a string of '0'/'1' characters.

    Parameters:
        string: source symbols; characters other than '0' and '1' do not
            change the coding interval.
        c0, c1: positive counts for '0' and '1'.  The probability of '0'
            is taken as c0 / (c0 + c1).
        adaptive: when true, the count of each symbol seen is incremented
            and the probability is recomputed before every symbol;
            otherwise the initial probability is used throughout.

    Returns:
        The code word as a string of '0' and '1' characters.

    Prints "warningA" / "warningB" when the split point had to be nudged
    off an interval endpoint.
    """
    assert c0 > 0
    assert c1 > 0
    low, high = 0, ONE
    if not adaptive:
        p0 = c0 / (c0 + c1)
    pieces = []
    pending = [0]  # count of undecided bits waiting to be written
    for symbol in string:
        if adaptive:
            p0 = c0 / (c0 + c1)
        split = low + int(p0 * (high - low))
        # The warnings mean the model asked for a probability too small
        # for the integer resolution, so it was rounded up to a bigger one.
        if split == low:
            split += 1
            print("warningA")
        if split == high:
            split -= 1
            print("warningB")
        if symbol == '1':
            low = split
            if adaptive:
                c1 += 1
        elif symbol == '0':
            high = split
            if adaptive:
                c0 += 1
        # any other character leaves the interval as it was

        # Emit a bit each time the interval sits wholly in one half.
        while low >= HALF or high <= HALF:
            if low >= HALF:
                pieces.append(clear(1, pending))
                low -= HALF
                high -= HALF
            else:
                pieces.append(clear(0, pending))
            low *= 2
            high *= 2
        assert 0 <= low <= HALF
        assert HALF <= high <= ONE

        # The interval straddles HALF but is narrow: widen it about the
        # middle and remember that one more bit is still undecided.
        while low > QUARTER and high < THREEQU:
            pending[0] += 1
            low = 2 * low - HALF
            high = 2 * high - HALF
        assert 0 <= low <= HALF
        assert HALF <= high <= ONE

    # Terminate: pick the side of HALF with more room, write its bit,
    # then write opposite bits while doubling the room up to HALF.
    if HALF - low > high - HALF:
        lead, filler, room = 0, 1, HALF - low
    else:
        lead, filler, room = 1, 0, high - HALF
    pieces.append(clear(lead, pending))
    while room < HALF:
        pieces.append(clear(filler, pending))
        room *= 2
    return "".join(pieces)
def decode(string, N, c0=BETA0, c1=BETA1, adaptive=True):
    """Decode a bit string produced by the matching arithmetic encoder.

    Parameters:
        string: encoded bits as '0'/'1' characters; other characters do
            not narrow the code interval.
        N: number of source symbols to recover (must be supplied).
        c0, c1: positive counts for '0' and '1'; must match the values
            used when encoding.
        adaptive: must match the encoder's setting.

    Returns:
        The decoded symbols as a string of '0' and '1'.  Decoding stops
        once N symbols are produced or the input is exhausted.

    Prints "warningA" / "warningB" when the split point had to be nudged
    off an interval endpoint.
    """
    assert c0 > 0
    assert c1 > 0
    b = ONE
    a = 0
    model_needs_updating = True
    if not adaptive:
        p0 = c0 / (c0 + c1)
    decoded = []
    u = 0
    v = ONE
    for c in string:
        if N <= 0:
            break  # out of the string-reading loop
        # (u,v) is the current "encoded alphabet" binary interval,
        # and halfway is its midpoint.
        # (a,b) is the current "source alphabet" interval,
        # and boundary is the "midpoint"
        assert u >= 0
        assert v <= ONE
        # Floor division keeps u and v integers, like a and b; true
        # division would turn them into floats.
        halfway = u + (v - u) // 2
        if c == '1':
            u = halfway
        elif c == '0':
            v = halfway
        # Read bits until we can decide what the source symbol was.
        # Then emulate the encoder's computations,
        # and tie (u,v) to tag along for the ride.
        while True:  # do-while
            if model_needs_updating:
                w = b - a
                if adaptive:
                    cT = c0 + c1
                    p0 = c0 / cT
                boundary = a + int(p0 * w)
                if boundary == a:
                    boundary += 1
                    print("warningA")
                if boundary == b:
                    boundary -= 1
                    print("warningB")
                model_needs_updating = False
            if boundary <= u:
                decoded.append("1")
                if adaptive:
                    c1 += 1
                a = boundary
                model_needs_updating = True
                N -= 1
            elif boundary >= v:
                decoded.append("0")
                if adaptive:
                    c0 += 1
                b = boundary
                model_needs_updating = True
                N -= 1
            # otherwise not enough bits have yet been read to know
            # the decision.

            # emulate outputting of bits by the encoder,
            # and tie (u,v) to tag along for the ride.
            while a >= HALF or b <= HALF:
                if a >= HALF:
                    a -= HALF
                    b -= HALF
                    u -= HALF
                    v -= HALF
                a *= 2
                b *= 2
                u *= 2
                v *= 2
                model_needs_updating = True
            assert a <= HALF
            assert b >= HALF
            assert a >= 0
            assert b <= ONE
            # if the gap a-b is getting small, rescale it
            while a > QUARTER and b < THREEQU:
                a *= 2
                b *= 2
                u *= 2
                v *= 2
                a -= HALF
                b -= HALF
                u -= HALF
                v -= HALF
            # this is the condition for this do-while loop
            if not (N > 0 and model_needs_updating):
                break
    return "".join(decoded)
def test():
    """Round-trip sample bit strings through encode/decode and print results.

    For every sample the original, encoded and decoded strings are
    printed, followed by PASS or FAIL depending on whether the decoded
    text equals the original.
    """
    samples = [
        "1010",
        "111",
        "00001000000000000000",
        "1",
        "10",
        "01",
        "0",
        "0000000",
        """
        00000000000000010000000000000000
        00000000000000001000000000000000
        00011000000
        """,
    ]
    for raw in samples:
        # strip all whitespace (including newlines) from the sample
        bits = "".join(raw.split())
        count = len(bits)  # the decoder must be told the symbol count
        print("original:", bits)
        coded = encode(bits, c0=10, c1=1)
        print("encoded: ", coded)
        restored = decode(coded, count, c0=10, c1=1)
        print("decoded: ", restored)
        print("PASS" if restored == bits else "FAIL")
        print()
# Run the round-trip self-test only when executed as a script.
if __name__ == "__main__":
    test()