From 0b198cc5c02937fb264d097c237e7a5c08f72611 Mon Sep 17 00:00:00 2001
From: Connor Olding
Date: Thu, 12 Aug 2021 21:06:00 -0700
Subject: [PATCH] overhaul the main function's async stuff

after i confirm i didn't break anything, next up is
the underlying try_ip function.
---
 respodns/dns.py     | 205 ++++++++++++++++++++++++++++----------------
 respodns/structs.py |   4 +-
 respodns/ui.py      |   6 +-
 3 files changed, 136 insertions(+), 79 deletions(-)

diff --git a/respodns/dns.py b/respodns/dns.py
index b65b4b4..b83cc6a 100644
--- a/respodns/dns.py
+++ b/respodns/dns.py
@@ -1,4 +1,9 @@
 from .structs import Options
+from sys import stdin, stdout, stderr
+
+
+def lament(*args, **kwargs):
+    print(*args, file=stderr, **kwargs)
 
 
 def detect_gfw(r, ip, check):
@@ -111,22 +116,14 @@ def process_result(res, ip, check, opts: Options):
     )
 
 
-async def try_ip(db, server_ip, checks, opts: Options):
+async def try_ip(db, server_ip, checks, opts: Options, callback=None):
     from .util import make_pooler
     from asyncio import sleep, CancelledError
 
     entries = []
-    deferred = []
     success = True
 
-    def maybe_put_ip(ip):
-        from asyncio import QueueFull
-        try:
-            opts.ips.put_nowait(ip)
-        except QueueFull:
-            deferred.append(ip)
-
 
     def finisher(done, pending):
         nonlocal success
         for task in done:
@@ -136,7 +133,9 @@ async def try_ip(db, server_ip, checks, opts: Options):
             entry = process_result(res, ip, check, opts)
-            map(maybe_put_ip, entry.addrs)
+            if callback is not None:
+                for addr in entry.addrs:
+                    callback(addr)
             entries.append(entry)
             if not entry.success:
                 if opts.early_stopping and success: # only cancel once
@@ -155,6 +154,7 @@ async def try_ip(db, server_ip, checks, opts: Options):
         res = await getaddrs(ip, check.domain, opts)
         return res, ip, check
 
+    #lament("BEGIN", server_ip)
     for i, check in enumerate(checks):
         first = i == 0
         if not first:
@@ -169,14 +169,13 @@ async def try_ip(db, server_ip, checks, opts: Options):
     else:
         await pooler()
 
-    for ip in deferred:
-        await opts.ips.put(ip)
-
     if not opts.dry:
         for entry in entries:
             db.push_entry(entry)
         db.commit()
 
+    #lament("FINISH", server_ip)
+
     if not success:
         first_failure = None
         assert len(entries) > 0
@@ -186,11 +185,52 @@ async def try_ip(db, server_ip, checks, opts: Options):
                 break
         else:
             assert 0, ("no failures found:", entries)
-        return server_ip, first_failure
-    return server_ip, None
+        return first_failure
+    return None
 
 
-async def sync_database(db, opts: Options):
+async def try_all_ips(db, try_me, checks, opts: Options, callback=None):
+    from asyncio import create_task, sleep, BoundedSemaphore
+
+    seen, total = 0, None
+
+    async def process(ip):
+        nonlocal seen
+        first = seen == 0
+        seen += 1
+        if opts.progress:
+            lament(f"#{seen}: {ip}" if total is None else f"#{seen}/{total}: {ip}")
+            stderr.flush()
+        if not first:
+            await sleep(opts.ip_wait)
+        first_failure = await try_ip(db, ip, checks, opts, callback)
+        if first_failure is None:
+            print(ip) # all tests for this server passed; pass it along to stdout
+        elif opts.dry: # don't save the error anywhere; pass it along to stdout
+            ff = first_failure
+            if ff.kind in ("shock", "adware"):
+                # don't print sketchy domains to console in case they're clicked.
+                print(ip, ff.reason, ff.kind, sep="\t")
+            else:
+                print(ip, ff.reason, ff.kind, ff.domain, sep="\t")
+
+    sem = BoundedSemaphore(opts.ip_simul)
+    tasks = []
+    while (res := await try_me.get()) is not None:
+        ip, total = res
+        #lament("TRYING", ip)
+        async with sem:
+            task = create_task(process(ip))
+            tasks.append(task)
+    for task in tasks:
+        await task
+
+    #lament("EXITING try_all_ips")
+
+
+def sync_database(db, callback=None):
+    # NOTE: this no longer takes Options.
+    # NOTE: this is no longer async.
     from .ips import china, blocks
 
     # TODO: handle addresses that were removed from respodns.ips.china.
@@ -199,90 +239,109 @@ async def sync_database(db, opts: Options):
             kwargs = {kw: True}
             if db is not None:
                 db.modify_address(ip, **kwargs)
-            await opts.ips.put(ip)
+            if callback is not None:
+                callback(ip)
 
 
-async def locate_ips(db, opts: Options):
+async def locate_ips(db, locate_me, ipinfo):
+    # NOTE: this no longer takes Options.
     from time import time
 
     seen = set()
     last_save = time()
-    while (ip := await opts.ips.get()) is not None:
-        if opts.ipinfo is not None and ip not in seen:
+    while (ip := await locate_me.get()) is not None:
+        if ipinfo is not None and ip not in seen:
+            #lament("LOCATE", ip)
             seen.add(ip)
-            code = await opts.ipinfo.find_country(ip)
+            code = await ipinfo.find_country(ip)
             if db is not None:
                 db.modify_address(ip, country_code=code)
             if time() >= last_save + 10.0: # only flush occasionally
-                opts.ipinfo.flush()
+                #lament("FLUSH", time() - last_save)
+                ipinfo.flush()
                 last_save = time()
-    opts.ipinfo.flush()
+    ipinfo.flush()
+
+    #lament("EXITING locate_ips")
 
 
-async def main(db, filepaths, checks, opts: Options):
-    from .util import make_pooler
-    from asyncio import sleep, create_task, Queue
-    from sys import stdin, stderr
-
-    opts.ips = Queue()
-    syncing = create_task(sync_database(db, opts))
-    geoip = create_task(locate_ips(db, opts))
-
-    def finisher(done, pending):
-        for task in done:
-            ip, first_failure = task.result()
-            if first_failure is None:
-                print(ip)
-            elif opts.dry:
-                ff = first_failure
-                if ff.kind in ("shock", "adware"):
-                    print(ip, ff.reason, ff.kind, sep="\t")
-                else:
-                    print(ip, ff.reason, ff.kind, ff.domain, sep="\t")
-
-    pooler = make_pooler(opts.ip_simul, finisher)
-
-    seen = 0
-
-    async def process(ip, total=None):
-        nonlocal seen
-        first = seen == 0
-        seen += 1
-        if opts.progress:
-            if total is None:
-                print(f"#{seen}: {ip}", file=stderr)
-            else:
-                print(f"#{seen}/{total}: {ip}", file=stderr)
-            stderr.flush()
-        if not first:
-            await sleep(opts.ip_wait)
-        await opts.ips.put(ip)
-        await pooler(try_ip(db, ip, checks, opts))
-
-    if opts.blocking_file_io:
+async def read_all_ips(filepaths, blocking=False, callback=None):
+    assert callback is not None, "that doesn't make sense!"
+    if blocking:
         from .ip_util import read_ips
         for filepath in filepaths:
             f = stdin if filepath == "" else open(filepath, "r")
             for ip in read_ips(f):
-                await process(ip)
+                #lament("READ", ip)
+                await callback(ip)
             if f != stdin:
                 f.close()
-
     else:
         from .ip_util import IpReader
         fps = [stdin if fp == "" else fp for fp in filepaths]
         with IpReader(*fps) as reader:
             async for ip in reader:
-                await process(ip, reader.total)
+                #lament("READ", ip)
+                await callback(ip, reader.total)
+
+    #lament("EXITING read_all_ips")
+
+
+async def main(db, filepaths, checks, ipinfo, opts: Options):
+    # ipinfo can be None.
+    from .util import make_pooler
+    from asyncio import Queue, QueueFull, create_task
+    from queue import SimpleQueue
+
+    deferred = SimpleQueue()
+    locate_me = Queue()
+    try_me = Queue()
+
+    def locate_later(ip):
+        try:
+            locate_me.put_nowait(ip)
+        except QueueFull:
+            deferred.put(ip)
+
+    seen = 0
+    async def try_soon(ip, total=None):
+        nonlocal seen
+        seen += 1
+        await try_me.put((ip, total))
+
+    sync_database(db, callback=locate_later)
+
+    reading = create_task(read_all_ips(filepaths, opts.blocking_file_io,
+                                       callback=try_soon))
+    trying = create_task(try_all_ips(db, try_me, checks, opts,
+                                     callback=locate_later))
+    locating = create_task(locate_ips(db, locate_me, ipinfo))
+
+    # these tasks feed each other with queues, so order them as such:
+    # reading -> trying -> locating
+
+    #lament("AWAIT reading")
+    await reading
+    #lament("AWAITED reading")
     if seen == 0:
+        #lament("UPDATING country codes")
         # no IPs were provided. refresh all the country codes instead.
         all_ips = db.all_ips()
         for i, ip in enumerate(all_ips):
             if opts.progress:
-                print(f"#{i + 1}/{len(all_ips)}: {ip}", file=stderr)
+                lament(f"#{i + 1}/{len(all_ips)}: {ip}")
-            await opts.ips.put(ip)
+            await locate_me.put(ip)
 
-    await pooler()
-    await syncing
-    await opts.ips.put(None) # end of queue
-    await geoip
+    await try_me.put(None)
+
+    #lament("AWAIT trying")
+    await trying
+    #lament("AWAITED trying")
+
+    #lament("STOPPING locating")
+    #done_locating.set()
+    await locate_me.put(None)
+
+    #lament("AWAIT locating")
+    await locating
+    #lament("AWAITED locating")
diff --git a/respodns/structs.py b/respodns/structs.py
index a2d4ce8..c388ced 100644
--- a/respodns/structs.py
+++ b/respodns/structs.py
@@ -4,10 +4,8 @@ from dataclasses import dataclass
 
 @dataclass
 class Options:
-    # TODO: move these out of Options, since they're really not.
+    # TODO: move this out of Options, since it's really not.
     execution: object = None
-    ipinfo: object = None
-    ips: object = None
 
     ip_simul: int = 15 # how many IPs to connect to at once
     domain_simul: int = 3 # how many domains per IP to request at once
diff --git a/respodns/ui.py b/respodns/ui.py
index 206a27a..3a25ba7 100644
--- a/respodns/ui.py
+++ b/respodns/ui.py
@@ -45,7 +45,7 @@ def ui(program, args):
     opts.early_stopping = opts.dry
     opts.progress = a.progress
 
-    opts.ipinfo = IpInfoByIpApi("ipinfo_cache.csv")
+    ipinfo = IpInfoByIpApi("ipinfo_cache.csv")
 
     if a.database is not None:
         if a.database.startswith("sqlite:"):
@@ -57,9 +57,9 @@ def ui(program, args):
     if a.debug:
         import logging
         logging.basicConfig(level=logging.DEBUG)
-        run(main(db, a.path, checks, opts), debug=True)
+        run(main(db, a.path, checks, ipinfo, opts), debug=True)
     else:
-        run(main(db, a.path, checks, opts))
+        run(main(db, a.path, checks, ipinfo, opts))
 
     if opts.dry:
         runwrap(None)
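
Not part of the patch itself: a minimal, runnable sketch of the pipeline shape that the reworked main() sets up above, with two asyncio queues chaining the reading, trying, and locating stages, and a None sentinel pushed into each queue only after the stage feeding it has been awaited. All names below are placeholders, not the respodns API.

    # illustrative sketch only; read/check/locate stand in for the
    # reading -> trying -> locating stages wired together in main().
    import asyncio

    async def read(try_me):
        # producer: feed work into the first queue
        for ip in ("1.1.1.1", "8.8.8.8", "9.9.9.9"):
            await try_me.put(ip)

    async def check(try_me, locate_me):
        # middle stage: drain its queue until the sentinel arrives
        while (ip := await try_me.get()) is not None:
            await locate_me.put(ip)  # pretend every server produced an address to locate

    async def locate(locate_me):
        # final stage: same drain-until-None loop
        while (ip := await locate_me.get()) is not None:
            print("located", ip)

    async def main():
        try_me, locate_me = asyncio.Queue(), asyncio.Queue()
        reading = asyncio.create_task(read(try_me))
        trying = asyncio.create_task(check(try_me, locate_me))
        locating = asyncio.create_task(locate(locate_me))
        await reading
        await try_me.put(None)     # reading is done; let the middle stage drain and exit
        await trying
        await locate_me.put(None)  # trying is done; let the last stage drain and exit
        await locating

    asyncio.run(main())

Awaiting the stages in feed order (reading, then trying, then locating) is what lets each None mean "no more input will ever arrive" for the stage that reads it.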