overhaul the try_ip function's async stuff

Connor Olding 2021-08-13 02:09:01 -07:00
parent 55024634f7
commit 3ca869ccad
2 changed files with 48 additions and 87 deletions
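In short: try_ip no longer funnels results through make_pooler's finisher callback. Each check after the first now runs as its own asyncio Task, a BoundedSemaphore caps how many are in flight, and a failed check cancels the remaining tasks directly when early stopping is enabled. A minimal sketch of the pattern, using stand-in names (work, run_all) rather than this module's real API:

    import asyncio

    async def work(item):
        # stand-in for the real per-check coroutine (getaddrs + process_result)
        await asyncio.sleep(0.05)

    async def run_all(items, limit=3, wait=0.01):
        sem = asyncio.BoundedSemaphore(limit)
        tasks = []

        async def process(item):
            try:
                await work(item)
            finally:
                sem.release()  # free a slot for the next task

        await work(items[0])  # the first item runs alone, as above
        for item in items[1:]:
            if tasks:
                await asyncio.sleep(wait)
            # acquire before create_task so tasks aren't spawned
            # faster than the semaphore allows them to run
            await sem.acquire()
            tasks.append(asyncio.create_task(process(item)))
        for t in tasks:
            await t

    asyncio.run(run_all(list(range(8))))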


@@ -33,14 +33,14 @@ def detect_gfw(r, ip, check):
     return False

-async def getaddrs(server, domain, opts, context=None):
+async def getaddrs(server, domain, impatient=False, context=None):
     from .ip_util import ipkey
     from dns.asyncresolver import Resolver
     from dns.exception import Timeout
     from dns.resolver import NXDOMAIN, NoAnswer, NoNameservers

     res = Resolver(configure=False)
-    if opts.impatient:
+    if impatient:
         res.timeout = 5
         res.lifetime = 2
     else:
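getaddrs now receives just the flag it uses instead of the whole Options object. For context on the two knobs being set: in dnspython, timeout bounds each individual attempt against a server, while lifetime bounds the entire resolution. A hedged sketch of the configuration (the nameserver address is a placeholder):

    from dns.asyncresolver import Resolver

    res = Resolver(configure=False)    # don't read /etc/resolv.conf
    res.nameservers = ["203.0.113.1"]  # placeholder address
    res.timeout = 5     # seconds per attempt
    res.lifetime = 2    # seconds for the whole query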
@@ -120,59 +120,61 @@ def process_result(res, ip, check, opts: Options):
     )

+def find_failure(entries):
+    assert len(entries) > 0
+    for entry in entries:
+        if not entry.success:
+            return entry
+    assert False, ("no failures found:", entries)
+
 async def try_ip(db, server_ip, checks, context, opts: Options, callback=None):
     # context can be None.
-    from .util import make_pooler
-    from asyncio import sleep, CancelledError
+    from asyncio import sleep, create_task, CancelledError, BoundedSemaphore

+    sem = BoundedSemaphore(opts.domain_simul)
     entries = []
+    tasks = []
     success = True

-    def finisher(done, pending):
+    async def _process(check):
         nonlocal success
-        for task in done:
-            try:
-                res, ip, check = task.result()
-            except CancelledError:
-                success = False
-                break
-            entry = process_result(res, ip, check, opts)
-            if callback is not None:
-                for addr in entry.addrs:
-                    callback(addr)
-            entries.append(entry)
-            if not entry.success:
-                if opts.early_stopping and success:  # only cancel once
-                    for pend in pending:
-                        # FIXME: this can still, somehow,
-                        # cancel the main function.
-                        pend.cancel()
-                success = False
+        res = await getaddrs(server_ip, check.domain, opts.impatient, context)
+        entry = process_result(res, server_ip, check, opts)
+        if callback is not None:
+            for addr in entry.addrs:
+                callback(addr)
+        entries.append(entry)
+        if not entry.success:
+            if opts.early_stopping and success:  # only cancel once
+                for task in tasks:
+                    if not task.done() and not task.cancelled():
+                        task.cancel()
+            success = False

-    pooler = make_pooler(opts.domain_simul, finisher)
+    # limit to one connection until the first check completes.
+    await _process(checks[0])

-    async def getaddrs_wrapper(ip, check):
-        # NOTE: could put right_now() stuff here!
-        # TODO: add duration field given in milliseconds (integer)
-        # by subtracting start and end datetimes.
-        res = await getaddrs(ip, check.domain, opts, context)
-        return res, ip, check
+    async def process(check):
+        try:
+            await _process(check)
+        finally:
+            sem.release()

-    #lament("TESTING", server_ip)
-    for i, check in enumerate(checks):
-        first = i == 0
-        if not first:
+    for check in checks[1:]:
+        if len(tasks) > 0:
             await sleep(opts.domain_wait)
-        await pooler(getaddrs_wrapper(server_ip, check))
-        if first:
-            # limit to one connection for the first check.
-            await pooler()
+        # acquire now instead of within the task so
+        # a ton of tasks aren't created all at once.
+        await sem.acquire()
         if not success:
-            if opts.early_stopping or first:
-                break
-    else:
-        await pooler()
+            break
+        #lament("ENTRY", server_ip, check)
+        task = create_task(process(check))
+        tasks.append(task)
+
+    for task in tasks:
+        if not task.cancelled():
+            await task

     if not opts.dry:
         for entry in entries:
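The old finisher carried a FIXME about cancellation somehow reaching the main function; the rewrite only ever cancels Tasks it created itself and skips ones already done or cancelled. A small sketch of why cancelling sibling Tasks is safe for the cancelling coroutine (stand-in workers, not the real checks):

    import asyncio

    async def worker(delay):
        await asyncio.sleep(delay)

    async def demo():
        tasks = [asyncio.create_task(worker(d)) for d in (0.3, 0.1, 0.2)]
        tasks[0].cancel()  # cancels only that Task, not demo() itself
        for t in tasks:
            try:
                await t
            except asyncio.CancelledError:
                pass  # a cancelled Task re-raises CancelledError when awaited

    asyncio.run(demo())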
@@ -181,17 +183,7 @@ async def try_ip(db, server_ip, checks, context, opts: Options, callback=None):
     #lament("TESTED", server_ip)

-    if not success:
-        first_failure = None
-        assert len(entries) > 0
-        for entry in entries:
-            if not entry.success:
-                first_failure = entry
-                break
-        else:
-            assert 0, ("no failures found:", entries)
-        return first_failure
-    return None
+    return None if success else find_failure(entries)

 async def try_all_ips(db, try_me, checks, context, opts: Options, callback=None):
@@ -341,11 +333,12 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
         await reading
         #lament("AWAITED reading")

-    if seen == 0:
+    if seen == 0 and db is not None:
         #lament("UPDATING country codes")
         # no IPs were provided. refresh all the country codes instead.
         all_ips = db.all_ips()
         for i, ip in enumerate(all_ips):
+            #lament("UPDATING", ip)
             if opts.progress:
                 lament(f"#{i + 1}/{len(all_ips)}: {ip}")
             await locate_me.put(ip)


@@ -45,38 +45,6 @@ def head(n, it):
     return res

-def taskize(item):
-    from types import CoroutineType
-    from asyncio import Task, create_task
-    if isinstance(item, CoroutineType):
-        assert not isinstance(item, Task)  # TODO: paranoid?
-        item = create_task(item)
-    return item
-
-def make_pooler(pool_size, finisher=None):
-    # TODO: write a less confusing interface
-    # that allows the code to be written more flatly.
-    # maybe like: async for done in apply(doit, [tuple_of_args]):
-    from asyncio import wait, FIRST_COMPLETED
-
-    pending = set()
-
-    async def pooler(item=None):
-        nonlocal pending
-        finish = item is None
-        if not finish:
-            pending.add(taskize(item))
-        desired_size = 0 if finish else pool_size - 1
-        while len(pending) > desired_size:
-            done, pending = await wait(pending, return_when=FIRST_COMPLETED)
-            if finisher is not None:
-                finisher(done, pending)
-
-    return pooler
-
 class AttrCheck:
     """
     Inheriting AttrCheck prevents accidentally setting attributes
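With try_ip rewritten, make_pooler and its taskize helper have no remaining callers, so they're removed. For reference, the helper was a thin callback layer over this core asyncio idiom; a minimal sketch with stand-in sleeps in place of real work:

    import asyncio

    async def demo():
        pending = {asyncio.create_task(asyncio.sleep(d)) for d in (0.3, 0.1, 0.2)}
        while pending:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED)
            print(f"{len(done)} done, {len(pending)} still pending")

    asyncio.run(demo())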