.
This commit is contained in:
parent
eeb5d2941e
commit
e61a32c615
2 changed files with 56 additions and 35 deletions
81
atttt.py
81
atttt.py
|
@ -7,6 +7,10 @@ from misc import *
|
|||
from basic import Brain
|
||||
|
||||
|
||||
def align(x, alignment):
|
||||
return (x + alignment // 2) // alignment * alignment
|
||||
|
||||
|
||||
def uniq_rows(a, return_index=False, return_inverse=False, return_counts=False):
|
||||
# via http://stackoverflow.com/a/16973510
|
||||
# black magic wrapper around np.unique
|
||||
|
@ -66,6 +70,9 @@ class PatternBrain(Brain):
|
|||
|
||||
def resolve_tokens(self, tokens):
|
||||
# positive values are just unicode characters
|
||||
if isinstance(tokens, int) or isinstance(tokens, np.int32):
|
||||
return tokens < 0 and self.tokens[tokens] or chr(tokens)
|
||||
else:
|
||||
return [o < 0 and self.tokens[o] or chr(o) for o in tokens]
|
||||
|
||||
|
||||
|
@ -75,6 +82,46 @@ class PatternBrain(Brain):
|
|||
return new_id
|
||||
|
||||
|
||||
@staticmethod
|
||||
def prepare_items(items, pad=True):
|
||||
new_items = []
|
||||
for item in items:
|
||||
item = item.strip('\n')
|
||||
# assert that the number of sequences is a multiple of 2
|
||||
# otherwise we can't .reshape() it to be two-dimensional later on
|
||||
next_biggest = align(len(item) + 1, 2)
|
||||
# initialize with padding (-1)
|
||||
new_item = -np.ones(next_biggest, dtype=np.int32)
|
||||
for i, c in enumerate(item):
|
||||
new_item[i] = ord(c)
|
||||
new_items.append(new_item)
|
||||
|
||||
# add an extra padding item to the head and tail
|
||||
# to make it easier to convert from sequences back to items later on
|
||||
if pad:
|
||||
pad = -np.ones(1, dtype=np.int32)
|
||||
new_items.insert(0, pad)
|
||||
new_items.append(pad)
|
||||
|
||||
return np.concatenate(new_items)
|
||||
|
||||
|
||||
def stat_tokens(self, all_items, skip_normal=False):
|
||||
unique, counts = np.unique(all_items, return_counts=True)
|
||||
count_order = np.argsort(counts)[::-1]
|
||||
counts_descending = counts[count_order]
|
||||
unique_descending = unique[count_order]
|
||||
for i, token_id in enumerate(unique_descending):
|
||||
if token_id == -1:
|
||||
continue
|
||||
if skip_normal and token_id >= 0:
|
||||
continue
|
||||
token = self.resolve_tokens(token_id)
|
||||
lament("token id {:5} occurs {:8} times: \"{}\"".format(
|
||||
token_id, counts_descending[i], token))
|
||||
lament("total tokens: {:5}".format(i + 1))
|
||||
|
||||
|
||||
def merge_all(self, all_items, merges, min_count=2):
|
||||
# set up a 2d array to step through at half the row length;
|
||||
# this means double redundancy; to acquire all the sequences.
|
||||
|
@ -125,16 +172,17 @@ class PatternBrain(Brain):
|
|||
here = np.where(found)
|
||||
sequences = np.delete(sequences, here, axis=0)
|
||||
|
||||
print("new token id {:5} occurs {:8} times: \"{}\"".format(new_id, len(here[0]), self.tokens[new_id]))
|
||||
lament("new token id {:5} occurs {:8} times: \"{}\"".format(
|
||||
new_id, len(here[0]), self.tokens[new_id]))
|
||||
|
||||
# TODO: find unused tokens
|
||||
# TODO: find unused tokens?
|
||||
|
||||
# reconstruct all_items out of the sequences
|
||||
all_items = sequences.reshape(-1)[::2][1:].copy()
|
||||
return all_items
|
||||
|
||||
|
||||
def learn_all(self, items, merges=0):
|
||||
def learn_all(self, items, merges=0, stat=True):
|
||||
min_count = 2 # minimum number of occurences to stop creating tokens at
|
||||
if merges < 0:
|
||||
min_count = -merges
|
||||
|
@ -144,29 +192,7 @@ class PatternBrain(Brain):
|
|||
|
||||
self.tokens = {-1: ''} # default with an empty padding token
|
||||
|
||||
# we need to assert that the number of sequences is a multiple of this
|
||||
# otherwise we can't .reshape() it to be two-dimensional later on
|
||||
alignment = 2
|
||||
align = lambda x: (x + alignment // 2) // alignment * alignment
|
||||
|
||||
new_items = []
|
||||
for item in items:
|
||||
item = item.strip('\n')
|
||||
# assert at least 1 padding character at the end
|
||||
next_biggest = align(len(item) + 1)
|
||||
# initialize with padding (-1)
|
||||
new_item = -np.ones(next_biggest, dtype=np.int32)
|
||||
for i, c in enumerate(item):
|
||||
new_item[i] = ord(c)
|
||||
new_items.append(new_item)
|
||||
|
||||
# add an extra padding item to the head and tail
|
||||
# to make it easier to convert from sequences back to items later on
|
||||
pad = -np.ones(1, dtype=np.int32)
|
||||
new_items.insert(0, pad)
|
||||
new_items.append(pad)
|
||||
|
||||
all_items = np.concatenate(new_items)
|
||||
all_items = self.prepare_items(items)
|
||||
|
||||
if merges > 0:
|
||||
all_items = self.merge_all(all_items, merges, min_count)
|
||||
|
@ -193,6 +219,9 @@ class PatternBrain(Brain):
|
|||
np_item.append(i)
|
||||
self.update()
|
||||
|
||||
if merges != 0 and stat:
|
||||
self.stat_tokens(all_items)
|
||||
|
||||
|
||||
def run(pname, args, env):
|
||||
if not 1 <= len(args) <= 2:
|
||||
|
|
8
misc.py
8
misc.py
|
@ -8,12 +8,4 @@ def die(*args, **kwargs):
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
def easytruncnorm(lower=0, upper=1, loc=0.5, scale=0.25):
|
||||
import scipy.stats as stats
|
||||
a = (lower - loc) / scale
|
||||
b = (upper - loc) / scale
|
||||
return stats.truncnorm(a=a, b=b, loc=loc, scale=scale)
|
||||
|
||||
|
||||
# only make some things visible to "from misc import *"
|
||||
__all__ = [o for o in locals() if type(o) != 'module' and not o.startswith('_')]
|
||||
|
|
Loading…
Reference in a new issue