rewrite rate-limiter to not create any tasks
This commit is contained in:
parent
ffbdf50674
commit
6508ee1210
2 changed files with 25 additions and 26 deletions
|
@ -284,7 +284,7 @@ async def read_all_ips(filepaths, blocking=False, callback=None):
|
|||
|
||||
async def main(db, filepaths, checks, ipinfo, opts: Options):
|
||||
# ipinfo can be None.
|
||||
from .util import LimitPerSecond
|
||||
from .util import RateLimiter
|
||||
from asyncio import Queue, QueueFull, create_task
|
||||
from queue import SimpleQueue
|
||||
|
||||
|
@ -292,7 +292,7 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
|
|||
locate_me = Queue()
|
||||
try_me = Queue()
|
||||
pps = opts.packets_per_second
|
||||
context = LimitPerSecond(pps) if pps > 0 else None
|
||||
context = RateLimiter(pps) if pps > 0 else None
|
||||
|
||||
def locate_later(ip):
|
||||
try:
|
||||
|
@ -335,6 +335,3 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
|
|||
|
||||
await locate_me.put(None)
|
||||
await locating
|
||||
|
||||
if context is not None and hasattr(context, "finish"):
|
||||
await context.finish()
|
||||
|
|
|
@ -58,37 +58,39 @@ class AttrCheck:
|
|||
raise AttributeError(name)
|
||||
|
||||
|
||||
async def _release_later(sem, time=1):
|
||||
from asyncio import sleep
|
||||
|
||||
await sleep(time)
|
||||
sem.release()
|
||||
def _present():
|
||||
from time import time
|
||||
return time()
|
||||
|
||||
|
||||
class LimitPerSecond:
|
||||
class RateLimiter:
|
||||
def __init__(self, limit):
|
||||
from asyncio import BoundedSemaphore
|
||||
from asyncio import Lock
|
||||
|
||||
if type(limit) is not int:
|
||||
raise ValueError("limit must be int")
|
||||
assert limit > 0, limit
|
||||
|
||||
self.unit = 1.0 # TODO: allow window length to be configured.
|
||||
self.limit = limit
|
||||
self.tasks = []
|
||||
self.sem = BoundedSemaphore(limit)
|
||||
self.times = []
|
||||
self.lock = Lock()
|
||||
self.eps = self.unit * 0.01 # to wait a tiny bit longer than specified
|
||||
|
||||
def eta(self):
|
||||
past = _present() - self.unit
|
||||
self.times = [time for time in self.times if time > past]
|
||||
ind = len(self.times) - self.limit
|
||||
return 0.0 if ind < 0 else self.times[ind] - past
|
||||
|
||||
async def __aenter__(self):
|
||||
#if self.sem.locked:
|
||||
# from sys import stderr
|
||||
# print("THROTTLING", file=stderr)
|
||||
await self.sem.acquire()
|
||||
from asyncio import sleep
|
||||
|
||||
async with self.lock:
|
||||
while (wait := self.eta()) > 0.0:
|
||||
# this is done in a loop in case sleep ends early (it can).
|
||||
await sleep(wait + self.eps)
|
||||
self.times.append(_present())
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
from asyncio import create_task
|
||||
|
||||
task = create_task(_release_later(self.sem))
|
||||
self.tasks.append(task)
|
||||
|
||||
async def finish(self):
|
||||
for task in self.tasks:
|
||||
await task
|
||||
pass
|
||||
|
|
Loading…
Reference in a new issue