diff --git a/respodns/dns.py b/respodns/dns.py
index 7f8771c..de2a48c 100644
--- a/respodns/dns.py
+++ b/respodns/dns.py
@@ -1,5 +1,4 @@
 from .structs import Options
-from .ip_info import find_country, flush
 
 
 def detect_gfw(r, ip, check):
@@ -149,11 +148,12 @@ async def try_ip(db, server_ip, checks, opts: Options):
         else:
             await pooler()
 
-    await find_country(server_ip, db)
-    for entry in entries:
-        for addr in entry.addrs:
-            await find_country(addr, db)
-    flush()
+    if opts.ipinfo is not None:
+        await opts.ipinfo.find_country(server_ip, db)
+        for entry in entries:
+            for addr in entry.addrs:
+                await opts.ipinfo.find_country(addr, db)
+        opts.ipinfo.flush()
 
     if not opts.dry:
         for entry in entries:
@@ -173,16 +173,18 @@ async def try_ip(db, server_ip, checks, opts: Options):
     return server_ip, None
 
 
-async def sync_database(db):
+async def sync_database(db, opts: Options):
     # TODO: handle addresses that were removed from respodns.ips.china.
     from .ips import china, blocks
-    for ip in china:
-        code = await find_country(ip)
-        db.modify_address(ip, china=True, country_code=code)
-    for ip in blocks:
-        code = await find_country(ip)
-        db.modify_address(ip, block_target=True, country_code=code)
-    flush()
+    for ips, kw in ((china, "china"), (blocks, "block_target")):
+        for ip in ips:
+            kwargs = dict()
+            kwargs[kw] = True
+            if opts.ipinfo is not None:
+                kwargs["country_code"] = await opts.ipinfo.find_country(ip)
+            db.modify_address(ip, **kwargs)
+    if opts.ipinfo is not None:
+        opts.ipinfo.flush()
 
 
 async def main(db, filepath, checks, opts: Options):
@@ -192,7 +194,7 @@ async def main(db, filepath, checks, opts: Options):
     from sys import stdin
 
     if db is not None:
-        await sync_database(db)
+        await sync_database(db, opts)
 
     def finisher(done, pending):
         for task in done:
diff --git a/respodns/ip_info.py b/respodns/ip_info.py
index a1270e4..ec1a380 100644
--- a/respodns/ip_info.py
+++ b/respodns/ip_info.py
@@ -1,200 +1,217 @@
-from asyncio import open_connection, sleep
 from collections import namedtuple
-from socket import gaierror
 from sys import stderr
 from time import time
 
 
 CacheLine = namedtuple("CacheLine", ("time", "code"))
-header = ["ip", "time", "code"]
 
 one_month = 365.25 / 12 * 24 * 60 * 60  # in seconds
-encoding = "latin-1"
-
-cache_filepath = "ipinfo_cache.csv"
-
-http_cooldown = 0
-give_up = False
-prepared = False
-stored = None
 
-async def http_lookup(ip):
-    global http_cooldown
+class IpInfoBase:
+    pass
 
-    host = "ip-api.com"
-    path = f"/csv/{ip}?fields=2"
-    err = None
+class IpInfoByIpApi(IpInfoBase):
+    def __init__(self, filepath, expiry=one_month):
+        self.filepath = filepath
+        self.expiry = expiry
+        self.cooldown = 0
+        self.give_up = False
+        self.prepared = False
+        self.stored = None
 
-    # Quote:
-    # Your implementation should always check the value of the X-Rl header,
-    # and if its is 0 you must not send any more requests
-    # for the duration of X-Ttl in seconds.
-    while time() < http_cooldown:
-        wait = http_cooldown - time()
-        wait = max(wait, 0.1)  # wait at least a little bit
-        await sleep(wait)
+        self.encoding = "latin-1"
+        self.host = "ip-api.com"
+        self.path = "/csv/{}?fields=2"
+        self.csv_header = ["ip", "time", "code"]
 
-    query_lines = (
-        f"GET {path} HTTP/1.1",
-        f"Host: {host}",
-        f"Connection: close",
-    )
-    query = "\r\n".join(query_lines) + "\r\n\r\n"
+    def decode(self, txt, mode="ignore"):
+        return txt.decode(self.encoding, mode)
 
-    reader, writer = await open_connection(host, 80)
+    def encode(self, txt, mode="ignore"):
+        return txt.encode(self.encoding, mode)
 
-    writer.write(query.encode(encoding, "strict"))
-    response = await reader.read()
-    lines = response.splitlines()
+    def _parse_headers(self, it):
+        # returns an error message as a string, or None.
+        x_cooldown = None
+        x_remaining = None
+        for line in it:
+            if line == b"":
+                break  # end of header
 
-    it = iter(lines)
-    line = next(it)
-    if line != b"HTTP/1.1 200 OK":
-        http_cooldown = time() + 60
-        err = "not ok"
-        it = ()  # exhaust iterator (not really)
+            head, _, tail = line.partition(b":")
 
-    x_cooldown = None
-    x_remaining = None
-
-    for line in it:
-        if line == b"":
-            break
-
-        head, _, tail = line.partition(b":")
-
-        # do some very basic validation.
-        if tail[0:1] == b" ":
-            tail = tail[1:]
-        else:
-            err = "bad tail"
-            break
-        if head in (b"Date", b"Content-Type", b"Content-Length",
-                    b"Access-Control-Allow-Origin"):
-            pass
-        elif head == b"X-Ttl":
-            if tail.isdigit():
-                x_cooldown = int(tail)
+            # do some very basic validation.
+            if tail[0:1] == b" ":
+                tail = tail[1:]
             else:
-                err = "X-Ttl not integer"
-                break
-        elif head == b"X-Rl":
-            if tail.isdigit():
-                x_remaining = int(tail)
+                s = self.decode(line, "replace")
+                return "bad tail: " + s
+            if head in (b"Date", b"Content-Type", b"Content-Length",
+                        b"Access-Control-Allow-Origin"):
+                pass
+            elif head == b"X-Ttl":
+                if tail.isdigit():
+                    x_cooldown = int(tail)
+                else:
+                    s = self.decode(tail, "replace")
+                    return "X-Ttl not an integer: " + s
+            elif head == b"X-Rl":
+                if tail.isdigit():
+                    x_remaining = int(tail)
+                else:
+                    s = self.decode(tail, "replace")
+                    return "X-Rl not an integer: " + s
             else:
-                err = "X-Rl not integer"
-                break
+                s = self.decode(head, "replace")
+                return "unexpected field: " + s
 
-    for i, line in enumerate(it):
-        if i == 0:
-            code = line
-        else:
-            err = "too many lines"
-            break
+        if x_remaining == 0:
+            self.cooldown = time() + x_cooldown
+            self.cooldown += 1.0  # still too frequent according to them
 
-    writer.close()
+        return None  # no error
 
-    if x_remaining == 0:
-        http_cooldown = time() + x_cooldown
-        http_cooldown += 1.0  # still too frequent according to them
+    async def wait(self):
+        from asyncio import sleep
 
-    if err:
-        return None, err
-    else:
-        return code, None
+        # Quote:
+        # Your implementation should always check the value of the X-Rl header,
+        # and if its is 0 you must not send any more requests
+        # for the duration of X-Ttl in seconds.
+        while time() < self.cooldown:
+            wait = self.cooldown - time()
+            wait = max(wait, 0.1)  # wait at least a little bit
+            await sleep(wait)
 
+    async def http_lookup(self, ip):
+        from asyncio import open_connection
 
-async def lookup(ip, timestamp):
-    global give_up
-    if give_up:
-        return None
+        await self.wait()
 
-    try:
-        code, err = await http_lookup(ip)
-        if err:
-            # retry once in case of rate-limiting
-            code, err = await http_lookup(ip)
-            if err:
-                return None
-    except gaierror:
-        give_up = True
-    except OSError:
-        give_up = True
+        path = self.path.format(ip)
+        query_lines = (
+            f"GET {path} HTTP/1.1",
+            f"Host: {self.host}",
+            f"Connection: close",
+        )
+        query = "\r\n".join(query_lines) + "\r\n\r\n"
 
-    code = code.decode(encoding, "ignore")
-    if code == "":
-        code = "--"
-    if len(code) != 2:
-        return None
-    info = CacheLine(timestamp, code)
-    return info
+        reader, writer = await open_connection(self.host, 80)
+        writer.write(self.encode(query, "strict"))
+        response = await reader.read()
+        lines = response.splitlines()
 
+        err = None
+        it = iter(lines)
 
-def prepare():
-    from csv import reader, writer
-    from os.path import exists
+        line = next(it)
+        if line != b"HTTP/1.1 200 OK":
+            self.cooldown = time() + 60
+            err = "not ok"
+            it = ()  # exhaust iterator (not really)
 
-    global stored
-    stored = dict()
+        if err is None:
+            err = self._parse_headers(it)
 
-    if not exists(cache_filepath):
-        with open(cache_filepath, "w") as f:
-            handle = writer(f)
-            handle.writerow(header)
-        return
+        for i, line in enumerate(it):
+            if i == 0:
+                code = line
+            else:
+                err = "too many lines"
+                break
 
-    with open(cache_filepath, "r", newline="", encoding="utf-8") as f:
-        for i, row in enumerate(reader(f)):
-            if i == 0:
-                assert row == header, row
-                continue
-            ip, time, code = row[0], float(row[1]), row[2]
-            info = CacheLine(time, code)
-            stored[ip] = info
+        writer.close()
 
+        return (code, None) if err is None else (None, err)
 
-def flush():
-    from csv import writer
-    if not stored:
-        return
-    with open(cache_filepath, "w", newline="", encoding="utf-8") as f:
-        handle = writer(f)
-        handle.writerow(header)
-        for ip, info in stored.items():
-            timestr = "{:.2f}".format(info.time)
-            handle.writerow([ip, timestr, info.code])
+    async def lookup(self, ip, timestamp):
+        from socket import gaierror
 
+        if self.give_up:
+            return None
 
-def cache(ip, info=None, timestamp=None, expiry=one_month):
-    global stored
-    if stored is None:
-        prepare()
-
-    now = time() if timestamp is None else timestamp
-
-    if info is None:
-        cached = stored.get(ip, None)
-        if cached is None:
-            return None
-        if now > cached.time + expiry:
-            return None
-        return cached
-    else:
-        assert isinstance(info, CacheLine)
-        stored[ip] = info
+        try:
+            code, err = await self.http_lookup(ip)
+            if err:
+                # retry once in case of rate-limiting
+                code, err = await self.http_lookup(ip)
+                if err:
+                    return None
+        except (gaierror, OSError):
+            self.give_up = True
+        if self.give_up:
+            return None
 
+        code = self.decode(code)
+        if code == "":
+            code = "--"
+        if len(code) != 2:
+            return None
+        info = CacheLine(timestamp, code)
+        return info
 
-async def find_country(ip, db=None):
-    now = time()
-    info = cache(ip, timestamp=now)
-    if info is None:
-        info = await lookup(ip, now)
-        if info is None:
-            return None
-        cache(ip, info)
-    if db is not None:
-        if db.country_code(ip) != info.code:
-            assert info.code is not None
-            db.country_code(ip, info.code)
-    return info.code
+    def prepare(self):
+        from csv import reader, writer
+        from os.path import exists
+
+        self.stored = dict()
+
+        if not exists(self.filepath):
+            with open(self.filepath, "w") as f:
+                handle = writer(f)
+                handle.writerow(self.csv_header)
+            return
+
+        with open(self.filepath, "r", newline="", encoding="utf-8") as f:
+            for i, row in enumerate(reader(f)):
+                if i == 0:
+                    assert row == self.csv_header, row
+                    continue
+                ip, time, code = row[0], float(row[1]), row[2]
+                info = CacheLine(time, code)
+                self.stored[ip] = info
+
+    def flush(self):
+        from csv import writer
+        if not self.stored:
+            return
+        with open(self.filepath, "w", newline="", encoding="utf-8") as f:
+            handle = writer(f)
+            handle.writerow(self.csv_header)
+            for ip, info in self.stored.items():
+                timestr = "{:.2f}".format(info.time)
+                handle.writerow([ip, timestr, info.code])
+
+    def cache(self, ip, info=None, timestamp=None):
+        if self.stored is None:
+            self.prepare()
+
+        now = time() if timestamp is None else timestamp
+
+        if info is None:
+            cached = self.stored.get(ip, None)
+            if cached is None:
+                return None
+            if now > cached.time + self.expiry:
+                return None
+            return cached
+        else:
+            assert isinstance(info, CacheLine), type(info)
+            self.stored[ip] = info
+
+    async def find_country(self, ip, db=None):
+        now = time()
+        info = self.cache(ip, timestamp=now)
+        if info is None:
+            info = await self.lookup(ip, now)
+            if info is None:
+                return None
+            self.cache(ip, info)
+        if db is not None:
+            if db.country_code(ip) != info.code:
+                assert info.code is not None
+                db.country_code(ip, info.code)
+        return info.code
diff --git a/respodns/ips.py b/respodns/ips.py
index a7773d6..fbaa5aa 100644
--- a/respodns/ips.py
+++ b/respodns/ips.py
@@ -9,7 +9,7 @@ china = {
     "58.221.250.86",
     "58.222.226.146",
     "114.254.201.131",
-# "204.13.152.3", # not china but seems to be poisoned
+    # "204.13.152.3", # not china but seems to be poisoned
     "218.94.128.126",
     "218.94.193.170",
     "218.107.55.108",
diff --git a/respodns/structs.py b/respodns/structs.py
index c467134..f8d3691 100644
--- a/respodns/structs.py
+++ b/respodns/structs.py
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 
 @dataclass
 class Options:
     execution: object = None
+    ipinfo: object = None
     ip_simul: int = 10  # how many IPs to connect to at once
     domain_simul: int = 3  # how many domains per IP to request at once
diff --git a/respodns/ui.py b/respodns/ui.py
index 64521e9..0bdba75 100644
--- a/respodns/ui.py
+++ b/respodns/ui.py
@@ -1,6 +1,7 @@
 def ui(program, args):
     from .db import RespoDB
     from .dns import main
+    from .ip_info import IpInfoByIpApi
     from .structs import Options
     from argparse import ArgumentParser
     from asyncio import run
@@ -26,6 +27,8 @@ def ui(program, args):
     opts.dry = a.database is None
     opts.early_stopping = opts.dry
 
+    opts.ipinfo = IpInfoByIpApi("ipinfo_cache.csv")
+
     if a.database is not None:
         if a.database.startswith("sqlite:"):
             uri = a.database