import math
import numpy as np
from misc import *


def normalize(counter):
    """Turn raw counts into relative frequencies.

    Returns a list with one ``(token, count / total, count / max_count)``
    tuple per entry of *counter*, in the counter's iteration order.
    An empty *counter* raises ValueError (from ``max``).
    """
    counts = list(counter.values())
    total = float(sum(counts))
    peak = float(max(counts))
    return [(c, cnt / total, cnt / peak) for c, cnt in counter.items()]


def normalize_sorted(counter):
    """Like :func:`normalize`, but sorted most-frequent first.

    If the elements were unsorted, the lazy subtraction method used to select
    tokens in ``Brain.next`` would still work, but temperature would then
    correspond to arbitrary tokens instead of more/less common tokens.
    """
    return sorted(normalize(counter), key=lambda t: t[1], reverse=True)


# http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139
class Brain:
    """Order-k n-gram (Markov) generator over strings or other sequences.

    ``order`` is the length of the history used as the state, and
    ``padding`` is the marker placed around every learned item. A generated
    token containing the padding ends a reply.
    """

    def __init__(self, padding, order=1, temperature=0.5):
        self.order = order
        self.temperature = temperature  # the setter validates and builds self.random
        self.padding = padding
        self.reset()

    def reset(self):
        """Discard everything learned so far."""
        import collections as cool
        # unnormalized: history -> Counter of the items that followed it
        self._machine = cool.defaultdict(cool.Counter)
        # normalized: history -> normalize_sorted() list, rebuilt by update()
        self.machine = None
        self.type = None    # type of the first learned item; later items must match
        self.dirty = False  # True when _machine changed since the last update()
        self.new = True     # cleared by a successful load()

    @property
    def temperature(self):
        """Sampling skew, strictly between 0 and 1; 0.5 samples uniformly."""
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        assert(0 < value < 1)
        self._temperature = value
        a = 1 - value * 2
        # http://www.mathopenref.com/graphfunctions.html?fx=(a*x-x)/(2*a*x-a-1)&sg=f&sh=f&xh=1&xl=0&yh=1&yl=0&ah=1&al=-1&a=0.5
        # tweak maps [0, 1) onto [0, 1) (tweak(0) == 0, tweak(1) == 1, no pole
        # in between for -1 < a < 1) and is the identity when value == 0.5.
        # NOTE(review): for value < 0.5, tweak(x) < x, so samples are pushed
        # toward 1, i.e. toward the *less* common end of the sorted
        # distribution walked in next() -- confirm this direction is intended.
        tweak = lambda x: (a * x - x) / (2 * a * x - a - 1)
        # self.random(n) -> numpy array of n samples in (0, 1]
        self.random = lambda n: 1 - tweak(np.random.random(n))

    def learn_all(self, items):
        """Learn every item in *items*, then refresh the normalized table."""
        for item in items:
            self.learn(item)
        self.update()

    def learn(self, item):
        """Count every (history, next item) transition in *item*.

        *item* is padded on both sides with ``order`` copies of the padding.
        All learned items must have the type of the first one. Strings are
        stripped, and empty strings are ignored. Call :meth:`update`
        afterwards (or use :meth:`learn_all`) to refresh the normalized table.

        Raises:
            TypeError: *item* does not have the type of earlier items.
        """
        assert(self.padding)
        if self.type is None and item is not None:
            self.type = type(item)
        if type(item) is not self.type:
            # TypeError subclasses Exception, so existing handlers still match.
            raise TypeError("that's no good: expected {}, got {}".format(
                self.type, type(item)))
        if self.type is str:
            item = item.strip()
            if not item:
                return
        pad = self.helper(self.padding) * self.order
        item = pad + item + pad
        # range() of a non-positive length is empty, so no explicit guard is needed.
        for i in range(len(item) - self.order):
            history, newitem = item[i:i + self.order], item[i + self.order]
            self._machine[history][newitem] += 1
        self.dirty = True

    def update(self):
        """Rebuild the normalized table if the raw counts changed."""
        if self.dirty and self._machine:
            self.machine = {hist: normalize_sorted(items)
                            for hist, items in self._machine.items()}
            self.dirty = False

    def next(self, history):
        """Sample the item that follows *history*.

        Only the last ``order`` items of *history* are used. Returns None,
        after a warning, when nothing has been learned for that history
        (including when nothing has been learned at all).
        """
        history = history[-self.order:]
        dist = self.machine.get(history) if self.machine else None
        if not dist:
            lament('warning: no value: {}'.format(history))
            return None
        # dist is sorted most-common first: subtract probabilities from the
        # sample until it is used up.
        x = self.random(1)[0]
        for c, cs, _ in dist:
            x -= cs
            if x <= 0:
                return c
        # The probabilities sum to 1 only up to float rounding, so a sample
        # close to 1 can survive the whole walk; it belongs to the last bucket.
        return dist[-1][0]

    # for overriding in subclasses
    # in case the input tokens aren't strings (e.g. tuples)
    def helper(self, v):
        return v

    def reply(self, item=None, maxn=1000):
        """Generate a sequence, starting from a fully padded history.

        Generation stops after *maxn* steps, when the sampled token contains
        the padding (that token is appended with the padding removed), or
        when the current history has never been seen. Returns the list of
        generated tokens. *item* is currently unused.
        """
        assert(self.padding)
        self.update()
        history = self.helper(self.padding) * self.order
        out = []
        for _ in range(maxn):
            c = self.next(history)
            if c is None:
                # dead end: nothing was learned for this history
                break
            if self.padding in c:
                out.append(c.replace(self.padding, ''))
                break
            history = history[-self.order:] + self.helper(c)
            out.append(c)
        return out

    def load(self, fn, raw=True):
        """Load a model written by :meth:`save`.

        *fn* is a path string or an already-open binary file object. A file
        object passed in is left open. With ``raw=True`` the unnormalized
        counts are loaded and re-normalized; otherwise the normalized table
        is loaded as-is. The load is cancelled, with a warning, on an order
        mismatch or when the expected table is missing or empty.
        """
        import pickle
        owns_file = isinstance(fn, str)
        f = open(fn, 'rb') if owns_file else fn
        try:
            # SECURITY: pickle.load can execute arbitrary code; only load
            # model files from trusted sources.
            d = pickle.load(f)
        finally:
            # Close only files we opened, before any early return below.
            if owns_file:
                f.close()
        if d['order'] != self.order:
            lament('warning: order mismatch. cancelling load.')
            return
        if raw:
            if not d.get('_machine'):
                lament('warning: no _machine. cancelling load.')
                return
            self._machine = d['_machine']
            self.dirty = True
            self.update()
        else:
            if not d.get('machine'):
                lament('warning: no machine. cancelling load.')
                return
            self.machine = d['machine']
        self.new = False

    def save(self, fn, raw=True):
        """Pickle the model's order plus either the raw or the normalized table.

        *fn* is a path string or an already-open binary file object. A file
        object passed in is left open.
        """
        import pickle
        d = {'order': self.order}
        if raw:
            d['_machine'] = self._machine
        else:
            d['machine'] = self.machine
        owns_file = isinstance(fn, str)
        f = open(fn, 'wb') if owns_file else fn
        try:
            pickle.dump(d, f)
        finally:
            if owns_file:
                f.close()