commit eeb5d2941e (parent 28edd29072)
3 changed files with 100 additions and 68 deletions
atttt.py (156 changes)
@@ -8,8 +8,8 @@ from basic import Brain
 
 
 def uniq_rows(a, return_index=False, return_inverse=False, return_counts=False):
+    # via http://stackoverflow.com/a/16973510
     # black magic wrapper around np.unique
-    # via np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
     return_any = return_index or return_inverse or return_counts
     if not return_any:
         np.unique(a.view(np.dtype((np.void, a.dtype.itemsize * a.shape[1])))).view(a.dtype).reshape(-1, a.shape[1])
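
Note on the black magic: np.unique flattens 2-D input, so uniq_rows first reinterprets each row as a single opaque scalar via np.dtype((np.void, ...)), which turns row-wise uniqueness into a 1-D problem. A minimal sketch of the trick outside the diff (the sample array is made up; the input must be contiguous for .view() to work):

import numpy as np

a = np.array([[1, 2], [3, 4], [1, 2]], dtype=np.int32)

# reinterpret each 8-byte row as one opaque np.void scalar
row_t = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
rows = np.ascontiguousarray(a).view(row_t)

# np.unique() now compares whole rows; view back and reshape to recover them
print(np.unique(rows).view(a.dtype).reshape(-1, a.shape[1]))
# [[1 2]
#  [3 4]]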
@@ -34,8 +34,9 @@ class ATTTT():
             return 1
 
 
-    def reply(self, item=None, maxn=1000, raw=False, attempts=None):
+    def reply(self, item=None, maxn=1000, include_scores=False, attempts=None):
         if attempts == None:
+            # just guess some value that'll take roughly the same amount of time
             attempts = int(2**12 / self.brain.order)
         lament('attempts:', attempts)
 
@@ -46,7 +47,7 @@ class ATTTT():
 
         result = sorted(replies, key=lambda t: t[1], reverse=True)[0]
 
-        if raw:
+        if include_scores:
             return result
         else:
             return result[0]
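
In other words, reply() scores every candidate, sorts by score, and include_scores only controls whether the caller sees the score alongside the text. A hedged usage sketch (names as in this file):

tool = ATTTT(brain)
text, score = tool.reply(maxn=1000, include_scores=True)  # best candidate with its score
text = tool.reply(maxn=1000)                              # same thing, text only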
@@ -63,56 +64,31 @@ class PatternBrain(Brain):
         return (v,)
 
 
-    def learn_all(self, items, merges=1):
-        min_count = 2
-        if merges < 0:
-            min_count = -merges
-            merges = 65536
-
-        # use numpy so this isn't nearly as disgustingly slow
-
-        int32_min = -2**(np.dtype(np.int32).itemsize * 8 - 1)
-        empty = int32_min
-        neg_lookup = {-1: ''} # default with padding
+    def resolve_tokens(self, tokens):
+        # positive values are just unicode characters
+        return [o < 0 and self.tokens[o] or chr(o) for o in tokens]
 
-        alignment = 2
-        align = lambda x: (x + alignment // 2) // alignment * alignment
+    def new_token(self, value):
+        new_id = -1 - len(self.tokens)
+        self.tokens[new_id] = value
+        return new_id
 
-        new_items = []
-        for item in items:
-            item = item.strip('\n')
-            # assert at least 1 padding character at the end
-            next_biggest = align(len(item) + 1)
-            # fill with padding (-1)
-            new_item = -np.ones(next_biggest, dtype=np.int32)
-            for i, c in enumerate(item):
-                new_item[i] = ord(c)
-            new_items.append(new_item)
 
-        # add an extra padding item to the head and tail
-        # for easier conversion from sequence back to all_items later on
-        pad = -np.ones(1, dtype=np.int32)
-        new_items.insert(0, pad)
-        new_items.append(pad)
 
-        all_items = np.concatenate(new_items)
-
-        if merges > 0:
-            # set up a 2d array to step through at half the row length,
-            # this means double redundancy, to acquire all the sequences.
-            # we don't have to .roll it later to get the other half,
-            # though that would require less memory.
-            sequences = all_items.repeat(2)[1:-1].reshape(-1, 2).copy()
+    def merge_all(self, all_items, merges, min_count=2):
+        # set up a 2d array to step through at half the row length;
+        # this means double redundancy; to acquire all the sequences.
+        # we could instead .roll it later to get the other half.
+        # that would require less memory, but memory isn't really a concern.
+        sequences = all_items.repeat(2)[1:-1].reshape(-1, 2).copy()
 
         for i in range(merges):
-            # learn
-            most_common = (None, 1)
-            # TODO: eventually check for empty here too
             invalid = np.any(sequences == -1, axis=1)
             valid_sequences = np.delete(sequences, np.where(invalid), axis=0)
             unique, counts = uniq_rows(valid_sequences, return_counts=True)
             count = counts.max()
 
+            most_common = (None, 1)
             if count > most_common[1]:
                 seq = unique[counts == count][0]
                 most_common = (seq, count)
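
The sequences construction is worth a worked example: repeat(2)[1:-1].reshape(-1, 2) turns the flat token stream into every overlapping adjacent pair, so uniq_rows can count pair frequencies in one pass (sample values are made up; rows touching the -1 padding are masked out by the invalid check):

import numpy as np

stream = np.array([5, 6, 7, 8], dtype=np.int32)
print(stream.repeat(2)[1:-1].reshape(-1, 2))
# repeat(2) -> [5 5 6 6 7 7 8 8]
# [1:-1]    -> [5 6 6 7 7 8]
# reshape   -> [[5 6]
#               [6 7]
#               [7 8]]   i.e. every adjacent (overlapping) pair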
@@ -121,45 +97,85 @@ class PatternBrain(Brain):
                 lament('no more valid sequences')
                 break
 
-            new_id = -1 - len(neg_lookup)
-            neg_lookup[new_id] = "".join([o < 0 and neg_lookup[o] or chr(o) for o in most_common[0]])
+            token_value = "".join(self.resolve_tokens(most_common[0]))
+            new_id = self.new_token(token_value)
 
-            if len("".join(neg_lookup.values())) > len(all_items):
-                lament('preventing dict from growing larger than source')
+            if len("".join(self.tokens.values())) > len(all_items):
+                # this might not ever occur
+                lament('preventing token dictionary from growing larger than source')
                 break
 
-            # replace our most common sequence in the sequences
+            # replace the most common two-token sequence
+            # with one token to represent both
             found = np.all(sequences == most_common[0], axis=1)
             before = np.roll(found, -1)
             after = np.roll(found, 1)
             # don't wrap around truth values
             before[-1] = False
             after[0] = False
-            # or remove padding
-            #before[0] = False
-            #after[-1] = False
             # remove the "found" sequences
+            # and update the previous/next,
+            # not unlike a doubly-linked list.
             befores = sequences[before].T.copy()
             befores[1] = new_id
             sequences[before] = befores.T
             afters = sequences[after].T.copy()
             afters[0] = new_id
             sequences[after] = afters.T
-            #sequences[found] = [empty, empty]
             here = np.where(found)
             sequences = np.delete(sequences, here, axis=0)
 
-            print("({:8}) new token: {:5} \"{}\"".format(len(here[0]), new_id, neg_lookup[new_id]))
+            print("new token id {:5} occurs {:8} times: \"{}\"".format(new_id, len(here[0]), self.tokens[new_id]))
 
+        # TODO: find unused tokens
+
+        # reconstruct all_items out of the sequences
+        all_items = sequences.reshape(-1)[::2][1:].copy()
+        return all_items
+
+
+    def learn_all(self, items, merges=0):
+        min_count = 2 # minimum number of occurences to stop creating tokens at
+        if merges < 0:
+            min_count = -merges
+            merges = 65536 # arbitrary sanity value
+
+        # we'll use numpy matrices so this isn't nearly as disgustingly slow
+
+        self.tokens = {-1: ''} # default with an empty padding token
+
+        # we need to assert that the number of sequences is a multiple of this
+        # otherwise we can't .reshape() it to be two-dimensional later on
+        alignment = 2
+        align = lambda x: (x + alignment // 2) // alignment * alignment
+
+        new_items = []
+        for item in items:
+            item = item.strip('\n')
+            # assert at least 1 padding character at the end
+            next_biggest = align(len(item) + 1)
+            # initialize with padding (-1)
+            new_item = -np.ones(next_biggest, dtype=np.int32)
+            for i, c in enumerate(item):
+                new_item[i] = ord(c)
+            new_items.append(new_item)
+
+        # add an extra padding item to the head and tail
+        # to make it easier to convert from sequences back to items later on
+        pad = -np.ones(1, dtype=np.int32)
+        new_items.insert(0, pad)
+        new_items.append(pad)
+
+        all_items = np.concatenate(new_items)
+
         if merges > 0:
-            # reconstruct all_items out of the sequences
-            all_items = sequences.reshape(-1)[::2][1:].copy()
+            all_items = self.merge_all(all_items, merges, min_count)
 
+        # begin the actual learning
         self.padding = '~'
         self.reset()
         np_item = []
         for i in all_items:
-            #for np_item in np.split(all_items, np.where(all_items == -1)):
             if i == -1:
                 if len(np_item) == 0:
                     continue
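
The masking above is the heart of the merge: every pair row equal to the winner is spliced out, and its two neighbor rows are rewritten to point at the new token, much like unlinking a node from a doubly-linked list. A self-contained re-creation of that step (not the original method: fancy-indexed assignment replaces the .T.copy() dance, and the toy values are made up):

import numpy as np

def merge_pair(sequences, pair, new_id):
    found = np.all(sequences == pair, axis=1)
    before = np.roll(found, -1)    # rows just before each match
    after = np.roll(found, 1)      # rows just after each match
    before[-1] = False             # the masks must not wrap around
    after[0] = False
    sequences[before, 1] = new_id  # neighbors now pair with the new token
    sequences[after, 0] = new_id
    return np.delete(sequences, np.where(found), axis=0)

# the stream 'ababa' as overlapping pairs; merging ('a', 'b') into token -2
pairs = np.array([[97, 98], [98, 97], [97, 98], [98, 97]], dtype=np.int32)
print(merge_pair(pairs, np.array([97, 98], dtype=np.int32), -2))
# [[-2 -2]
#  [-2 97]]   i.e. the stream is now (-2, -2, 97): 'ab', 'ab', 'a'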
@@ -167,32 +183,40 @@ class PatternBrain(Brain):
                 for i in np_item:
                     if i < 0:
                         assert(i != -1)
-                        item += self.helper(neg_lookup[i])
+                        item += self.helper(self.tokens[i])
                     else:
                         item += self.helper(chr(i))
                 #die(np_item, item)
                 self.learn(item)
                 np_item = []
-            elif i != empty:
+            else:
                 np_item.append(i)
         self.update()
 
 
 def run(pname, args, env):
     if not 1 <= len(args) <= 2:
-        lament("usage: {} {{input file}} [state_fn file]".format(sys.argv[0]))
-        sys.exit(1)
+        lament("usage: {} {{input file}} [savestate file]".format(pname))
+        return 1
 
-    args = dict(enumerate(args)) # for .get()
+    args = dict(enumerate(args)) # just for the .get() method
 
     fn = args[0]
     state_fn = args.get(1, None)
 
+    # the number of lines to output.
     count = int(env.get('COUNT', '8'))
-    order = int(env.get('ORDER', '3'))
-    temperature = float(env.get('TEMPERATURE', '0'))
-    maxn = int(env.get('MAXN', '1000'))
+    # learn and sample using this number of sequential tokens.
+    order = int(env.get('ORDER', '2'))
+    # how experimental to be with sampling.
+    # probably doesn't work properly.
+    temperature = float(env.get('TEMPERATURE', '0.5'))
+    # the max character length of output. (not guaranteed)
+    maxn = int(env.get('MAXN', '240'))
+    # attempts to maximize scoring
     attempts = int(env.get('ATTEMPTS', '-1'))
+    # if positive, maximum number of tokens to merge.
+    # if negative, minimum number of occurences to stop at.
    merges = int(env.get('MERGES', '0'))
 
    if attempts <= 0:
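
Since every knob now comes from the environment, a typical invocation would look something like MERGES=-2 ORDER=2 COUNT=8 python atttt.py corpus.txt state.json (file names hypothetical): per the comments above, keep merging pairs until none occurs at least twice, then learn and sample with two tokens of context.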
@@ -201,11 +225,12 @@ def run(pname, args, env):
     brain = PatternBrain(order=order, temperature=temperature)
     tool = ATTTT(brain)
 
-    lament('# loading')
     if state_fn:
+        lament('# loading')
         try:
             brain.load(state_fn, raw=False)
         except FileNotFoundError:
+            lament('# no file to load. skipping')
             pass
 
     if brain and brain.new:
@@ -214,6 +239,7 @@ def run(pname, args, env):
         brain.learn_all(lines, merges)
 
     if brain and brain.new and state_fn:
+        lament('# saving')
         brain.save(state_fn, raw=False)
 
     lament('# replying')
@@ -222,6 +248,8 @@ def run(pname, args, env):
         #print('{:6.1f}\t{}'.format(reply[1], reply[0]))
         print(tool.reply(maxn=maxn, attempts=attempts))
 
+    return 0
+
 
 if __name__ == '__main__':
     import sys
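
The token id scheme that replaces neg_lookup: ids >= 0 are raw Unicode code points, ids < 0 index learned multi-character tokens, and -1 is reserved for padding. A standalone sketch of the round trip (a plain dict instead of the class, and a ternary in place of the and/or idiom, which would misbehave on falsy values like the empty padding string):

tokens = {-1: ''}                    # -1 is reserved for padding

def new_token(value):
    new_id = -1 - len(tokens)        # the first learned token gets id -2
    tokens[new_id] = value
    return new_id

def resolve_tokens(ids):
    # negative ids name learned tokens; nonnegative ids are code points
    return [tokens[o] if o < 0 else chr(o) for o in ids]

th = new_token('th')                 # -> -2
print(resolve_tokens([th, 101]))     # ['th', 'e']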
basic.py (11 changes)
@@ -24,12 +24,10 @@ def normalize_sorted(counter):
 # http://nbviewer.jupyter.org/gist/yoavg/d76121dfde2618422139
 class Brain:
 
-    # TODO: don't default padding here, but make sure it's set before running
-    # the reason is it's the only place that's specific to a string anymore
-    def __init__(self, order=1, temperature=0.5, padding="~"):
+    def __init__(self, order=1, temperature=0.5):
         self.order = order
-        self.padding = padding
         self.temperature = temperature
+        self.padding = None
 
         self.reset()
 
@@ -77,6 +75,8 @@ class Brain:
 
 
     def learn(self, item):
+        assert(self.padding)
+
         if self.type is None and item is not None:
             self.type = type(item)
         if type(item) is not self.type:
@@ -123,11 +123,14 @@ class Brain:
         return c
 
 
+    # for overriding in subclasses
+    # in case the input tokens aren't strings (e.g. tuples)
     def helper(self, v):
         return v
 
 
     def reply(self, item=None, maxn=1000):
+        assert(self.padding)
         self.update()
 
         history = self.helper(self.padding) * self.order
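
One more note on helper(): it exists so the sliding-window code can stay agnostic about token type. Multiplication and += concatenate for both str and tuple, so the base Brain works on characters while PatternBrain (whose helper returns (v,)) works on whole tokens:

padding = '~'
print(padding * 3)   # '~~~'            (base Brain: a string of characters)
print(('~',) * 3)    # ('~', '~', '~')  (PatternBrain: a tuple of tokens)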
misc.py (1 change)
@@ -3,6 +3,7 @@ lament = lambda *args, **kwargs: print(*args, file=sys.stderr, **kwargs)
 
 
 def die(*args, **kwargs):
+    # just for ad-hoc debugging really
     lament(*args, **kwargs)
     sys.exit(1)
 