optim/onn/math.py

16 lines
540 B
Python
Raw Normal View History

2018-01-21 14:04:25 -08:00
import numpy as np
2018-01-22 11:40:36 -08:00
2018-01-21 14:04:25 -08:00
def rolling(a, window):
    """Return all runs of ``window`` consecutive elements of ``a``.

    Row ``i`` of the result holds ``a[i:i + window]``, so the result has
    shape ``(a.size - window + 1, window)``.

    Args:
        a: 1-D array (or array-like). Input with any other number of
            dimensions is flattened in C order before windowing.
        window: number of elements per row.

    Returns:
        A 2-D array built with ``as_strided``: the rows overlap in memory
        rather than being copies, so writing to one element of the result
        changes every row that contains that element.

    Raises:
        ValueError: if ``window`` is negative or greater than
            ``a.size + 1`` (``window == a.size + 1`` gives zero rows).
    """
    # http://stackoverflow.com/a/4924433
    a = np.asarray(a)
    if a.ndim != 1:
        # Keeps the "window over the flattened data" result for N-d input;
        # reshape copies only when the data cannot be viewed flat.
        a = a.reshape(-1)

    count = a.size - window + 1
    if window < 0 or count < 0:
        raise ValueError(
            "window must be between 0 and {}, got {}".format(a.size + 1, window)
        )

    # Use the array's real element step instead of assuming itemsize, so
    # sliced or reversed input (e.g. x[::2]) is windowed correctly.
    step = a.strides[0]
    return np.lib.stride_tricks.as_strided(
        a, shape=(count, window), strides=(step, step)
    )
2018-01-22 11:40:36 -08:00
2018-01-21 14:04:25 -08:00
def rolling_batch(a, window):
    """Sliding windows along the last axis, computed per batch (axis 0).

    Element ``[b, i, j]`` of the result is ``a[b, i + j]`` for 2-D input,
    giving shape ``(a.shape[0], a.shape[-1] - window + 1, window)``.
    For input with more than two dimensions only index 0 of every
    intermediate axis is windowed, i.e. ``a[b, 0, ..., 0, i + j]``.

    Args:
        a: array (or array-like) with at least 2 dimensions; axis 0 is
            the batch axis and the last axis is the one being windowed.
        window: number of elements per window.

    Returns:
        A 3-D array built with ``as_strided``: the windows overlap in
        memory rather than being copies, so writing to one element of the
        result changes every window that contains that element.

    Raises:
        ValueError: if ``a`` has fewer than 2 dimensions, or if ``window``
            is negative or greater than ``a.shape[-1] + 1``.
    """
    a = np.asarray(a)
    if a.ndim < 2:
        # With a single axis the batch and element strides coincide and the
        # requested shape would describe memory beyond the buffer.
        raise ValueError(
            "rolling_batch needs at least 2 dimensions, got {}".format(a.ndim)
        )

    count = a.shape[-1] - window + 1
    if window < 0 or count < 0:
        raise ValueError(
            "window must be between 0 and {}, got {}".format(
                a.shape[-1] + 1, window
            )
        )

    # Take the steps from the array itself rather than recomputing them
    # from shape and itemsize, so sliced, transposed or Fortran-ordered
    # input is windowed correctly.
    batch_step = a.strides[0]
    step = a.strides[-1]
    return np.lib.stride_tricks.as_strided(
        a,
        shape=(a.shape[0], count, window),
        strides=(batch_step, step, step),
    )