add a configurable soft limit to number of DNS connections
yes, i know UDP connections don't technically exist outside of sockets at the OS level, don't @ me.
This commit is contained in:
parent
367329cf8c
commit
55024634f7
3 changed files with 61 additions and 11 deletions
|
@ -33,7 +33,7 @@ def detect_gfw(r, ip, check):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def getaddrs(server, domain, opts):
|
async def getaddrs(server, domain, opts, context=None):
|
||||||
from .ip_util import ipkey
|
from .ip_util import ipkey
|
||||||
from dns.asyncresolver import Resolver
|
from dns.asyncresolver import Resolver
|
||||||
from dns.exception import Timeout
|
from dns.exception import Timeout
|
||||||
|
@ -48,7 +48,11 @@ async def getaddrs(server, domain, opts):
|
||||||
res.lifetime = 9
|
res.lifetime = 9
|
||||||
res.nameservers = [server]
|
res.nameservers = [server]
|
||||||
try:
|
try:
|
||||||
ans = await res.resolve(domain, "A", search=False)
|
if context is not None:
|
||||||
|
async with context:
|
||||||
|
ans = await res.resolve(domain, "A", search=False)
|
||||||
|
else:
|
||||||
|
ans = await res.resolve(domain, "A", search=False)
|
||||||
except NXDOMAIN:
|
except NXDOMAIN:
|
||||||
return ["NXDOMAIN"]
|
return ["NXDOMAIN"]
|
||||||
except NoAnswer:
|
except NoAnswer:
|
||||||
|
@ -116,7 +120,8 @@ def process_result(res, ip, check, opts: Options):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def try_ip(db, server_ip, checks, opts: Options, callback=None):
|
async def try_ip(db, server_ip, checks, context, opts: Options, callback=None):
|
||||||
|
# context can be None.
|
||||||
from .util import make_pooler
|
from .util import make_pooler
|
||||||
from asyncio import sleep, CancelledError
|
from asyncio import sleep, CancelledError
|
||||||
|
|
||||||
|
@ -151,10 +156,10 @@ async def try_ip(db, server_ip, checks, opts: Options, callback=None):
|
||||||
# NOTE: could put right_now() stuff here!
|
# NOTE: could put right_now() stuff here!
|
||||||
# TODO: add duration field given in milliseconds (integer)
|
# TODO: add duration field given in milliseconds (integer)
|
||||||
# by subtracting start and end datetimes.
|
# by subtracting start and end datetimes.
|
||||||
res = await getaddrs(ip, check.domain, opts)
|
res = await getaddrs(ip, check.domain, opts, context)
|
||||||
return res, ip, check
|
return res, ip, check
|
||||||
|
|
||||||
#lament("BEGIN", server_ip)
|
#lament("TESTING", server_ip)
|
||||||
for i, check in enumerate(checks):
|
for i, check in enumerate(checks):
|
||||||
first = i == 0
|
first = i == 0
|
||||||
if not first:
|
if not first:
|
||||||
|
@ -174,7 +179,7 @@ async def try_ip(db, server_ip, checks, opts: Options, callback=None):
|
||||||
db.push_entry(entry)
|
db.push_entry(entry)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
#lament("FINISH", server_ip)
|
#lament("TESTED", server_ip)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
first_failure = None
|
first_failure = None
|
||||||
|
@ -189,7 +194,8 @@ async def try_ip(db, server_ip, checks, opts: Options, callback=None):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def try_all_ips(db, try_me, checks, opts: Options, callback=None):
|
async def try_all_ips(db, try_me, checks, context, opts: Options, callback=None):
|
||||||
|
# context can be None.
|
||||||
from asyncio import create_task, sleep, BoundedSemaphore
|
from asyncio import create_task, sleep, BoundedSemaphore
|
||||||
|
|
||||||
seen, total = 0, None
|
seen, total = 0, None
|
||||||
|
@ -202,7 +208,7 @@ async def try_all_ips(db, try_me, checks, opts: Options, callback=None):
|
||||||
lament(f"#{seen}: {ip}" if total is None else f"#{seen}/{total}: {ip}")
|
lament(f"#{seen}: {ip}" if total is None else f"#{seen}/{total}: {ip}")
|
||||||
stderr.flush()
|
stderr.flush()
|
||||||
|
|
||||||
first_failure = await try_ip(db, ip, checks, opts, callback)
|
first_failure = await try_ip(db, ip, checks, context, opts, callback)
|
||||||
|
|
||||||
if first_failure is None:
|
if first_failure is None:
|
||||||
print(ip) # all tests for this server passed; pass it along to stdout
|
print(ip) # all tests for this server passed; pass it along to stdout
|
||||||
|
@ -243,6 +249,7 @@ def sync_database(db, callback=None):
|
||||||
from .ips import china, blocks
|
from .ips import china, blocks
|
||||||
|
|
||||||
# TODO: handle addresses that were removed from respodns.ips.china.
|
# TODO: handle addresses that were removed from respodns.ips.china.
|
||||||
|
# i could probably just do ip.startswith("- ") and remove those.
|
||||||
for ips, kw in ((china, "china"), (blocks, "block_target")):
|
for ips, kw in ((china, "china"), (blocks, "block_target")):
|
||||||
for ip in ips:
|
for ip in ips:
|
||||||
kwargs = {kw: True}
|
kwargs = {kw: True}
|
||||||
|
@ -297,12 +304,15 @@ async def read_all_ips(filepaths, blocking=False, callback=None):
|
||||||
|
|
||||||
async def main(db, filepaths, checks, ipinfo, opts: Options):
|
async def main(db, filepaths, checks, ipinfo, opts: Options):
|
||||||
# ipinfo can be None.
|
# ipinfo can be None.
|
||||||
|
from .util import LimitPerSecond
|
||||||
from asyncio import Queue, QueueFull, create_task
|
from asyncio import Queue, QueueFull, create_task
|
||||||
from queue import SimpleQueue
|
from queue import SimpleQueue
|
||||||
|
|
||||||
deferred = SimpleQueue()
|
deferred = SimpleQueue()
|
||||||
locate_me = Queue()
|
locate_me = Queue()
|
||||||
try_me = Queue()
|
try_me = Queue()
|
||||||
|
pps = opts.packets_per_second
|
||||||
|
context = LimitPerSecond(pps) if pps > 0 else None
|
||||||
|
|
||||||
def locate_later(ip):
|
def locate_later(ip):
|
||||||
try:
|
try:
|
||||||
|
@ -320,7 +330,7 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
|
||||||
|
|
||||||
reading = create_task(read_all_ips(filepaths, opts.blocking_file_io,
|
reading = create_task(read_all_ips(filepaths, opts.blocking_file_io,
|
||||||
callback=try_soon))
|
callback=try_soon))
|
||||||
trying = create_task(try_all_ips(db, try_me, checks, opts,
|
trying = create_task(try_all_ips(db, try_me, checks, context, opts,
|
||||||
callback=locate_later))
|
callback=locate_later))
|
||||||
locating = create_task(locate_ips(db, locate_me, ipinfo))
|
locating = create_task(locate_ips(db, locate_me, ipinfo))
|
||||||
|
|
||||||
|
@ -355,3 +365,6 @@ async def main(db, filepaths, checks, ipinfo, opts: Options):
|
||||||
#lament("AWAIT locating")
|
#lament("AWAIT locating")
|
||||||
await locating
|
await locating
|
||||||
#lament("AWAITED locating")
|
#lament("AWAITED locating")
|
||||||
|
|
||||||
|
if context is not None and hasattr(context, "finish"):
|
||||||
|
await context.finish()
|
||||||
|
|
|
@ -7,8 +7,9 @@ class Options:
|
||||||
# TODO: move this out of Options, since it's really not.
|
# TODO: move this out of Options, since it's really not.
|
||||||
execution: object = None
|
execution: object = None
|
||||||
|
|
||||||
ip_simul: int = 15 # how many IPs to connect to at once
|
ip_simul: int = 30 # how many IPs to connect to at once
|
||||||
domain_simul: int = 3 # how many domains per IP to request at once
|
domain_simul: int = 3 # how many domains per IP to request at once
|
||||||
|
packets_per_second: int = 50 # rough limit on all outgoing DNS packets
|
||||||
|
|
||||||
ip_wait: float = 0.05
|
ip_wait: float = 0.05
|
||||||
domain_wait: float = 0.25
|
domain_wait: float = 0.25
|
||||||
|
|
|
@ -88,3 +88,39 @@ class AttrCheck:
|
||||||
super().__setattr__(name, value)
|
super().__setattr__(name, value)
|
||||||
else:
|
else:
|
||||||
raise AttributeError(name)
|
raise AttributeError(name)
|
||||||
|
|
||||||
|
|
||||||
|
async def _release_later(sem, time=1):
|
||||||
|
from asyncio import sleep
|
||||||
|
|
||||||
|
await sleep(time)
|
||||||
|
sem.release()
|
||||||
|
|
||||||
|
|
||||||
|
class LimitPerSecond:
|
||||||
|
def __init__(self, limit):
|
||||||
|
from asyncio import BoundedSemaphore
|
||||||
|
|
||||||
|
if type(limit) is not int:
|
||||||
|
raise ValueError("limit must be int")
|
||||||
|
assert limit > 0, limit
|
||||||
|
|
||||||
|
self.limit = limit
|
||||||
|
self.tasks = []
|
||||||
|
self.sem = BoundedSemaphore(limit)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
#if self.sem.locked:
|
||||||
|
# from sys import stderr
|
||||||
|
# print("THROTTLING", file=stderr)
|
||||||
|
await self.sem.acquire()
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
from asyncio import create_task
|
||||||
|
|
||||||
|
task = create_task(_release_later(self.sem))
|
||||||
|
self.tasks.append(task)
|
||||||
|
|
||||||
|
async def finish(self):
|
||||||
|
for task in self.tasks:
|
||||||
|
await task
|
||||||
|
|
Loading…
Add table
Reference in a new issue