add a configurable soft limit to number of DNS connections

yes, i know UDP connections don't technically exist
outside of sockets at the OS level, don't @ me.
This commit is contained in:
Connor Olding 2021-08-13 00:52:10 -07:00
parent 367329cf8c
commit 55024634f7
3 changed files with 61 additions and 11 deletions

View file

@ -33,7 +33,7 @@ def detect_gfw(r, ip, check):
return False
async def getaddrs(server, domain, opts):
async def getaddrs(server, domain, opts, context=None):
from .ip_util import ipkey
from dns.asyncresolver import Resolver
from dns.exception import Timeout
@ -48,7 +48,11 @@ async def getaddrs(server, domain, opts):
res.lifetime = 9
res.nameservers = [server]
try:
ans = await res.resolve(domain, "A", search=False)
if context is not None:
async with context:
ans = await res.resolve(domain, "A", search=False)
else:
ans = await res.resolve(domain, "A", search=False)
except NXDOMAIN:
return ["NXDOMAIN"]
except NoAnswer:
@ -116,7 +120,8 @@ def process_result(res, ip, check, opts: Options):
)
async def try_ip(db, server_ip, checks, opts: Options, callback=None):
async def try_ip(db, server_ip, checks, context, opts: Options, callback=None):
# context can be None.
from .util import make_pooler
from asyncio import sleep, CancelledError
@ -151,10 +156,10 @@ async def try_ip(db, server_ip, checks, opts: Options, callback=None):
# NOTE: could put right_now() stuff here!
# TODO: add duration field given in milliseconds (integer)
# by subtracting start and end datetimes.
res = await getaddrs(ip, check.domain, opts)
res = await getaddrs(ip, check.domain, opts, context)
return res, ip, check
#lament("BEGIN", server_ip)
#lament("TESTING", server_ip)
for i, check in enumerate(checks):
first = i == 0
if not first:
@ -174,7 +179,7 @@ async def try_ip(db, server_ip, checks, opts: Options, callback=None):
db.push_entry(entry)
db.commit()
#lament("FINISH", server_ip)
#lament("TESTED", server_ip)
if not success:
first_failure = None
@ -189,7 +194,8 @@ async def try_ip(db, server_ip, checks, opts: Options, callback=None):
return None
async def try_all_ips(db, try_me, checks, opts: Options, callback=None):
async def try_all_ips(db, try_me, checks, context, opts: Options, callback=None):
# context can be None.
from asyncio import create_task, sleep, BoundedSemaphore
seen, total = 0, None
@ -202,7 +208,7 @@ async def try_all_ips(db, try_me, checks, opts: Options, callback=None):
lament(f"#{seen}: {ip}" if total is None else f"#{seen}/{total}: {ip}")
stderr.flush()
first_failure = await try_ip(db, ip, checks, opts, callback)
first_failure = await try_ip(db, ip, checks, context, opts, callback)
if first_failure is None:
print(ip) # all tests for this server passed; pass it along to stdout
@ -243,6 +249,7 @@ def sync_database(db, callback=None):
from .ips import china, blocks
# TODO: handle addresses that were removed from respodns.ips.china.
# i could probably just do ip.startswith("- ") and remove those.
for ips, kw in ((china, "china"), (blocks, "block_target")):
for ip in ips:
kwargs = {kw: True}
@ -297,12 +304,15 @@ async def read_all_ips(filepaths, blocking=False, callback=None):
async def main(db, filepaths, checks, ipinfo, opts: Options):
# ipinfo can be None.
from .util import LimitPerSecond
from asyncio import Queue, QueueFull, create_task
from queue import SimpleQueue
deferred = SimpleQueue()
locate_me = Queue()
try_me = Queue()
pps = opts.packets_per_second
context = LimitPerSecond(pps) if pps > 0 else None
def locate_later(ip):
try:
@ -320,7 +330,7 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
reading = create_task(read_all_ips(filepaths, opts.blocking_file_io,
callback=try_soon))
trying = create_task(try_all_ips(db, try_me, checks, opts,
trying = create_task(try_all_ips(db, try_me, checks, context, opts,
callback=locate_later))
locating = create_task(locate_ips(db, locate_me, ipinfo))
@ -355,3 +365,6 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
#lament("AWAIT locating")
await locating
#lament("AWAITED locating")
if context is not None and hasattr(context, "finish"):
await context.finish()

View file

@ -7,8 +7,9 @@ class Options:
# TODO: move this out of Options, since it's really not.
execution: object = None
ip_simul: int = 15 # how many IPs to connect to at once
domain_simul: int = 3 # how many domains per IP to request at once
ip_simul: int = 30 # how many IPs to connect to at once
domain_simul: int = 3 # how many domains per IP to request at once
packets_per_second: int = 50 # rough limit on all outgoing DNS packets
ip_wait: float = 0.05
domain_wait: float = 0.25

View file

@ -88,3 +88,39 @@ class AttrCheck:
super().__setattr__(name, value)
else:
raise AttributeError(name)
async def _release_later(sem, time=1):
from asyncio import sleep
await sleep(time)
sem.release()
class LimitPerSecond:
def __init__(self, limit):
from asyncio import BoundedSemaphore
if type(limit) is not int:
raise ValueError("limit must be int")
assert limit > 0, limit
self.limit = limit
self.tasks = []
self.sem = BoundedSemaphore(limit)
async def __aenter__(self):
#if self.sem.locked:
# from sys import stderr
# print("THROTTLING", file=stderr)
await self.sem.acquire()
async def __aexit__(self, exc_type, exc_value, traceback):
from asyncio import create_task
task = create_task(_release_later(self.sem))
self.tasks.append(task)
async def finish(self):
for task in self.tasks:
await task