write and use async iterator for IpReader

also fixes the counter in the blocking version when progress is enabled
This commit is contained in:
Connor Olding 2020-09-04 14:15:21 +02:00
parent 853d25c51a
commit 70308230e6
2 changed files with 32 additions and 7 deletions

View file

@ -220,26 +220,31 @@ async def main(db, filepaths, checks, opts: Options):
if blocking_file_io: if blocking_file_io:
from .ip_util import read_ips from .ip_util import read_ips
seen = 0
for filepath in filepaths: for filepath in filepaths:
f = stdin if filepath == "" else open(filepath, "r") f = stdin if filepath == "" else open(filepath, "r")
for i, ip in enumerate(read_ips(f)): for ip in read_ips(f):
first = i == 0 first = seen == 0
seen += 1
if opts.progress: if opts.progress:
print(f"#{i + 1}: {ip}", file=stderr) print(f"#{seen}: {ip}", file=stderr)
stderr.flush() stderr.flush()
if not first: if not first:
await sleep(opts.ip_wait) await sleep(opts.ip_wait)
await pooler(try_ip(db, ip, checks, opts)) await pooler(try_ip(db, ip, checks, opts))
if f != stdin: if f != stdin:
f.close() f.close()
else: # blocking_file_io else: # blocking_file_io
from .ip_util import IpReader from .ip_util import IpReader
fps = [stdin if fp == "" else fp for fp in filepaths] fps = [stdin if fp == "" else fp for fp in filepaths]
with IpReader(*fps) as reader: with IpReader(*fps) as reader:
for i, ip in enumerate(reader): seen = 0
first = i == 0 async for ip in reader:
first = seen == 0
seen += 1
if opts.progress: if opts.progress:
print(f"#{i + 1}/{reader.total}: {ip}", file=stderr) print(f"#{seen}/{reader.total}: {ip}", file=stderr)
stderr.flush() stderr.flush()
if not first: if not first:
await sleep(opts.ip_wait) await sleep(opts.ip_wait)

View file

@ -67,7 +67,27 @@ class IpReader:
results.put(self.queue.get(timeout=1.0)) results.put(self.queue.get(timeout=1.0))
self.total += 1 self.total += 1
except Empty: except Empty:
print("blocking on IpReader", file=stderr) pass
while not self.queue.empty():
results.put(self.queue.get())
self.total += 1
if not results.empty():
yield results.get()
while not results.empty():
yield results.get()
return _next()
def __aiter__(self):
from asyncio import sleep
from queue import SimpleQueue
from sys import stderr
async def _next():
results = SimpleQueue()
while self.is_running() or not self.queue.empty():
if self.queue.empty() and results.empty():
await sleep(0.1) # this incurs some latency, but alas...
while not self.queue.empty(): while not self.queue.empty():
results.put(self.queue.get()) results.put(self.queue.get())
self.total += 1 self.total += 1