"""Country-code lookup for IPv4 addresses via ip-api.com, with a CSV cache."""

from collections import namedtuple
from time import time

# One cache entry: the time the lookup was made and the code it returned.
CacheLine = namedtuple("CacheLine", ("time", "code"))

one_month = 365.25 / 12 * 24 * 60 * 60  # in seconds

# Back-off (in seconds) applied when the server gives no usable rate-limit
# information: a non-200 response, or "X-Rl: 0" without an "X-Ttl" header.
fallback_cooldown = 30


def mask_ip(ip):
    """Return `ip` with its last octet replaced by 0 (e.g. "1.2.3.4" -> "1.2.3.0")."""
    # this assumes IP info is the same for a /24 subnet.
    return ".".join(ip.split(".")[:3]) + ".0"


class IpInfoBase:
    """Base class for IP-information providers."""


class IpInfoByIpApi(IpInfoBase):
    """Look up country codes through ip-api.com and cache them in a CSV file.

    Results are cached per /24 subnet (see `mask_ip`). The cache is loaded
    lazily on first use and only written back to disk by `flush`.
    """

    def __init__(self, filepath, expiry=one_month):
        """Set up the provider; nothing is read or written until first use.

        filepath: path of the CSV cache file.
        expiry:   age in seconds after which a cached entry is ignored.
        """
        self.filepath = filepath
        self.expiry = expiry
        self.cooldown = 0  # time() value before which no request may be sent
        self.give_up = False  # set once a network error has occurred
        self.stored = None  # masked ip -> CacheLine; None until prepare()
        self.encoding = "latin-1"
        self.host = "ip-api.com"
        self.path = "/csv/{}?fields=2"
        self.csv_header = ["ip", "time", "code"]

    def decode(self, txt, mode="ignore"):
        """Decode bytes to str using this instance's encoding."""
        return txt.decode(self.encoding, mode)

    def encode(self, txt, mode="ignore"):
        """Encode str to bytes using this instance's encoding."""
        return txt.encode(self.encoding, mode)

    def _parse_headers(self, it):
        """Consume HTTP header lines from `it` up to and including the blank line.

        Updates `self.cooldown` when the headers report that no requests
        remain (X-Rl is 0).

        Returns an error message as a string, or None.
        """
        x_cooldown = None
        x_remaining = None

        for line in it:
            if line == b"":
                break  # end of header

            head, _, tail = line.partition(b":")

            # do some very basic validation.
            if tail[0:1] != b" ":
                return "bad tail: " + self.decode(line, "replace")
            tail = tail[1:]

            if head in (b"Date", b"Content-Type", b"Content-Length",
                        b"Access-Control-Allow-Origin"):
                pass
            elif head == b"X-Ttl":
                if not tail.isdigit():
                    return "X-Ttl not an integer: " + self.decode(tail, "replace")
                x_cooldown = int(tail)
            elif head == b"X-Rl":
                if not tail.isdigit():
                    return "X-Rl not an integer: " + self.decode(tail, "replace")
                x_remaining = int(tail)
            else:
                return "unexpected field: " + self.decode(head, "replace")

        if x_remaining == 0:
            if x_cooldown is None:
                # Rate-limited, but the server did not say for how long;
                # back off for a fixed period instead of crashing on None.
                x_cooldown = fallback_cooldown
            self.cooldown = time() + x_cooldown
            self.cooldown += 1.0  # still too frequent according to them

        return None  # no error

    async def wait(self):
        """Sleep until the current rate-limit cooldown has passed."""
        from asyncio import sleep

        # Quote:
        # Your implementation should always check the value of the X-Rl header,
        # and if its is 0 you must not send any more requests
        # for the duration of X-Ttl in seconds.
        while time() < self.cooldown:
            wait = self.cooldown - time()
            wait = max(0, wait)  # must not be negative!
            wait += 0.1  # still too frequent according to them
            await sleep(wait)

    async def http_lookup(self, ip):
        """Query the service for `ip` over plain HTTP.

        Returns `(code, None)` on success, where `code` is the raw body line
        as bytes, or `(None, err)` where `err` is an error message string.
        Connection errors (OSError) propagate to the caller.
        """
        from asyncio import open_connection

        await self.wait()

        path = self.path.format(ip)
        query_lines = (
            f"GET {path} HTTP/1.1",
            f"Host: {self.host}",
            "Connection: close",
        )
        query = "\r\n".join(query_lines) + "\r\n\r\n"

        # TODO: only perform DNS lookup once.
        # TODO: only open connection once if possible (keep-alive).
        reader, writer = await open_connection(self.host, 80)
        try:
            writer.write(self.encode(query, "strict"))
            await writer.drain()
            # "Connection: close" means EOF marks the end of the response.
            response = await reader.read()
        finally:
            # Always release the socket, even if the exchange failed.
            writer.close()

        it = iter(response.splitlines())
        code = None

        if next(it, None) == b"HTTP/1.1 200 OK":
            err = self._parse_headers(it)
            # The body is expected to be exactly one line.
            code = next(it, None)
            if code is None:
                err = "missing body"
            if next(it, None) is not None:
                err = "too many lines"
        else:
            err = "not ok"
            self.cooldown = time() + fallback_cooldown

        return (code, None) if err is None else (None, err)

    async def lookup(self, ip, timestamp):
        """Fetch the country code for `ip` from the network.

        Returns a `CacheLine(timestamp, code)`, or None when the lookup
        failed. After a network error (OSError, which includes DNS failures)
        the instance gives up and every later call returns None immediately.
        """
        if self.give_up:
            return None

        try:
            code, err = await self.http_lookup(ip)
            if err:
                # retry once in case of rate-limiting
                code, err = await self.http_lookup(ip)
        except OSError:  # socket.gaierror is a subclass of OSError
            self.give_up = True
            return None

        if err:
            return None

        code = self.decode(code)
        if code == "":
            code = "--"  # placeholder for an empty answer
        if len(code) != 2:
            return None

        return CacheLine(timestamp, code)

    def prepare(self):
        """Load the cache from `self.filepath`, creating the file if missing."""
        from csv import reader, writer
        from os.path import exists

        self.stored = dict()

        if not exists(self.filepath):
            # Same newline/encoding settings as flush() and the read below.
            with open(self.filepath, "w", newline="", encoding="utf-8") as f:
                handle = writer(f)
                handle.writerow(self.csv_header)
            return

        with open(self.filepath, "r", newline="", encoding="utf-8") as f:
            for i, row in enumerate(reader(f)):
                if i == 0:
                    assert row == self.csv_header, row
                    continue
                if not row:
                    continue  # tolerate blank lines
                ip, stamp, code = row[0], float(row[1]), row[2]
                self.stored[mask_ip(ip)] = CacheLine(stamp, code)

    def flush(self):
        """Write the in-memory cache to disk; does nothing if it is empty or unloaded."""
        from csv import writer

        if not self.stored:
            return

        with open(self.filepath, "w", newline="", encoding="utf-8") as f:
            handle = writer(f)
            handle.writerow(self.csv_header)
            for ip, info in self.stored.items():
                timestr = "{:.2f}".format(info.time)
                handle.writerow([mask_ip(ip), timestr, info.code])

    def cache(self, ip, info=None, timestamp=None):
        """Read from or write to the in-memory cache.

        With `info` None: return the cached `CacheLine` for `ip`'s subnet, or
        None if there is none or it is older than `self.expiry` relative to
        `timestamp` (default: the current time).
        With `info` given: store it for `ip`'s subnet and return None.
        """
        if self.stored is None:
            self.prepare()

        now = time() if timestamp is None else timestamp
        masked = mask_ip(ip)

        if info is not None:
            assert isinstance(info, CacheLine), type(info)
            self.stored[masked] = info
            return None

        cached = self.stored.get(masked, None)
        if cached is None:
            return None
        if now > cached.time + self.expiry:
            return None  # expired
        return cached

    async def find_country(self, ip, db=None):
        """Return the country code for `ip`, or None if it cannot be determined.

        The cache is consulted first and the network only on a miss. If `db`
        is given, its `country_code(ip)` is compared with the result and
        updated through `country_code(ip, code)` when they differ.
        """
        now = time()
        info = self.cache(ip, timestamp=now)
        if info is None:
            info = await self.lookup(ip, now)
        if info is None:
            return None

        self.cache(ip, info)
        if db is not None:
            if db.country_code(ip) != info.code:
                assert info.code is not None
                db.country_code(ip, info.code)
        return info.code