From 70308230e61792c97e2ac047bacb1c0493eddcc6 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Fri, 4 Sep 2020 14:15:21 +0200 Subject: [PATCH] write and use async iterator for IpReader also fixes the counter in the blocking version when progress is enabled --- respodns/dns.py | 17 +++++++++++------ respodns/ip_util.py | 22 +++++++++++++++++++++- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/respodns/dns.py b/respodns/dns.py index 8f518db..bb0ace6 100644 --- a/respodns/dns.py +++ b/respodns/dns.py @@ -220,26 +220,31 @@ async def main(db, filepaths, checks, opts: Options): if blocking_file_io: from .ip_util import read_ips + seen = 0 for filepath in filepaths: f = stdin if filepath == "" else open(filepath, "r") - for i, ip in enumerate(read_ips(f)): - first = i == 0 + for ip in read_ips(f): + first = seen == 0 + seen += 1 if opts.progress: - print(f"#{i + 1}: {ip}", file=stderr) + print(f"#{seen}: {ip}", file=stderr) stderr.flush() if not first: await sleep(opts.ip_wait) await pooler(try_ip(db, ip, checks, opts)) if f != stdin: f.close() + else: # blocking_file_io from .ip_util import IpReader fps = [stdin if fp == "" else fp for fp in filepaths] with IpReader(*fps) as reader: - for i, ip in enumerate(reader): - first = i == 0 + seen = 0 + async for ip in reader: + first = seen == 0 + seen += 1 if opts.progress: - print(f"#{i + 1}/{reader.total}: {ip}", file=stderr) + print(f"#{seen}/{reader.total}: {ip}", file=stderr) stderr.flush() if not first: await sleep(opts.ip_wait) diff --git a/respodns/ip_util.py b/respodns/ip_util.py index 33b05ee..0e1d106 100644 --- a/respodns/ip_util.py +++ b/respodns/ip_util.py @@ -67,7 +67,27 @@ class IpReader: results.put(self.queue.get(timeout=1.0)) self.total += 1 except Empty: - print("blocking on IpReader", file=stderr) + pass + while not self.queue.empty(): + results.put(self.queue.get()) + self.total += 1 + if not results.empty(): + yield results.get() + while not results.empty(): + yield results.get() + + return _next() + + def __aiter__(self): + from asyncio import sleep + from queue import SimpleQueue + from sys import stderr + + async def _next(): + results = SimpleQueue() + while self.is_running() or not self.queue.empty(): + if self.queue.empty() and results.empty(): + await sleep(0.1) # this incurs some latency, but alas... while not self.queue.empty(): results.put(self.queue.get()) self.total += 1