add a configurable soft limit to number of DNS connections
yes, i know UDP connections don't technically exist outside of sockets at the OS level, don't @ me.
This commit is contained in:
parent
367329cf8c
commit
55024634f7
3 changed files with 61 additions and 11 deletions
|
@ -33,7 +33,7 @@ def detect_gfw(r, ip, check):
|
|||
return False
|
||||
|
||||
|
||||
async def getaddrs(server, domain, opts):
|
||||
async def getaddrs(server, domain, opts, context=None):
|
||||
from .ip_util import ipkey
|
||||
from dns.asyncresolver import Resolver
|
||||
from dns.exception import Timeout
|
||||
|
@ -48,7 +48,11 @@ async def getaddrs(server, domain, opts):
|
|||
res.lifetime = 9
|
||||
res.nameservers = [server]
|
||||
try:
|
||||
ans = await res.resolve(domain, "A", search=False)
|
||||
if context is not None:
|
||||
async with context:
|
||||
ans = await res.resolve(domain, "A", search=False)
|
||||
else:
|
||||
ans = await res.resolve(domain, "A", search=False)
|
||||
except NXDOMAIN:
|
||||
return ["NXDOMAIN"]
|
||||
except NoAnswer:
|
||||
|
@ -116,7 +120,8 @@ def process_result(res, ip, check, opts: Options):
|
|||
)
|
||||
|
||||
|
||||
async def try_ip(db, server_ip, checks, opts: Options, callback=None):
|
||||
async def try_ip(db, server_ip, checks, context, opts: Options, callback=None):
|
||||
# context can be None.
|
||||
from .util import make_pooler
|
||||
from asyncio import sleep, CancelledError
|
||||
|
||||
|
@ -151,10 +156,10 @@ async def try_ip(db, server_ip, checks, opts: Options, callback=None):
|
|||
# NOTE: could put right_now() stuff here!
|
||||
# TODO: add duration field given in milliseconds (integer)
|
||||
# by subtracting start and end datetimes.
|
||||
res = await getaddrs(ip, check.domain, opts)
|
||||
res = await getaddrs(ip, check.domain, opts, context)
|
||||
return res, ip, check
|
||||
|
||||
#lament("BEGIN", server_ip)
|
||||
#lament("TESTING", server_ip)
|
||||
for i, check in enumerate(checks):
|
||||
first = i == 0
|
||||
if not first:
|
||||
|
@ -174,7 +179,7 @@ async def try_ip(db, server_ip, checks, opts: Options, callback=None):
|
|||
db.push_entry(entry)
|
||||
db.commit()
|
||||
|
||||
#lament("FINISH", server_ip)
|
||||
#lament("TESTED", server_ip)
|
||||
|
||||
if not success:
|
||||
first_failure = None
|
||||
|
@ -189,7 +194,8 @@ async def try_ip(db, server_ip, checks, opts: Options, callback=None):
|
|||
return None
|
||||
|
||||
|
||||
async def try_all_ips(db, try_me, checks, opts: Options, callback=None):
|
||||
async def try_all_ips(db, try_me, checks, context, opts: Options, callback=None):
|
||||
# context can be None.
|
||||
from asyncio import create_task, sleep, BoundedSemaphore
|
||||
|
||||
seen, total = 0, None
|
||||
|
@ -202,7 +208,7 @@ async def try_all_ips(db, try_me, checks, opts: Options, callback=None):
|
|||
lament(f"#{seen}: {ip}" if total is None else f"#{seen}/{total}: {ip}")
|
||||
stderr.flush()
|
||||
|
||||
first_failure = await try_ip(db, ip, checks, opts, callback)
|
||||
first_failure = await try_ip(db, ip, checks, context, opts, callback)
|
||||
|
||||
if first_failure is None:
|
||||
print(ip) # all tests for this server passed; pass it along to stdout
|
||||
|
@ -243,6 +249,7 @@ def sync_database(db, callback=None):
|
|||
from .ips import china, blocks
|
||||
|
||||
# TODO: handle addresses that were removed from respodns.ips.china.
|
||||
# i could probably just do ip.startswith("- ") and remove those.
|
||||
for ips, kw in ((china, "china"), (blocks, "block_target")):
|
||||
for ip in ips:
|
||||
kwargs = {kw: True}
|
||||
|
@ -297,12 +304,15 @@ async def read_all_ips(filepaths, blocking=False, callback=None):
|
|||
|
||||
async def main(db, filepaths, checks, ipinfo, opts: Options):
|
||||
# ipinfo can be None.
|
||||
from .util import LimitPerSecond
|
||||
from asyncio import Queue, QueueFull, create_task
|
||||
from queue import SimpleQueue
|
||||
|
||||
deferred = SimpleQueue()
|
||||
locate_me = Queue()
|
||||
try_me = Queue()
|
||||
pps = opts.packets_per_second
|
||||
context = LimitPerSecond(pps) if pps > 0 else None
|
||||
|
||||
def locate_later(ip):
|
||||
try:
|
||||
|
@ -320,7 +330,7 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
|
|||
|
||||
reading = create_task(read_all_ips(filepaths, opts.blocking_file_io,
|
||||
callback=try_soon))
|
||||
trying = create_task(try_all_ips(db, try_me, checks, opts,
|
||||
trying = create_task(try_all_ips(db, try_me, checks, context, opts,
|
||||
callback=locate_later))
|
||||
locating = create_task(locate_ips(db, locate_me, ipinfo))
|
||||
|
||||
|
@ -355,3 +365,6 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
|
|||
#lament("AWAIT locating")
|
||||
await locating
|
||||
#lament("AWAITED locating")
|
||||
|
||||
if context is not None and hasattr(context, "finish"):
|
||||
await context.finish()
|
||||
|
|
|
@ -7,8 +7,9 @@ class Options:
|
|||
# TODO: move this out of Options, since it's really not.
|
||||
execution: object = None
|
||||
|
||||
ip_simul: int = 15 # how many IPs to connect to at once
|
||||
domain_simul: int = 3 # how many domains per IP to request at once
|
||||
ip_simul: int = 30 # how many IPs to connect to at once
|
||||
domain_simul: int = 3 # how many domains per IP to request at once
|
||||
packets_per_second: int = 50 # rough limit on all outgoing DNS packets
|
||||
|
||||
ip_wait: float = 0.05
|
||||
domain_wait: float = 0.25
|
||||
|
|
|
@ -88,3 +88,39 @@ class AttrCheck:
|
|||
super().__setattr__(name, value)
|
||||
else:
|
||||
raise AttributeError(name)
|
||||
|
||||
|
||||
async def _release_later(sem, time=1):
|
||||
from asyncio import sleep
|
||||
|
||||
await sleep(time)
|
||||
sem.release()
|
||||
|
||||
|
||||
class LimitPerSecond:
|
||||
def __init__(self, limit):
|
||||
from asyncio import BoundedSemaphore
|
||||
|
||||
if type(limit) is not int:
|
||||
raise ValueError("limit must be int")
|
||||
assert limit > 0, limit
|
||||
|
||||
self.limit = limit
|
||||
self.tasks = []
|
||||
self.sem = BoundedSemaphore(limit)
|
||||
|
||||
async def __aenter__(self):
|
||||
#if self.sem.locked:
|
||||
# from sys import stderr
|
||||
# print("THROTTLING", file=stderr)
|
||||
await self.sem.acquire()
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
from asyncio import create_task
|
||||
|
||||
task = create_task(_release_later(self.sem))
|
||||
self.tasks.append(task)
|
||||
|
||||
async def finish(self):
|
||||
for task in self.tasks:
|
||||
await task
|
||||
|
|
Loading…
Reference in a new issue