From 6508ee12106925654b8050fdd01020c770c75c7b Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Sat, 14 Aug 2021 00:40:38 -0700 Subject: [PATCH] rewrite rate-limiter to not create any tasks --- respodns/dns.py | 7 ++----- respodns/util.py | 44 +++++++++++++++++++++++--------------------- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/respodns/dns.py b/respodns/dns.py index 39a4b82..4debcd5 100644 --- a/respodns/dns.py +++ b/respodns/dns.py @@ -284,7 +284,7 @@ async def read_all_ips(filepaths, blocking=False, callback=None): async def main(db, filepaths, checks, ipinfo, opts: Options): # ipinfo can be None. - from .util import LimitPerSecond + from .util import RateLimiter from asyncio import Queue, QueueFull, create_task from queue import SimpleQueue @@ -292,7 +292,7 @@ async def main(db, filepaths, checks, ipinfo, opts: Options): locate_me = Queue() try_me = Queue() pps = opts.packets_per_second - context = LimitPerSecond(pps) if pps > 0 else None + context = RateLimiter(pps) if pps > 0 else None def locate_later(ip): try: @@ -335,6 +335,3 @@ async def main(db, filepaths, checks, ipinfo, opts: Options): await locate_me.put(None) await locating - - if context is not None and hasattr(context, "finish"): - await context.finish() diff --git a/respodns/util.py b/respodns/util.py index 4137ec8..cece8d1 100644 --- a/respodns/util.py +++ b/respodns/util.py @@ -58,37 +58,39 @@ class AttrCheck: raise AttributeError(name) -async def _release_later(sem, time=1): - from asyncio import sleep - - await sleep(time) - sem.release() +def _present(): + from time import time + return time() -class LimitPerSecond: +class RateLimiter: def __init__(self, limit): - from asyncio import BoundedSemaphore + from asyncio import Lock if type(limit) is not int: raise ValueError("limit must be int") assert limit > 0, limit + self.unit = 1.0 # TODO: allow window length to be configured. self.limit = limit - self.tasks = [] - self.sem = BoundedSemaphore(limit) + self.times = [] + self.lock = Lock() + self.eps = self.unit * 0.01 # to wait a tiny bit longer than specified + + def eta(self): + past = _present() - self.unit + self.times = [time for time in self.times if time > past] + ind = len(self.times) - self.limit + return 0.0 if ind < 0 else self.times[ind] - past async def __aenter__(self): - #if self.sem.locked: - # from sys import stderr - # print("THROTTLING", file=stderr) - await self.sem.acquire() + from asyncio import sleep + + async with self.lock: + while (wait := self.eta()) > 0.0: + # this is done in a loop in case sleep ends early (it can). + await sleep(wait + self.eps) + self.times.append(_present()) async def __aexit__(self, exc_type, exc_value, traceback): - from asyncio import create_task - - task = create_task(_release_later(self.sem)) - self.tasks.append(task) - - async def finish(self): - for task in self.tasks: - await task + pass