From eeb5d2941e591ab488c495eb21cb5706dac423bd Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Wed, 25 May 2016 07:31:48 -0700 Subject: [PATCH] . --- atttt.py | 156 ++++++++++++++++++++++++++++++++----------------------- basic.py | 11 ++-- misc.py | 1 + 3 files changed, 100 insertions(+), 68 deletions(-) diff --git a/atttt.py b/atttt.py index 4e65b20..0763bcc 100755 --- a/atttt.py +++ b/atttt.py @@ -8,8 +8,8 @@ from basic import Brain def uniq_rows(a, return_index=False, return_inverse=False, return_counts=False): + # via http://stackoverflow.com/a/16973510 # black magic wrapper around np.unique - # via np.dtype((np.void, a.dtype.itemsize * a.shape[1])) return_any = return_index or return_inverse or return_counts if not return_any: np.unique(a.view(np.dtype((np.void, a.dtype.itemsize * a.shape[1])))).view(a.dtype).reshape(-1, a.shape[1]) @@ -34,8 +34,9 @@ class ATTTT(): return 1 - def reply(self, item=None, maxn=1000, raw=False, attempts=None): + def reply(self, item=None, maxn=1000, include_scores=False, attempts=None): if attempts == None: + # just guess some value that'll take roughly the same amount of time attempts = int(2**12 / self.brain.order) lament('attempts:', attempts) @@ -46,7 +47,7 @@ class ATTTT(): result = sorted(replies, key=lambda t: t[1], reverse=True)[0] - if raw: + if include_scores: return result else: return result[0] @@ -63,56 +64,31 @@ class PatternBrain(Brain): return (v,) - def learn_all(self, items, merges=1): - min_count = 2 - if merges < 0: - min_count = -merges - merges = 65536 + def resolve_tokens(self, tokens): + # positive values are just unicode characters + return [o < 0 and self.tokens[o] or chr(o) for o in tokens] - # use numpy so this isn't nearly as disgustingly slow - int32_min = -2**(np.dtype(np.int32).itemsize * 8 - 1) - empty = int32_min - neg_lookup = {-1: ''} # default with padding + def new_token(self, value): + new_id = -1 - len(self.tokens) + self.tokens[new_id] = value + return new_id - alignment = 2 - align = lambda x: (x + alignment // 2) // alignment * alignment - new_items = [] - for item in items: - item = item.strip('\n') - # assert at least 1 padding character at the end - next_biggest = align(len(item) + 1) - # fill with padding (-1) - new_item = -np.ones(next_biggest, dtype=np.int32) - for i, c in enumerate(item): - new_item[i] = ord(c) - new_items.append(new_item) - - # add an extra padding item to the head and tail - # for easier conversion from sequence back to all_items later on - pad = -np.ones(1, dtype=np.int32) - new_items.insert(0, pad) - new_items.append(pad) - - all_items = np.concatenate(new_items) - - if merges > 0: - # set up a 2d array to step through at half the row length, - # this means double redundancy, to acquire all the sequences. - # we don't have to .roll it later to get the other half, - # though that would require less memory. - sequences = all_items.repeat(2)[1:-1].reshape(-1, 2).copy() + def merge_all(self, all_items, merges, min_count=2): + # set up a 2d array to step through at half the row length; + # this means double redundancy; to acquire all the sequences. + # we could instead .roll it later to get the other half. + # that would require less memory, but memory isn't really a concern. + sequences = all_items.repeat(2)[1:-1].reshape(-1, 2).copy() for i in range(merges): - # learn - most_common = (None, 1) - # TODO: eventually check for empty here too invalid = np.any(sequences == -1, axis=1) valid_sequences = np.delete(sequences, np.where(invalid), axis=0) unique, counts = uniq_rows(valid_sequences, return_counts=True) count = counts.max() + most_common = (None, 1) if count > most_common[1]: seq = unique[counts == count][0] most_common = (seq, count) @@ -121,45 +97,85 @@ class PatternBrain(Brain): lament('no more valid sequences') break - new_id = -1 - len(neg_lookup) - neg_lookup[new_id] = "".join([o < 0 and neg_lookup[o] or chr(o) for o in most_common[0]]) + token_value = "".join(self.resolve_tokens(most_common[0])) + new_id = self.new_token(token_value) - if len("".join(neg_lookup.values())) > len(all_items): - lament('preventing dict from growing larger than source') + if len("".join(self.tokens.values())) > len(all_items): + # this might not ever occur + lament('preventing token dictionary from growing larger than source') break - # replace our most common sequence in the sequences + # replace the most common two-token sequence + # with one token to represent both found = np.all(sequences == most_common[0], axis=1) before = np.roll(found, -1) after = np.roll(found, 1) # don't wrap around truth values before[-1] = False after[0] = False - # or remove padding - #before[0] = False - #after[-1] = False # remove the "found" sequences + # and update the previous/next, + # not unlike a doubly-linked list. befores = sequences[before].T.copy() befores[1] = new_id sequences[before] = befores.T afters = sequences[after].T.copy() afters[0] = new_id sequences[after] = afters.T - #sequences[found] = [empty, empty] here = np.where(found) sequences = np.delete(sequences, here, axis=0) - print("({:8}) new token: {:5} \"{}\"".format(len(here[0]), new_id, neg_lookup[new_id])) + print("new token id {:5} occurs {:8} times: \"{}\"".format(new_id, len(here[0]), self.tokens[new_id])) + + # TODO: find unused tokens + + # reconstruct all_items out of the sequences + all_items = sequences.reshape(-1)[::2][1:].copy() + return all_items + + + def learn_all(self, items, merges=0): + min_count = 2 # minimum number of occurences to stop creating tokens at + if merges < 0: + min_count = -merges + merges = 65536 # arbitrary sanity value + + # we'll use numpy matrices so this isn't nearly as disgustingly slow + + self.tokens = {-1: ''} # default with an empty padding token + + # we need to assert that the number of sequences is a multiple of this + # otherwise we can't .reshape() it to be two-dimensional later on + alignment = 2 + align = lambda x: (x + alignment // 2) // alignment * alignment + + new_items = [] + for item in items: + item = item.strip('\n') + # assert at least 1 padding character at the end + next_biggest = align(len(item) + 1) + # initialize with padding (-1) + new_item = -np.ones(next_biggest, dtype=np.int32) + for i, c in enumerate(item): + new_item[i] = ord(c) + new_items.append(new_item) + + # add an extra padding item to the head and tail + # to make it easier to convert from sequences back to items later on + pad = -np.ones(1, dtype=np.int32) + new_items.insert(0, pad) + new_items.append(pad) + + all_items = np.concatenate(new_items) if merges > 0: - # reconstruct all_items out of the sequences - all_items = sequences.reshape(-1)[::2][1:].copy() + all_items = self.merge_all(all_items, merges, min_count) + # begin the actual learning self.padding = '~' self.reset() np_item = [] for i in all_items: - #for np_item in np.split(all_items, np.where(all_items == -1)): if i == -1: if len(np_item) == 0: continue @@ -167,32 +183,40 @@ class PatternBrain(Brain): for i in np_item: if i < 0: assert(i != -1) - item += self.helper(neg_lookup[i]) + item += self.helper(self.tokens[i]) else: item += self.helper(chr(i)) #die(np_item, item) self.learn(item) np_item = [] - elif i != empty: + else: np_item.append(i) self.update() def run(pname, args, env): if not 1 <= len(args) <= 2: - lament("usage: {} {{input file}} [state_fn file]".format(sys.argv[0])) - sys.exit(1) + lament("usage: {} {{input file}} [savestate file]".format(pname)) + return 1 - args = dict(enumerate(args)) # for .get() + args = dict(enumerate(args)) # just for the .get() method fn = args[0] state_fn = args.get(1, None) + # the number of lines to output. count = int(env.get('COUNT', '8')) - order = int(env.get('ORDER', '3')) - temperature = float(env.get('TEMPERATURE', '0')) - maxn = int(env.get('MAXN', '1000')) + # learn and sample using this number of sequential tokens. + order = int(env.get('ORDER', '2')) + # how experimental to be with sampling. + # probably doesn't work properly. + temperature = float(env.get('TEMPERATURE', '0.5')) + # the max character length of output. (not guaranteed) + maxn = int(env.get('MAXN', '240')) + # attempts to maximize scoring attempts = int(env.get('ATTEMPTS', '-1')) + # if positive, maximum number of tokens to merge. + # if negative, minimum number of occurences to stop at. merges = int(env.get('MERGES', '0')) if attempts <= 0: @@ -201,11 +225,12 @@ def run(pname, args, env): brain = PatternBrain(order=order, temperature=temperature) tool = ATTTT(brain) - lament('# loading') if state_fn: + lament('# loading') try: brain.load(state_fn, raw=False) except FileNotFoundError: + lament('# no file to load. skipping') pass if brain and brain.new: @@ -214,6 +239,7 @@ def run(pname, args, env): brain.learn_all(lines, merges) if brain and brain.new and state_fn: + lament('# saving') brain.save(state_fn, raw=False) lament('# replying') @@ -222,6 +248,8 @@ def run(pname, args, env): #print('{:6.1f}\t{}'.format(reply[1], reply[0])) print(tool.reply(maxn=maxn, attempts=attempts)) + return 0 + if __name__ == '__main__': import sys diff --git a/basic.py b/basic.py index a19e174..417a5ab 100755 --- a/basic.py +++ b/basic.py @@ -24,12 +24,10 @@ def normalize_sorted(counter): # http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139 class Brain: - # TODO: don't default padding here, but make sure it's set before running - # the reason is it's the only place that's specific to a string anymore - def __init__(self, order=1, temperature=0.5, padding="~"): + def __init__(self, order=1, temperature=0.5): self.order = order - self.padding = padding self.temperature = temperature + self.padding = None self.reset() @@ -77,6 +75,8 @@ class Brain: def learn(self, item): + assert(self.padding) + if self.type is None and item is not None: self.type = type(item) if type(item) is not self.type: @@ -123,11 +123,14 @@ class Brain: return c + # for overriding in subclasses + # in case the input tokens aren't strings (e.g. tuples) def helper(self, v): return v def reply(self, item=None, maxn=1000): + assert(self.padding) self.update() history = self.helper(self.padding) * self.order diff --git a/misc.py b/misc.py index e2cd2d7..5fcff9d 100755 --- a/misc.py +++ b/misc.py @@ -3,6 +3,7 @@ lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs) def die(*args, **kwargs): + # just for ad-hoc debugging really lament(*args, **kwargs) sys.exit(1)