gists/basic.py

200 lines
4.8 KiB
Python
Raw Normal View History

2016-05-24 20:15:26 -07:00
import math
import numpy as np
from misc import *
def normalize(counter):
    """Scale the counts of a mapping two ways.

    Args:
        counter: mapping of key -> numeric count (e.g. a
            ``collections.Counter``).

    Returns:
        A dict mapping each key to a ``(share, relative)`` tuple, where
        ``share`` is the count divided by the sum of all counts and
        ``relative`` is the count divided by the largest count.
        An empty mapping yields an empty dict.

    Raises:
        ZeroDivisionError: if the counts sum to zero or the largest
            count is zero.
    """
    # max() raises ValueError on an empty sequence, so bail out early.
    if not counter:
        return {}

    counts = counter.values()
    total = float(sum(counts))
    peak = float(max(counts))
    return {c: (cnt / total, cnt / peak) for c, cnt in counter.items()}
def normalize_sorted(counter):
    """Return the normalized counts as a list, most frequent first.

    Mostly a debugging aid.

    Args:
        counter: mapping of key -> numeric count, as accepted by
            ``normalize``.

    Returns:
        A list of ``(key, (share, relative))`` pairs sorted by the
        ``(share, relative)`` tuple in descending order.
    """
    # normalize() returns a dict; iterating it directly would yield only the
    # keys, so sort its items to make t[1] the normalized tuple.
    return sorted(normalize(counter).items(), key=lambda t: t[1], reverse=True)
# http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139
class Brain:
    """Order-N Markov-style sequence model with adjustable sampling.

    The model records, for every window of ``order`` consecutive elements
    (the "history"), how often each following element was seen. ``reply``
    then walks those tables to generate a new sequence.

    ``padding`` must be set to a truthy value (by the caller or a subclass)
    before ``learn`` or ``reply`` is used; it marks the start and end of
    every learned item.
    """

    def __init__(self, order=1, temperature=0.5):
        """Create an empty model.

        Args:
            order: number of preceding elements used as the history key.
            temperature: sampling setting, see the ``temperature`` property.
        """
        self.order = order
        self.temperature = temperature
        self.padding = None
        self.reset()

    def reset(self):
        """Discard everything learned and return to the initial state."""
        import collections

        # unnormalized: history -> Counter of the elements that followed it
        self._machine = collections.defaultdict(collections.Counter)
        # normalized: history -> {element: (share, relative)}, built by update()
        self.machine = None
        # type of the items learned so far; fixed by the first learn() call
        self.type = None
        # True when _machine has changes not yet reflected in machine
        self.dirty = False
        # True until a load() succeeds
        self.new = True

    @property
    def temperature(self):
        """Sampling setting; assigning it also rebuilds ``self.random``."""
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        # Values outside [0, 1] would otherwise fail inside math.atanh with
        # an opaque "math domain error"; raise the same exception type with
        # a clearer message.
        if not 0 <= value <= 1:
            raise ValueError(
                'temperature must be within [0, 1], got {!r}'.format(value))

        self._temperature = value

        if value == 1:
            # TODO: proper distribution stuff
            self.random = lambda count: np.random.random(count)**2
        elif value == 0:
            self.random = np.random.random
        else:
            # Map (1 - value) through a scaled atanh so that:
            #   (1 - value) = 0.25  ->  unbound = 0.0
            #   (1 - value) = 0.50  ->  unbound = 0.5
            #   (1 - value) = 0.75  ->  unbound = 1.0
            point75 = 1
            const = (point75 * 2 - 1) / math.atanh(0.75 * 2 - 1)
            unbound = (math.atanh((1 - value) * 2 - 1) * const + 1) / 2
            # NOTE(review): easytruncnorm comes from misc; presumably a
            # truncated normal on [0, 1] centred on `unbound` with scale
            # 0.25 -- confirm against its definition.
            self.random = easytruncnorm(0, 1, unbound, 0.25).rvs

    def learn_all(self, items):
        """Learn every item in ``items``, then rebuild the normalized tables."""
        for item in items:
            self.learn(item)
        self.update()

    def learn(self, item):
        """Add one item (a sequence) to the unnormalized tables.

        Strings are stripped of surrounding whitespace first; empty items
        are ignored. The item is padded on both sides with ``order`` copies
        of the padding before its transitions are counted.

        Raises:
            AssertionError: if ``self.padding`` is not set.
            Exception: if ``item`` is not of the same type as the first
                item ever learned.
        """
        assert(self.padding)

        if self.type is None and item is not None:
            self.type = type(item)
        if type(item) is not self.type:
            raise Exception("that's no good")

        if self.type is str:
            item = item.strip()
        if len(item) == 0:
            return

        order = self.order
        pad = self.helper(self.padding) * order
        item = pad + item + pad

        stop = len(item) - order
        if stop <= 0:
            return
        for i in range(stop):
            history, newitem = item[i:i + order], item[i + order]
            self._machine[history][newitem] += 1
        self.dirty = True

    def update(self):
        """Rebuild the normalized tables if there are unprocessed counts."""
        if self.dirty and self._machine:
            self.machine = {hist: normalize(items)
                            for hist, items in self._machine.items()}
            self.dirty = False

    def next(self, history):
        """Sample the element that follows ``history``.

        Only the last ``order`` elements of ``history`` are used.

        Returns:
            The sampled element, or None (after a warning) when the
            history has never been seen or nothing has been learned yet.
        """
        history = history[-self.order:]

        dist = None
        if self.machine is not None:
            dist = self.machine.get(history, None)
        if dist is None:
            lament('warning: no value: {}'.format(history))
            return None

        x = self.random(1)
        c = None
        for c, v in dist.items():
            # walk the cumulative distribution; comparing x against each
            # share directly would be wrong.
            x = x - v[0]
            if x <= 0:
                return c
        # Floating-point rounding can leave a tiny positive residue after
        # the last share; fall back to the last candidate instead of
        # silently returning None.
        return c

    # for overriding in subclasses
    # in case the input tokens aren't strings (e.g. tuples)
    def helper(self, v):
        """Wrap a single element so it can be concatenated to a history."""
        return v

    def reply(self, item=None, maxn=1000):
        """Generate a sequence by repeatedly sampling from the model.

        Args:
            item: currently unused; kept for interface compatibility.
            maxn: upper bound on the number of sampling steps.

        Returns:
            A list of the sampled elements. Generation stops when an
            element containing the padding is drawn (it is appended with
            the padding removed), when no continuation is available, or
            after ``maxn`` steps.

        Raises:
            AssertionError: if ``self.padding`` is not set.
        """
        assert(self.padding)

        self.update()

        history = self.helper(self.padding) * self.order

        out = []
        for _ in range(maxn):
            c = self.next(history)
            if c is None:
                # unknown history: next() already warned, stop here rather
                # than crash on c.find below.
                break
            if c.find(self.padding) != -1:
                out.append(c.replace(self.padding, ''))
                break
            history = history[-self.order:] + self.helper(c)
            out.append(c)
        return out

    def load(self, fn, raw=True):
        """Load tables previously written by ``save``.

        Args:
            fn: a file path (str) or an already-open binary file object.
                A file object passed in is left open for the caller.
            raw: if True, load the unnormalized counts and rebuild the
                normalized tables; otherwise load the normalized tables.

        The load is cancelled with a warning when the stored order differs
        from ``self.order`` or when the requested table is missing.

        NOTE: this uses pickle, which can execute arbitrary code while
        loading. Only load files from trusted sources.
        """
        import pickle

        if isinstance(fn, str):
            # read everything up front so the file is closed on every path
            with open(fn, 'rb') as f:
                d = pickle.load(f)
        else:
            d = pickle.load(fn)

        if d['order'] != self.order:
            lament('warning: order mismatch. cancelling load.')
            return
        self.order = d['order']

        if raw:
            if not d.get('_machine'):
                lament('warning: no _machine. cancelling load.')
                return
            self._machine = d['_machine']
            self.dirty = True
            self.update()
        else:
            if not d.get('machine'):
                lament('warning: no machine. cancelling load.')
                return
            self.machine = d['machine']

        self.new = False

    def save(self, fn, raw=True):
        """Write the model's tables with pickle.

        Args:
            fn: a file path (str) or an already-open binary file object.
                A file object passed in is left open for the caller.
            raw: if True, store the unnormalized counts; otherwise store
                the normalized tables.
        """
        import pickle

        d = {'order': self.order}
        if raw:
            d['_machine'] = self._machine
        else:
            d['machine'] = self.machine

        if isinstance(fn, str):
            with open(fn, 'wb') as f:
                pickle.dump(d, f)
        else:
            pickle.dump(d, fn)