rewrite rate-limiter to not create any tasks
This commit is contained in:
parent
ffbdf50674
commit
6508ee1210
2 changed files with 25 additions and 26 deletions
|
@ -284,7 +284,7 @@ async def read_all_ips(filepaths, blocking=False, callback=None):
|
||||||
|
|
||||||
async def main(db, filepaths, checks, ipinfo, opts: Options):
|
async def main(db, filepaths, checks, ipinfo, opts: Options):
|
||||||
# ipinfo can be None.
|
# ipinfo can be None.
|
||||||
from .util import LimitPerSecond
|
from .util import RateLimiter
|
||||||
from asyncio import Queue, QueueFull, create_task
|
from asyncio import Queue, QueueFull, create_task
|
||||||
from queue import SimpleQueue
|
from queue import SimpleQueue
|
||||||
|
|
||||||
|
@ -292,7 +292,7 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
|
||||||
locate_me = Queue()
|
locate_me = Queue()
|
||||||
try_me = Queue()
|
try_me = Queue()
|
||||||
pps = opts.packets_per_second
|
pps = opts.packets_per_second
|
||||||
context = LimitPerSecond(pps) if pps > 0 else None
|
context = RateLimiter(pps) if pps > 0 else None
|
||||||
|
|
||||||
def locate_later(ip):
|
def locate_later(ip):
|
||||||
try:
|
try:
|
||||||
|
@ -335,6 +335,3 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
|
||||||
|
|
||||||
await locate_me.put(None)
|
await locate_me.put(None)
|
||||||
await locating
|
await locating
|
||||||
|
|
||||||
if context is not None and hasattr(context, "finish"):
|
|
||||||
await context.finish()
|
|
||||||
|
|
|
@ -58,37 +58,39 @@ class AttrCheck:
|
||||||
raise AttributeError(name)
|
raise AttributeError(name)
|
||||||
|
|
||||||
|
|
||||||
async def _release_later(sem, time=1):
|
def _present():
|
||||||
from asyncio import sleep
|
from time import time
|
||||||
|
return time()
|
||||||
await sleep(time)
|
|
||||||
sem.release()
|
|
||||||
|
|
||||||
|
|
||||||
class LimitPerSecond:
|
class RateLimiter:
|
||||||
def __init__(self, limit):
|
def __init__(self, limit):
|
||||||
from asyncio import BoundedSemaphore
|
from asyncio import Lock
|
||||||
|
|
||||||
if type(limit) is not int:
|
if type(limit) is not int:
|
||||||
raise ValueError("limit must be int")
|
raise ValueError("limit must be int")
|
||||||
assert limit > 0, limit
|
assert limit > 0, limit
|
||||||
|
|
||||||
|
self.unit = 1.0 # TODO: allow window length to be configured.
|
||||||
self.limit = limit
|
self.limit = limit
|
||||||
self.tasks = []
|
self.times = []
|
||||||
self.sem = BoundedSemaphore(limit)
|
self.lock = Lock()
|
||||||
|
self.eps = self.unit * 0.01 # to wait a tiny bit longer than specified
|
||||||
|
|
||||||
|
def eta(self):
|
||||||
|
past = _present() - self.unit
|
||||||
|
self.times = [time for time in self.times if time > past]
|
||||||
|
ind = len(self.times) - self.limit
|
||||||
|
return 0.0 if ind < 0 else self.times[ind] - past
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
#if self.sem.locked:
|
from asyncio import sleep
|
||||||
# from sys import stderr
|
|
||||||
# print("THROTTLING", file=stderr)
|
async with self.lock:
|
||||||
await self.sem.acquire()
|
while (wait := self.eta()) > 0.0:
|
||||||
|
# this is done in a loop in case sleep ends early (it can).
|
||||||
|
await sleep(wait + self.eps)
|
||||||
|
self.times.append(_present())
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
from asyncio import create_task
|
pass
|
||||||
|
|
||||||
task = create_task(_release_later(self.sem))
|
|
||||||
self.tasks.append(task)
|
|
||||||
|
|
||||||
async def finish(self):
|
|
||||||
for task in self.tasks:
|
|
||||||
await task
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue