Merge remote-tracking branch 'atttt/master'
commit 61188f00da
3 changed files with 480 additions and 0 deletions
atttt.py (executable file, +280 lines)
@@ -0,0 +1,280 @@
#!/usr/bin/env python3

import sys

import numpy as np

from misc import *
from basic import Brain


def align(x, alignment):
    return (x + alignment // 2) // alignment * alignment


def uniq_rows(a, return_index=False, return_inverse=False, return_counts=False):
    # via http://stackoverflow.com/a/16973510
    # black magic wrapper around np.unique
    return_any = return_index or return_inverse or return_counts
    if not return_any:
        return np.unique(a.view(np.dtype((np.void, a.dtype.itemsize * a.shape[1])))).view(a.dtype).reshape(-1, a.shape[1])
    else:
        void_dtype = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
        ret = np.unique(a.view(void_dtype), return_index, return_inverse, return_counts)
        return (ret[0].view(a.dtype).reshape(-1, a.shape[1]),) + ret[1:]


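(Reviewer note: newer NumPy, 1.13 and up, can deduplicate rows directly, which gives the same result as uniq_rows up to row order. A minimal sketch for comparison:)

    import numpy as np

    a = np.array([[1, 2], [1, 2], [3, 4]])
    rows, counts = np.unique(a, axis=0, return_counts=True)
    # rows   -> [[1, 2], [3, 4]]
    # counts -> [2, 1]
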
class ATTTT():

    def __init__(self, brain):
        self.brain = brain
        self.score = self._score

    def _score(self, reply, maxn):
        if len(reply) > maxn:
            return -999999999

        #return len(reply)
        return 1

    def reply(self, item=None, maxn=1000, include_scores=False, attempts=None):
        if attempts is None:
            # just guess some value that'll take roughly the same amount of time
            attempts = int(2**12 / self.brain.order)
        lament('attempts:', attempts)

        replies = []
        for i in range(attempts):
            reply = "".join(self.brain.reply(item=item, maxn=maxn + 1))
            replies += [(reply, self.score(reply, maxn))]

        result = sorted(replies, key=lambda t: t[1], reverse=True)[0]

        if include_scores:
            return result
        else:
            return result[0]


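(Reviewer note: since reply only needs the single best candidate, sorting the whole list does more work than necessary; max with a key makes the same selection in one pass over the replies list built above:)

    result = max(replies, key=lambda t: t[1])
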
class PatternBrain(Brain):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, padding='~', **kwargs)
        self.tokens = []

    def helper(self, v):
        return (v,)

    def resolve_tokens(self, tokens):
        # positive values are just unicode characters;
        # negative values index into our merged-token table
        if isinstance(tokens, (int, np.int32)):
            return self.tokens[tokens] if tokens < 0 else chr(tokens)
        else:
            return [self.tokens[o] if o < 0 else chr(o) for o in tokens]

    def new_token(self, value):
        new_id = -1 - len(self.tokens)
        self.tokens[new_id] = value
        return new_id

    @staticmethod
    def prepare_items(items, pad=True):
        new_items = []
        for item in items:
            item = item.strip('\n')
            # round the padded length up to a multiple of 2,
            # otherwise we can't .reshape() it to be two-dimensional later on
            next_biggest = align(len(item) + 1, 2)
            # initialize with padding (-1)
            new_item = -np.ones(next_biggest, dtype=np.int32)
            for i, c in enumerate(item):
                new_item[i] = ord(c)
            new_items.append(new_item)

        # add an extra padding item to the head and tail
        # to make it easier to convert from sequences back to items later on
        if pad:
            pad = -np.ones(1, dtype=np.int32)
            new_items.insert(0, pad)
            new_items.append(pad)

        return np.concatenate(new_items)


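(Reviewer note: to make the layout concrete, a small worked example assuming the module above; 97/98/99 are the code points of a, b, c:)

    out = PatternBrain.prepare_items(["ab", "c"])
    # "ab" pads to length 4, "c" to length 2, plus one pad item on each end:
    # out -> [-1,  97, 98, -1, -1,  99, -1,  -1]
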
    def stat_tokens(self, all_items, skip_normal=False):
        unique, counts = np.unique(all_items, return_counts=True)
        count_order = np.argsort(counts)[::-1]
        counts_descending = counts[count_order]
        unique_descending = unique[count_order]
        for i, token_id in enumerate(unique_descending):
            if token_id == -1:
                continue
            if skip_normal and token_id >= 0:
                continue
            token = self.resolve_tokens(token_id)
            lament("token id {:5} occurs {:8} times: \"{}\"".format(
                token_id, counts_descending[i], token))
        lament("total tokens: {:5}".format(i + 1))


    def merge_all(self, all_items, merges, min_count=2):
        # set up a 2d array to step through at half the row length;
        # this gives double redundancy, to acquire all the sequences.
        # we could instead .roll it later to get the other half.
        # that would require less memory, but memory isn't really a concern.
        sequences = all_items.repeat(2)[1:-1].reshape(-1, 2).copy()

        for i in range(merges):
            invalid = np.any(sequences == -1, axis=1)
            valid_sequences = np.delete(sequences, np.where(invalid), axis=0)
            unique, counts = uniq_rows(valid_sequences, return_counts=True)
            count = counts.max()

            most_common = (None, 1)
            if count > most_common[1]:
                seq = unique[counts == count][0]
                most_common = (seq, count)

            if most_common[0] is None or most_common[1] <= 1 or most_common[1] < min_count:
                lament('no more valid sequences')
                break

            token_value = "".join(self.resolve_tokens(most_common[0]))
            new_id = self.new_token(token_value)

            # replace the most common two-token sequence
            # with one token to represent both
            found = np.all(sequences == most_common[0], axis=1)
            before = np.roll(found, -1)
            after = np.roll(found, 1)
            # don't wrap around truth values
            before[-1] = False
            after[0] = False
            # remove the "found" sequences
            # and update the previous/next,
            # not unlike a doubly-linked list.
            befores = sequences[before].T.copy()
            befores[1] = new_id
            sequences[before] = befores.T
            afters = sequences[after].T.copy()
            afters[0] = new_id
            sequences[after] = afters.T
            here = np.where(found)
            sequences = np.delete(sequences, here, axis=0)

            lament("new token id {:5} occurs {:8} times: \"{}\"".format(
                new_id, len(here[0]), self.tokens[new_id]))

        # reconstruct all_items out of the sequences
        all_items = sequences.reshape(-1)[::2][1:].copy()
        return all_items


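(Reviewer note: the loop above is, in effect, a byte-pair-encoding style procedure: find the most frequent adjacent pair, mint a token for it, and splice it in. A minimal pure-Python sketch of one merge step, which the numpy code vectorizes; bpe_step is a hypothetical helper, not part of this module:)

    from collections import Counter

    def bpe_step(seq, new_id):
        # count adjacent pairs, skipping pairs that touch padding (-1)
        pairs = Counter(p for p in zip(seq, seq[1:]) if -1 not in p)
        if not pairs:
            return seq
        pair, count = pairs.most_common(1)[0]
        if count < 2:
            return seq
        # rewrite the sequence, replacing each occurrence with new_id
        out, i = [], 0
        while i < len(seq):
            if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
                out.append(new_id)
                i += 2
            else:
                out.append(seq[i])
                i += 1
        return out
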
    def learn_all(self, items, merges=0, stat=True):
        min_count = 2  # minimum number of occurrences to stop creating tokens at
        if merges < 0:
            min_count = -merges
            merges = 65536  # arbitrary sanity value

        # we'll use numpy arrays so this isn't nearly as disgustingly slow

        self.tokens = {-1: ''}  # default with an empty padding token

        all_items = self.prepare_items(items)

        if merges > 0:
            all_items = self.merge_all(all_items, merges, min_count)

        # begin the actual learning
        self.reset()
        np_item = []
        for i in all_items:
            if i == -1:
                if len(np_item) == 0:
                    continue
                item = tuple()
                for t in np_item:
                    if t < 0:
                        assert t != -1
                        item += self.helper(self.tokens[t])
                    else:
                        item += self.helper(chr(t))
                #die(np_item, item)
                self.learn(item)
                np_item = []
            else:
                np_item.append(i)
        self.update()

        if merges != 0 and stat:
            self.stat_tokens(all_items)


def run(pname, args, env):
    if not 1 <= len(args) <= 2:
        lament("usage: {} {{input file}} [savestate file]".format(pname))
        return 1

    args = dict(enumerate(args))  # just for the .get() method

    fn = args[0]
    state_fn = args.get(1, None)

    # the number of lines to output.
    count = int(env.get('COUNT', '8'))
    # learn and sample using this number of sequential tokens.
    order = int(env.get('ORDER', '2'))
    # how experimental to be with sampling.
    # probably doesn't work properly.
    temperature = float(env.get('TEMPERATURE', '0.5'))
    # the max character length of output. (not guaranteed)
    maxn = int(env.get('MAXN', '240'))
    # attempts to maximize scoring
    attempts = int(env.get('ATTEMPTS', '-1'))
    # if positive, the maximum number of tokens to merge.
    # if negative, the minimum number of occurrences to stop at.
    merges = int(env.get('MERGES', '0'))

    if attempts <= 0:
        attempts = None

    brain = PatternBrain(order=order, temperature=temperature)
    tool = ATTTT(brain)

    if state_fn:
        lament('# loading')
        try:
            brain.load(state_fn, raw=False)
        except FileNotFoundError:
            lament('# no file to load. skipping')

    if brain and brain.new:
        lament('# learning')
        lines = open(fn).readlines()
        brain.learn_all(lines, merges)

    if brain and brain.new and state_fn:
        lament('# saving')
        brain.save(state_fn, raw=False)

    lament('# replying')
    for i in range(count):
        #reply = tool.reply(maxn=maxn, include_scores=True, attempts=attempts)
        #print('{:6.1f}\t{}'.format(reply[1], reply[0]))
        print(tool.reply(maxn=maxn, attempts=attempts))

    return 0


if __name__ == '__main__':
    import os
    pname = sys.argv[0] if len(sys.argv) > 0 else ''
    args = sys.argv[1:] if len(sys.argv) > 1 else []
    sys.exit(run(pname, args, os.environ))
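(Reviewer note: per the env lookups in run above, a typical invocation is driven entirely by environment variables, e.g. `COUNT=4 ORDER=3 MERGES=100 python3 atttt.py corpus.txt state.pickle`; the file names here are placeholders.)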
basic.py (executable file, +189 lines)
@@ -0,0 +1,189 @@
import math

import numpy as np

from misc import *


def normalize(counter):
    v = counter.values()
    s = float(sum(v))
    m = float(max(v))
    del v
    return [(c, cnt / s, cnt / m) for c, cnt in counter.items()]


def normalize_sorted(counter):
    # if the elements were unsorted,
    # we couldn't use our lazy method (subtraction) of selecting tokens,
    # and temperature would correspond to arbitrary tokens
    # instead of more/less common tokens.
    return sorted(normalize(counter), key=lambda t: t[1], reverse=True)


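(Reviewer note: a concrete check of the two ratios, sum-normalized and max-normalized, assuming the functions above:)

    from collections import Counter

    dist = normalize_sorted(Counter("aab"))
    # -> [('a', 2/3, 1.0), ('b', 1/3, 0.5)], most common first
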
# http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139
class Brain:

    def __init__(self, padding, order=1, temperature=0.5):
        self.order = order
        self.temperature = temperature
        self.padding = padding

        self.reset()

    def reset(self):
        import collections as cool
        # unnormalized
        self._machine = cool.defaultdict(cool.Counter)
        # normalized
        self.machine = None

        self.type = None
        self.dirty = False
        self.new = True

    @property
    def temperature(self):
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        assert 0 < value < 1
        self._temperature = value

        a = 1 - value * 2
        # http://www.mathopenref.com/graphfunctions.html?fx=(a*x-x)/(2*a*x-a-1)&sg=f&sh=f&xh=1&xl=0&yh=1&yl=0&ah=1&al=-1&a=0.5
        tweak = lambda x: (a * x - x) / (2 * a * x - a - 1)
        self.random = lambda n: 1 - tweak(np.random.random(n))


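(Reviewer note: a quick way to see what the curve does, assuming the setter above: at temperature 0.5 we get a = 0 and the tweak reduces to the identity, so sampling stays uniform; away from 0.5 it biases draws toward more or less common tokens.)

    a = 1 - 0.5 * 2  # temperature 0.5 gives a = 0
    tweak = lambda x: (a * x - x) / (2 * a * x - a - 1)
    assert tweak(0.25) == 0.25 and tweak(0.75) == 0.75  # identity at this setting
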
    def learn_all(self, items):
        for item in items:
            self.learn(item)
        self.update()

    def learn(self, item):
        assert self.padding

        if self.type is None and item is not None:
            self.type = type(item)
        if type(item) is not self.type:
            raise Exception("that's no good")

        if self.type is str:
            item = item.strip()

        if len(item) == 0:
            return

        pad = self.helper(self.padding) * self.order
        item = pad + item + pad

        stop = len(item) - self.order
        if stop > 0:
            for i in range(stop):
                history, newitem = item[i:i + self.order], item[i + self.order]
                self._machine[history][newitem] += 1

        self.dirty = True

    def update(self):
        if self.dirty and self._machine:
            self.machine = {hist: normalize_sorted(items)
                            for hist, items in self._machine.items()}
            self.dirty = False

    def next(self, history):
        history = history[-self.order:]

        dist = self.machine.get(history, None)
        if dist is None:
            lament('warning: no value: {}'.format(history))
            return None

        x = self.random(1)
        for c, cs, cm in dist:
            x = x - cs
            if x <= 0:
                return c


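(Reviewer note: to trace the subtraction sampling in next, assume a normalized distribution [('a', 0.6, …), ('b', 0.4, …)] and a draw x = 0.7: subtracting 0.6 leaves 0.1, still positive, so 'a' is skipped; subtracting 0.4 leaves -0.3, so 'b' is returned. Smaller draws land on more common tokens, which is why normalize_sorted must order by frequency first.)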
    # for overriding in subclasses
    # in case the input tokens aren't strings (e.g. tuples)
    def helper(self, v):
        return v

    def reply(self, item=None, maxn=1000):
        assert self.padding
        self.update()

        history = self.helper(self.padding) * self.order

        out = []
        for i in range(maxn):
            c = self.next(history)
            if c is None:
                break
            if c.find(self.padding) != -1:
                out.append(c.replace(self.padding, ''))
                break
            history = history[-self.order:] + self.helper(c)
            out.append(c)

        return out

    def load(self, fn, raw=True):
        import pickle
        if isinstance(fn, str):
            f = open(fn, 'rb')
        else:
            f = fn

        d = pickle.load(f)

        if d['order'] != self.order:
            lament('warning: order mismatch. cancelling load.')
            return
        self.order = d['order']

        if raw:
            if not d.get('_machine'):
                lament('warning: no _machine. cancelling load.')
                return
            self._machine = d['_machine']

            self.dirty = True
            self.update()
        else:
            if not d.get('machine'):
                lament('warning: no machine. cancelling load.')
                return
            self.machine = d['machine']

        self.new = False
        if f != fn:
            f.close()

    def save(self, fn, raw=True):
        import pickle
        if isinstance(fn, str):
            f = open(fn, 'wb')
        else:
            f = fn

        d = {}
        d['order'] = self.order
        if raw:
            d['_machine'] = self._machine
        else:
            d['machine'] = self.machine
        pickle.dump(d, f)

        if f != fn:
            f.close()
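(Reviewer note: a minimal end-to-end sketch of the class, assuming basic.py as above; the corpus here is made up:)

    from basic import Brain

    brain = Brain(padding='~', order=2, temperature=0.5)
    brain.learn_all(["hello world", "hello there"])
    print("".join(brain.reply(maxn=40)))  # e.g. "hello world"
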
misc.py (executable file, +11 lines)
@@ -0,0 +1,11 @@
import sys

lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)


def die(*args, **kwargs):
    # just for ad-hoc debugging really
    lament(*args, **kwargs)
    sys.exit(1)


# export every public name that isn't a module;
# locals() maps names to values, so filter on the values
__all__ = [name for name, value in list(locals().items())
           if not isinstance(value, type(sys)) and not name.startswith('_')]