#!/usr/bin/env python3

import sys

import numpy as np

from misc import *
from basic import Brain


def align(x, alignment):
    # round x up to the next multiple of alignment
    return (x + alignment - 1) // alignment * alignment


def uniq_rows(a, return_index=False, return_inverse=False, return_counts=False):
    # black magic wrapper around np.unique to make it operate on rows.
    # via http://stackoverflow.com/a/16973510
    void_dtype = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    return_any = return_index or return_inverse or return_counts
    if not return_any:
        return np.unique(a.view(void_dtype)).view(a.dtype).reshape(-1, a.shape[1])
    else:
        ret = np.unique(a.view(void_dtype), return_index=return_index,
                        return_inverse=return_inverse, return_counts=return_counts)
        return (ret[0].view(a.dtype).reshape(-1, a.shape[1]),) + ret[1:]


class ATTTT:
    def __init__(self, brain):
        self.brain = brain
        # kept as an attribute so the scoring function can be swapped out
        self.score = self._score

    def _score(self, reply, maxn):
        if len(reply) > maxn:
            return -999999999
        #return len(reply)
        return 1

    def reply(self, item=None, maxn=1000, include_scores=False, attempts=None):
        if attempts is None:
            # just guess some value that'll take roughly the same amount of time
            attempts = int(2**12 / self.brain.order)
        lament('attempts:', attempts)

        replies = []
        for i in range(attempts):
            reply = "".join(self.brain.reply(item=item, maxn=maxn + 1))
            replies.append((reply, self.score(reply, maxn)))

        result = sorted(replies, key=lambda t: t[1], reverse=True)[0]

        if include_scores:
            return result
        else:
            return result[0]


class PatternBrain(Brain):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, padding='~', **kwargs)
        # tokens are keyed by negative ids; -1 is the empty padding token
        self.tokens = {-1: ''}

    def helper(self, v):
        # wrap a single token value in a tuple
        return (v,)

    def resolve_tokens(self, tokens):
        # negative values are token ids; positive values are just unicode characters
        if isinstance(tokens, (int, np.integer)):
            return self.tokens[tokens] if tokens < 0 else chr(tokens)
        else:
            return [self.tokens[t] if t < 0 else chr(t) for t in tokens]

    def new_token(self, value):
        new_id = -1 - len(self.tokens)
        self.tokens[new_id] = value
        return new_id

    @staticmethod
    def prepare_items(items, pad=True):
        new_items = []
        for item in items:
            item = item.strip('\n')
            # ensure the item length is even (leaving at least one spot
            # of padding), otherwise we can't .reshape() the concatenated
            # array to be two-dimensional later on
            next_biggest = align(len(item) + 1, 2)
            # initialize with padding (-1)
            new_item = -np.ones(next_biggest, dtype=np.int32)
            for i, c in enumerate(item):
                new_item[i] = ord(c)
            new_items.append(new_item)

        # add an extra padding item to the head and tail
        # to make it easier to convert from sequences back to items later on
        if pad:
            pad = -np.ones(1, dtype=np.int32)
            new_items.insert(0, pad)
            new_items.append(pad)

        return np.concatenate(new_items)

    def stat_tokens(self, all_items, skip_normal=False):
        unique, counts = np.unique(all_items, return_counts=True)
        count_order = np.argsort(counts)[::-1]
        counts_descending = counts[count_order]
        unique_descending = unique[count_order]
        for i, token_id in enumerate(unique_descending):
            if token_id == -1:  # skip padding
                continue
            if skip_normal and token_id >= 0:
                continue
            token = self.resolve_tokens(token_id)
            lament("token id {:5} occurs {:8} times: \"{}\"".format(
                token_id, counts_descending[i], token))
        lament("total tokens: {:5}".format(i + 1))

    def merge_all(self, all_items, merges, min_count=2):
        # set up a 2d array of two-token sequences, stepped through
        # at half the row length so that every overlapping sequence appears.
        # this doubles the redundancy; we could instead .roll the array
        # later to get the other half, which would require less memory,
        # but memory isn't really a concern.
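        # for example (toy input; 97..99 are 'a'..'c', -1 is padding):
        #   all_items       == [ -1, 97, 98, 99, -1 ]
        #   .repeat(2)      -> [ -1, -1, 97, 97, 98, 98, 99, 99, -1, -1 ]
        #   [1:-1]          -> [ -1, 97, 97, 98, 98, 99, 99, -1 ]
        #   .reshape(-1, 2) -> [ [-1, 97], [97, 98], [98, 99], [99, -1] ]
        # i.e. every overlapping two-token window over the stream.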
        sequences = all_items.repeat(2)[1:-1].reshape(-1, 2).copy()

        for i in range(merges):
            # drop sequences containing padding; those must never be merged
            invalid = np.any(sequences == -1, axis=1)
            valid_sequences = np.delete(sequences, np.where(invalid), axis=0)
            unique, counts = uniq_rows(valid_sequences, return_counts=True)

            # find the most common two-token sequence
            most_common = (None, 1)
            if len(counts) > 0:
                count = counts.max()
                if count > most_common[1]:
                    seq = unique[counts == count][0]
                    most_common = (seq, count)

            if most_common[0] is None or most_common[1] <= 1 or most_common[1] < min_count:
                lament('no more valid sequences')
                break

            token_value = "".join(self.resolve_tokens(most_common[0]))
            new_id = self.new_token(token_value)

            # replace the most common two-token sequence
            # with one token to represent both
            found = np.all(sequences == most_common[0], axis=1)
            before = np.roll(found, -1)
            after = np.roll(found, 1)
            # don't wrap around truth values
            before[-1] = False
            after[0] = False

            # remove the "found" sequences
            # and update the previous/next,
            # not unlike a doubly-linked list.
            befores = sequences[before].T.copy()
            befores[1] = new_id
            sequences[before] = befores.T
            afters = sequences[after].T.copy()
            afters[0] = new_id
            sequences[after] = afters.T

            here = np.where(found)
            sequences = np.delete(sequences, here, axis=0)

            lament("new token id {:5} occurs {:8} times: \"{}\"".format(
                new_id, len(here[0]), self.tokens[new_id]))

        # reconstruct all_items out of the sequences: the first column of
        # every row, dropping the head/tail padding added by prepare_items
        all_items = sequences.reshape(-1)[::2][1:].copy()
        return all_items

    def learn_all(self, items, merges=0, stat=True):
        # the minimum number of occurrences at which to stop creating tokens
        min_count = 2
        if merges < 0:
            min_count = -merges
            merges = 65536  # arbitrary sanity value

        # we'll use numpy arrays so this isn't nearly as disgustingly slow
        self.tokens = {-1: ''}  # default with an empty padding token
        all_items = self.prepare_items(items)

        if merges > 0:
            all_items = self.merge_all(all_items, merges, min_count)

        # begin the actual learning:
        # split on padding and feed each item to the underlying Brain
        self.reset()
        np_item = []
        for token_id in all_items:
            if token_id == -1:
                if len(np_item) == 0:
                    continue
                item = tuple()
                for t in np_item:
                    if t < 0:
                        assert t != -1
                        item += self.helper(self.tokens[t])
                    else:
                        item += self.helper(chr(t))
                #die(np_item, item)
                self.learn(item)
                np_item = []
            else:
                np_item.append(token_id)
        self.update()

        if merges != 0 and stat:
            self.stat_tokens(all_items)
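# a minimal sketch of programmatic use, mirroring what run() does below
# ('corpus.txt' is a hypothetical file; misc and basic must be importable):
#
#   brain = PatternBrain(order=2, temperature=0.5)
#   tool = ATTTT(brain)
#   brain.learn_all(open('corpus.txt').readlines(), merges=100)
#   print(tool.reply(maxn=240))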
def run(pname, args, env):
    if not 1 <= len(args) <= 2:
        lament("usage: {} {{input file}} [savestate file]".format(pname))
        return 1

    args = dict(enumerate(args))  # just for the .get() method
    fn = args[0]
    state_fn = args.get(1, None)

    # the number of lines to output.
    count = int(env.get('COUNT', '8'))
    # learn and sample using this number of sequential tokens.
    order = int(env.get('ORDER', '2'))
    # how experimental to be with sampling.
    # probably doesn't work properly.
    temperature = float(env.get('TEMPERATURE', '0.5'))
    # the max character length of output. (not guaranteed)
    maxn = int(env.get('MAXN', '240'))
    # the number of attempts made to maximize scoring.
    attempts = int(env.get('ATTEMPTS', '-1'))
    # if positive, the maximum number of tokens to merge.
    # if negative, the minimum number of occurrences to stop at.
    merges = int(env.get('MERGES', '0'))

    if attempts <= 0:
        attempts = None

    brain = PatternBrain(order=order, temperature=temperature)
    tool = ATTTT(brain)

    if state_fn:
        lament('# loading')
        try:
            brain.load(state_fn, raw=False)
        except FileNotFoundError:
            lament('# no file to load. skipping')

    if brain and brain.new:
        lament('# learning')
        with open(fn) as f:
            brain.learn_all(f.readlines(), merges)

    if brain and brain.new and state_fn:
        lament('# saving')
        brain.save(state_fn, raw=False)

    lament('# replying')
    for i in range(count):
        #reply = tool.reply(maxn=maxn, raw=True, attempts=attempts)
        #print('{:6.1f}\t{}'.format(reply[1], reply[0]))
        print(tool.reply(maxn=maxn, attempts=attempts))

    return 0


if __name__ == '__main__':
    import os
    pname = sys.argv[0] if len(sys.argv) > 0 else ''
    args = sys.argv[1:] if len(sys.argv) > 1 else []
    sys.exit(run(pname, args, os.environ))
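# example invocation (script and file names here are hypothetical):
#   MERGES=100 COUNT=8 ./atttt.py corpus.txt state.bin
# learns from corpus.txt with up to 100 token merges, saves the brain to
# state.bin, and prints 8 generated replies.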