gists/atttt/basic.py
2019-03-11 06:50:26 +01:00

189 lines
4.6 KiB
Python
Executable file

import math
import numpy as np
from misc import *
def normalize(counter):
    """Convert raw counts into (token, share_of_total, share_of_max) triples.

    Args:
        counter: mapping of token -> count (e.g. a ``collections.Counter``).

    Returns:
        A list of ``(token, count / total, count / max_count)`` tuples in the
        mapping's iteration order. An empty mapping yields an empty list
        instead of raising.
    """
    if not counter:
        # max() of an empty sequence would raise ValueError.
        return []
    counts = counter.values()
    total = float(sum(counts))
    peak = float(max(counts))
    return [(tok, cnt / total, cnt / peak) for tok, cnt in counter.items()]
def normalize_sorted(counter):
    """Return ``normalize(counter)`` ordered from most to least frequent.

    The sampler picks a token by subtracting each probability, in list
    order, from one random draw until the remainder reaches zero or below.
    The list order therefore decides which tokens a skewed draw (the
    temperature) favours. Sorting by descending probability ties the skew
    to more/less common tokens; an unsorted list would tie it to arbitrary
    ones.
    """
    triples = normalize(counter)
    # Stable in-place sort on the share-of-total column, largest first.
    triples.sort(key=lambda triple: triple[1], reverse=True)
    return triples
# http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139
class Brain:
    """Order-N Markov chain that learns sequences and generates new ones.

    Each learned item (a string by default) is padded on both ends with
    ``helper(padding) * order``. Every window of ``order`` elements is then
    counted as a history followed by the next element. ``reply`` walks the
    chain from the padding history until it emits an element that contains
    the padding, or until ``maxn`` elements have been produced.
    """

    def __init__(self, padding, order=1, temperature=0.5):
        # padding: sentinel marking sequence start/end; must be truthy.
        # order: number of preceding elements used as the history key.
        # temperature: in (0, 1); 0.5 leaves the random draw uniform,
        #   other values skew it (see the temperature setter).
        self.order = order
        self.temperature = temperature
        self.padding = padding
        self.reset()

    def reset(self):
        """Forget everything learned so far."""
        import collections
        # unnormalized: history -> Counter(next element -> count)
        self._machine = collections.defaultdict(collections.Counter)
        # normalized: history -> [(element, p_total, p_max), ...] sorted by
        # descending probability; rebuilt lazily by update()
        self.machine = None
        self.type = None    # type of the first learned item
        self.dirty = False  # True when _machine changed since last update()
        self.new = True     # cleared by a successful load()

    @property
    def temperature(self):
        """Sampling skew in the open interval (0, 1); 0.5 is neutral."""
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        assert 0 < value < 1
        self._temperature = value
        a = 1 - value * 2
        # Warp a uniform sample x in [0, 1) through a curve that maps 0->0
        # and 1->1. It is the identity when a == 0 (value == 0.5) and bends
        # toward one end otherwise. Plot:
        # http://www.mathopenref.com/graphfunctions.html?fx=(a*x-x)/(2*a*x-a-1)&sg=f&sh=f&xh=1&xl=0&yh=1&yl=0&ah=1&al=-1&a=0.5
        tweak = lambda x: (a * x - x) / (2 * a * x - a - 1)
        self.random = lambda n: 1 - tweak(np.random.random(n))

    def learn_all(self, items):
        """Learn every item in *items*, then rebuild the normalized table."""
        for item in items:
            self.learn(item)
        self.update()

    def learn(self, item):
        """Count the transitions of one sequence into the raw table.

        The first non-None item fixes the accepted item type. Later items of
        a different type raise ``Exception``. String items are stripped
        first, and empty items are ignored. Call ``update()`` (or use
        ``learn_all``) afterwards to refresh the normalized table.
        """
        assert self.padding
        if self.type is None and item is not None:
            self.type = type(item)
        if type(item) is not self.type:
            raise Exception("that's no good")
        if self.type is str:
            item = item.strip()
        if len(item) == 0:
            return
        pad = self.helper(self.padding) * self.order
        item = pad + item + pad
        # range() is simply empty if the padded item is too short.
        for i in range(len(item) - self.order):
            history = item[i:i + self.order]
            self._machine[history][item[i + self.order]] += 1
        self.dirty = True

    def update(self):
        """Rebuild the normalized table if the raw counts changed."""
        if self.dirty and self._machine:
            self.machine = {hist: normalize_sorted(counts)
                            for hist, counts in self._machine.items()}
            self.dirty = False

    def next(self, history):
        """Sample the element that follows *history*.

        Only the last ``order`` elements of *history* are used. Returns None,
        after a warning via ``lament``, when that history has no entry or
        nothing has been normalized yet.
        """
        history = history[-self.order:]
        # self.machine is None until update() has run on learned data.
        dist = (self.machine or {}).get(history)
        if not dist:
            lament('warning: no value: {}'.format(history))
            return None
        # Subtract probabilities, in list order (most common first when
        # built by update()), from one random draw. The element that brings
        # the remainder to zero or below is chosen.
        x = self.random(1)
        for c, cs, cm in dist:
            x = x - cs
            if x <= 0:
                return c
        # Float rounding can leave the probabilities summing to slightly less
        # than the draw. Fall back to the last element instead of
        # returning None.
        return dist[-1][0]

    # for overriding in subclasses
    # in case the input tokens aren't strings (e.g. tuples)
    def helper(self, v):
        """Wrap one element so it can be concatenated onto a history.

        This base implementation is the identity, which suits string items.
        Subclasses with non-string items (e.g. tuples) can override it.
        """
        return v

    def reply(self, item=None, maxn=1000):
        """Generate a sequence by walking the chain from the start padding.

        Args:
            item: not used by this implementation; kept for interface
                compatibility (presumably for subclasses).
            maxn: maximum number of elements to generate.

        Returns:
            The list of generated elements. Generation stops early in two
            cases. If an element contains the padding, it is appended with
            the padding removed. If ``next`` finds no continuation, nothing
            more is appended.
        """
        assert self.padding
        self.update()
        history = self.helper(self.padding) * self.order
        out = []
        for _ in range(maxn):
            c = self.next(history)
            if c is None:
                # Dead end: return what we have instead of crashing on
                # None.find() below.
                break
            if c.find(self.padding) != -1:
                out.append(c.replace(self.padding, ''))
                break
            history = history[-self.order:] + self.helper(c)
            out.append(c)
        return out

    def load(self, fn, raw=True):
        """Restore a model written by ``save``.

        Args:
            fn: path (str or os.PathLike) or an open binary file object.
                A path is opened and closed here; a file object is left
                open.
            raw: if True, load the raw counts (``_machine``) and
                renormalize. Otherwise load the stored normalized
                ``machine`` table.

        The load is cancelled with a warning (via ``lament``) when the stored
        order differs from ``self.order``, or when the requested table is
        missing or empty.
        """
        import os
        import pickle
        owns_file = isinstance(fn, (str, os.PathLike))
        f = open(fn, 'rb') if owns_file else fn
        try:
            # NOTE(security): unpickling can execute arbitrary code; only
            # load model files from trusted sources.
            d = pickle.load(f)
        finally:
            # Close paths we opened even when returning early below or
            # when unpickling fails.
            if owns_file:
                f.close()
        if d['order'] != self.order:
            lament('warning: order mismatch. cancelling load.')
            return
        if raw:
            if not d.get('_machine'):
                lament('warning: no _machine. cancelling load.')
                return
            self._machine = d['_machine']
            self.dirty = True
            self.update()
        else:
            if not d.get('machine'):
                lament('warning: no machine. cancelling load.')
                return
            self.machine = d['machine']
        self.new = False

    def save(self, fn, raw=True):
        """Pickle the model's order and one of its tables.

        Args:
            fn: path (str or os.PathLike) or an open binary file object.
                A path is opened (truncated) and closed here; a file object
                is left open.
            raw: if True, store the raw counts (``_machine``). Otherwise
                store the normalized ``machine`` table.
        """
        import os
        import pickle
        d = {'order': self.order}
        if raw:
            d['_machine'] = self._machine
        else:
            # Make sure the normalized table reflects everything learned.
            self.update()
            d['machine'] = self.machine
        if isinstance(fn, (str, os.PathLike)):
            with open(fn, 'wb') as f:
                pickle.dump(d, f)
        else:
            pickle.dump(d, fn)