This commit is contained in:
Connor Olding 2016-05-25 11:30:07 -07:00
parent eeb5d2941e
commit e61a32c615
2 changed files with 56 additions and 35 deletions

View File

@ -7,6 +7,10 @@ from misc import *
from basic import Brain
def align(x, alignment):
return (x + alignment // 2) // alignment * alignment
def uniq_rows(a, return_index=False, return_inverse=False, return_counts=False):
# via http://stackoverflow.com/a/16973510
# black magic wrapper around np.unique
@ -66,7 +70,10 @@ class PatternBrain(Brain):
def resolve_tokens(self, tokens):
# positive values are just unicode characters
return [o < 0 and self.tokens[o] or chr(o) for o in tokens]
if isinstance(tokens, int) or isinstance(tokens, np.int32):
return tokens < 0 and self.tokens[tokens] or chr(tokens)
else:
return [o < 0 and self.tokens[o] or chr(o) for o in tokens]
def new_token(self, value):
@ -75,6 +82,46 @@ class PatternBrain(Brain):
return new_id
@staticmethod
def prepare_items(items, pad=True):
new_items = []
for item in items:
item = item.strip('\n')
# assert that the number of sequences is a multiple of 2
# otherwise we can't .reshape() it to be two-dimensional later on
next_biggest = align(len(item) + 1, 2)
# initialize with padding (-1)
new_item = -np.ones(next_biggest, dtype=np.int32)
for i, c in enumerate(item):
new_item[i] = ord(c)
new_items.append(new_item)
# add an extra padding item to the head and tail
# to make it easier to convert from sequences back to items later on
if pad:
pad = -np.ones(1, dtype=np.int32)
new_items.insert(0, pad)
new_items.append(pad)
return np.concatenate(new_items)
def stat_tokens(self, all_items, skip_normal=False):
unique, counts = np.unique(all_items, return_counts=True)
count_order = np.argsort(counts)[::-1]
counts_descending = counts[count_order]
unique_descending = unique[count_order]
for i, token_id in enumerate(unique_descending):
if token_id == -1:
continue
if skip_normal and token_id >= 0:
continue
token = self.resolve_tokens(token_id)
lament("token id {:5} occurs {:8} times: \"{}\"".format(
token_id, counts_descending[i], token))
lament("total tokens: {:5}".format(i + 1))
def merge_all(self, all_items, merges, min_count=2):
# set up a 2d array to step through at half the row length;
# this means double redundancy; to acquire all the sequences.
@ -125,16 +172,17 @@ class PatternBrain(Brain):
here = np.where(found)
sequences = np.delete(sequences, here, axis=0)
print("new token id {:5} occurs {:8} times: \"{}\"".format(new_id, len(here[0]), self.tokens[new_id]))
lament("new token id {:5} occurs {:8} times: \"{}\"".format(
new_id, len(here[0]), self.tokens[new_id]))
# TODO: find unused tokens
# TODO: find unused tokens?
# reconstruct all_items out of the sequences
all_items = sequences.reshape(-1)[::2][1:].copy()
return all_items
def learn_all(self, items, merges=0):
def learn_all(self, items, merges=0, stat=True):
min_count = 2 # minimum number of occurences to stop creating tokens at
if merges < 0:
min_count = -merges
@ -144,29 +192,7 @@ class PatternBrain(Brain):
self.tokens = {-1: ''} # default with an empty padding token
# we need to assert that the number of sequences is a multiple of this
# otherwise we can't .reshape() it to be two-dimensional later on
alignment = 2
align = lambda x: (x + alignment // 2) // alignment * alignment
new_items = []
for item in items:
item = item.strip('\n')
# assert at least 1 padding character at the end
next_biggest = align(len(item) + 1)
# initialize with padding (-1)
new_item = -np.ones(next_biggest, dtype=np.int32)
for i, c in enumerate(item):
new_item[i] = ord(c)
new_items.append(new_item)
# add an extra padding item to the head and tail
# to make it easier to convert from sequences back to items later on
pad = -np.ones(1, dtype=np.int32)
new_items.insert(0, pad)
new_items.append(pad)
all_items = np.concatenate(new_items)
all_items = self.prepare_items(items)
if merges > 0:
all_items = self.merge_all(all_items, merges, min_count)
@ -193,6 +219,9 @@ class PatternBrain(Brain):
np_item.append(i)
self.update()
if merges != 0 and stat:
self.stat_tokens(all_items)
def run(pname, args, env):
if not 1 <= len(args) <= 2:

View File

@ -8,12 +8,4 @@ def die(*args, **kwargs):
sys.exit(1)
def easytruncnorm(lower=0, upper=1, loc=0.5, scale=0.25):
import scipy.stats as stats
a = (lower - loc) / scale
b = (upper - loc) / scale
return stats.truncnorm(a=a, b=b, loc=loc, scale=scale)
# only make some things visible to "from misc import *"
__all__ = [o for o in locals() if type(o) != 'module' and not o.startswith('_')]