#!/usr/bin/env python3
"""Markov-chain text babbler with optional pair-merging tokenisation.

This module gathers, in one runnable file, the code that the patch spread
over ``misc.py`` (stderr helpers and a truncated-normal factory),
``basic.py`` (``normalize`` and the ``Brain`` Markov model) and ``atttt.py``
(``uniq_rows``, ``ATTTT``, ``PatternBrain`` and the command-line entry
point).  Names that used to arrive through ``from misc import *`` and
``from basic import Brain`` are defined directly below.
"""

import collections
import math
import os
import pickle
import sys

import numpy as np

# Explicit public API.  The previous computed list tried to drop modules with
# ``type(o) != 'module'``, which compares a type against a string and is
# therefore always true.
__all__ = [
    "lament",
    "die",
    "easytruncnorm",
    "normalize",
    "normalize_sorted",
    "Brain",
    "uniq_rows",
    "ATTTT",
    "PatternBrain",
    "run",
]


# --------------------------------------------------------------------------
# misc helpers
# --------------------------------------------------------------------------

def lament(*args, **kwargs):
    """Print to stderr; takes the same arguments as ``print`` except ``file``."""
    print(*args, file=sys.stderr, **kwargs)


def die(*args, **kwargs):
    """Print a message to stderr and exit the process with status 1."""
    lament(*args, **kwargs)
    sys.exit(1)


def easytruncnorm(lower=0, upper=1, loc=0.5, scale=0.25):
    """Return a frozen ``scipy.stats.truncnorm`` limited to [lower, upper].

    ``loc`` and ``scale`` are the mean and standard deviation of the
    underlying normal; the bounds are converted to the standardised ``a``/``b``
    form that ``truncnorm`` expects.
    """
    # Imported lazily so scipy is only needed when this helper is used.
    import scipy.stats as stats
    a = (lower - loc) / scale
    b = (upper - loc) / scale
    return stats.truncnorm(a=a, b=b, loc=loc, scale=scale)


# --------------------------------------------------------------------------
# basic Markov model
# --------------------------------------------------------------------------

def normalize(counter):
    """Map every key of *counter* to ``(count / total, count / largest)``.

    An empty mapping yields an empty dict instead of raising from ``max()``.
    """
    if not counter:
        return {}
    counts = counter.values()
    total = float(sum(counts))
    largest = float(max(counts))
    return {key: (cnt / total, cnt / largest) for key, cnt in counter.items()}


def normalize_sorted(counter):
    """Return ``(key, count / total, count / largest)`` tuples, biggest share first.

    Mostly a debugging aid.  ``normalize`` returns a dict, so its items are
    flattened into tuples before sorting on the share (index 1).
    """
    rows = [(key, share, ratio)
            for key, (share, ratio) in normalize(counter).items()]
    return sorted(rows, key=lambda row: row[1], reverse=True)


# http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139
class Brain:
    """Order-N Markov chain over sequences (strings by default).

    ``_machine`` holds raw successor counts per history, ``machine`` the
    normalised form produced by ``update()``.
    """

    # TODO: don't default padding here, but make sure it's set before running
    #       the reason is it's the only place that's specific to a string anymore
    def __init__(self, order=1, temperature=0.5, padding="~"):
        self.order = order
        self.padding = padding
        self.temperature = temperature

        self.reset()

    def reset(self):
        """Forget everything learned and mark the brain as new."""
        # unnormalized
        self._machine = collections.defaultdict(collections.Counter)
        # normalized
        self.machine = None

        self.type = None
        self.dirty = False
        self.new = True

    @property
    def temperature(self):
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        """Store *value* and pick the matching random source ``self.random``.

        Raises:
            ValueError: if *value* lies outside the closed range [0, 1].
        """
        if value == 1:
            self._temperature = value
            # TODO: proper distribution stuff
            self.random = lambda count: np.random.random(count)**2
            return
        if value == 0:
            self._temperature = value
            self.random = np.random.random
            return
        if not 0 < value < 1:
            # atanh below is only defined for arguments strictly inside
            # (-1, 1); fail with a readable message rather than
            # "math domain error".
            raise ValueError(
                "temperature must be within [0, 1], got {!r}".format(value))

        self._temperature = value
        # Mapping of (1 - value) to the truncated normal's centre:
        # +0.25 = -0.0
        # +0.50 = +0.5
        # +0.75 = +1.0
        point75 = 1
        const = (point75 * 2 - 1) / math.atanh(0.75 * 2 - 1)
        unbound = (math.atanh((1 - value) * 2 - 1) * const + 1) / 2
        self.random = easytruncnorm(0, 1, unbound, 0.25).rvs

    def learn_all(self, items):
        """Learn every item, then rebuild the normalised table."""
        for item in items:
            self.learn(item)
        self.update()

    def learn(self, item):
        """Count the order-N transitions of one padded item.

        The type of the first learned item is remembered; feeding a
        different type afterwards raises ``TypeError``.
        """
        if self.type is None and item is not None:
            self.type = type(item)
        if type(item) is not self.type:
            raise TypeError("that's no good")

        if self.type is str:
            item = item.strip()

        if len(item) == 0:
            return

        pad = self.helper(self.padding) * self.order
        item = pad + item + pad

        # range() is empty when the padded item is not longer than the order.
        for start in range(len(item) - self.order):
            history = item[start:start + self.order]
            successor = item[start + self.order]
            self._machine[history][successor] += 1

        self.dirty = True

    def update(self):
        """Rebuild ``machine`` from ``_machine`` if anything changed."""
        if self.dirty and self._machine:
            self.machine = {hist: normalize(items)
                            for hist, items in self._machine.items()}
            self.dirty = False

    def next(self, history):
        """Draw the successor of the last ``order`` elements of *history*.

        Returns ``None`` (after a warning on stderr) when nothing has been
        learned for that history.
        """
        history = history[-self.order:]

        dist = self.machine.get(history) if self.machine else None
        if dist is None:
            lament('warning: no value: {}'.format(history))
            return None

        x = float(np.ravel(self.random(1))[0])
        choice = None
        for choice, weights in dist.items():
            x -= weights[0]
            if x <= 0:
                return choice
        # The shares can sum to slightly less than the draw through float
        # rounding; fall back to the last candidate instead of returning None.
        return choice

    def helper(self, v):
        """Wrap one element so it can be concatenated onto a history."""
        return v

    def reply(self, item=None, maxn=1000):
        """Generate up to *maxn* elements, stopping at the padding marker.

        *item* is accepted for interface compatibility and is not used.
        """
        self.update()

        history = self.helper(self.padding) * self.order

        out = []
        for _ in range(maxn):
            c = self.next(history)
            if c is None:
                # No known continuation (e.g. nothing learned yet).
                break
            if c.find(self.padding) != -1:
                out.append(c.replace(self.padding, ''))
                break
            history = history[-self.order:] + self.helper(c)
            out.append(c)

        return out

    def load(self, fn, raw=True):
        """Load state pickled by ``save`` from a path or an open binary file.

        The load is cancelled with a warning when the stored order differs
        from ``self.order`` or the requested table is missing/empty.  A file
        opened here is always closed, including on those early exits.

        NOTE: ``pickle.load`` can execute arbitrary code contained in the
        file; only load state files you trust.
        """
        opened_here = isinstance(fn, str)
        f = open(fn, 'rb') if opened_here else fn
        try:
            d = pickle.load(f)

            if d['order'] != self.order:
                lament('warning: order mismatch. cancelling load.')
                return
            self.order = d['order']

            if raw:
                if not d.get('_machine'):
                    lament('warning: no _machine. cancelling load.')
                    return
                self._machine = d['_machine']

                self.dirty = True
                self.update()
            else:
                if not d.get('machine'):
                    lament('warning: no machine. cancelling load.')
                    return
                self.machine = d['machine']

            self.new = False
        finally:
            if opened_here:
                f.close()

    def save(self, fn, raw=True):
        """Pickle the order plus the raw or normalised table to *fn*.

        *fn* is a path or an already-open binary file; only a file opened
        here is closed.
        """
        opened_here = isinstance(fn, str)
        f = open(fn, 'wb') if opened_here else fn
        try:
            d = {'order': self.order}
            if raw:
                d['_machine'] = self._machine
            else:
                d['machine'] = self.machine
            pickle.dump(d, f)
        finally:
            if opened_here:
                f.close()


# --------------------------------------------------------------------------
# pattern-merging brain and reply picker
# --------------------------------------------------------------------------

def uniq_rows(a, return_index=False, return_inverse=False, return_counts=False):
    """Return the unique rows of the 2-D array *a* via ``np.unique``.

    Each row is viewed as one opaque ``np.void`` item so ``np.unique`` treats
    it as a single value.  With no flag set the unique rows are returned;
    otherwise a tuple ``(rows, *extras)`` in ``np.unique``'s order.

    NOTE(review): row order follows np.unique's ordering of the raw-bytes
    view, which presumably need not match numeric row order - confirm before
    relying on it.
    """
    a = np.ascontiguousarray(a)  # the void view needs contiguous rows
    width = a.shape[1]
    void_dtype = np.dtype((np.void, a.dtype.itemsize * width))
    ret = np.unique(a.view(void_dtype),
                    return_index=return_index,
                    return_inverse=return_inverse,
                    return_counts=return_counts)

    if not (return_index or return_inverse or return_counts):
        # This branch used to compute the result without returning it.
        return ret.view(a.dtype).reshape(-1, width)

    rows = ret[0].view(a.dtype).reshape(-1, width)
    # Flatten the extras so their shape does not depend on the numpy version.
    extras = tuple(np.reshape(extra, -1) for extra in ret[1:])
    return (rows,) + extras


class ATTTT:
    """Ask a brain for several replies and keep the best-scoring one."""

    def __init__(self, brain):
        self.brain = brain
        # Instance attribute so a different scoring callable can be swapped in.
        self.score = self._score

    def _score(self, reply, maxn):
        """Score 1 for replies within *maxn* characters, hugely negative otherwise."""
        if len(reply) > maxn:
            return -999999999
        return 1

    def reply(self, item=None, maxn=1000, raw=False, attempts=None):
        """Return the best of *attempts* generated replies.

        When *attempts* is ``None`` it defaults to ``2**12 / brain.order``.
        At least one attempt is always made.  With ``raw=True`` the
        ``(reply, score)`` pair is returned instead of just the text.
        """
        if attempts is None:
            attempts = int(2**12 / self.brain.order)
            lament('attempts:', attempts)
        attempts = max(1, attempts)

        replies = []
        for _ in range(attempts):
            text = "".join(self.brain.reply(item=item, maxn=maxn + 1))
            replies.append((text, self.score(text, maxn)))

        # max() keeps the first of equally scored replies, exactly like a
        # stable descending sort followed by [0].
        best = max(replies, key=lambda pair: pair[1])
        return best if raw else best[0]


class PatternBrain(Brain):
    """Brain over tuples of tokens, where frequent character pairs may be
    merged into longer tokens before learning."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokens = []

    def helper(self, v):
        return (v,)

    def learn_all(self, items, merges=1):
        """Tokenise *items* (lines of text) and learn them.

        merges > 0: perform up to that many merges of the most common
            adjacent pair (which must occur at least twice).
        merges < 0: merge until no pair occurs at least ``-merges`` times
            (capped at 65536 rounds).
        merges == 0: learn plain characters.

        Each merge is reported on stdout.  Previously learned state is
        discarded and the padding is set to '~'.
        """
        min_count = 2
        if merges < 0:
            min_count = -merges
            merges = 65536

        # use numpy so this isn't nearly as disgustingly slow

        empty = int(np.iinfo(np.int32).min)  # reserved "hole" marker
        neg_lookup = {-1: ''}  # default with padding

        alignment = 2

        def align(x):
            return (x + alignment // 2) // alignment * alignment

        new_items = []
        for line in items:
            line = line.strip('\n')
            # guarantee at least 1 padding slot at the end
            next_biggest = align(len(line) + 1)
            # fill with padding (-1), then overwrite with code points
            encoded = -np.ones(next_biggest, dtype=np.int32)
            for pos, char in enumerate(line):
                encoded[pos] = ord(char)
            new_items.append(encoded)

        # add an extra padding item to the head and tail
        # for easier conversion from sequence back to all_items later on
        pad = -np.ones(1, dtype=np.int32)
        new_items.insert(0, pad)
        new_items.append(pad)

        all_items = np.concatenate(new_items)

        if merges > 0:
            # one row per adjacent pair: (a0, a1), (a1, a2), ...
            # this stores every element twice; rolling would need less
            # memory but this keeps the pair logic simple.
            sequences = all_items.repeat(2)[1:-1].reshape(-1, 2).copy()

        for _ in range(merges):
            # pairs touching padding can never be merged
            # TODO: eventually check for empty here too
            invalid = np.any(sequences == -1, axis=1)
            valid_sequences = sequences[~invalid]
            if len(valid_sequences) == 0:
                # counts.max() below would raise on an empty array
                lament('no more valid sequences')
                break

            unique, counts = uniq_rows(valid_sequences, return_counts=True)
            count = counts.max()
            if count <= 1 or count < min_count:
                lament('no more valid sequences')
                break
            pair = unique[counts == count][0]

            new_id = -1 - len(neg_lookup)
            neg_lookup[new_id] = "".join(
                neg_lookup[o] if o < 0 else chr(o) for o in pair.tolist())

            if len("".join(neg_lookup.values())) > len(all_items):
                lament('preventing dict from growing larger than source')
                break

            # replace our most common sequence in the sequences
            found = np.all(sequences == pair, axis=1)
            before = np.roll(found, -1)
            after = np.roll(found, 1)
            # don't wrap around truth values
            before[-1] = False
            after[0] = False
            # the neighbouring pairs now end / start with the new token ...
            sequences[before, 1] = new_id
            sequences[after, 0] = new_id
            # ... and the merged pair rows themselves disappear
            sequences = sequences[~found]

            print("({:8}) new token: {:5} \"{}\"".format(
                int(found.sum()), new_id, neg_lookup[new_id]))

        if merges > 0:
            # reconstruct all_items out of the sequences
            all_items = sequences.reshape(-1)[::2][1:].copy()

        self.padding = '~'
        self.reset()
        pending = []
        for code in all_items.tolist():
            if code == -1:
                # padding ends the current item, if there is one
                if not pending:
                    continue
                item = tuple()
                for token_id in pending:
                    token = neg_lookup[token_id] if token_id < 0 else chr(token_id)
                    item += self.helper(token)
                self.learn(item)
                pending = []
            elif code != empty:
                pending.append(code)
        self.update()


def run(pname, args, env):
    """Command-line entry point.

    *args* is ``[input file]`` optionally followed by a state file.  *env*
    supplies COUNT (8), ORDER (3), TEMPERATURE (0), MAXN (1000),
    ATTEMPTS (-1, meaning automatic) and MERGES (0).  Exits with status 1 on
    a usage error.
    """
    if not 1 <= len(args) <= 2:
        lament("usage: {} {{input file}} [state_fn file]".format(pname))
        sys.exit(1)

    fn = args[0]
    state_fn = args[1] if len(args) > 1 else None

    count = int(env.get('COUNT', '8'))
    order = int(env.get('ORDER', '3'))
    temperature = float(env.get('TEMPERATURE', '0'))
    maxn = int(env.get('MAXN', '1000'))
    attempts = int(env.get('ATTEMPTS', '-1'))
    merges = int(env.get('MERGES', '0'))

    if attempts <= 0:
        attempts = None

    brain = PatternBrain(order=order, temperature=temperature)
    tool = ATTTT(brain)

    lament('# loading')
    if state_fn:
        try:
            brain.load(state_fn, raw=False)
        except FileNotFoundError:
            # No saved state yet: learn from the input file instead.
            pass

    learned_now = brain.new
    if learned_now:
        lament('# learning')
        with open(fn) as handle:
            lines = handle.readlines()
        brain.learn_all(lines, merges)

    if learned_now and brain.new and state_fn:
        brain.save(state_fn, raw=False)

    lament('# replying')
    for _ in range(count):
        print(tool.reply(maxn=maxn, attempts=attempts))


if __name__ == '__main__':
    pname = sys.argv[0] if sys.argv else ''
    args = sys.argv[1:]
    sys.exit(run(pname, args, os.environ))