diff --git a/atttt.py b/atttt.py
new file mode 100755
index 0000000..02f662c
--- /dev/null
+++ b/atttt.py
@@ -0,0 +1,280 @@
+#!/usr/bin/env python3
+
+import sys
+import numpy as np
+
+from misc import *
+from basic import Brain
+
+
+def align(x, alignment):
+    # round x up to the next multiple of alignment
+    return (x + alignment - 1) // alignment * alignment
+
+
+def uniq_rows(a, return_index=False, return_inverse=False, return_counts=False):
+    # via http://stackoverflow.com/a/16973510
+    # black-magic wrapper around np.unique:
+    # viewing each row as a single void scalar makes np.unique
+    # deduplicate whole rows instead of individual elements.
+    void_dtype = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
+    return_any = return_index or return_inverse or return_counts
+    if not return_any:
+        return np.unique(a.view(void_dtype)).view(a.dtype).reshape(-1, a.shape[1])
+    else:
+        ret = np.unique(a.view(void_dtype), return_index, return_inverse, return_counts)
+        return (ret[0].view(a.dtype).reshape(-1, a.shape[1]),) + ret[1:]
+
+
+class ATTTT():
+
+    def __init__(self, brain):
+        self.brain = brain
+        self.score = self._score
+
+
+    def _score(self, reply, maxn):
+        if len(reply) > maxn:
+            return -999999999
+
+        #return len(reply)
+        return 1
+
+
+    def reply(self, item=None, maxn=1000, include_scores=False, attempts=None):
+        if attempts is None:
+            # just guess some value that'll take roughly the same amount of time
+            attempts = int(2**12 / self.brain.order)
+        lament('attempts:', attempts)
+
+        replies = []
+        for i in range(attempts):
+            reply = "".join(self.brain.reply(item=item, maxn=maxn+1))
+            replies += [(reply, self.score(reply, maxn))]
+
+        result = sorted(replies, key=lambda t: t[1], reverse=True)[0]
+
+        if include_scores:
+            return result
+        else:
+            return result[0]
+
+
+class PatternBrain(Brain):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, padding='~', **kwargs)
+        # token ids are negative; -1 is reserved for the empty padding token
+        self.tokens = {-1: ''}
+
+
+    def helper(self, v):
+        return (v,)
+
+
+    def resolve_tokens(self, tokens):
+        # negative values are merged-token ids;
+        # positive values are just unicode code points
+        if isinstance(tokens, (int, np.integer)):
+            return self.tokens[tokens] if tokens < 0 else chr(tokens)
+        else:
+            return [self.tokens[o] if o < 0 else chr(o) for o in tokens]
+
+
+    def new_token(self, value):
+        new_id = -1 - len(self.tokens)
+        self.tokens[new_id] = value
+        return new_id
+
+
+    @staticmethod
+    def prepare_items(items, pad=True):
+        new_items = []
+        for item in items:
+            item = item.strip('\n')
+            # round the length (plus a trailing padding element)
+            # up to a multiple of 2,
+            # otherwise we can't .reshape() it to be two-dimensional later on
+            next_biggest = align(len(item) + 1, 2)
+            # initialize with padding (-1)
+            new_item = -np.ones(next_biggest, dtype=np.int32)
+            for i, c in enumerate(item):
+                new_item[i] = ord(c)
+            new_items.append(new_item)
+
+        # add an extra padding item to the head and tail
+        # to make it easier to convert from sequences back to items later on
+        if pad:
+            pad = -np.ones(1, dtype=np.int32)
+            new_items.insert(0, pad)
+            new_items.append(pad)
+
+        return np.concatenate(new_items)
+
+
+    def stat_tokens(self, all_items, skip_normal=False):
+        unique, counts = np.unique(all_items, return_counts=True)
+        count_order = np.argsort(counts)[::-1]
+        counts_descending = counts[count_order]
+        unique_descending = unique[count_order]
+        for i, token_id in enumerate(unique_descending):
+            if token_id == -1:
+                continue
+            if skip_normal and token_id >= 0:
+                continue
+            token = self.resolve_tokens(token_id)
+            lament("token id {:5} occurs {:8} times: \"{}\"".format(
+                token_id, counts_descending[i], token))
+        lament("total tokens: {:5}".format(i + 1))
+
+
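+    # merge_all below is essentially byte-pair encoding (BPE):
+    # each pass finds the most common pair of adjacent tokens and replaces
+    # it with a single new negative-id token. illustrative example (not
+    # from the original source): given the items "aab" and "aac", the pair
+    # ('a', 'a') is the most common, so it becomes a new token, say -2,
+    # leaving [-2, 'b'] and [-2, 'c']; later passes may merge pairs that
+    # themselves contain -2, building up longer learned substrings.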
+    def merge_all(self, all_items, merges, min_count=2):
+        # set up a 2d array holding every overlapping two-token sequence.
+        # this means double redundancy, but memory isn't really a concern;
+        # we could instead .roll a flat array later to get the other half.
+        sequences = all_items.repeat(2)[1:-1].reshape(-1, 2).copy()
+
+        for i in range(merges):
+            # ignore any sequence that contains padding
+            invalid = np.any(sequences == -1, axis=1)
+            valid_sequences = np.delete(sequences, np.where(invalid), axis=0)
+            unique, counts = uniq_rows(valid_sequences, return_counts=True)
+
+            count = counts.max()
+            if count <= 1 or count < min_count:
+                lament('no more valid sequences')
+                break
+            most_common = unique[counts == count][0]
+
+            token_value = "".join(self.resolve_tokens(most_common))
+            new_id = self.new_token(token_value)
+
+            # replace the most common two-token sequence
+            # with one token to represent both
+            found = np.all(sequences == most_common, axis=1)
+            before = np.roll(found, -1)
+            after = np.roll(found, 1)
+            # don't wrap the truth values around the ends
+            before[-1] = False
+            after[0] = False
+            # remove the "found" sequences
+            # and update the previous/next sequences,
+            # not unlike removal from a doubly-linked list.
+            befores = sequences[before].T.copy()
+            befores[1] = new_id
+            sequences[before] = befores.T
+            afters = sequences[after].T.copy()
+            afters[0] = new_id
+            sequences[after] = afters.T
+            here = np.where(found)
+            sequences = np.delete(sequences, here, axis=0)
+
+            lament("new token id {:5} occurs {:8} times: \"{}\"".format(
+                new_id, len(here[0]), self.tokens[new_id]))
+
+        # reconstruct all_items out of the remaining sequences
+        all_items = sequences.reshape(-1)[::2][1:].copy()
+        return all_items
+
+
+    def learn_all(self, items, merges=0, stat=True):
+        # minimum number of occurrences needed to keep creating tokens
+        min_count = 2
+        if merges < 0:
+            min_count = -merges
+            merges = 65536  # arbitrary sanity limit
+
+        # we'll use numpy arrays so this isn't nearly as disgustingly slow
+
+        self.tokens = {-1: ''}  # start with just the empty padding token
+
+        all_items = self.prepare_items(items)
+
+        if merges > 0:
+            all_items = self.merge_all(all_items, merges, min_count)
+
+        # begin the actual learning:
+        # split the stream back into items on padding (-1)
+        # and feed each one to the underlying Brain
+        self.reset()
+        np_item = []
+        for i in all_items:
+            if i == -1:
+                if len(np_item) == 0:
+                    continue
+                item = tuple()
+                for t in np_item:
+                    if t < 0:
+                        assert t != -1
+                        item += self.helper(self.tokens[t])
+                    else:
+                        item += self.helper(chr(t))
+                #die(np_item, item)
+                self.learn(item)
+                np_item = []
+            else:
+                np_item.append(i)
+        self.update()
+
+        if merges != 0 and stat:
+            self.stat_tokens(all_items)
+
+
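+# note: learn_all interprets its "merges" argument (the MERGES variable
+# read by run() below) in two ways: MERGES=100 performs at most 100
+# merges, while MERGES=-5 keeps merging until the most common pair
+# occurs fewer than 5 times.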
+def run(pname, args, env):
+    if not 1 <= len(args) <= 2:
+        lament("usage: {} {{input file}} [savestate file]".format(pname))
+        return 1
+
+    args = dict(enumerate(args))  # just for the .get() method
+
+    fn = args[0]
+    state_fn = args.get(1, None)
+
+    # the number of lines to output.
+    count = int(env.get('COUNT', '8'))
+    # learn and sample using this number of sequential tokens.
+    order = int(env.get('ORDER', '2'))
+    # how experimental to be with sampling.
+    # probably doesn't work properly.
+    temperature = float(env.get('TEMPERATURE', '0.5'))
+    # the max character length of output. (not guaranteed)
+    maxn = int(env.get('MAXN', '240'))
+    # the number of reply attempts to score; non-positive picks a default.
+    attempts = int(env.get('ATTEMPTS', '-1'))
+    # if positive, the maximum number of tokens to merge.
+    # if negative, the minimum number of occurrences to stop merging at.
+    merges = int(env.get('MERGES', '0'))
+
+    if attempts <= 0:
+        attempts = None
+
+    brain = PatternBrain(order=order, temperature=temperature)
+    tool = ATTTT(brain)
+
+    if state_fn:
+        lament('# loading')
+        try:
+            brain.load(state_fn, raw=False)
+        except FileNotFoundError:
+            lament('# no file to load. skipping')
+
+    if brain.new:
+        lament('# learning')
+        with open(fn) as f:
+            lines = f.readlines()
+        brain.learn_all(lines, merges)
+
+    if brain.new and state_fn:
+        lament('# saving')
+        brain.save(state_fn, raw=False)
+
+    lament('# replying')
+    for i in range(count):
+        #reply = tool.reply(maxn=maxn, include_scores=True, attempts=attempts)
+        #print('{:6.1f}\t{}'.format(reply[1], reply[0]))
+        print(tool.reply(maxn=maxn, attempts=attempts))
+
+    return 0
+
+
+if __name__ == '__main__':
+    import os
+    pname = sys.argv[0] if len(sys.argv) > 0 else ''
+    args = sys.argv[1:]
+    sys.exit(run(pname, args, os.environ))
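
An editorial aside, not part of the diff: the core trick in PatternBrain.merge_all
is that duplicating the token stream, trimming one element from each end, and
folding it into rows of two enumerates every overlapping adjacent pair in one
shot. A minimal sketch (toy values chosen for illustration):

    import numpy as np

    # a toy stream as prepare_items would produce it: "aab" surrounded by
    # padding (-1); 97 and 98 are ord('a') and ord('b')
    all_items = np.array([-1, 97, 97, 98, -1], dtype=np.int32)

    # the same idiom merge_all uses to get all adjacent pairs
    pairs = all_items.repeat(2)[1:-1].reshape(-1, 2)
    print(pairs)
    # [[-1 97]
    #  [97 97]
    #  [97 98]
    #  [98 -1]]
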
diff --git a/basic.py b/basic.py
new file mode 100755
index 0000000..0dc1416
--- /dev/null
+++ b/basic.py
@@ -0,0 +1,189 @@
+import numpy as np
+
+from misc import *
+
+
+def normalize(counter):
+    v = counter.values()
+    s = float(sum(v))
+    m = float(max(v))
+    return [(c, cnt / s, cnt / m) for c, cnt in counter.items()]
+
+
+def normalize_sorted(counter):
+    # if the elements were unsorted,
+    # we couldn't use our lazy method (subtraction) of selecting tokens,
+    # and temperature would correspond to arbitrary tokens
+    # instead of more/less common ones.
+    return sorted(normalize(counter), key=lambda t: t[1], reverse=True)
+
+
+# http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139
+class Brain:
+
+    def __init__(self, padding, order=1, temperature=0.5):
+        self.order = order
+        self.temperature = temperature
+        self.padding = padding
+
+        self.reset()
+
+
+    def reset(self):
+        import collections as cool
+        # unnormalized
+        self._machine = cool.defaultdict(cool.Counter)
+        # normalized
+        self.machine = None
+
+        self.type = None
+        self.dirty = False
+        self.new = True
+
+
+    @property
+    def temperature(self):
+        return self._temperature
+
+
+    @temperature.setter
+    def temperature(self, value):
+        assert 0 < value < 1
+        self._temperature = value
+
+        a = 1 - value * 2
+        # http://www.mathopenref.com/graphfunctions.html?fx=(a*x-x)/(2*a*x-a-1)&sg=f&sh=f&xh=1&xl=0&yh=1&yl=0&ah=1&al=-1&a=0.5
+        tweak = lambda x: (a * x - x) / (2 * a * x - a - 1)
+        self.random = lambda n: 1 - tweak(np.random.random(n))
+
+
+    def learn_all(self, items):
+        for item in items:
+            self.learn(item)
+        self.update()
+
+
+    def learn(self, item):
+        assert self.padding
+
+        if self.type is None and item is not None:
+            self.type = type(item)
+        if type(item) is not self.type:
+            raise Exception("that's no good")
+
+        if self.type is str:
+            item = item.strip()
+
+        if len(item) == 0:
+            return
+
+        pad = self.helper(self.padding) * self.order
+        item = pad + item + pad
+
+        stop = len(item) - self.order
+        if stop > 0:
+            for i in range(stop):
+                history, newitem = item[i:i+self.order], item[i+self.order]
+                self._machine[history][newitem] += 1
+
+        self.dirty = True
+
+
+    def update(self):
+        if self.dirty and self._machine:
+            self.machine = {hist: normalize_sorted(items)
+                            for hist, items in self._machine.items()}
+            self.dirty = False
+
+
+    def next(self, history):
+        history = history[-self.order:]
+
+        dist = self.machine.get(history, None)
+        if dist is None:
+            lament('warning: no value: {}'.format(history))
+            return None
+
+        x = self.random(1)
+        for c, cs, cm in dist:
+            x = x - cs
+            if x <= 0:
+                return c
+
+
+    # for overriding in subclasses,
+    # in case the input tokens aren't strings (e.g. tuples)
+    def helper(self, v):
+        return v
+
+
+    def reply(self, item=None, maxn=1000):
+        assert self.padding
+        self.update()
+
+        history = self.helper(self.padding) * self.order
+
+        out = []
+        for i in range(maxn):
+            c = self.next(history)
+            if c is None:
+                break
+            if c.find(self.padding) != -1:
+                out.append(c.replace(self.padding, ''))
+                break
+            history = history[-self.order:] + self.helper(c)
+            out.append(c)
+
+        return out
+
+
+    def load(self, fn, raw=True):
+        import pickle
+        if isinstance(fn, str):
+            f = open(fn, 'rb')
+        else:
+            f = fn
+
+        d = pickle.load(f)
+
+        if d['order'] != self.order:
+            lament('warning: order mismatch. cancelling load.')
+            return
+
+        if raw:
+            if not d.get('_machine'):
+                lament('warning: no _machine. cancelling load.')
+                return
+            self._machine = d['_machine']
+
+            self.dirty = True
+            self.update()
+        else:
+            if not d.get('machine'):
+                lament('warning: no machine. cancelling load.')
+                return
+            self.machine = d['machine']
+
+        self.new = False
+        if f is not fn:
+            f.close()
+
+
+    def save(self, fn, raw=True):
+        import pickle
+        if isinstance(fn, str):
+            f = open(fn, 'wb')
+        else:
+            f = fn
+
+        d = {}
+        d['order'] = self.order
+        if raw:
+            d['_machine'] = self._machine
+        else:
+            d['machine'] = self.machine
+        pickle.dump(d, f)
+
+        if f is not fn:
+            f.close()
diff --git a/misc.py b/misc.py
new file mode 100755
index 0000000..36da0d4
--- /dev/null
+++ b/misc.py
@@ -0,0 +1,11 @@
+import sys
+
+
+def lament(*args, **kwargs):
+    print(*args, file=sys.stderr, **kwargs)
+
+
+def die(*args, **kwargs):
+    # just for ad-hoc debugging really
+    lament(*args, **kwargs)
+    sys.exit(1)
+
+
+# export everything public that isn't a module
+__all__ = [name for name, value in list(globals().items())
+           if not isinstance(value, type(sys)) and not name.startswith('_')]
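
For reference, a minimal driver (editorial sketch, assuming the files above are
importable and that a corpus.txt exists; the filename is a placeholder) that
mirrors what run() does without the environment-variable plumbing:

    from atttt import PatternBrain, ATTTT

    # order and temperature match run()'s defaults
    brain = PatternBrain(order=2, temperature=0.5)
    with open('corpus.txt') as f:
        brain.learn_all(f.readlines(), merges=0)

    tool = ATTTT(brain)
    for _ in range(8):
        print(tool.reply(maxn=240))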