rewrite ipinfo into a class

This commit is contained in:
Connor Olding 2020-09-03 12:41:47 +02:00
parent d0d29bd0d4
commit 0603081c49
5 changed files with 197 additions and 174 deletions

View file

@ -1,5 +1,4 @@
from .structs import Options from .structs import Options
from .ip_info import find_country, flush
def detect_gfw(r, ip, check): def detect_gfw(r, ip, check):
@ -149,11 +148,12 @@ async def try_ip(db, server_ip, checks, opts: Options):
else: else:
await pooler() await pooler()
await find_country(server_ip, db) if opts.ipinfo is not None:
await opts.ipinfo.find_country(server_ip, db)
for entry in entries: for entry in entries:
for addr in entry.addrs: for addr in entry.addrs:
await find_country(addr, db) await opts.ipinfo.find_country(addr, db)
flush() opts.ipinfo.flush()
if not opts.dry: if not opts.dry:
for entry in entries: for entry in entries:
@ -173,16 +173,18 @@ async def try_ip(db, server_ip, checks, opts: Options):
return server_ip, None return server_ip, None
async def sync_database(db): async def sync_database(db, opts: Options):
# TODO: handle addresses that were removed from respodns.ips.china. # TODO: handle addresses that were removed from respodns.ips.china.
from .ips import china, blocks from .ips import china, blocks
for ip in china: for ips, kw in ((china, "china"), (blocks, "block_target")):
code = await find_country(ip) for ip in ips:
db.modify_address(ip, china=True, country_code=code) kwargs = dict()
for ip in blocks: kwargs[kw] = True
code = await find_country(ip) if opts.ipinfo is not None:
db.modify_address(ip, block_target=True, country_code=code) kwargs["country_code"] = await opts.ipinfo.find_country(ip)
flush() db.modify_address(ip, **kwargs)
if opts.ipinfo is not None:
opts.ipinfo.flush()
async def main(db, filepath, checks, opts: Options): async def main(db, filepath, checks, opts: Options):
@ -192,7 +194,7 @@ async def main(db, filepath, checks, opts: Options):
from sys import stdin from sys import stdin
if db is not None: if db is not None:
await sync_database(db) await sync_database(db, opts)
def finisher(done, pending): def finisher(done, pending):
for task in done: for task in done:

View file

@ -1,66 +1,41 @@
from asyncio import open_connection, sleep
from collections import namedtuple from collections import namedtuple
from socket import gaierror
from sys import stderr from sys import stderr
from time import time from time import time
CacheLine = namedtuple("CacheLine", ("time", "code")) CacheLine = namedtuple("CacheLine", ("time", "code"))
header = ["ip", "time", "code"]
one_month = 365.25 / 12 * 24 * 60 * 60 # in seconds one_month = 365.25 / 12 * 24 * 60 * 60 # in seconds
encoding = "latin-1"
cache_filepath = "ipinfo_cache.csv"
http_cooldown = 0
give_up = False
prepared = False
stored = None
async def http_lookup(ip): class IpInfoBase:
global http_cooldown pass
host = "ip-api.com"
path = f"/csv/{ip}?fields=2"
err = None class IpInfoByIpApi(IpInfoBase):
def __init__(self, filepath, expiry=one_month):
self.filepath = filepath
self.expiry = expiry
self.cooldown = 0
self.give_up = False
self.prepared = False
self.stored = None
# Quote: self.encoding = "latin-1"
# Your implementation should always check the value of the X-Rl header, self.host = "ip-api.com"
# and if its is 0 you must not send any more requests self.path = "/csv/{}?fields=2"
# for the duration of X-Ttl in seconds. self.csv_header = ["ip", "time", "code"]
while time() < http_cooldown:
wait = http_cooldown - time()
wait = max(wait, 0.1) # wait at least a little bit
await sleep(wait)
query_lines = ( def decode(self, txt, mode="ignore"):
f"GET {path} HTTP/1.1", return txt.decode(self.encoding, mode)
f"Host: {host}",
f"Connection: close",
)
query = "\r\n".join(query_lines) + "\r\n\r\n"
reader, writer = await open_connection(host, 80) def encode(self, txt, mode="ignore"):
return txt.encode(self.encoding, mode)
writer.write(query.encode(encoding, "strict"))
response = await reader.read()
lines = response.splitlines()
it = iter(lines)
line = next(it)
if line != b"HTTP/1.1 200 OK":
http_cooldown = time() + 60
err = "not ok"
it = () # exhaust iterator (not really)
x_cooldown = None
x_remaining = None
def _parse_headers(self, it):
# returns an error message as a string, or None.
for line in it: for line in it:
if line == b"": if line == b"":
break break # end of header
head, _, tail = line.partition(b":") head, _, tail = line.partition(b":")
@ -68,8 +43,8 @@ async def http_lookup(ip):
if tail[0:1] == b" ": if tail[0:1] == b" ":
tail = tail[1:] tail = tail[1:]
else: else:
err = "bad tail" s = self.decode(line, "replace")
break return "bad tail: " + s
if head in (b"Date", b"Content-Type", b"Content-Length", if head in (b"Date", b"Content-Type", b"Content-Length",
b"Access-Control-Allow-Origin"): b"Access-Control-Allow-Origin"):
pass pass
@ -77,14 +52,69 @@ async def http_lookup(ip):
if tail.isdigit(): if tail.isdigit():
x_cooldown = int(tail) x_cooldown = int(tail)
else: else:
err = "X-Ttl not integer" s = self.decode(tail, "replace")
break return "X-Ttl not an integer: " + s
elif head == b"X-Rl": elif head == b"X-Rl":
if tail.isdigit(): if tail.isdigit():
x_remaining = int(tail) x_remaining = int(tail)
else: else:
err = "X-Rl not integer" s = self.decode(tail, "replace")
break return "X-Rl not an integer: " + s
else:
s = self.decode(head, "replace")
return "unexpected field: " + s
if x_remaining == 0:
self.cooldown = time() + x_cooldown
self.cooldown += 1.0 # still too frequent according to them
return None # no error
async def wait(self):
from asyncio import sleep
# Quote:
# Your implementation should always check the value of the X-Rl header,
# and if its is 0 you must not send any more requests
# for the duration of X-Ttl in seconds.
while time() < self.cooldown:
wait = self.cooldown - time()
wait = max(wait, 0.1) # wait at least a little bit
await sleep(wait)
async def http_lookup(self, ip):
from asyncio import open_connection
await self.wait()
path = self.path.format(ip)
query_lines = (
f"GET {path} HTTP/1.1",
f"Host: {self.host}",
f"Connection: close",
)
query = "\r\n".join(query_lines) + "\r\n\r\n"
reader, writer = await open_connection(self.host, 80)
writer.write(self.encode(query, "strict"))
response = await reader.read()
lines = response.splitlines()
err = None
it = iter(lines)
line = next(it)
if line != b"HTTP/1.1 200 OK":
self.cooldown = time() + 60
err = "not ok"
it = () # exhaust iterator (not really)
x_cooldown = None
x_remaining = None
if err is None:
err = self._parse_headers(it)
for i, line in enumerate(it): for i, line in enumerate(it):
if i == 0: if i == 0:
@ -95,26 +125,19 @@ async def http_lookup(ip):
writer.close() writer.close()
if x_remaining == 0: return (code, None) if err is None else (None, err)
http_cooldown = time() + x_cooldown
http_cooldown += 1.0 # still too frequent according to them
if err: async def lookup(self, ip, timestamp):
return None, err from socket import gaierror
else:
return code, None
if self.give_up:
async def lookup(ip, timestamp):
global give_up
if give_up:
return None return None
try: try:
code, err = await http_lookup(ip) code, err = await self.http_lookup(ip)
if err: if err:
# retry once in case of rate-limiting # retry once in case of rate-limiting
code, err = await http_lookup(ip) code, err = await self.http_lookup(ip)
if err: if err:
return None return None
except gaierror: except gaierror:
@ -122,7 +145,7 @@ async def lookup(ip, timestamp):
except OSError: except OSError:
give_up = True give_up = True
code = code.decode(encoding, "ignore") code = self.decode(code)
if code == "": if code == "":
code = "--" code = "--"
if len(code) != 2: if len(code) != 2:
@ -130,69 +153,63 @@ async def lookup(ip, timestamp):
info = CacheLine(timestamp, code) info = CacheLine(timestamp, code)
return info return info
def prepare(self):
def prepare():
from csv import reader, writer from csv import reader, writer
from os.path import exists from os.path import exists
global stored self.stored = dict()
stored = dict()
if not exists(cache_filepath): if not exists(self.filepath):
with open(cache_filepath, "w") as f: with open(self.filepath, "w") as f:
handle = writer(f) handle = writer(f)
handle.writerow(header) handle.writerow(self.csv_header)
return return
with open(cache_filepath, "r", newline="", encoding="utf-8") as f: with open(self.filepath, "r", newline="", encoding="utf-8") as f:
for i, row in enumerate(reader(f)): for i, row in enumerate(reader(f)):
if i == 0: if i == 0:
assert row == header, row assert row == self.csv_header, row
continue continue
ip, time, code = row[0], float(row[1]), row[2] ip, time, code = row[0], float(row[1]), row[2]
info = CacheLine(time, code) info = CacheLine(time, code)
stored[ip] = info self.stored[ip] = info
def flush(self):
def flush():
from csv import writer from csv import writer
if not stored: if not self.stored:
return return
with open(cache_filepath, "w", newline="", encoding="utf-8") as f: with open(self.filepath, "w", newline="", encoding="utf-8") as f:
handle = writer(f) handle = writer(f)
handle.writerow(header) handle.writerow(self.csv_header)
for ip, info in stored.items(): for ip, info in self.stored.items():
timestr = "{:.2f}".format(info.time) timestr = "{:.2f}".format(info.time)
handle.writerow([ip, timestr, info.code]) handle.writerow([ip, timestr, info.code])
def cache(self, ip, info=None, timestamp=None):
def cache(ip, info=None, timestamp=None, expiry=one_month): if self.stored is None:
global stored self.prepare()
if stored is None:
prepare()
now = time() if timestamp is None else timestamp now = time() if timestamp is None else timestamp
if info is None: if info is None:
cached = stored.get(ip, None) cached = self.stored.get(ip, None)
if cached is None: if cached is None:
return None return None
if now > cached.time + expiry: if now > cached.time + self.expiry:
return None return None
return cached return cached
else: else:
assert isinstance(info, CacheLine) assert isinstance(info, CacheLine), type(info)
stored[ip] = info self.stored[ip] = info
async def find_country(self, ip, db=None):
async def find_country(ip, db=None):
now = time() now = time()
info = cache(ip, timestamp=now) info = self.cache(ip, timestamp=now)
if info is None: if info is None:
info = await lookup(ip, now) info = await self.lookup(ip, now)
if info is None: if info is None:
return None return None
cache(ip, info) self.cache(ip, info)
if db is not None: if db is not None:
if db.country_code(ip) != info.code: if db.country_code(ip) != info.code:
assert info.code is not None assert info.code is not None

View file

@ -5,6 +5,7 @@ from dataclasses import dataclass
@dataclass @dataclass
class Options: class Options:
execution: object = None execution: object = None
ipinfo: object = None
ip_simul: int = 10 # how many IPs to connect to at once ip_simul: int = 10 # how many IPs to connect to at once
domain_simul: int = 3 # how many domains per IP to request at once domain_simul: int = 3 # how many domains per IP to request at once

View file

@ -1,6 +1,7 @@
def ui(program, args): def ui(program, args):
from .db import RespoDB from .db import RespoDB
from .dns import main from .dns import main
from .ip_info import IpInfoByIpApi
from .structs import Options from .structs import Options
from argparse import ArgumentParser from argparse import ArgumentParser
from asyncio import run from asyncio import run
@ -26,6 +27,8 @@ def ui(program, args):
opts.dry = a.database is None opts.dry = a.database is None
opts.early_stopping = opts.dry opts.early_stopping = opts.dry
opts.ipinfo = IpInfoByIpApi("ipinfo_cache.csv")
if a.database is not None: if a.database is not None:
if a.database.startswith("sqlite:"): if a.database.startswith("sqlite:"):
uri = a.database uri = a.database