diff --git a/respodns/dns.py b/respodns/dns.py index 7977799..38db406 100644 --- a/respodns/dns.py +++ b/respodns/dns.py @@ -33,14 +33,14 @@ def detect_gfw(r, ip, check): return False -async def getaddrs(server, domain, opts, context=None): +async def getaddrs(server, domain, impatient=False, context=None): from .ip_util import ipkey from dns.asyncresolver import Resolver from dns.exception import Timeout from dns.resolver import NXDOMAIN, NoAnswer, NoNameservers res = Resolver(configure=False) - if opts.impatient: + if impatient: res.timeout = 5 res.lifetime = 2 else: @@ -120,59 +120,61 @@ def process_result(res, ip, check, opts: Options): ) +def find_failure(entries): + assert len(entries) > 0 + for entry in entries: + if not entry.success: + return entry + assert False, ("no failures found:", entries) + + async def try_ip(db, server_ip, checks, context, opts: Options, callback=None): # context can be None. - from .util import make_pooler - from asyncio import sleep, CancelledError + from asyncio import sleep, create_task, CancelledError, BoundedSemaphore + sem = BoundedSemaphore(opts.domain_simul) entries = [] - + tasks = [] success = True - def finisher(done, pending): + async def _process(check): nonlocal success - for task in done: - try: - res, ip, check = task.result() - except CancelledError: - success = False - break - entry = process_result(res, ip, check, opts) - if callback is not None: - for addr in entry.addrs: - callback(addr) - entries.append(entry) - if not entry.success: - if opts.early_stopping and success: # only cancel once - for pend in pending: - # FIXME: this can still, somehow, - # cancel the main function. - pend.cancel() - success = False + res = await getaddrs(server_ip, check.domain, opts.impatient, context) + entry = process_result(res, server_ip, check, opts) + if callback is not None: + for addr in entry.addrs: + callback(addr) + entries.append(entry) + if not entry.success: + if opts.early_stopping and success: # only cancel once + for task in tasks: + if not task.done() and not task.cancelled(): + task.cancel() + success = False - pooler = make_pooler(opts.domain_simul, finisher) + # limit to one connection until the first check completes. + await _process(checks[0]) - async def getaddrs_wrapper(ip, check): - # NOTE: could put right_now() stuff here! - # TODO: add duration field given in milliseconds (integer) - # by subtracting start and end datetimes. - res = await getaddrs(ip, check.domain, opts, context) - return res, ip, check + async def process(check): + try: + await _process(check) + finally: + sem.release() - #lament("TESTING", server_ip) - for i, check in enumerate(checks): - first = i == 0 - if not first: + for check in checks[1:]: + if len(tasks) > 0: await sleep(opts.domain_wait) - await pooler(getaddrs_wrapper(server_ip, check)) - if first: - # limit to one connection for the first check. - await pooler() + # acquire now instead of within the task so + # a ton of tasks aren't created all at once. + await sem.acquire() if not success: - if opts.early_stopping or first: - break - else: - await pooler() + break + #lament("ENTRY", server_ip, check) + task = create_task(process(check)) + tasks.append(task) + for task in tasks: + if not task.cancelled(): + await task if not opts.dry: for entry in entries: @@ -181,17 +183,7 @@ async def try_ip(db, server_ip, checks, context, opts: Options, callback=None): #lament("TESTED", server_ip) - if not success: - first_failure = None - assert len(entries) > 0 - for entry in entries: - if not entry.success: - first_failure = entry - break - else: - assert 0, ("no failures found:", entries) - return first_failure - return None + return None if success else find_failure(entries) async def try_all_ips(db, try_me, checks, context, opts: Options, callback=None): @@ -341,11 +333,12 @@ async def main(db, filepaths, checks, ipinfo, opts: Options): await reading #lament("AWAITED reading") - if seen == 0: + if seen == 0 and db is not None: #lament("UPDATING country codes") # no IPs were provided. refresh all the country codes instead. all_ips = db.all_ips() for i, ip in enumerate(all_ips): + #lament("UPDATING", ip) if opts.progress: lament(f"#{i + 1}/{len(all_ips)}: {ip}") await locate_me.put(ip) diff --git a/respodns/util.py b/respodns/util.py index 486577f..4137ec8 100644 --- a/respodns/util.py +++ b/respodns/util.py @@ -45,38 +45,6 @@ def head(n, it): return res -def taskize(item): - from types import CoroutineType - from asyncio import Task, create_task - - if isinstance(item, CoroutineType): - assert not isinstance(item, Task) # TODO: paranoid? - item = create_task(item) - return item - - -def make_pooler(pool_size, finisher=None): - # TODO: write a less confusing interface - # that allows the code to be written more flatly. - # maybe like: async for done in apply(doit, [tuple_of_args]): - from asyncio import wait, FIRST_COMPLETED - - pending = set() - - async def pooler(item=None): - nonlocal pending - finish = item is None - if not finish: - pending.add(taskize(item)) - desired_size = 0 if finish else pool_size - 1 - while len(pending) > desired_size: - done, pending = await wait(pending, return_when=FIRST_COMPLETED) - if finisher is not None: - finisher(done, pending) - - return pooler - - class AttrCheck: """ Inheriting AttrCheck prevents accidentally setting attributes