.
This commit is contained in:
parent
db3171ac29
commit
28edd29072
4 changed files with 445 additions and 1 deletions
1
.dummy
1
.dummy
|
@ -1 +0,0 @@
|
||||||
.
|
|
231
atttt.py
Executable file
231
atttt.py
Executable file
|
@ -0,0 +1,231 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from misc import *
|
||||||
|
from basic import Brain
|
||||||
|
|
||||||
|
|
||||||
|
def uniq_rows(a, return_index=False, return_inverse=False, return_counts=False):
    """Return the unique rows of the 2-D array ``a``.

    Every row is reinterpreted as a single opaque ``np.void`` scalar that
    spans the whole row, so ``np.unique`` compares complete rows instead of
    individual elements. The result order is whatever ``np.unique`` yields
    for those void scalars, not necessarily numeric row order.

    Parameters:
        a: 2-D array; it is made C-contiguous before being viewed.
        return_index, return_inverse, return_counts: forwarded to
            ``np.unique``.

    Returns:
        The array of unique rows when no flag is set; otherwise a tuple whose
        first element is that array, followed by the extra arrays that
        ``np.unique`` produced for the requested flags.
    """
    # viewing rows as void scalars requires each row to be one contiguous run
    a = np.ascontiguousarray(a)
    width = a.shape[1]
    void_dtype = np.dtype((np.void, a.dtype.itemsize * width))
    # ravel so index/inverse arrays refer to row numbers in a flat 1-D layout
    rows = a.view(void_dtype).ravel()

    if not (return_index or return_inverse or return_counts):
        # the original computed this value but forgot to return it
        return np.unique(rows).view(a.dtype).reshape(-1, width)

    ret = np.unique(rows, return_index, return_inverse, return_counts)
    return (ret[0].view(a.dtype).reshape(-1, width),) + ret[1:]
|
||||||
|
|
||||||
|
|
||||||
|
class ATTTT():
    """Generate several candidate replies from a brain and keep the best one.

    ``brain`` must provide ``reply(item=..., maxn=...)`` returning an iterable
    of strings, and an ``order`` attribute (used to pick a default number of
    attempts).
    """

    def __init__(self, brain):
        self.brain = brain
        # scoring is looked up through an attribute so it can be swapped out
        self.score = self._score

    def _score(self, reply, maxn):
        """Score a candidate: heavily penalise replies longer than ``maxn``.

        Every reply that fits gets the same score (1).
        """
        if len(reply) > maxn:
            return -999999999

        #return len(reply)
        return 1

    def reply(self, item=None, maxn=1000, raw=False, attempts=None):
        """Return the best-scoring reply out of ``attempts`` generated ones.

        Parameters:
            item: forwarded to ``brain.reply``.
            maxn: maximum acceptable reply length; the brain is asked for up
                to ``maxn + 1`` tokens so overlong replies can be detected
                and penalised by the scorer.
            raw: when true, return the ``(reply, score)`` tuple instead of
                just the reply string.
            attempts: number of candidates to generate. ``None`` picks
                ``2**12 / brain.order`` (at least 1).

        Raises:
            ValueError: if ``attempts`` is smaller than 1.
        """
        if attempts is None:
            # never fall to zero attempts, even for very large orders
            attempts = max(1, int(2**12 / self.brain.order))
            lament('attempts:', attempts)
        if attempts < 1:
            raise ValueError('attempts must be at least 1')

        candidates = []
        for _ in range(attempts):
            text = "".join(self.brain.reply(item=item, maxn=maxn+1))
            candidates.append((text, self.score(text, maxn)))

        # max() returns the first of equally scored candidates, which is the
        # same pick a stable descending sort followed by [0] would make
        result = max(candidates, key=lambda t: t[1])

        if raw:
            return result
        return result[0]
|
||||||
|
|
||||||
|
|
||||||
|
class PatternBrain(Brain):
    """Brain that learns over tuples of tokens instead of plain strings.

    Before learning, the most frequent adjacent token pairs of the input are
    repeatedly merged into new multi-character tokens, so the chain is built
    over those merged tokens.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): not used anywhere in the visible code
        self.tokens = []

    def helper(self, v):
        """Wrap a single token in a 1-tuple so it can be concatenated."""
        return (v,)

    def learn_all(self, items, merges=1):
        """Tokenise ``items`` by pair-merging, then learn every line.

        Parameters:
            items: iterable of strings; trailing/leading newlines are
                stripped from each.
            merges: maximum number of pair merges to perform. ``0`` disables
                merging. A negative value means "merge until no pair occurs
                at least ``-merges`` times" (capped at 65536 merges).

        Side effects: sets ``self.padding`` to ``'~'``, resets the brain,
        learns one tuple of tokens per line, prints every new token to
        stdout and finally calls ``self.update()``.
        """
        min_count = 2
        if merges < 0:
            min_count = -merges
            merges = 65536

        # use numpy so this isn't nearly as disgustingly slow

        # sentinel reserved for blanked-out slots; nothing in this method
        # writes it at the moment, but the final pass already skips it
        empty = int(np.iinfo(np.int32).min)
        # negative ids stand for non-character tokens; -1 is the padding
        neg_lookup = {-1: ''}  # default with padding

        alignment = 2

        # pad items are added at the head and tail for easier conversion from
        # the pair sequences back to all_items later on
        encoded = [-np.ones(1, dtype=np.int32)]
        for item in items:
            item = item.strip('\n')
            # guarantee at least 1 padding element at the end of each line,
            # rounding the length to a multiple of the alignment
            size = (len(item) + 1 + alignment // 2) // alignment * alignment
            # fill with padding (-1), then overwrite with code points
            row = -np.ones(size, dtype=np.int32)
            for pos, c in enumerate(item):
                row[pos] = ord(c)
            encoded.append(row)
        encoded.append(-np.ones(1, dtype=np.int32))

        all_items = np.concatenate(encoded)

        if merges > 0:
            # 2-column array of every adjacent pair: row k is
            # (all_items[k], all_items[k + 1]). this doubles the memory but
            # means nothing has to be rolled later to get the other half.
            sequences = all_items.repeat(2)[1:-1].reshape(-1, 2).copy()

            for _ in range(merges):
                # learn: find the most frequent pair without padding
                # TODO: eventually check for empty here too
                has_padding = np.any(sequences == -1, axis=1)
                valid_sequences = sequences[~has_padding]
                if len(valid_sequences) == 0:
                    # nothing left to count; counts.max() would raise
                    lament('no more valid sequences')
                    break

                unique, counts = uniq_rows(valid_sequences, return_counts=True)
                count = counts.max()
                if count <= 1 or count < min_count:
                    lament('no more valid sequences')
                    break
                # first pair that reaches the maximum count
                pair = unique[np.argmax(counts)]

                new_id = -1 - len(neg_lookup)
                neg_lookup[new_id] = "".join(
                    neg_lookup[int(o)] if o < 0 else chr(int(o)) for o in pair)

                if len("".join(neg_lookup.values())) > len(all_items):
                    lament('preventing dict from growing larger than source')
                    break

                # replace our most common sequence in the sequences
                found = np.all(sequences == pair, axis=1)
                before = np.roll(found, -1)
                after = np.roll(found, 1)
                # don't wrap around truth values
                before[-1] = False
                after[0] = False
                # the row before a match ends in the new token, the row after
                # it starts with the new token, and the match itself goes away
                sequences[before, 1] = new_id
                sequences[after, 0] = new_id
                replaced = int(np.count_nonzero(found))
                sequences = sequences[~found]

                print("({:8}) new token: {:5} \"{}\"".format(replaced, new_id, neg_lookup[new_id]))

            # reconstruct all_items out of the sequences: the first column
            # holds every element but the last one; drop the head pad
            all_items = sequences[1:, 0].copy()

        self.padding = '~'
        self.reset()

        pending = []
        for i in all_items:
            if i == -1:
                # padding ends the current line (runs of padding are skipped)
                if len(pending) == 0:
                    continue
                item = tuple()
                for tok in pending:
                    text = neg_lookup[int(tok)] if tok < 0 else chr(int(tok))
                    item += self.helper(text)
                self.learn(item)
                pending = []
            elif i != empty:
                pending.append(i)

        self.update()
|
||||||
|
|
||||||
|
|
||||||
|
def run(pname, args, env):
    """Command-line entry point: learn from a text file and print replies.

    Parameters:
        pname: program name, used in the usage message.
        args: positional arguments: the input file and, optionally, a state
            file to load the brain from / save it to.
        env: mapping (e.g. ``os.environ``) providing the optional settings
            COUNT, ORDER, TEMPERATURE, MAXN, ATTEMPTS and MERGES.

    Exits with status 1 when the number of arguments is wrong.
    """
    if not 1 <= len(args) <= 2:
        lament("usage: {} {{input file}} [state_fn file]".format(pname))
        sys.exit(1)

    args = dict(enumerate(args))  # for .get()

    fn = args[0]
    state_fn = args.get(1, None)

    count = int(env.get('COUNT', '8'))
    order = int(env.get('ORDER', '3'))
    temperature = float(env.get('TEMPERATURE', '0'))
    maxn = int(env.get('MAXN', '1000'))
    attempts = int(env.get('ATTEMPTS', '-1'))
    merges = int(env.get('MERGES', '0'))

    # non-positive means "let the tool choose"
    if attempts <= 0:
        attempts = None

    brain = PatternBrain(order=order, temperature=temperature)
    tool = ATTTT(brain)

    lament('# loading')
    if state_fn:
        try:
            brain.load(state_fn, raw=False)
        except FileNotFoundError:
            # no saved state yet: learn from the input file below
            pass

    if brain.new:
        lament('# learning')
        # close the input file as soon as it has been read
        with open(fn) as f:
            lines = f.readlines()
        brain.learn_all(lines, merges)

        if state_fn:
            brain.save(state_fn, raw=False)

    lament('# replying')
    for _ in range(count):
        #reply = tool.reply(maxn=maxn, raw=True, attempts=attempts)
        #print('{:6.1f}\t{}'.format(reply[1], reply[0]))
        print(tool.reply(maxn=maxn, attempts=attempts))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    import os
    import sys

    # program name (empty when argv is empty) and the remaining arguments
    pname = sys.argv[0] if sys.argv else ''
    args = sys.argv[1:]
    sys.exit(run(pname, args, os.environ))
|
196
basic.py
Executable file
196
basic.py
Executable file
|
@ -0,0 +1,196 @@
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from misc import *
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(counter):
    """Map every key of ``counter`` to ``(count / total, count / maximum)``.

    ``counter`` is any mapping of keys to numeric counts. The first value of
    each pair is the key's share of the sum of all counts, the second its
    count relative to the largest count. Key order is preserved.
    """
    counts = counter.values()
    total = float(sum(counts))
    peak = float(max(counts))
    return {key: (n / total, n / peak) for key, n in counter.items()}
|
||||||
|
# return [(c, cnt/s, cnt/m) for c, cnt in counter.items()]
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_sorted(counter):
    """Return ``normalize(counter)`` as a list sorted by descending share.

    Each entry is ``(key, count / total, count / maximum)``. Mostly useful
    for debugging.

    The previous implementation iterated the dict returned by ``normalize``
    directly, i.e. it sorted the *keys* by their second element instead of
    sorting the entries by frequency.
    """
    entries = [(key, of_total, of_max)
               for key, (of_total, of_max) in normalize(counter).items()]
    return sorted(entries, key=lambda t: t[1], reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
# http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139
|
||||||
|
class Brain:
    """Order-N Markov chain over sequences (strings by default).

    ``_machine`` maps each history of ``order`` elements to a Counter of the
    elements that followed it; ``machine`` is the normalized form produced by
    ``update()`` and is what generation reads.
    """

    # TODO: don't default padding here, but make sure it's set before running
    # the reason is it's the only place that's specific to a string anymore
    def __init__(self, order=1, temperature=0.5, padding="~"):
        self.order = order
        self.padding = padding
        self.temperature = temperature

        self.reset()

    def reset(self):
        """Forget everything learned and mark the brain as new."""
        import collections as cool
        # unnormalized
        self._machine = cool.defaultdict(cool.Counter)
        # normalized
        self.machine = None

        self.type = None
        self.dirty = False
        self.new = True

    @property
    def temperature(self):
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        """Store the temperature and select the matching random source.

        ``self.random(count)`` returns ``count`` samples in [0, 1]:
        squared uniform samples for 1, plain uniform samples for 0, and a
        truncated normal otherwise. Other values outside the open interval
        (0, 1) make ``math.atanh`` raise ValueError.
        """
        self._temperature = value

        if value == 1:
            # TODO: proper distribution stuff
            self.random = lambda count: np.random.random(count)**2
        elif value == 0:
            self.random = np.random.random
        else:
            # centre of the truncated normal as a function of value:
            # 0.25 -> 1.0, 0.50 -> 0.5, 0.75 -> 0.0
            point75 = 1
            const = (point75 * 2 - 1) / math.atanh(0.75 * 2 - 1)
            unbound = (math.atanh((1 - value) * 2 - 1) * const + 1) / 2
            self.random = easytruncnorm(0, 1, unbound, 0.25).rvs

    def learn_all(self, items):
        """Learn every item, then rebuild the normalized machine."""
        for item in items:
            self.learn(item)
        self.update()

    def learn(self, item):
        """Count the transitions of one item (padded on both sides).

        The type of the first learned item is remembered; an item of any
        other type raises ``Exception``. Strings are stripped of surrounding
        whitespace; empty items are ignored.
        """
        if self.type is None and item is not None:
            self.type = type(item)
        if type(item) is not self.type:
            raise Exception("that's no good")

        if self.type is str:
            item = item.strip()

        if len(item) == 0:
            return

        pad = self.helper(self.padding) * self.order
        item = pad + item + pad

        stop = len(item) - self.order
        if stop > 0:
            for i in range(stop):
                history, newitem = item[i:i+self.order], item[i+self.order]
                self._machine[history][newitem] += 1

        self.dirty = True

    def update(self):
        """Rebuild ``machine`` from ``_machine`` if anything changed."""
        if self.dirty and self._machine:
            self.machine = {hist: normalize(items)
                            for hist, items in self._machine.items()}
            self.dirty = False

    def next(self, history):
        """Pick a random successor for the last ``order`` elements.

        Returns ``None`` (after a warning) when nothing is known about the
        history, including when nothing has been learned at all.
        """
        history = history[-self.order:]

        dist = self.machine.get(history, None) if self.machine else None
        if dist is None:
            lament('warning: no value: {}'.format(history))
            return None

        x = self.random(1)
        c = None
        for c, v in dist.items():
            # if x <= v: # this is a bad idea
            x = x - v[0]
            if x <= 0:
                return c
        # the probabilities can sum to slightly less than the sample because
        # of float rounding; fall back to the last candidate instead of
        # silently returning None
        return c

    def helper(self, v):
        """Turn one element into something concatenable with a history."""
        return v

    def reply(self, item=None, maxn=1000):
        """Generate up to ``maxn`` elements, stopping at the padding.

        ``item`` is accepted but not used. Returns the list of generated
        elements; generation ends early when no successor is known.
        """
        self.update()

        history = self.helper(self.padding) * self.order

        out = []
        for _ in range(maxn):
            c = self.next(history)
            if c is None:
                # unknown history: stop instead of crashing on c.find
                break
            if c.find(self.padding) != -1:
                out.append(c.replace(self.padding, ''))
                break
            history = history[-self.order:] + self.helper(c)
            out.append(c)

        return out

    def load(self, fn, raw=True):
        """Load a pickled state from a path or an open binary file.

        With ``raw`` the unnormalized counts are loaded and normalized
        again; otherwise the already-normalized machine is loaded. The load
        is cancelled with a warning when the stored order differs or the
        requested machine is missing. A file opened here is always closed;
        a file object passed in is left open.
        """
        import pickle
        opened_here = isinstance(fn, str)
        f = open(fn, 'rb') if opened_here else fn
        try:
            # NOTE: unpickling can execute arbitrary code; only load state
            # files from trusted sources
            d = pickle.load(f)
        finally:
            if opened_here:
                f.close()

        if d['order'] != self.order:
            lament('warning: order mismatch. cancelling load.')
            return

        if raw:
            if not d.get('_machine'):
                lament('warning: no _machine. cancelling load.')
                return
            self._machine = d['_machine']

            self.dirty = True
            self.update()
        else:
            if not d.get('machine'):
                lament('warning: no machine. cancelling load.')
                return
            self.machine = d['machine']

        self.new = False

    def save(self, fn, raw=True):
        """Pickle the order plus the raw or normalized machine.

        ``fn`` is a path or an open binary file; a file opened here is
        always closed, a file object passed in is left open.
        """
        import pickle
        d = {'order': self.order}
        if raw:
            d['_machine'] = self._machine
        else:
            d['machine'] = self.machine

        opened_here = isinstance(fn, str)
        f = open(fn, 'wb') if opened_here else fn
        try:
            pickle.dump(d, f)
        finally:
            if opened_here:
                f.close()
|
18
misc.py
Executable file
18
misc.py
Executable file
|
@ -0,0 +1,18 @@
|
||||||
|
import sys
|
||||||
|
def lament(*args, **kwargs):
    """Print like ``print`` does, but to standard error."""
    print(*args, file=sys.stderr, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def die(*args, **kwargs):
    """Write a message to standard error, then exit with status 1."""
    lament(*args, **kwargs)
    raise SystemExit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def easytruncnorm(lower=0, upper=1, loc=0.5, scale=0.25):
    """Build a frozen truncated normal limited to ``[lower, upper]``.

    ``scipy.stats.truncnorm`` takes its bounds relative to ``loc`` in units
    of ``scale``; this wrapper accepts them as absolute values instead.
    """
    from scipy import stats

    bounds = {
        'a': (lower - loc) / scale,
        'b': (upper - loc) / scale,
    }
    return stats.truncnorm(loc=loc, scale=scale, **bounds)
|
||||||
|
|
||||||
|
|
||||||
|
# only make some things visible to "from misc import *":
# public (non-underscore) names that are not imported modules.
# the old test compared type(name) with the string 'module', which is always
# unequal, so modules such as sys were exported as well.
__all__ = [name for name, value in list(globals().items())
           if not name.startswith('_') and type(value).__name__ != 'module']
|
Loading…
Reference in a new issue