diff --git a/respodns/dns.py b/respodns/dns.py index 03f5eff..7977799 100644 --- a/respodns/dns.py +++ b/respodns/dns.py @@ -33,7 +33,7 @@ def detect_gfw(r, ip, check): return False -async def getaddrs(server, domain, opts): +async def getaddrs(server, domain, opts, context=None): from .ip_util import ipkey from dns.asyncresolver import Resolver from dns.exception import Timeout @@ -48,7 +48,11 @@ async def getaddrs(server, domain, opts): res.lifetime = 9 res.nameservers = [server] try: - ans = await res.resolve(domain, "A", search=False) + if context is not None: + async with context: + ans = await res.resolve(domain, "A", search=False) + else: + ans = await res.resolve(domain, "A", search=False) except NXDOMAIN: return ["NXDOMAIN"] except NoAnswer: @@ -116,7 +120,8 @@ def process_result(res, ip, check, opts: Options): ) -async def try_ip(db, server_ip, checks, opts: Options, callback=None): +async def try_ip(db, server_ip, checks, context, opts: Options, callback=None): + # context can be None. from .util import make_pooler from asyncio import sleep, CancelledError @@ -151,10 +156,10 @@ async def try_ip(db, server_ip, checks, opts: Options, callback=None): # NOTE: could put right_now() stuff here! # TODO: add duration field given in milliseconds (integer) # by subtracting start and end datetimes. - res = await getaddrs(ip, check.domain, opts) + res = await getaddrs(ip, check.domain, opts, context) return res, ip, check - #lament("BEGIN", server_ip) + #lament("TESTING", server_ip) for i, check in enumerate(checks): first = i == 0 if not first: @@ -174,7 +179,7 @@ async def try_ip(db, server_ip, checks, opts: Options, callback=None): db.push_entry(entry) db.commit() - #lament("FINISH", server_ip) + #lament("TESTED", server_ip) if not success: first_failure = None @@ -189,7 +194,8 @@ async def try_ip(db, server_ip, checks, opts: Options, callback=None): return None -async def try_all_ips(db, try_me, checks, opts: Options, callback=None): +async def try_all_ips(db, try_me, checks, context, opts: Options, callback=None): + # context can be None. from asyncio import create_task, sleep, BoundedSemaphore seen, total = 0, None @@ -202,7 +208,7 @@ async def try_all_ips(db, try_me, checks, opts: Options, callback=None): lament(f"#{seen}: {ip}" if total is None else f"#{seen}/{total}: {ip}") stderr.flush() - first_failure = await try_ip(db, ip, checks, opts, callback) + first_failure = await try_ip(db, ip, checks, context, opts, callback) if first_failure is None: print(ip) # all tests for this server passed; pass it along to stdout @@ -243,6 +249,7 @@ def sync_database(db, callback=None): from .ips import china, blocks # TODO: handle addresses that were removed from respodns.ips.china. + # i could probably just do ip.startswith("- ") and remove those. for ips, kw in ((china, "china"), (blocks, "block_target")): for ip in ips: kwargs = {kw: True} @@ -297,12 +304,15 @@ async def read_all_ips(filepaths, blocking=False, callback=None): async def main(db, filepaths, checks, ipinfo, opts: Options): # ipinfo can be None. + from .util import LimitPerSecond from asyncio import Queue, QueueFull, create_task from queue import SimpleQueue deferred = SimpleQueue() locate_me = Queue() try_me = Queue() + pps = opts.packets_per_second + context = LimitPerSecond(pps) if pps > 0 else None def locate_later(ip): try: @@ -320,7 +330,7 @@ async def main(db, filepaths, checks, ipinfo, opts: Options): reading = create_task(read_all_ips(filepaths, opts.blocking_file_io, callback=try_soon)) - trying = create_task(try_all_ips(db, try_me, checks, opts, + trying = create_task(try_all_ips(db, try_me, checks, context, opts, callback=locate_later)) locating = create_task(locate_ips(db, locate_me, ipinfo)) @@ -355,3 +365,6 @@ async def main(db, filepaths, checks, ipinfo, opts: Options): #lament("AWAIT locating") await locating #lament("AWAITED locating") + + if context is not None and hasattr(context, "finish"): + await context.finish() diff --git a/respodns/structs.py b/respodns/structs.py index c388ced..f861b4e 100644 --- a/respodns/structs.py +++ b/respodns/structs.py @@ -7,8 +7,9 @@ class Options: # TODO: move this out of Options, since it's really not. execution: object = None - ip_simul: int = 15 # how many IPs to connect to at once - domain_simul: int = 3 # how many domains per IP to request at once + ip_simul: int = 30 # how many IPs to connect to at once + domain_simul: int = 3 # how many domains per IP to request at once + packets_per_second: int = 50 # rough limit on all outgoing DNS packets ip_wait: float = 0.05 domain_wait: float = 0.25 diff --git a/respodns/util.py b/respodns/util.py index afff64e..486577f 100644 --- a/respodns/util.py +++ b/respodns/util.py @@ -88,3 +88,39 @@ class AttrCheck: super().__setattr__(name, value) else: raise AttributeError(name) + + +async def _release_later(sem, time=1): + from asyncio import sleep + + await sleep(time) + sem.release() + + +class LimitPerSecond: + def __init__(self, limit): + from asyncio import BoundedSemaphore + + if type(limit) is not int: + raise ValueError("limit must be int") + assert limit > 0, limit + + self.limit = limit + self.tasks = [] + self.sem = BoundedSemaphore(limit) + + async def __aenter__(self): + #if self.sem.locked: + # from sys import stderr + # print("THROTTLING", file=stderr) + await self.sem.acquire() + + async def __aexit__(self, exc_type, exc_value, traceback): + from asyncio import create_task + + task = create_task(_release_later(self.sem)) + self.tasks.append(task) + + async def finish(self): + for task in self.tasks: + await task