189 lines
4.6 KiB
Python
Executable file
189 lines
4.6 KiB
Python
Executable file
import math
|
|
import numpy as np
|
|
|
|
from misc import *
|
|
|
|
|
|
def normalize(counter):
    """Turn a mapping of token -> count into normalized triples.

    Args:
        counter: a mapping (e.g. ``collections.Counter``) from tokens to
            numeric counts.

    Returns:
        A list of ``(token, count / total, count / largest)`` tuples in the
        mapping's iteration order, where ``total`` is the sum of all counts
        and ``largest`` is the biggest single count. An empty mapping yields
        an empty list instead of raising from ``max()``.
    """
    if not counter:
        # max() of an empty sequence raises ValueError; there is nothing
        # to normalize anyway.
        return []

    counts = counter.values()
    total = float(sum(counts))
    largest = float(max(counts))

    return [(token, count / total, count / largest)
            for token, count in counter.items()]
|
|
|
|
|
|
def normalize_sorted(counter):
    """Return ``normalize(counter)`` ordered from most to least probable.

    The ordering matters to the sampler: tokens are selected lazily by
    subtracting each probability from a random draw, so if the entries were
    unsorted the temperature would favour arbitrary tokens rather than
    more/less common ones.
    """
    ranked = normalize(counter)
    # Index 1 of each triple is the count divided by the total count.
    # list.sort is stable, exactly like sorted(), so ties keep their order.
    ranked.sort(key=lambda entry: entry[1], reverse=True)
    return ranked
|
|
|
|
|
|
# http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139
|
|
class Brain:
    """A fixed-order Markov chain over sequences (plain strings by default).

    ``learn`` counts which token follows each ``order``-long history,
    ``update`` converts the counts into sorted probability lists, and
    ``reply`` walks the chain from the padding history until the padding
    token is produced again.

    Attributes:
        order: length of the history used to predict the next token.
        padding: token placed ``order`` times before and after each learned
            item; it marks both the start and the end of a sequence.
        _machine: raw counts, ``history -> Counter(next_token -> count)``.
        machine: normalized form, ``history -> [(token, p, p_rel), ...]``
            sorted by descending ``p``; ``None`` until ``update`` has run
            on learned data (or a non-raw ``load``).
        type: type of the first learned item; later items must match it.
        dirty: True when ``_machine`` changed since the last ``update``.
        new: True until a ``load`` succeeds.
    """

    def __init__(self, padding, order=1, temperature=0.5):
        """Create an empty chain.

        Args:
            padding: start/end marker token (must be truthy to learn/reply).
            order: number of preceding tokens that form a history.
            temperature: sampling bias strictly between 0 and 1; see the
                ``temperature`` property.
        """
        self.order = order
        self.temperature = temperature
        self.padding = padding

        self.reset()

    def reset(self):
        """Discard everything learned or loaded and return to a fresh state."""
        import collections

        # unnormalized
        self._machine = collections.defaultdict(collections.Counter)
        # normalized
        self.machine = None

        self.type = None
        self.dirty = False
        self.new = True

    @property
    def temperature(self):
        """Sampling bias in the open interval (0, 1).

        0.5 leaves the random draw uniform. Larger values push the draw
        toward 0, which selects entries near the front of the sorted
        distribution (the most frequent tokens); smaller values push it
        toward 1, i.e. toward the tail.
        """
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        # Kept as an assert so callers see the same AssertionError as before.
        assert(0 < value < 1)
        self._temperature = value

        # Maps (0, 1) onto (1, -1); a == 0 makes tweak() the identity.
        a = 1 - value * 2

        # http://www.mathopenref.com/graphfunctions.html?fx=(a*x-x)/(2*a*x-a-1)&sg=f&sh=f&xh=1&xl=0&yh=1&yl=0&ah=1&al=-1&a=0.5
        def tweak(x):
            return (a * x - x) / (2 * a * x - a - 1)

        # self.random(n) returns n biased draws as a numpy array.
        self.random = lambda n: 1 - tweak(np.random.random(n))

    def learn_all(self, items):
        """Learn every item in ``items``, then rebuild the normalized table."""
        for item in items:
            self.learn(item)
        self.update()

    def learn(self, item):
        """Add one sequence to the raw transition counts.

        The item is wrapped in ``order`` copies of the padding on each side
        so the chain learns how sequences begin and end. String items are
        stripped first; empty items are ignored.

        Raises:
            AssertionError: if ``padding`` is falsy.
            Exception: if the item's type differs from the first learned
                item's type (this includes a ``None`` first item).
        """
        assert(self.padding)

        if self.type is None and item is not None:
            self.type = type(item)
        if type(item) is not self.type:
            raise Exception("that's no good")

        if self.type is str:
            item = item.strip()

        if len(item) == 0:
            return

        pad = self.helper(self.padding) * self.order
        item = pad + item + pad

        order = self.order
        # range() is empty when the padded item is not longer than `order`.
        for i in range(len(item) - order):
            history, newitem = item[i:i + order], item[i + order]
            self._machine[history][newitem] += 1

        self.dirty = True

    def update(self):
        """Rebuild ``machine`` from ``_machine`` if there are pending changes."""
        if self.dirty and self._machine:
            self.machine = {hist: normalize_sorted(items)
                            for hist, items in self._machine.items()}
            self.dirty = False

    def next(self, history):
        """Sample the token that follows ``history``.

        Only the last ``order`` elements of ``history`` are used. A biased
        random draw is reduced by each token's probability, in descending
        order of probability, and the token that brings it to zero or below
        is chosen.

        Returns:
            The sampled token, or ``None`` when the history is unknown
            (a warning is emitted via ``lament``, which is expected to come
            from the module's ``misc`` star import) or when its distribution
            is empty.
        """
        history = history[-self.order:]

        # machine is None until something has been learned or loaded;
        # treat that the same as an unknown history rather than crashing.
        dist = None if self.machine is None else self.machine.get(history, None)
        if dist is None:
            lament('warning: no value: {}'.format(history))
            return None

        x = self.random(1)
        c = None
        for c, cs, cm in dist:
            x = x - cs
            if x <= 0:
                return c

        # The probabilities can sum to slightly less than the draw because of
        # floating-point rounding; fall back to the last (least likely) token
        # instead of silently returning None.
        return c

    # for overriding in subclasses
    # in case the input tokens aren't strings (e.g. tuples)
    def helper(self, v):
        """Wrap a single token so it can be concatenated onto a history."""
        return v

    def reply(self, item=None, maxn=1000):
        """Generate a sequence of tokens by walking the chain.

        Args:
            item: unused; accepted for interface compatibility.
            maxn: upper bound on the number of sampling steps.

        Returns:
            A list of generated tokens. Generation stops when a token
            containing the padding is drawn (that token is appended with the
            padding removed), when no next token is available, or after
            ``maxn`` steps.

        Raises:
            AssertionError: if ``padding`` is falsy.
        """
        assert(self.padding)
        self.update()

        history = self.helper(self.padding) * self.order

        out = []
        for _ in range(maxn):
            c = self.next(history)
            if c is None:
                # Unknown history or empty distribution: nothing more to say.
                break
            if c.find(self.padding) != -1:
                out.append(c.replace(self.padding, ''))
                break
            history = history[-self.order:] + self.helper(c)
            out.append(c)

        return out

    def load(self, fn, raw=True):
        """Load a table previously written by ``save``.

        Args:
            fn: a path string, or an already-open binary file object. Only
                files opened here are closed here.
            raw: when True read the raw counts (``'_machine'``) and
                re-normalize; otherwise read the normalized table
                (``'machine'``) directly.

        The load is cancelled with a warning if the stored order differs
        from ``self.order`` or the requested table is missing/empty.
        """
        import pickle

        opened_here = isinstance(fn, str)
        f = open(fn, 'rb') if opened_here else fn
        try:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load files from trusted sources.
            d = pickle.load(f)
        finally:
            # Close on every path, including errors and cancelled loads.
            if opened_here:
                f.close()

        if d['order'] != self.order:
            lament('warning: order mismatch. cancelling load.')
            return

        if raw:
            if not d.get('_machine'):
                lament('warning: no _machine. cancelling load.')
                return
            self._machine = d['_machine']

            self.dirty = True
            self.update()
        else:
            if not d.get('machine'):
                lament('warning: no machine. cancelling load.')
                return
            self.machine = d['machine']

        self.new = False

    def save(self, fn, raw=True):
        """Pickle the chain's order and one of its tables.

        Args:
            fn: a path string, or an already-open binary file object. Only
                files opened here are closed here.
            raw: when True store the raw counts under ``'_machine'``;
                otherwise store the normalized table under ``'machine'``.
        """
        import pickle

        d = {'order': self.order}
        if raw:
            d['_machine'] = self._machine
        else:
            d['machine'] = self.machine

        if isinstance(fn, str):
            # The context manager closes the file even if dump() raises.
            with open(fn, 'wb') as f:
                pickle.dump(d, f)
        else:
            pickle.dump(d, fn)
|