overhaul the try_ip function's async stuff
This commit is contained in:
parent
55024634f7
commit
3ca869ccad
2 changed files with 48 additions and 87 deletions
103
respodns/dns.py
103
respodns/dns.py
|
@ -33,14 +33,14 @@ def detect_gfw(r, ip, check):
|
|||
return False
|
||||
|
||||
|
||||
async def getaddrs(server, domain, opts, context=None):
|
||||
async def getaddrs(server, domain, impatient=False, context=None):
|
||||
from .ip_util import ipkey
|
||||
from dns.asyncresolver import Resolver
|
||||
from dns.exception import Timeout
|
||||
from dns.resolver import NXDOMAIN, NoAnswer, NoNameservers
|
||||
|
||||
res = Resolver(configure=False)
|
||||
if opts.impatient:
|
||||
if impatient:
|
||||
res.timeout = 5
|
||||
res.lifetime = 2
|
||||
else:
|
||||
|
@ -120,59 +120,61 @@ def process_result(res, ip, check, opts: Options):
|
|||
)
|
||||
|
||||
|
||||
def find_failure(entries):
|
||||
assert len(entries) > 0
|
||||
for entry in entries:
|
||||
if not entry.success:
|
||||
return entry
|
||||
assert False, ("no failures found:", entries)
|
||||
|
||||
|
||||
async def try_ip(db, server_ip, checks, context, opts: Options, callback=None):
|
||||
# context can be None.
|
||||
from .util import make_pooler
|
||||
from asyncio import sleep, CancelledError
|
||||
from asyncio import sleep, create_task, CancelledError, BoundedSemaphore
|
||||
|
||||
sem = BoundedSemaphore(opts.domain_simul)
|
||||
entries = []
|
||||
|
||||
tasks = []
|
||||
success = True
|
||||
|
||||
def finisher(done, pending):
|
||||
async def _process(check):
|
||||
nonlocal success
|
||||
for task in done:
|
||||
try:
|
||||
res, ip, check = task.result()
|
||||
except CancelledError:
|
||||
success = False
|
||||
break
|
||||
entry = process_result(res, ip, check, opts)
|
||||
if callback is not None:
|
||||
for addr in entry.addrs:
|
||||
callback(addr)
|
||||
entries.append(entry)
|
||||
if not entry.success:
|
||||
if opts.early_stopping and success: # only cancel once
|
||||
for pend in pending:
|
||||
# FIXME: this can still, somehow,
|
||||
# cancel the main function.
|
||||
pend.cancel()
|
||||
success = False
|
||||
res = await getaddrs(server_ip, check.domain, opts.impatient, context)
|
||||
entry = process_result(res, server_ip, check, opts)
|
||||
if callback is not None:
|
||||
for addr in entry.addrs:
|
||||
callback(addr)
|
||||
entries.append(entry)
|
||||
if not entry.success:
|
||||
if opts.early_stopping and success: # only cancel once
|
||||
for task in tasks:
|
||||
if not task.done() and not task.cancelled():
|
||||
task.cancel()
|
||||
success = False
|
||||
|
||||
pooler = make_pooler(opts.domain_simul, finisher)
|
||||
# limit to one connection until the first check completes.
|
||||
await _process(checks[0])
|
||||
|
||||
async def getaddrs_wrapper(ip, check):
|
||||
# NOTE: could put right_now() stuff here!
|
||||
# TODO: add duration field given in milliseconds (integer)
|
||||
# by subtracting start and end datetimes.
|
||||
res = await getaddrs(ip, check.domain, opts, context)
|
||||
return res, ip, check
|
||||
async def process(check):
|
||||
try:
|
||||
await _process(check)
|
||||
finally:
|
||||
sem.release()
|
||||
|
||||
#lament("TESTING", server_ip)
|
||||
for i, check in enumerate(checks):
|
||||
first = i == 0
|
||||
if not first:
|
||||
for check in checks[1:]:
|
||||
if len(tasks) > 0:
|
||||
await sleep(opts.domain_wait)
|
||||
await pooler(getaddrs_wrapper(server_ip, check))
|
||||
if first:
|
||||
# limit to one connection for the first check.
|
||||
await pooler()
|
||||
# acquire now instead of within the task so
|
||||
# a ton of tasks aren't created all at once.
|
||||
await sem.acquire()
|
||||
if not success:
|
||||
if opts.early_stopping or first:
|
||||
break
|
||||
else:
|
||||
await pooler()
|
||||
break
|
||||
#lament("ENTRY", server_ip, check)
|
||||
task = create_task(process(check))
|
||||
tasks.append(task)
|
||||
for task in tasks:
|
||||
if not task.cancelled():
|
||||
await task
|
||||
|
||||
if not opts.dry:
|
||||
for entry in entries:
|
||||
|
@ -181,17 +183,7 @@ async def try_ip(db, server_ip, checks, context, opts: Options, callback=None):
|
|||
|
||||
#lament("TESTED", server_ip)
|
||||
|
||||
if not success:
|
||||
first_failure = None
|
||||
assert len(entries) > 0
|
||||
for entry in entries:
|
||||
if not entry.success:
|
||||
first_failure = entry
|
||||
break
|
||||
else:
|
||||
assert 0, ("no failures found:", entries)
|
||||
return first_failure
|
||||
return None
|
||||
return None if success else find_failure(entries)
|
||||
|
||||
|
||||
async def try_all_ips(db, try_me, checks, context, opts: Options, callback=None):
|
||||
|
@ -341,11 +333,12 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
|
|||
await reading
|
||||
#lament("AWAITED reading")
|
||||
|
||||
if seen == 0:
|
||||
if seen == 0 and db is not None:
|
||||
#lament("UPDATING country codes")
|
||||
# no IPs were provided. refresh all the country codes instead.
|
||||
all_ips = db.all_ips()
|
||||
for i, ip in enumerate(all_ips):
|
||||
#lament("UPDATING", ip)
|
||||
if opts.progress:
|
||||
lament(f"#{i + 1}/{len(all_ips)}: {ip}")
|
||||
await locate_me.put(ip)
|
||||
|
|
|
@ -45,38 +45,6 @@ def head(n, it):
|
|||
return res
|
||||
|
||||
|
||||
def taskize(item):
|
||||
from types import CoroutineType
|
||||
from asyncio import Task, create_task
|
||||
|
||||
if isinstance(item, CoroutineType):
|
||||
assert not isinstance(item, Task) # TODO: paranoid?
|
||||
item = create_task(item)
|
||||
return item
|
||||
|
||||
|
||||
def make_pooler(pool_size, finisher=None):
|
||||
# TODO: write a less confusing interface
|
||||
# that allows the code to be written more flatly.
|
||||
# maybe like: async for done in apply(doit, [tuple_of_args]):
|
||||
from asyncio import wait, FIRST_COMPLETED
|
||||
|
||||
pending = set()
|
||||
|
||||
async def pooler(item=None):
|
||||
nonlocal pending
|
||||
finish = item is None
|
||||
if not finish:
|
||||
pending.add(taskize(item))
|
||||
desired_size = 0 if finish else pool_size - 1
|
||||
while len(pending) > desired_size:
|
||||
done, pending = await wait(pending, return_when=FIRST_COMPLETED)
|
||||
if finisher is not None:
|
||||
finisher(done, pending)
|
||||
|
||||
return pooler
|
||||
|
||||
|
||||
class AttrCheck:
|
||||
"""
|
||||
Inheriting AttrCheck prevents accidentally setting attributes
|
||||
|
|
Loading…
Reference in a new issue