.
This commit is contained in:
parent
eeb5d2941e
commit
e61a32c615
2 changed files with 56 additions and 35 deletions
83
atttt.py
83
atttt.py
|
@ -7,6 +7,10 @@ from misc import *
|
||||||
from basic import Brain
|
from basic import Brain
|
||||||
|
|
||||||
|
|
||||||
|
def align(x, alignment):
|
||||||
|
return (x + alignment // 2) // alignment * alignment
|
||||||
|
|
||||||
|
|
||||||
def uniq_rows(a, return_index=False, return_inverse=False, return_counts=False):
|
def uniq_rows(a, return_index=False, return_inverse=False, return_counts=False):
|
||||||
# via http://stackoverflow.com/a/16973510
|
# via http://stackoverflow.com/a/16973510
|
||||||
# black magic wrapper around np.unique
|
# black magic wrapper around np.unique
|
||||||
|
@ -66,7 +70,10 @@ class PatternBrain(Brain):
|
||||||
|
|
||||||
def resolve_tokens(self, tokens):
|
def resolve_tokens(self, tokens):
|
||||||
# positive values are just unicode characters
|
# positive values are just unicode characters
|
||||||
return [o < 0 and self.tokens[o] or chr(o) for o in tokens]
|
if isinstance(tokens, int) or isinstance(tokens, np.int32):
|
||||||
|
return tokens < 0 and self.tokens[tokens] or chr(tokens)
|
||||||
|
else:
|
||||||
|
return [o < 0 and self.tokens[o] or chr(o) for o in tokens]
|
||||||
|
|
||||||
|
|
||||||
def new_token(self, value):
|
def new_token(self, value):
|
||||||
|
@ -75,6 +82,46 @@ class PatternBrain(Brain):
|
||||||
return new_id
|
return new_id
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prepare_items(items, pad=True):
|
||||||
|
new_items = []
|
||||||
|
for item in items:
|
||||||
|
item = item.strip('\n')
|
||||||
|
# assert that the number of sequences is a multiple of 2
|
||||||
|
# otherwise we can't .reshape() it to be two-dimensional later on
|
||||||
|
next_biggest = align(len(item) + 1, 2)
|
||||||
|
# initialize with padding (-1)
|
||||||
|
new_item = -np.ones(next_biggest, dtype=np.int32)
|
||||||
|
for i, c in enumerate(item):
|
||||||
|
new_item[i] = ord(c)
|
||||||
|
new_items.append(new_item)
|
||||||
|
|
||||||
|
# add an extra padding item to the head and tail
|
||||||
|
# to make it easier to convert from sequences back to items later on
|
||||||
|
if pad:
|
||||||
|
pad = -np.ones(1, dtype=np.int32)
|
||||||
|
new_items.insert(0, pad)
|
||||||
|
new_items.append(pad)
|
||||||
|
|
||||||
|
return np.concatenate(new_items)
|
||||||
|
|
||||||
|
|
||||||
|
def stat_tokens(self, all_items, skip_normal=False):
|
||||||
|
unique, counts = np.unique(all_items, return_counts=True)
|
||||||
|
count_order = np.argsort(counts)[::-1]
|
||||||
|
counts_descending = counts[count_order]
|
||||||
|
unique_descending = unique[count_order]
|
||||||
|
for i, token_id in enumerate(unique_descending):
|
||||||
|
if token_id == -1:
|
||||||
|
continue
|
||||||
|
if skip_normal and token_id >= 0:
|
||||||
|
continue
|
||||||
|
token = self.resolve_tokens(token_id)
|
||||||
|
lament("token id {:5} occurs {:8} times: \"{}\"".format(
|
||||||
|
token_id, counts_descending[i], token))
|
||||||
|
lament("total tokens: {:5}".format(i + 1))
|
||||||
|
|
||||||
|
|
||||||
def merge_all(self, all_items, merges, min_count=2):
|
def merge_all(self, all_items, merges, min_count=2):
|
||||||
# set up a 2d array to step through at half the row length;
|
# set up a 2d array to step through at half the row length;
|
||||||
# this means double redundancy; to acquire all the sequences.
|
# this means double redundancy; to acquire all the sequences.
|
||||||
|
@ -125,16 +172,17 @@ class PatternBrain(Brain):
|
||||||
here = np.where(found)
|
here = np.where(found)
|
||||||
sequences = np.delete(sequences, here, axis=0)
|
sequences = np.delete(sequences, here, axis=0)
|
||||||
|
|
||||||
print("new token id {:5} occurs {:8} times: \"{}\"".format(new_id, len(here[0]), self.tokens[new_id]))
|
lament("new token id {:5} occurs {:8} times: \"{}\"".format(
|
||||||
|
new_id, len(here[0]), self.tokens[new_id]))
|
||||||
|
|
||||||
# TODO: find unused tokens
|
# TODO: find unused tokens?
|
||||||
|
|
||||||
# reconstruct all_items out of the sequences
|
# reconstruct all_items out of the sequences
|
||||||
all_items = sequences.reshape(-1)[::2][1:].copy()
|
all_items = sequences.reshape(-1)[::2][1:].copy()
|
||||||
return all_items
|
return all_items
|
||||||
|
|
||||||
|
|
||||||
def learn_all(self, items, merges=0):
|
def learn_all(self, items, merges=0, stat=True):
|
||||||
min_count = 2 # minimum number of occurences to stop creating tokens at
|
min_count = 2 # minimum number of occurences to stop creating tokens at
|
||||||
if merges < 0:
|
if merges < 0:
|
||||||
min_count = -merges
|
min_count = -merges
|
||||||
|
@ -144,29 +192,7 @@ class PatternBrain(Brain):
|
||||||
|
|
||||||
self.tokens = {-1: ''} # default with an empty padding token
|
self.tokens = {-1: ''} # default with an empty padding token
|
||||||
|
|
||||||
# we need to assert that the number of sequences is a multiple of this
|
all_items = self.prepare_items(items)
|
||||||
# otherwise we can't .reshape() it to be two-dimensional later on
|
|
||||||
alignment = 2
|
|
||||||
align = lambda x: (x + alignment // 2) // alignment * alignment
|
|
||||||
|
|
||||||
new_items = []
|
|
||||||
for item in items:
|
|
||||||
item = item.strip('\n')
|
|
||||||
# assert at least 1 padding character at the end
|
|
||||||
next_biggest = align(len(item) + 1)
|
|
||||||
# initialize with padding (-1)
|
|
||||||
new_item = -np.ones(next_biggest, dtype=np.int32)
|
|
||||||
for i, c in enumerate(item):
|
|
||||||
new_item[i] = ord(c)
|
|
||||||
new_items.append(new_item)
|
|
||||||
|
|
||||||
# add an extra padding item to the head and tail
|
|
||||||
# to make it easier to convert from sequences back to items later on
|
|
||||||
pad = -np.ones(1, dtype=np.int32)
|
|
||||||
new_items.insert(0, pad)
|
|
||||||
new_items.append(pad)
|
|
||||||
|
|
||||||
all_items = np.concatenate(new_items)
|
|
||||||
|
|
||||||
if merges > 0:
|
if merges > 0:
|
||||||
all_items = self.merge_all(all_items, merges, min_count)
|
all_items = self.merge_all(all_items, merges, min_count)
|
||||||
|
@ -193,6 +219,9 @@ class PatternBrain(Brain):
|
||||||
np_item.append(i)
|
np_item.append(i)
|
||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
|
if merges != 0 and stat:
|
||||||
|
self.stat_tokens(all_items)
|
||||||
|
|
||||||
|
|
||||||
def run(pname, args, env):
|
def run(pname, args, env):
|
||||||
if not 1 <= len(args) <= 2:
|
if not 1 <= len(args) <= 2:
|
||||||
|
|
8
misc.py
8
misc.py
|
@ -8,12 +8,4 @@ def die(*args, **kwargs):
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def easytruncnorm(lower=0, upper=1, loc=0.5, scale=0.25):
|
|
||||||
import scipy.stats as stats
|
|
||||||
a = (lower - loc) / scale
|
|
||||||
b = (upper - loc) / scale
|
|
||||||
return stats.truncnorm(a=a, b=b, loc=loc, scale=scale)
|
|
||||||
|
|
||||||
|
|
||||||
# only make some things visible to "from misc import *"
|
|
||||||
__all__ = [o for o in locals() if type(o) != 'module' and not o.startswith('_')]
|
__all__ = [o for o in locals() if type(o) != 'module' and not o.startswith('_')]
|
||||||
|
|
Loading…
Reference in a new issue