2016-05-24 20:15:26 -07:00
|
|
|
import math
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from misc import *
|
|
|
|
|
|
|
|
|
|
|
|
def normalize(counter):
    """Turn raw successor counts into (share, relative) pairs.

    For every key ``c`` with count ``cnt`` the result maps
    ``c -> (cnt / total, cnt / peak)``, where ``total`` is the sum of all
    counts and ``peak`` is the largest count.  An empty counter yields an
    empty dict instead of raising from ``max()``.
    """
    counts = list(counter.values())
    if not counts:
        return {}
    # float() keeps true division under Python 2 as well.
    total = float(sum(counts))
    peak = float(max(counts))
    return {c: (cnt / total, cnt / peak) for c, cnt in counter.items()}
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_sorted(counter):
    """Return the keys of ``counter`` ordered from most to least frequent.

    Mostly a debugging aid.  ``normalize`` returns a dict mapping each key to
    ``(share, relative)``.  The keys are therefore sorted by their share, not
    by indexing into the keys themselves.  Indexing the keys sorted by their
    second character and crashed on one-character keys.
    """
    dist = normalize(counter)
    return sorted(dist, key=lambda c: dist[c][0], reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
# http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139
|
|
|
|
class Brain:
    """Order-``order`` Markov chain over the elements of learned items.

    ``learn`` pads each item with ``padding`` on both ends. For every window
    of ``order`` consecutive elements, it counts which element follows.
    ``update`` normalizes those counts. ``reply`` walks the chain from an
    all-padding history until a padding element comes out or ``maxn`` steps
    have been taken.

    A truthy ``padding`` must be assigned before ``learn`` or ``reply`` is
    called.  Subclasses may override ``helper`` when the tokens are not
    plain strings (e.g. tuples).
    """

    def __init__(self, order=1, temperature=0.5):
        self.order = order
        # Assigned through the property setter, which also selects self.random.
        self.temperature = temperature
        # Start/end marker element; the caller must set it before learn/reply.
        self.padding = None

        self.reset()

    def reset(self):
        """Discard all learned state (order, temperature and padding are kept)."""
        import collections as cool

        # unnormalized: history -> Counter of successor counts
        self._machine = cool.defaultdict(cool.Counter)
        # normalized: history -> {successor: (share, relative)}, built by update()
        self.machine = None

        # Type of the first learned item; every later item must match it.
        self.type = None
        # True when _machine changed since machine was last rebuilt.
        self.dirty = False
        # Cleared once a model has been loaded successfully.
        self.new = True

    @property
    def temperature(self):
        """The value last assigned to ``temperature`` (see the setter)."""
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        """Store ``value`` and pick the random source that ``next`` draws from.

        ``next`` subtracts successor shares from one draw in [0, 1] and stops
        where the remainder reaches zero. Draws skewed toward 0 therefore
        favor successors that come first in the distribution's iteration
        order. NOTE(review): that order presumably follows first-seen order,
        not probability order; confirm this is the intended meaning of
        "temperature".

        * ``0``: a plain uniform draw (ordinary frequency sampling).
        * ``1``: a squared uniform draw, skewed toward 0.
        * anything else: ``easytruncnorm(0, 1, center, 0.25).rvs``. This is
          assumed to be a normal truncated to [0, 1]; the helper is not shown
          here (TODO confirm). ``center`` maps 0.25 -> 1.0, 0.5 -> 0.5 and
          0.75 -> 0.0. Values outside [0, 1] raise ValueError from
          ``math.atanh``.
        """
        self._temperature = value

        if value == 1:
            # TODO: proper distribution stuff
            self.random = lambda count: np.random.random(count)**2
        elif value == 0:
            self.random = np.random.random
        else:
            # Scale atanh so that value 0.25 (atanh argument +0.5) lands on a
            # center of 1.0 and value 0.75 (argument -0.5) lands on 0.0.
            point75 = 1
            const = (point75 * 2 - 1) / math.atanh(0.75 * 2 - 1)
            unbound = (math.atanh((1 - value) * 2 - 1) * const + 1) / 2
            self.random = easytruncnorm(0, 1, unbound, 0.25).rvs

    def learn_all(self, items):
        """Learn every item in ``items``, then rebuild the normalized table once."""
        for item in items:
            self.learn(item)
        self.update()

    def learn(self, item):
        """Count the successor transitions of one item.

        Strings are stripped first and skipped if nothing remains.  The
        normalized table is not rebuilt here; ``update`` does that (it is
        also called by ``learn_all`` and ``reply``).

        Raises:
            AssertionError: ``padding`` has not been set.
            TypeError: ``item``'s type differs from the first learned item's.
        """
        assert self.padding, 'padding must be set before learning'

        if self.type is None and item is not None:
            self.type = type(item)
        if type(item) is not self.type:
            # TypeError subclasses Exception, so existing handlers still match.
            raise TypeError("that's no good")

        if self.type is str:
            item = item.strip()
            if not item:
                return

        pad = self.helper(self.padding) * self.order
        item = pad + item + pad

        # Each window of `order` elements is followed by exactly one element.
        # The trailing pad makes the item's ending a learned transition too.
        for i in range(len(item) - self.order):
            history = item[i:i + self.order]
            self._machine[history][item[i + self.order]] += 1

        self.dirty = True

    def update(self):
        """Rebuild ``machine`` from ``_machine`` if anything was learned since."""
        if self.dirty and self._machine:
            self.machine = {hist: normalize(counts)
                            for hist, counts in self._machine.items()}
            self.dirty = False

    def next(self, history):
        """Sample the element that follows ``history``.

        Only the last ``order`` elements of ``history`` are consulted.
        Returns None, after reporting a warning through ``lament``, when that
        history was never learned or no table has been built yet.
        """
        history = history[-self.order:]

        dist = self.machine.get(history) if self.machine else None
        if dist is None:
            lament('warning: no value: {}'.format(history))
            return None

        x = self.random(1)
        c = None
        for c, v in dist.items():
            # v is (share, relative); walk the cumulative shares.
            x = x - v[0]
            if x <= 0:
                return c
        # Float rounding can leave the shares summing just below the draw;
        # fall back to the last candidate rather than returning None.
        return c

    def helper(self, v):
        """Adapt ``v`` before it is repeated/concatenated into a history.

        Identity here; meant to be overridden in subclasses whose tokens are
        not strings (such as tuples).
        """
        return v

    def reply(self, item=None, maxn=1000):
        """Generate one sequence by walking the chain from an all-padding start.

        Args:
            item: unused; kept for interface compatibility.
            maxn: maximum number of elements to generate.

        Returns:
            A list of generated elements. If a padding element is produced,
            it is appended last with the padding removed (``''`` for
            single-character strings) and generation stops. Generation also
            stops, without that final entry, when ``next`` returns None.
        """
        assert self.padding, 'padding must be set before replying'
        self.update()

        history = self.helper(self.padding) * self.order
        out = []
        for _ in range(maxn):
            c = self.next(history)
            if c is None:
                # Unknown history: next() already warned; stop instead of crashing.
                break
            if c.find(self.padding) != -1:
                out.append(c.replace(self.padding, ''))
                break
            history = history[-self.order:] + self.helper(c)
            out.append(c)
        return out

    def load(self, fn, raw=True):
        """Load a model written by ``save``.

        Args:
            fn: a path (str), which is opened and closed here, or a readable
                binary file object, which is left open.
            raw: load the unnormalized counts (and renormalize them) instead
                of the normalized table.

        The load is cancelled, with a warning through ``lament``, if the
        stored order differs from ``self.order`` or the requested table is
        missing or empty.  ``new`` becomes False only on success.

        SECURITY: pickle can execute arbitrary code while loading; only load
        files from trusted sources.
        """
        import pickle

        # Read and close the file up front so the early returns below can't
        # leak the handle.
        if isinstance(fn, str):
            with open(fn, 'rb') as f:
                d = pickle.load(f)
        else:
            d = pickle.load(fn)

        if d['order'] != self.order:
            lament('warning: order mismatch. cancelling load.')
            return

        if raw:
            if not d.get('_machine'):
                lament('warning: no _machine. cancelling load.')
                return
            self._machine = d['_machine']
            self.dirty = True
            self.update()
        else:
            if not d.get('machine'):
                lament('warning: no machine. cancelling load.')
                return
            # NOTE(review): if dirty is still True from earlier learning, the
            # next update() will overwrite this table; confirm that is intended.
            self.machine = d['machine']

        self.new = False

    def save(self, fn, raw=True):
        """Pickle the model (its order plus one table) to ``fn``.

        Args:
            fn: a path (str), which is opened and closed here, or a writable
                binary file object, which is left open.
            raw: store the unnormalized counts (``_machine``) rather than the
                normalized table (``machine``).
        """
        import pickle

        d = {'order': self.order}
        if raw:
            d['_machine'] = self._machine
        else:
            d['machine'] = self.machine

        if isinstance(fn, str):
            # `with` closes the file even if pickling fails.
            with open(fn, 'wb') as f:
                pickle.dump(d, f)
        else:
            pickle.dump(d, fn)
|