overhaul the try_ip function's async stuff

This commit is contained in:
Connor Olding 2021-08-13 02:09:01 -07:00
parent 55024634f7
commit 3ca869ccad
2 changed files with 48 additions and 87 deletions

View file

@ -33,14 +33,14 @@ def detect_gfw(r, ip, check):
return False
async def getaddrs(server, domain, opts, context=None):
async def getaddrs(server, domain, impatient=False, context=None):
from .ip_util import ipkey
from dns.asyncresolver import Resolver
from dns.exception import Timeout
from dns.resolver import NXDOMAIN, NoAnswer, NoNameservers
res = Resolver(configure=False)
if opts.impatient:
if impatient:
res.timeout = 5
res.lifetime = 2
else:
@ -120,59 +120,61 @@ def process_result(res, ip, check, opts: Options):
)
def find_failure(entries):
assert len(entries) > 0
for entry in entries:
if not entry.success:
return entry
assert False, ("no failures found:", entries)
async def try_ip(db, server_ip, checks, context, opts: Options, callback=None):
# context can be None.
from .util import make_pooler
from asyncio import sleep, CancelledError
from asyncio import sleep, create_task, CancelledError, BoundedSemaphore
sem = BoundedSemaphore(opts.domain_simul)
entries = []
tasks = []
success = True
def finisher(done, pending):
async def _process(check):
nonlocal success
for task in done:
try:
res, ip, check = task.result()
except CancelledError:
success = False
break
entry = process_result(res, ip, check, opts)
if callback is not None:
for addr in entry.addrs:
callback(addr)
entries.append(entry)
if not entry.success:
if opts.early_stopping and success: # only cancel once
for pend in pending:
# FIXME: this can still, somehow,
# cancel the main function.
pend.cancel()
success = False
res = await getaddrs(server_ip, check.domain, opts.impatient, context)
entry = process_result(res, server_ip, check, opts)
if callback is not None:
for addr in entry.addrs:
callback(addr)
entries.append(entry)
if not entry.success:
if opts.early_stopping and success: # only cancel once
for task in tasks:
if not task.done() and not task.cancelled():
task.cancel()
success = False
pooler = make_pooler(opts.domain_simul, finisher)
# limit to one connection until the first check completes.
await _process(checks[0])
async def getaddrs_wrapper(ip, check):
# NOTE: could put right_now() stuff here!
# TODO: add duration field given in milliseconds (integer)
# by subtracting start and end datetimes.
res = await getaddrs(ip, check.domain, opts, context)
return res, ip, check
async def process(check):
try:
await _process(check)
finally:
sem.release()
#lament("TESTING", server_ip)
for i, check in enumerate(checks):
first = i == 0
if not first:
for check in checks[1:]:
if len(tasks) > 0:
await sleep(opts.domain_wait)
await pooler(getaddrs_wrapper(server_ip, check))
if first:
# limit to one connection for the first check.
await pooler()
# acquire now instead of within the task so
# a ton of tasks aren't created all at once.
await sem.acquire()
if not success:
if opts.early_stopping or first:
break
else:
await pooler()
break
#lament("ENTRY", server_ip, check)
task = create_task(process(check))
tasks.append(task)
for task in tasks:
if not task.cancelled():
await task
if not opts.dry:
for entry in entries:
@ -181,17 +183,7 @@ async def try_ip(db, server_ip, checks, context, opts: Options, callback=None):
#lament("TESTED", server_ip)
if not success:
first_failure = None
assert len(entries) > 0
for entry in entries:
if not entry.success:
first_failure = entry
break
else:
assert 0, ("no failures found:", entries)
return first_failure
return None
return None if success else find_failure(entries)
async def try_all_ips(db, try_me, checks, context, opts: Options, callback=None):
@ -341,11 +333,12 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
await reading
#lament("AWAITED reading")
if seen == 0:
if seen == 0 and db is not None:
#lament("UPDATING country codes")
# no IPs were provided. refresh all the country codes instead.
all_ips = db.all_ips()
for i, ip in enumerate(all_ips):
#lament("UPDATING", ip)
if opts.progress:
lament(f"#{i + 1}/{len(all_ips)}: {ip}")
await locate_me.put(ip)

View file

@ -45,38 +45,6 @@ def head(n, it):
return res
def taskize(item):
from types import CoroutineType
from asyncio import Task, create_task
if isinstance(item, CoroutineType):
assert not isinstance(item, Task) # TODO: paranoid?
item = create_task(item)
return item
def make_pooler(pool_size, finisher=None):
# TODO: write a less confusing interface
# that allows the code to be written more flatly.
# maybe like: async for done in apply(doit, [tuple_of_args]):
from asyncio import wait, FIRST_COMPLETED
pending = set()
async def pooler(item=None):
nonlocal pending
finish = item is None
if not finish:
pending.add(taskize(item))
desired_size = 0 if finish else pool_size - 1
while len(pending) > desired_size:
done, pending = await wait(pending, return_when=FIRST_COMPLETED)
if finisher is not None:
finisher(done, pending)
return pooler
class AttrCheck:
"""
Inheriting AttrCheck prevents accidentally setting attributes