overhaul the try_ip function's async stuff
This commit is contained in:
parent
55024634f7
commit
3ca869ccad
2 changed files with 48 additions and 87 deletions
103
respodns/dns.py
103
respodns/dns.py
|
@ -33,14 +33,14 @@ def detect_gfw(r, ip, check):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def getaddrs(server, domain, opts, context=None):
|
async def getaddrs(server, domain, impatient=False, context=None):
|
||||||
from .ip_util import ipkey
|
from .ip_util import ipkey
|
||||||
from dns.asyncresolver import Resolver
|
from dns.asyncresolver import Resolver
|
||||||
from dns.exception import Timeout
|
from dns.exception import Timeout
|
||||||
from dns.resolver import NXDOMAIN, NoAnswer, NoNameservers
|
from dns.resolver import NXDOMAIN, NoAnswer, NoNameservers
|
||||||
|
|
||||||
res = Resolver(configure=False)
|
res = Resolver(configure=False)
|
||||||
if opts.impatient:
|
if impatient:
|
||||||
res.timeout = 5
|
res.timeout = 5
|
||||||
res.lifetime = 2
|
res.lifetime = 2
|
||||||
else:
|
else:
|
||||||
|
@ -120,59 +120,61 @@ def process_result(res, ip, check, opts: Options):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def find_failure(entries):
|
||||||
|
assert len(entries) > 0
|
||||||
|
for entry in entries:
|
||||||
|
if not entry.success:
|
||||||
|
return entry
|
||||||
|
assert False, ("no failures found:", entries)
|
||||||
|
|
||||||
|
|
||||||
async def try_ip(db, server_ip, checks, context, opts: Options, callback=None):
|
async def try_ip(db, server_ip, checks, context, opts: Options, callback=None):
|
||||||
# context can be None.
|
# context can be None.
|
||||||
from .util import make_pooler
|
from asyncio import sleep, create_task, CancelledError, BoundedSemaphore
|
||||||
from asyncio import sleep, CancelledError
|
|
||||||
|
|
||||||
|
sem = BoundedSemaphore(opts.domain_simul)
|
||||||
entries = []
|
entries = []
|
||||||
|
tasks = []
|
||||||
success = True
|
success = True
|
||||||
|
|
||||||
def finisher(done, pending):
|
async def _process(check):
|
||||||
nonlocal success
|
nonlocal success
|
||||||
for task in done:
|
res = await getaddrs(server_ip, check.domain, opts.impatient, context)
|
||||||
try:
|
entry = process_result(res, server_ip, check, opts)
|
||||||
res, ip, check = task.result()
|
if callback is not None:
|
||||||
except CancelledError:
|
for addr in entry.addrs:
|
||||||
success = False
|
callback(addr)
|
||||||
break
|
entries.append(entry)
|
||||||
entry = process_result(res, ip, check, opts)
|
if not entry.success:
|
||||||
if callback is not None:
|
if opts.early_stopping and success: # only cancel once
|
||||||
for addr in entry.addrs:
|
for task in tasks:
|
||||||
callback(addr)
|
if not task.done() and not task.cancelled():
|
||||||
entries.append(entry)
|
task.cancel()
|
||||||
if not entry.success:
|
success = False
|
||||||
if opts.early_stopping and success: # only cancel once
|
|
||||||
for pend in pending:
|
|
||||||
# FIXME: this can still, somehow,
|
|
||||||
# cancel the main function.
|
|
||||||
pend.cancel()
|
|
||||||
success = False
|
|
||||||
|
|
||||||
pooler = make_pooler(opts.domain_simul, finisher)
|
# limit to one connection until the first check completes.
|
||||||
|
await _process(checks[0])
|
||||||
|
|
||||||
async def getaddrs_wrapper(ip, check):
|
async def process(check):
|
||||||
# NOTE: could put right_now() stuff here!
|
try:
|
||||||
# TODO: add duration field given in milliseconds (integer)
|
await _process(check)
|
||||||
# by subtracting start and end datetimes.
|
finally:
|
||||||
res = await getaddrs(ip, check.domain, opts, context)
|
sem.release()
|
||||||
return res, ip, check
|
|
||||||
|
|
||||||
#lament("TESTING", server_ip)
|
for check in checks[1:]:
|
||||||
for i, check in enumerate(checks):
|
if len(tasks) > 0:
|
||||||
first = i == 0
|
|
||||||
if not first:
|
|
||||||
await sleep(opts.domain_wait)
|
await sleep(opts.domain_wait)
|
||||||
await pooler(getaddrs_wrapper(server_ip, check))
|
# acquire now instead of within the task so
|
||||||
if first:
|
# a ton of tasks aren't created all at once.
|
||||||
# limit to one connection for the first check.
|
await sem.acquire()
|
||||||
await pooler()
|
|
||||||
if not success:
|
if not success:
|
||||||
if opts.early_stopping or first:
|
break
|
||||||
break
|
#lament("ENTRY", server_ip, check)
|
||||||
else:
|
task = create_task(process(check))
|
||||||
await pooler()
|
tasks.append(task)
|
||||||
|
for task in tasks:
|
||||||
|
if not task.cancelled():
|
||||||
|
await task
|
||||||
|
|
||||||
if not opts.dry:
|
if not opts.dry:
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
|
@ -181,17 +183,7 @@ async def try_ip(db, server_ip, checks, context, opts: Options, callback=None):
|
||||||
|
|
||||||
#lament("TESTED", server_ip)
|
#lament("TESTED", server_ip)
|
||||||
|
|
||||||
if not success:
|
return None if success else find_failure(entries)
|
||||||
first_failure = None
|
|
||||||
assert len(entries) > 0
|
|
||||||
for entry in entries:
|
|
||||||
if not entry.success:
|
|
||||||
first_failure = entry
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
assert 0, ("no failures found:", entries)
|
|
||||||
return first_failure
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def try_all_ips(db, try_me, checks, context, opts: Options, callback=None):
|
async def try_all_ips(db, try_me, checks, context, opts: Options, callback=None):
|
||||||
|
@ -341,11 +333,12 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
|
||||||
await reading
|
await reading
|
||||||
#lament("AWAITED reading")
|
#lament("AWAITED reading")
|
||||||
|
|
||||||
if seen == 0:
|
if seen == 0 and db is not None:
|
||||||
#lament("UPDATING country codes")
|
#lament("UPDATING country codes")
|
||||||
# no IPs were provided. refresh all the country codes instead.
|
# no IPs were provided. refresh all the country codes instead.
|
||||||
all_ips = db.all_ips()
|
all_ips = db.all_ips()
|
||||||
for i, ip in enumerate(all_ips):
|
for i, ip in enumerate(all_ips):
|
||||||
|
#lament("UPDATING", ip)
|
||||||
if opts.progress:
|
if opts.progress:
|
||||||
lament(f"#{i + 1}/{len(all_ips)}: {ip}")
|
lament(f"#{i + 1}/{len(all_ips)}: {ip}")
|
||||||
await locate_me.put(ip)
|
await locate_me.put(ip)
|
||||||
|
|
|
@ -45,38 +45,6 @@ def head(n, it):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def taskize(item):
|
|
||||||
from types import CoroutineType
|
|
||||||
from asyncio import Task, create_task
|
|
||||||
|
|
||||||
if isinstance(item, CoroutineType):
|
|
||||||
assert not isinstance(item, Task) # TODO: paranoid?
|
|
||||||
item = create_task(item)
|
|
||||||
return item
|
|
||||||
|
|
||||||
|
|
||||||
def make_pooler(pool_size, finisher=None):
|
|
||||||
# TODO: write a less confusing interface
|
|
||||||
# that allows the code to be written more flatly.
|
|
||||||
# maybe like: async for done in apply(doit, [tuple_of_args]):
|
|
||||||
from asyncio import wait, FIRST_COMPLETED
|
|
||||||
|
|
||||||
pending = set()
|
|
||||||
|
|
||||||
async def pooler(item=None):
|
|
||||||
nonlocal pending
|
|
||||||
finish = item is None
|
|
||||||
if not finish:
|
|
||||||
pending.add(taskize(item))
|
|
||||||
desired_size = 0 if finish else pool_size - 1
|
|
||||||
while len(pending) > desired_size:
|
|
||||||
done, pending = await wait(pending, return_when=FIRST_COMPLETED)
|
|
||||||
if finisher is not None:
|
|
||||||
finisher(done, pending)
|
|
||||||
|
|
||||||
return pooler
|
|
||||||
|
|
||||||
|
|
||||||
class AttrCheck:
|
class AttrCheck:
|
||||||
"""
|
"""
|
||||||
Inheriting AttrCheck prevents accidentally setting attributes
|
Inheriting AttrCheck prevents accidentally setting attributes
|
||||||
|
|
Loading…
Add table
Reference in a new issue