.
This commit is contained in:
parent
e61a32c615
commit
b028ee53d9
2 changed files with 16 additions and 34 deletions
10
atttt.py
10
atttt.py
|
@ -60,7 +60,7 @@ class ATTTT():
|
||||||
class PatternBrain(Brain):
|
class PatternBrain(Brain):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, padding='~', **kwargs)
|
||||||
self.tokens = []
|
self.tokens = []
|
||||||
|
|
||||||
|
|
||||||
|
@ -147,11 +147,6 @@ class PatternBrain(Brain):
|
||||||
token_value = "".join(self.resolve_tokens(most_common[0]))
|
token_value = "".join(self.resolve_tokens(most_common[0]))
|
||||||
new_id = self.new_token(token_value)
|
new_id = self.new_token(token_value)
|
||||||
|
|
||||||
if len("".join(self.tokens.values())) > len(all_items):
|
|
||||||
# this might not ever occur
|
|
||||||
lament('preventing token dictionary from growing larger than source')
|
|
||||||
break
|
|
||||||
|
|
||||||
# replace the most common two-token sequence
|
# replace the most common two-token sequence
|
||||||
# with one token to represent both
|
# with one token to represent both
|
||||||
found = np.all(sequences == most_common[0], axis=1)
|
found = np.all(sequences == most_common[0], axis=1)
|
||||||
|
@ -175,8 +170,6 @@ class PatternBrain(Brain):
|
||||||
lament("new token id {:5} occurs {:8} times: \"{}\"".format(
|
lament("new token id {:5} occurs {:8} times: \"{}\"".format(
|
||||||
new_id, len(here[0]), self.tokens[new_id]))
|
new_id, len(here[0]), self.tokens[new_id]))
|
||||||
|
|
||||||
# TODO: find unused tokens?
|
|
||||||
|
|
||||||
# reconstruct all_items out of the sequences
|
# reconstruct all_items out of the sequences
|
||||||
all_items = sequences.reshape(-1)[::2][1:].copy()
|
all_items = sequences.reshape(-1)[::2][1:].copy()
|
||||||
return all_items
|
return all_items
|
||||||
|
@ -198,7 +191,6 @@ class PatternBrain(Brain):
|
||||||
all_items = self.merge_all(all_items, merges, min_count)
|
all_items = self.merge_all(all_items, merges, min_count)
|
||||||
|
|
||||||
# begin the actual learning
|
# begin the actual learning
|
||||||
self.padding = '~'
|
|
||||||
self.reset()
|
self.reset()
|
||||||
np_item = []
|
np_item = []
|
||||||
for i in all_items:
|
for i in all_items:
|
||||||
|
|
40
basic.py
40
basic.py
|
@ -9,25 +9,24 @@ def normalize(counter):
|
||||||
s = float(sum(v))
|
s = float(sum(v))
|
||||||
m = float(max(v))
|
m = float(max(v))
|
||||||
del v
|
del v
|
||||||
d = {}
|
return [(c, cnt/s, cnt/m) for c, cnt in counter.items()]
|
||||||
for c, cnt in counter.items():
|
|
||||||
d[c] = (cnt/s, cnt/m)
|
|
||||||
return d
|
|
||||||
# return [(c, cnt/s, cnt/m) for c, cnt in counter.items()]
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_sorted(counter):
|
def normalize_sorted(counter):
|
||||||
# mostly just for debugging i guess?
|
# if the elements were unsorted,
|
||||||
|
# we couldn't use our lazy method (subtraction) of selecting tokens
|
||||||
|
# and temperature would correspond to arbitrary tokens
|
||||||
|
# instead of more/less common tokens.
|
||||||
return sorted(normalize(counter), key=lambda t: t[1], reverse=True)
|
return sorted(normalize(counter), key=lambda t: t[1], reverse=True)
|
||||||
|
|
||||||
|
|
||||||
# http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139
|
# http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139
|
||||||
class Brain:
|
class Brain:
|
||||||
|
|
||||||
def __init__(self, order=1, temperature=0.5):
|
def __init__(self, padding, order=1, temperature=0.5):
|
||||||
self.order = order
|
self.order = order
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.padding = None
|
self.padding = padding
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
@ -51,21 +50,13 @@ class Brain:
|
||||||
|
|
||||||
@temperature.setter
|
@temperature.setter
|
||||||
def temperature(self, value):
|
def temperature(self, value):
|
||||||
|
assert(0 < value < 1)
|
||||||
self._temperature = value
|
self._temperature = value
|
||||||
|
|
||||||
if value == 1:
|
a = 1 - value * 2
|
||||||
# TODO: proper distribution stuff
|
# http://www.mathopenref.com/graphfunctions.html?fx=(a*x-x)/(2*a*x-a-1)&sg=f&sh=f&xh=1&xl=0&yh=1&yl=0&ah=1&al=-1&a=0.5
|
||||||
self.random = lambda count: np.random.random(count)**2
|
tweak = lambda x: (a * x - x) / (2 * a * x - a - 1)
|
||||||
elif value == 0:
|
self.random = lambda n: 1 - tweak(np.random.random(n))
|
||||||
self.random = np.random.random
|
|
||||||
else:
|
|
||||||
# +0.25 = -0.0
|
|
||||||
# +0.50 = +0.5
|
|
||||||
# +0.75 = +1.0
|
|
||||||
point75 = 1
|
|
||||||
const = (point75 * 2 - 1) / math.atanh(0.75 * 2 - 1)
|
|
||||||
unbound = (math.atanh((1 - value) * 2 - 1) * const + 1) / 2
|
|
||||||
self.random = easytruncnorm(0, 1, unbound, 0.25).rvs
|
|
||||||
|
|
||||||
|
|
||||||
def learn_all(self, items):
|
def learn_all(self, items):
|
||||||
|
@ -102,7 +93,7 @@ class Brain:
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
if self.dirty and self._machine:
|
if self.dirty and self._machine:
|
||||||
self.machine = {hist:normalize(items)
|
self.machine = {hist: normalize_sorted(items)
|
||||||
for hist, items in self._machine.items()}
|
for hist, items in self._machine.items()}
|
||||||
self.dirty = False
|
self.dirty = False
|
||||||
|
|
||||||
|
@ -116,9 +107,8 @@ class Brain:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
x = self.random(1)
|
x = self.random(1)
|
||||||
for c, v in dist.items():
|
for c, cs, cm in dist:
|
||||||
# if x <= v: # this is a bad idea
|
x = x - cs
|
||||||
x = x - v[0]
|
|
||||||
if x <= 0:
|
if x <= 0:
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue