Overhaul the async handling in the main function.
After I confirm that nothing is broken, the next step is to rework the underlying try_ip function.
This commit is contained in:
parent
788d75a5e9
commit
0b198cc5c0
3 changed files with 136 additions and 79 deletions
205
respodns/dns.py
205
respodns/dns.py
|
@ -1,4 +1,9 @@
|
|||
from .structs import Options
|
||||
from sys import stdin, stdout, stderr
|
||||
|
||||
|
||||
def lament(*args, **kwargs):
    """Print *args* to standard error.

    Accepts the same keyword arguments as the built-in ``print`` except
    ``file``, which is always ``stderr``.
    """
    print(*args, **kwargs, file=stderr)
|
||||
|
||||
|
||||
def detect_gfw(r, ip, check):
|
||||
|
@ -111,22 +116,14 @@ def process_result(res, ip, check, opts: Options):
|
|||
)
|
||||
|
||||
|
||||
async def try_ip(db, server_ip, checks, opts: Options):
|
||||
async def try_ip(db, server_ip, checks, opts: Options, callback=None):
|
||||
from .util import make_pooler
|
||||
from asyncio import sleep, CancelledError
|
||||
|
||||
entries = []
|
||||
deferred = []
|
||||
|
||||
success = True
|
||||
|
||||
def maybe_put_ip(ip):
|
||||
from asyncio import QueueFull
|
||||
try:
|
||||
opts.ips.put_nowait(ip)
|
||||
except QueueFull:
|
||||
deferred.append(ip)
|
||||
|
||||
def finisher(done, pending):
|
||||
nonlocal success
|
||||
for task in done:
|
||||
|
@ -136,7 +133,9 @@ async def try_ip(db, server_ip, checks, opts: Options):
|
|||
success = False
|
||||
break
|
||||
entry = process_result(res, ip, check, opts)
|
||||
map(maybe_put_ip, entry.addrs)
|
||||
if callback is not None:
|
||||
for addr in entry.addrs:
|
||||
callback(addr)
|
||||
entries.append(entry)
|
||||
if not entry.success:
|
||||
if opts.early_stopping and success: # only cancel once
|
||||
|
@ -155,6 +154,7 @@ async def try_ip(db, server_ip, checks, opts: Options):
|
|||
res = await getaddrs(ip, check.domain, opts)
|
||||
return res, ip, check
|
||||
|
||||
#lament("BEGIN", server_ip)
|
||||
for i, check in enumerate(checks):
|
||||
first = i == 0
|
||||
if not first:
|
||||
|
@ -169,14 +169,13 @@ async def try_ip(db, server_ip, checks, opts: Options):
|
|||
else:
|
||||
await pooler()
|
||||
|
||||
for ip in deferred:
|
||||
await opts.ips.put(ip)
|
||||
|
||||
if not opts.dry:
|
||||
for entry in entries:
|
||||
db.push_entry(entry)
|
||||
db.commit()
|
||||
|
||||
#lament("FINISH", server_ip)
|
||||
|
||||
if not success:
|
||||
first_failure = None
|
||||
assert len(entries) > 0
|
||||
|
@ -186,11 +185,52 @@ async def try_ip(db, server_ip, checks, opts: Options):
|
|||
break
|
||||
else:
|
||||
assert 0, ("no failures found:", entries)
|
||||
return server_ip, first_failure
|
||||
return server_ip, None
|
||||
return first_failure
|
||||
return None
|
||||
|
||||
|
||||
async def sync_database(db, opts: Options):
|
||||
async def try_all_ips(db, try_me, checks, opts: Options, callback=None):
    """Run ``try_ip`` for every server IP received on the *try_me* queue.

    Items on *try_me* are ``(ip, total)`` pairs; a ``None`` item ends the
    stream. Each IP is handled in its own task, with at most
    ``opts.ip_simul`` tasks in flight at any time. *callback* is passed
    through to ``try_ip`` unchanged.

    Results go to stdout: the bare IP when ``try_ip`` reports no failure,
    or (only when ``opts.dry`` is set) a tab-separated failure line.
    """
    from asyncio import create_task, sleep, BoundedSemaphore

    seen, total = 0, None

    async def process(ip):
        nonlocal seen
        first = seen == 0
        seen += 1
        if opts.progress:
            lament(f"#{seen}: {ip}" if total is None else f"#{seen}/{total}: {ip}")
            stderr.flush()
        if not first:
            await sleep(opts.ip_wait)
        first_failure = await try_ip(db, ip, checks, opts, callback)
        if first_failure is None:
            print(ip)  # all tests for this server passed; pass it along to stdout
        elif opts.dry:  # don't save the error anywhere; pass it along to stdout
            ff = first_failure
            if ff.kind in ("shock", "adware"):
                # don't print sketchy domains to console in case they're clicked.
                print(ip, ff.reason, ff.kind, sep="\t")
            else:
                print(ip, ff.reason, ff.kind, ff.domain, sep="\t")

    sem = BoundedSemaphore(opts.ip_simul)
    tasks = []
    while (res := await try_me.get()) is not None:
        ip, total = res
        #lament("TRYING", ip)
        # Hold a slot for the whole lifetime of the task. Wrapping only
        # create_task() in "async with sem" would release the slot as soon
        # as the task was scheduled and so would not limit anything.
        await sem.acquire()
        task = create_task(process(ip))
        # Released on completion, whether the task succeeded or raised.
        task.add_done_callback(lambda _finished: sem.release())
        tasks.append(task)

    # Wait for stragglers; awaiting also re-raises any task exception.
    for task in tasks:
        await task

    #lament("EXITING try_all_ips")
|
||||
|
||||
|
||||
def sync_database(db, callback=None):
|
||||
# NOTE: this no longer takes Options.
|
||||
# NOTE: this is no longer async.
|
||||
from .ips import china, blocks
|
||||
|
||||
# TODO: handle addresses that were removed from respodns.ips.china.
|
||||
|
@ -199,90 +239,109 @@ async def sync_database(db, opts: Options):
|
|||
kwargs = {kw: True}
|
||||
if db is not None:
|
||||
db.modify_address(ip, **kwargs)
|
||||
await opts.ips.put(ip)
|
||||
if callback is not None:
|
||||
callback(ip)
|
||||
|
||||
|
||||
async def locate_ips(db, opts: Options):
|
||||
async def locate_ips(db, locate_me, ipinfo):
    """Look up and store a country code for each IP received on *locate_me*.

    Reads IPs from the *locate_me* queue until a ``None`` item arrives.
    Every IP not seen before is resolved with ``ipinfo.find_country`` and,
    when *db* is not None, written with ``db.modify_address``. ``ipinfo``
    is flushed at most once every 10 seconds while running and once more
    at the end.

    *ipinfo* may be None; the queue is then simply drained and nothing is
    looked up or flushed.
    """
    # NOTE: this no longer takes Options.
    from time import time

    seen = set()
    last_save = time()

    while (ip := await locate_me.get()) is not None:
        if ipinfo is None or ip in seen:
            continue  # nothing to look up with, or already handled
        #lament("LOCATE", ip)
        seen.add(ip)
        code = await ipinfo.find_country(ip)
        if db is not None:
            db.modify_address(ip, country_code=code)
        if time() >= last_save + 10.0:  # only flush occasionally
            #lament("FLUSH", time() - last_save)
            ipinfo.flush()
            last_save = time()

    # Guard the final flush: ipinfo is optional, and calling .flush() on
    # None would raise AttributeError after the queue has been drained.
    if ipinfo is not None:
        ipinfo.flush()

    #lament("EXITING locate_ips")
|
||||
|
||||
|
||||
async def main(db, filepaths, checks, opts: Options):
|
||||
from .util import make_pooler
|
||||
from asyncio import sleep, create_task, Queue
|
||||
from sys import stdin, stderr
|
||||
|
||||
opts.ips = Queue()
|
||||
syncing = create_task(sync_database(db, opts))
|
||||
geoip = create_task(locate_ips(db, opts))
|
||||
|
||||
def finisher(done, pending):
|
||||
for task in done:
|
||||
ip, first_failure = task.result()
|
||||
if first_failure is None:
|
||||
print(ip)
|
||||
elif opts.dry:
|
||||
ff = first_failure
|
||||
if ff.kind in ("shock", "adware"):
|
||||
print(ip, ff.reason, ff.kind, sep="\t")
|
||||
else:
|
||||
print(ip, ff.reason, ff.kind, ff.domain, sep="\t")
|
||||
|
||||
pooler = make_pooler(opts.ip_simul, finisher)
|
||||
|
||||
seen = 0
|
||||
|
||||
async def process(ip, total=None):
|
||||
nonlocal seen
|
||||
first = seen == 0
|
||||
seen += 1
|
||||
if opts.progress:
|
||||
if total is None:
|
||||
print(f"#{seen}: {ip}", file=stderr)
|
||||
else:
|
||||
print(f"#{seen}/{total}: {ip}", file=stderr)
|
||||
stderr.flush()
|
||||
if not first:
|
||||
await sleep(opts.ip_wait)
|
||||
await opts.ips.put(ip)
|
||||
await pooler(try_ip(db, ip, checks, opts))
|
||||
|
||||
if opts.blocking_file_io:
|
||||
async def read_all_ips(filepaths, blocking=False, callback=None):
    """Read IPs from each path in *filepaths* and await *callback* on each.

    An empty-string path means standard input. With ``blocking=True`` the
    files are read through ``read_ips`` and the callback is awaited as
    ``callback(ip)``; otherwise an ``IpReader`` is iterated asynchronously
    and the callback is awaited as ``callback(ip, reader.total)``.

    *callback* is required despite its default; omitting it trips the
    assertion below.
    """
    assert callback is not None, "that doesn't make sense!"

    if blocking:
        from .ip_util import read_ips

        for filepath in filepaths:
            f = stdin if filepath == "" else open(filepath, "r")
            try:
                for ip in read_ips(f):
                    #lament("READ", ip)
                    await callback(ip)
            finally:
                # Close files we opened even if reading or the callback
                # raised; stdin is not ours to close.
                if f is not stdin:
                    f.close()

    else:
        from .ip_util import IpReader

        fps = [stdin if fp == "" else fp for fp in filepaths]
        with IpReader(*fps) as reader:
            async for ip in reader:
                #lament("READ", ip)
                await callback(ip, reader.total)

    #lament("EXITING read_all_ips")
|
||||
|
||||
|
||||
async def main(db, filepaths, checks, ipinfo, opts: Options):
    """Drive the reading -> trying -> locating pipeline.

    Three tasks are connected by queues: ``read_all_ips`` feeds ``try_me``,
    ``try_all_ips`` consumes ``try_me`` and hands addresses to
    ``locate_later``, and ``locate_ips`` consumes ``locate_me``. Each queue
    is closed by putting ``None`` on it once its producer has finished.

    If no IPs were read at all, every IP from ``db.all_ips()`` is queued
    for a country-code refresh instead (skipped when *db* is None).
    """
    # ipinfo can be None.
    from asyncio import Queue, QueueFull, create_task
    from queue import SimpleQueue

    # NOTE(review): locate_me is created without a maxsize, so QueueFull is
    # not expected here, and nothing in this function drains `deferred`.
    deferred = SimpleQueue()
    locate_me = Queue()
    try_me = Queue()

    def locate_later(ip):
        try:
            locate_me.put_nowait(ip)
        except QueueFull:
            deferred.put(ip)

    seen = 0

    async def try_soon(ip, total=None):
        nonlocal seen
        seen += 1
        await try_me.put((ip, total))

    sync_database(db, callback=locate_later)

    reading = create_task(read_all_ips(filepaths, opts.blocking_file_io,
                                       callback=try_soon))
    trying = create_task(try_all_ips(db, try_me, checks, opts,
                                     callback=locate_later))
    locating = create_task(locate_ips(db, locate_me, ipinfo))

    # these tasks feed each other with queues, so order them as such:
    # reading -> trying -> locating

    #lament("AWAIT reading")
    await reading
    #lament("AWAITED reading")

    # db is treated as optional elsewhere in this module, so don't call
    # into it when it is None.
    if seen == 0 and db is not None:
        #lament("UPDATING country codes")
        # no IPs were provided. refresh all the country codes instead.
        all_ips = db.all_ips()
        count = len(all_ips)
        for i, ip in enumerate(all_ips, 1):
            if opts.progress:
                lament(f"#{i}/{count}: {ip}")
            await locate_me.put(ip)

    await try_me.put(None)  # end of queue

    #lament("AWAIT trying")
    await trying
    #lament("AWAITED trying")

    #lament("STOPPING locating")
    await locate_me.put(None)  # end of queue

    #lament("AWAIT locating")
    await locating
    #lament("AWAITED locating")
|
||||
|
|
|
@ -4,10 +4,8 @@ from dataclasses import dataclass
|
|||
|
||||
@dataclass
|
||||
class Options:
|
||||
# TODO: move these out of Options, since they're really not.
|
||||
# TODO: move this out of Options, since it's really not.
|
||||
execution: object = None
|
||||
ipinfo: object = None
|
||||
ips: object = None
|
||||
|
||||
ip_simul: int = 15 # how many IPs to connect to at once
|
||||
domain_simul: int = 3 # how many domains per IP to request at once
|
||||
|
|
|
@ -45,7 +45,7 @@ def ui(program, args):
|
|||
opts.early_stopping = opts.dry
|
||||
opts.progress = a.progress
|
||||
|
||||
opts.ipinfo = IpInfoByIpApi("ipinfo_cache.csv")
|
||||
ipinfo = IpInfoByIpApi("ipinfo_cache.csv")
|
||||
|
||||
if a.database is not None:
|
||||
if a.database.startswith("sqlite:"):
|
||||
|
@ -57,9 +57,9 @@ def ui(program, args):
|
|||
if a.debug:
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
run(main(db, a.path, checks, opts), debug=True)
|
||||
run(main(db, a.path, checks, ipinfo, opts), debug=True)
|
||||
else:
|
||||
run(main(db, a.path, checks, opts))
|
||||
run(main(db, a.path, checks, ipinfo, opts))
|
||||
|
||||
if opts.dry:
|
||||
runwrap(None)
|
||||
|
|
Loading…
Reference in a new issue