overhaul the main function's async stuff
after i confirm i didn't break anything, next up is the underlying try_ip function.
This commit is contained in:
parent
788d75a5e9
commit
0b198cc5c0
3 changed files with 136 additions and 79 deletions
205
respodns/dns.py
205
respodns/dns.py
|
@ -1,4 +1,9 @@
|
||||||
from .structs import Options
|
from .structs import Options
|
||||||
|
from sys import stdin, stdout, stderr
|
||||||
|
|
||||||
|
|
||||||
|
def lament(*args, **kwargs):
|
||||||
|
print(*args, file=stderr, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def detect_gfw(r, ip, check):
|
def detect_gfw(r, ip, check):
|
||||||
|
@ -111,22 +116,14 @@ def process_result(res, ip, check, opts: Options):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def try_ip(db, server_ip, checks, opts: Options):
|
async def try_ip(db, server_ip, checks, opts: Options, callback=None):
|
||||||
from .util import make_pooler
|
from .util import make_pooler
|
||||||
from asyncio import sleep, CancelledError
|
from asyncio import sleep, CancelledError
|
||||||
|
|
||||||
entries = []
|
entries = []
|
||||||
deferred = []
|
|
||||||
|
|
||||||
success = True
|
success = True
|
||||||
|
|
||||||
def maybe_put_ip(ip):
|
|
||||||
from asyncio import QueueFull
|
|
||||||
try:
|
|
||||||
opts.ips.put_nowait(ip)
|
|
||||||
except QueueFull:
|
|
||||||
deferred.append(ip)
|
|
||||||
|
|
||||||
def finisher(done, pending):
|
def finisher(done, pending):
|
||||||
nonlocal success
|
nonlocal success
|
||||||
for task in done:
|
for task in done:
|
||||||
|
@ -136,7 +133,9 @@ async def try_ip(db, server_ip, checks, opts: Options):
|
||||||
success = False
|
success = False
|
||||||
break
|
break
|
||||||
entry = process_result(res, ip, check, opts)
|
entry = process_result(res, ip, check, opts)
|
||||||
map(maybe_put_ip, entry.addrs)
|
if callback is not None:
|
||||||
|
for addr in entry.addrs:
|
||||||
|
callback(addr)
|
||||||
entries.append(entry)
|
entries.append(entry)
|
||||||
if not entry.success:
|
if not entry.success:
|
||||||
if opts.early_stopping and success: # only cancel once
|
if opts.early_stopping and success: # only cancel once
|
||||||
|
@ -155,6 +154,7 @@ async def try_ip(db, server_ip, checks, opts: Options):
|
||||||
res = await getaddrs(ip, check.domain, opts)
|
res = await getaddrs(ip, check.domain, opts)
|
||||||
return res, ip, check
|
return res, ip, check
|
||||||
|
|
||||||
|
#lament("BEGIN", server_ip)
|
||||||
for i, check in enumerate(checks):
|
for i, check in enumerate(checks):
|
||||||
first = i == 0
|
first = i == 0
|
||||||
if not first:
|
if not first:
|
||||||
|
@ -169,14 +169,13 @@ async def try_ip(db, server_ip, checks, opts: Options):
|
||||||
else:
|
else:
|
||||||
await pooler()
|
await pooler()
|
||||||
|
|
||||||
for ip in deferred:
|
|
||||||
await opts.ips.put(ip)
|
|
||||||
|
|
||||||
if not opts.dry:
|
if not opts.dry:
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
db.push_entry(entry)
|
db.push_entry(entry)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
#lament("FINISH", server_ip)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
first_failure = None
|
first_failure = None
|
||||||
assert len(entries) > 0
|
assert len(entries) > 0
|
||||||
|
@ -186,11 +185,52 @@ async def try_ip(db, server_ip, checks, opts: Options):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
assert 0, ("no failures found:", entries)
|
assert 0, ("no failures found:", entries)
|
||||||
return server_ip, first_failure
|
return first_failure
|
||||||
return server_ip, None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def sync_database(db, opts: Options):
|
async def try_all_ips(db, try_me, checks, opts: Options, callback=None):
|
||||||
|
from asyncio import create_task, sleep, BoundedSemaphore
|
||||||
|
|
||||||
|
seen, total = 0, None
|
||||||
|
|
||||||
|
async def process(ip):
|
||||||
|
nonlocal seen
|
||||||
|
first = seen == 0
|
||||||
|
seen += 1
|
||||||
|
if opts.progress:
|
||||||
|
lament(f"#{seen}: {ip}" if total is None else f"#{seen}/{total}: {ip}")
|
||||||
|
stderr.flush()
|
||||||
|
if not first:
|
||||||
|
await sleep(opts.ip_wait)
|
||||||
|
first_failure = await try_ip(db, ip, checks, opts, callback)
|
||||||
|
if first_failure is None:
|
||||||
|
print(ip) # all tests for this server passed; pass it along to stdout
|
||||||
|
elif opts.dry: # don't save the error anywhere; pass it along to stdout
|
||||||
|
ff = first_failure
|
||||||
|
if ff.kind in ("shock", "adware"):
|
||||||
|
# don't print sketchy domains to console in case they're clicked.
|
||||||
|
print(ip, ff.reason, ff.kind, sep="\t")
|
||||||
|
else:
|
||||||
|
print(ip, ff.reason, ff.kind, ff.domain, sep="\t")
|
||||||
|
|
||||||
|
sem = BoundedSemaphore(opts.ip_simul)
|
||||||
|
tasks = []
|
||||||
|
while (res := await try_me.get()) is not None:
|
||||||
|
ip, total = res
|
||||||
|
#lament("TRYING", ip)
|
||||||
|
async with sem:
|
||||||
|
task = create_task(process(ip))
|
||||||
|
tasks.append(task)
|
||||||
|
for task in tasks:
|
||||||
|
await task
|
||||||
|
|
||||||
|
#lament("EXITING try_all_ips")
|
||||||
|
|
||||||
|
|
||||||
|
def sync_database(db, callback=None):
|
||||||
|
# NOTE: this no longer takes Options.
|
||||||
|
# NOTE: this is no longer async.
|
||||||
from .ips import china, blocks
|
from .ips import china, blocks
|
||||||
|
|
||||||
# TODO: handle addresses that were removed from respodns.ips.china.
|
# TODO: handle addresses that were removed from respodns.ips.china.
|
||||||
|
@ -199,90 +239,109 @@ async def sync_database(db, opts: Options):
|
||||||
kwargs = {kw: True}
|
kwargs = {kw: True}
|
||||||
if db is not None:
|
if db is not None:
|
||||||
db.modify_address(ip, **kwargs)
|
db.modify_address(ip, **kwargs)
|
||||||
await opts.ips.put(ip)
|
if callback is not None:
|
||||||
|
callback(ip)
|
||||||
|
|
||||||
|
|
||||||
async def locate_ips(db, opts: Options):
|
async def locate_ips(db, locate_me, ipinfo):
|
||||||
|
# NOTE: this no longer takes Options.
|
||||||
from time import time
|
from time import time
|
||||||
seen = set()
|
seen = set()
|
||||||
last_save = time()
|
last_save = time()
|
||||||
while (ip := await opts.ips.get()) is not None:
|
while (ip := await locate_me.get()) is not None:
|
||||||
if opts.ipinfo is not None and ip not in seen:
|
if ipinfo is not None and ip not in seen:
|
||||||
|
#lament("LOCATE", ip)
|
||||||
seen.add(ip)
|
seen.add(ip)
|
||||||
code = await opts.ipinfo.find_country(ip)
|
code = await ipinfo.find_country(ip)
|
||||||
if db is not None:
|
if db is not None:
|
||||||
db.modify_address(ip, country_code=code)
|
db.modify_address(ip, country_code=code)
|
||||||
if time() >= last_save + 10.0: # only flush occasionally
|
if time() >= last_save + 10.0: # only flush occasionally
|
||||||
opts.ipinfo.flush()
|
#lament("FLUSH", time() - last_save)
|
||||||
|
ipinfo.flush()
|
||||||
last_save = time()
|
last_save = time()
|
||||||
opts.ipinfo.flush()
|
ipinfo.flush()
|
||||||
|
|
||||||
|
#lament("EXITING locate_ips")
|
||||||
|
|
||||||
|
|
||||||
async def main(db, filepaths, checks, opts: Options):
|
async def read_all_ips(filepaths, blocking=False, callback=None):
|
||||||
from .util import make_pooler
|
assert callback is not None, "that doesn't make sense!"
|
||||||
from asyncio import sleep, create_task, Queue
|
if blocking:
|
||||||
from sys import stdin, stderr
|
|
||||||
|
|
||||||
opts.ips = Queue()
|
|
||||||
syncing = create_task(sync_database(db, opts))
|
|
||||||
geoip = create_task(locate_ips(db, opts))
|
|
||||||
|
|
||||||
def finisher(done, pending):
|
|
||||||
for task in done:
|
|
||||||
ip, first_failure = task.result()
|
|
||||||
if first_failure is None:
|
|
||||||
print(ip)
|
|
||||||
elif opts.dry:
|
|
||||||
ff = first_failure
|
|
||||||
if ff.kind in ("shock", "adware"):
|
|
||||||
print(ip, ff.reason, ff.kind, sep="\t")
|
|
||||||
else:
|
|
||||||
print(ip, ff.reason, ff.kind, ff.domain, sep="\t")
|
|
||||||
|
|
||||||
pooler = make_pooler(opts.ip_simul, finisher)
|
|
||||||
|
|
||||||
seen = 0
|
|
||||||
|
|
||||||
async def process(ip, total=None):
|
|
||||||
nonlocal seen
|
|
||||||
first = seen == 0
|
|
||||||
seen += 1
|
|
||||||
if opts.progress:
|
|
||||||
if total is None:
|
|
||||||
print(f"#{seen}: {ip}", file=stderr)
|
|
||||||
else:
|
|
||||||
print(f"#{seen}/{total}: {ip}", file=stderr)
|
|
||||||
stderr.flush()
|
|
||||||
if not first:
|
|
||||||
await sleep(opts.ip_wait)
|
|
||||||
await opts.ips.put(ip)
|
|
||||||
await pooler(try_ip(db, ip, checks, opts))
|
|
||||||
|
|
||||||
if opts.blocking_file_io:
|
|
||||||
from .ip_util import read_ips
|
from .ip_util import read_ips
|
||||||
for filepath in filepaths:
|
for filepath in filepaths:
|
||||||
f = stdin if filepath == "" else open(filepath, "r")
|
f = stdin if filepath == "" else open(filepath, "r")
|
||||||
for ip in read_ips(f):
|
for ip in read_ips(f):
|
||||||
await process(ip)
|
#lament("READ", ip)
|
||||||
|
await callback(ip)
|
||||||
if f != stdin:
|
if f != stdin:
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from .ip_util import IpReader
|
from .ip_util import IpReader
|
||||||
fps = [stdin if fp == "" else fp for fp in filepaths]
|
fps = [stdin if fp == "" else fp for fp in filepaths]
|
||||||
with IpReader(*fps) as reader:
|
with IpReader(*fps) as reader:
|
||||||
async for ip in reader:
|
async for ip in reader:
|
||||||
await process(ip, reader.total)
|
#lament("READ", ip)
|
||||||
|
await callback(ip, reader.total)
|
||||||
|
|
||||||
|
#lament("EXITING read_all_ips")
|
||||||
|
|
||||||
|
|
||||||
|
async def main(db, filepaths, checks, ipinfo, opts: Options):
|
||||||
|
# ipinfo can be None.
|
||||||
|
from .util import make_pooler
|
||||||
|
from asyncio import Queue, QueueFull, create_task
|
||||||
|
from queue import SimpleQueue
|
||||||
|
|
||||||
|
deferred = SimpleQueue()
|
||||||
|
locate_me = Queue()
|
||||||
|
try_me = Queue()
|
||||||
|
|
||||||
|
def locate_later(ip):
|
||||||
|
try:
|
||||||
|
locate_me.put_nowait(ip)
|
||||||
|
except QueueFull:
|
||||||
|
deferred.put(ip)
|
||||||
|
|
||||||
|
seen = 0
|
||||||
|
async def try_soon(ip, total=None):
|
||||||
|
nonlocal seen
|
||||||
|
seen += 1
|
||||||
|
await try_me.put((ip, total))
|
||||||
|
|
||||||
|
sync_database(db, callback=locate_later)
|
||||||
|
|
||||||
|
reading = create_task(read_all_ips(filepaths, opts.blocking_file_io,
|
||||||
|
callback=try_soon))
|
||||||
|
trying = create_task(try_all_ips(db, try_me, checks, opts,
|
||||||
|
callback=locate_later))
|
||||||
|
locating = create_task(locate_ips(db, locate_me, ipinfo))
|
||||||
|
|
||||||
|
# these tasks feed each other with queues, so order them as such:
|
||||||
|
# reading -> trying -> locating
|
||||||
|
|
||||||
|
#lament("AWAIT reading")
|
||||||
|
await reading
|
||||||
|
#lament("AWAITED reading")
|
||||||
|
|
||||||
if seen == 0:
|
if seen == 0:
|
||||||
|
#lament("UPDATING country codes")
|
||||||
# no IPs were provided. refresh all the country codes instead.
|
# no IPs were provided. refresh all the country codes instead.
|
||||||
all_ips = db.all_ips()
|
all_ips = db.all_ips()
|
||||||
for i, ip in enumerate(all_ips):
|
for i, ip in enumerate(all_ips):
|
||||||
if opts.progress:
|
if opts.progress:
|
||||||
print(f"#{i + 1}/{len(all_ips)}: {ip}", file=stderr)
|
lament(f"#{i + 1}/{len(all_ips)}: {ip}")
|
||||||
await opts.ips.put(ip)
|
await locate_me.put(ip)
|
||||||
|
|
||||||
await pooler()
|
await try_me.put(None)
|
||||||
await syncing
|
|
||||||
await opts.ips.put(None) # end of queue
|
#lament("AWAIT trying")
|
||||||
await geoip
|
await trying
|
||||||
|
#lament("AWAITED trying")
|
||||||
|
|
||||||
|
#lament("STOPPING locating")
|
||||||
|
#done_locating.set()
|
||||||
|
await locate_me.put(None)
|
||||||
|
|
||||||
|
#lament("AWAIT locating")
|
||||||
|
await locating
|
||||||
|
#lament("AWAITED locating")
|
||||||
|
|
|
@ -4,10 +4,8 @@ from dataclasses import dataclass
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Options:
|
class Options:
|
||||||
# TODO: move these out of Options, since they're really not.
|
# TODO: move this out of Options, since it's really not.
|
||||||
execution: object = None
|
execution: object = None
|
||||||
ipinfo: object = None
|
|
||||||
ips: object = None
|
|
||||||
|
|
||||||
ip_simul: int = 15 # how many IPs to connect to at once
|
ip_simul: int = 15 # how many IPs to connect to at once
|
||||||
domain_simul: int = 3 # how many domains per IP to request at once
|
domain_simul: int = 3 # how many domains per IP to request at once
|
||||||
|
|
|
@ -45,7 +45,7 @@ def ui(program, args):
|
||||||
opts.early_stopping = opts.dry
|
opts.early_stopping = opts.dry
|
||||||
opts.progress = a.progress
|
opts.progress = a.progress
|
||||||
|
|
||||||
opts.ipinfo = IpInfoByIpApi("ipinfo_cache.csv")
|
ipinfo = IpInfoByIpApi("ipinfo_cache.csv")
|
||||||
|
|
||||||
if a.database is not None:
|
if a.database is not None:
|
||||||
if a.database.startswith("sqlite:"):
|
if a.database.startswith("sqlite:"):
|
||||||
|
@ -57,9 +57,9 @@ def ui(program, args):
|
||||||
if a.debug:
|
if a.debug:
|
||||||
import logging
|
import logging
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
run(main(db, a.path, checks, opts), debug=True)
|
run(main(db, a.path, checks, ipinfo, opts), debug=True)
|
||||||
else:
|
else:
|
||||||
run(main(db, a.path, checks, opts))
|
run(main(db, a.path, checks, ipinfo, opts))
|
||||||
|
|
||||||
if opts.dry:
|
if opts.dry:
|
||||||
runwrap(None)
|
runwrap(None)
|
||||||
|
|
Loading…
Add table
Reference in a new issue