overhaul the main function's async stuff

after i confirm i didn't break anything,
next up is the underlying try_ip function.
This commit is contained in:
Connor Olding 2021-08-12 21:06:00 -07:00
parent 788d75a5e9
commit 0b198cc5c0
3 changed files with 136 additions and 79 deletions
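
For orientation: the rewritten main() below wires three tasks together with asyncio queues, using None as an end-of-stream sentinel (reading feeds trying, which feeds locating). Here is a minimal, self-contained sketch of that shutdown pattern; the stage names and sample IPs are illustrative stand-ins, not the project's actual API.

# Sketch of the queue-fed pipeline: reader -> try_me -> worker -> locate_me -> locator.
# None marks end-of-stream, mirroring the diff below.
import asyncio

async def read_stage(out_queue):
    # stand-in for read_all_ips(): feed IPs into the pipeline
    for ip in ("1.1.1.1", "8.8.8.8"):
        await out_queue.put(ip)

async def try_stage(in_queue, out_queue):
    # stand-in for try_all_ips(): consume until the None sentinel
    while (ip := await in_queue.get()) is not None:
        await out_queue.put(ip)  # forward addresses to the locator

async def locate_stage(in_queue):
    # stand-in for locate_ips(): consume until the None sentinel
    while (ip := await in_queue.get()) is not None:
        pass  # e.g. look up a country code here

async def main():
    try_me, locate_me = asyncio.Queue(), asyncio.Queue()
    trying = asyncio.create_task(try_stage(try_me, locate_me))
    locating = asyncio.create_task(locate_stage(locate_me))
    await read_stage(try_me)
    await try_me.put(None)     # close the upstream stage first...
    await trying
    await locate_me.put(None)  # ...then the one downstream of it
    await locating

asyncio.run(main())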

View file

@@ -1,4 +1,9 @@
from .structs import Options
from sys import stdin, stdout, stderr
def lament(*args, **kwargs):
print(*args, file=stderr, **kwargs)
def detect_gfw(r, ip, check):
@@ -111,22 +116,14 @@ def process_result(res, ip, check, opts: Options):
)
async def try_ip(db, server_ip, checks, opts: Options):
async def try_ip(db, server_ip, checks, opts: Options, callback=None):
from .util import make_pooler
from asyncio import sleep, CancelledError
entries = []
deferred = []
success = True
def maybe_put_ip(ip):
from asyncio import QueueFull
try:
opts.ips.put_nowait(ip)
except QueueFull:
deferred.append(ip)
def finisher(done, pending):
nonlocal success
for task in done:
@@ -136,7 +133,9 @@ async def try_ip(db, server_ip, checks, opts: Options):
success = False
break
entry = process_result(res, ip, check, opts)
map(maybe_put_ip, entry.addrs)
if callback is not None:
for addr in entry.addrs:
callback(addr)
entries.append(entry)
if not entry.success:
if opts.early_stopping and success: # only cancel once
@@ -155,6 +154,7 @@ async def try_ip(db, server_ip, checks, opts: Options):
res = await getaddrs(ip, check.domain, opts)
return res, ip, check
#lament("BEGIN", server_ip)
for i, check in enumerate(checks):
first = i == 0
if not first:
@@ -169,14 +169,13 @@ async def try_ip(db, server_ip, checks, opts: Options):
else:
await pooler()
for ip in deferred:
await opts.ips.put(ip)
if not opts.dry:
for entry in entries:
db.push_entry(entry)
db.commit()
#lament("FINISH", server_ip)
if not success:
first_failure = None
assert len(entries) > 0
@@ -186,11 +185,52 @@ async def try_ip(db, server_ip, checks, opts: Options):
break
else:
assert 0, ("no failures found:", entries)
return server_ip, first_failure
return server_ip, None
return first_failure
return None
async def sync_database(db, opts: Options):
async def try_all_ips(db, try_me, checks, opts: Options, callback=None):
from asyncio import create_task, sleep, BoundedSemaphore
seen, total = 0, None
async def process(ip):
nonlocal seen
first = seen == 0
seen += 1
if opts.progress:
lament(f"#{seen}: {ip}" if total is None else f"#{seen}/{total}: {ip}")
stderr.flush()
if not first:
await sleep(opts.ip_wait)
first_failure = await try_ip(db, ip, checks, opts, callback)
if first_failure is None:
print(ip) # all tests for this server passed; pass it along to stdout
elif opts.dry: # don't save the error anywhere; pass it along to stdout
ff = first_failure
if ff.kind in ("shock", "adware"):
# don't print sketchy domains to console in case they're clicked.
print(ip, ff.reason, ff.kind, sep="\t")
else:
print(ip, ff.reason, ff.kind, ff.domain, sep="\t")
sem = BoundedSemaphore(opts.ip_simul)
tasks = []
while (res := await try_me.get()) is not None:
ip, total = res
#lament("TRYING", ip)
async with sem:
task = create_task(process(ip))
tasks.append(task)
for task in tasks:
await task
#lament("EXITING try_all_ips")
def sync_database(db, callback=None):
# NOTE: this no longer takes Options.
# NOTE: this is no longer async.
from .ips import china, blocks
# TODO: handle addresses that were removed from respodns.ips.china.
@@ -199,90 +239,109 @@ async def sync_database(db, opts: Options):
kwargs = {kw: True}
if db is not None:
db.modify_address(ip, **kwargs)
await opts.ips.put(ip)
if callback is not None:
callback(ip)
async def locate_ips(db, opts: Options):
async def locate_ips(db, locate_me, ipinfo):
# NOTE: this no longer takes Options.
from time import time
seen = set()
last_save = time()
while (ip := await opts.ips.get()) is not None:
if opts.ipinfo is not None and ip not in seen:
while (ip := await locate_me.get()) is not None:
if ipinfo is not None and ip not in seen:
#lament("LOCATE", ip)
seen.add(ip)
code = await opts.ipinfo.find_country(ip)
code = await ipinfo.find_country(ip)
if db is not None:
db.modify_address(ip, country_code=code)
if time() >= last_save + 10.0: # only flush occasionally
opts.ipinfo.flush()
#lament("FLUSH", time() - last_save)
ipinfo.flush()
last_save = time()
opts.ipinfo.flush()
ipinfo.flush()
#lament("EXITING locate_ips")
async def main(db, filepaths, checks, opts: Options):
from .util import make_pooler
from asyncio import sleep, create_task, Queue
from sys import stdin, stderr
opts.ips = Queue()
syncing = create_task(sync_database(db, opts))
geoip = create_task(locate_ips(db, opts))
def finisher(done, pending):
for task in done:
ip, first_failure = task.result()
if first_failure is None:
print(ip)
elif opts.dry:
ff = first_failure
if ff.kind in ("shock", "adware"):
print(ip, ff.reason, ff.kind, sep="\t")
else:
print(ip, ff.reason, ff.kind, ff.domain, sep="\t")
pooler = make_pooler(opts.ip_simul, finisher)
seen = 0
async def process(ip, total=None):
nonlocal seen
first = seen == 0
seen += 1
if opts.progress:
if total is None:
print(f"#{seen}: {ip}", file=stderr)
else:
print(f"#{seen}/{total}: {ip}", file=stderr)
stderr.flush()
if not first:
await sleep(opts.ip_wait)
await opts.ips.put(ip)
await pooler(try_ip(db, ip, checks, opts))
if opts.blocking_file_io:
async def read_all_ips(filepaths, blocking=False, callback=None):
assert callback is not None, "that doesn't make sense!"
if blocking:
from .ip_util import read_ips
for filepath in filepaths:
f = stdin if filepath == "" else open(filepath, "r")
for ip in read_ips(f):
await process(ip)
#lament("READ", ip)
await callback(ip)
if f != stdin:
f.close()
else:
from .ip_util import IpReader
fps = [stdin if fp == "" else fp for fp in filepaths]
with IpReader(*fps) as reader:
async for ip in reader:
await process(ip, reader.total)
#lament("READ", ip)
await callback(ip, reader.total)
#lament("EXITING read_all_ips")
async def main(db, filepaths, checks, ipinfo, opts: Options):
# ipinfo can be None.
from .util import make_pooler
from asyncio import Queue, QueueFull, create_task
from queue import SimpleQueue
deferred = SimpleQueue()
locate_me = Queue()
try_me = Queue()
def locate_later(ip):
try:
locate_me.put_nowait(ip)
except QueueFull:
deferred.put(ip)
seen = 0
async def try_soon(ip, total=None):
nonlocal seen
seen += 1
await try_me.put((ip, total))
sync_database(db, callback=locate_later)
reading = create_task(read_all_ips(filepaths, opts.blocking_file_io,
callback=try_soon))
trying = create_task(try_all_ips(db, try_me, checks, opts,
callback=locate_later))
locating = create_task(locate_ips(db, locate_me, ipinfo))
# these tasks feed each other with queues, so order them as such:
# reading -> trying -> locating
#lament("AWAIT reading")
await reading
#lament("AWAITED reading")
if seen == 0:
#lament("UPDATING country codes")
# no IPs were provided. refresh all the country codes instead.
all_ips = db.all_ips()
for i, ip in enumerate(all_ips):
if opts.progress:
print(f"#{i + 1}/{len(all_ips)}: {ip}", file=stderr)
await opts.ips.put(ip)
lament(f"#{i + 1}/{len(all_ips)}: {ip}")
await locate_me.put(ip)
await pooler()
await syncing
await opts.ips.put(None) # end of queue
await geoip
await try_me.put(None)
#lament("AWAIT trying")
await trying
#lament("AWAITED trying")
#lament("STOPPING locating")
#done_locating.set()
await locate_me.put(None)
#lament("AWAIT locating")
await locating
#lament("AWAITED locating")

View file

@@ -4,10 +4,8 @@ from dataclasses import dataclass
@dataclass
class Options:
# TODO: move these out of Options, since they're really not.
# TODO: move this out of Options, since it's really not.
execution: object = None
ipinfo: object = None
ips: object = None
ip_simul: int = 15 # how many IPs to connect to at once
domain_simul: int = 3 # how many domains per IP to request at once

View file

@@ -45,7 +45,7 @@ def ui(program, args):
opts.early_stopping = opts.dry
opts.progress = a.progress
opts.ipinfo = IpInfoByIpApi("ipinfo_cache.csv")
ipinfo = IpInfoByIpApi("ipinfo_cache.csv")
if a.database is not None:
if a.database.startswith("sqlite:"):
@@ -57,9 +57,9 @@ def ui(program, args):
if a.debug:
import logging
logging.basicConfig(level=logging.DEBUG)
run(main(db, a.path, checks, opts), debug=True)
run(main(db, a.path, checks, ipinfo, opts), debug=True)
else:
run(main(db, a.path, checks, opts))
run(main(db, a.path, checks, ipinfo, opts))
if opts.dry:
runwrap(None)