From e6c080bf32423b292bed607573c9f43c42984a81 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Fri, 4 Sep 2020 13:09:49 +0200 Subject: [PATCH] support reading IPs from multiple files at once this also adds the progress flag that was missing previously. --- respodns/dns.py | 43 +++++++++++++++++++++---------- respodns/ip_util.py | 63 +++++++++++++++++++++++++++++++++++++++++++-- respodns/ui.py | 7 +++-- 3 files changed, 96 insertions(+), 17 deletions(-) diff --git a/respodns/dns.py b/respodns/dns.py index 5b8ec5d..ba33095 100644 --- a/respodns/dns.py +++ b/respodns/dns.py @@ -194,8 +194,7 @@ async def sync_database(db, opts: Options): opts.ipinfo.flush() -async def main(db, filepath, checks, opts: Options): - from .ip_util import read_ips +async def main(db, filepaths, checks, opts: Options): from .util import make_pooler from asyncio import sleep, create_task from sys import stdin, stderr @@ -216,17 +215,35 @@ async def main(db, filepath, checks, opts: Options): pooler = make_pooler(opts.ip_simul, finisher) - f = stdin if filepath == "" else open(filepath, "r") - for i, ip in enumerate(read_ips(f)): - first = i == 0 - if opts.progress: - print(f"#{i}: {ip}", file=stderr) - stderr.flush() - if not first: - await sleep(opts.ip_wait) - await pooler(try_ip(db, ip, checks, opts)) - if f != stdin: - f.close() + blocking_file_io = False # TODO: put in Options. + + if blocking_file_io: + from .ip_util import read_ips + + for filepath in filepaths: + f = stdin if filepath == "" else open(filepath, "r") + for i, ip in enumerate(read_ips(f)): + first = i == 0 + if opts.progress: + print(f"#{i}: {ip}", file=stderr) + stderr.flush() + if not first: + await sleep(opts.ip_wait) + await pooler(try_ip(db, ip, checks, opts)) + if f != stdin: + f.close() + else: # blocking_file_io + from .ip_util import IpReader + fps = [stdin if fp == "" else fp for fp in filepaths] + with IpReader(*fps) as reader: + for i, ip in enumerate(reader): + first = i == 0 + if opts.progress: + print(f"#{i}/{reader.total}: {ip}", file=stderr) + stderr.flush() + if not first: + await sleep(opts.ip_wait) + await pooler(try_ip(db, ip, checks, opts)) await pooler() await syncing diff --git a/respodns/ip_util.py b/respodns/ip_util.py index e3c8cef..f56cf81 100644 --- a/respodns/ip_util.py +++ b/respodns/ip_util.py @@ -1,10 +1,10 @@ import re + ipv4_pattern = re.compile(r"(\d+)\.(\d+)\.(\d+)\.(\d+)", re.ASCII) def read_ips(f): - # TODO: make async and more robust. (regex pls) - # TODO: does readlines() block if the pipe is left open i.e. user input? + # TODO: make more robust. (regex pls) for ip in f.readlines(): if "#" in ip: ip, _, _ = ip.partition("#") @@ -27,3 +27,62 @@ def ipkey(ip_string): # this is more lenient than addr_to_int. segs = [int(s) for s in ip_string.replace(":", ".").split(".")] return sum(256**(3 - i) * seg for i, seg in enumerate(segs)) + + +def ip_reader_worker(fp, queue): + from io import IOBase + + needs_closing = not isinstance(fp, IOBase) + f = open(fp, "r") if needs_closing else fp + + try: + for ip in read_ips(f): + queue.put(ip) + finally: + if needs_closing: + f.close() + + +class IpReader: + def __init__(self, *paths_and_handles): + from queue import Queue + + self.fps = paths_and_handles + self.queue = Queue() + self.threads = [] + self.total = 0 + + def running(self): + return any(thread.is_alive() for thread in self.threads) + + def __iter__(self): + return self + + def __next__(self): + # TODO: rewrite such that self.total is useful. (get as many at once) + from queue import Empty + + while self.running() or not self.queue.empty(): + try: + res = self.queue.get(block=True, timeout=1.0) + if res is not None: + self.total += 1 + return res + except Empty: + from sys import stderr + print("blocking on IpReader", file=stderr) + + raise StopIteration + + def __enter__(self): + from threading import Thread + + for fp in self.fps: + thread = Thread(target=ip_reader_worker, args=(fp, self.queue)) + self.threads.append(thread) + thread.start() + + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass diff --git a/respodns/ui.py b/respodns/ui.py index 0bdba75..1b0de2c 100644 --- a/respodns/ui.py +++ b/respodns/ui.py @@ -11,12 +11,14 @@ def ui(program, args): desc = name + ": test and log DNS records" parser = ArgumentParser(name, description=desc) - # TODO: support multiple paths. nargs="+", iterate with pooling? desc = "a path to a file containing IPv4 addresses which host DNS servers" - parser.add_argument("path", metavar="file-path", help=desc) + parser.add_argument("path", metavar="file-path", nargs="+", help=desc) parser.add_argument("--database", help="specify database for logging") + desc = "enable pretty-printing progress to stderr" + parser.add_argument("--progress", action="store_true", help=desc) + a = parser.parse_args(args) checks = [] @@ -26,6 +28,7 @@ def ui(program, args): opts = Options() opts.dry = a.database is None opts.early_stopping = opts.dry + opts.progress = a.progress opts.ipinfo = IpInfoByIpApi("ipinfo_cache.csv")