respodns/respodns/ip_info.py

221 lines
6.7 KiB
Python

from collections import namedtuple
from time import time
# One cache entry: the time the lookup was made and the code that came back.
CacheLine = namedtuple("CacheLine", "time code")

# Default cache lifetime: one average month (a twelfth of a year), in seconds.
one_month = (365.25 / 12) * (24 * 60 * 60)
def mask_ip(ip):
    """Return *ip* with everything past its third dotted field replaced by "0".

    At most the first three dot-separated fields are kept, and ".0" is
    always appended, so "192.168.1.77" becomes "192.168.1.0".
    """
    # this assumes IP info is the same for a /24 subnet.
    network = ip.split(".")[:3]
    network.append("0")
    return ".".join(network)
class IpInfoBase:
    """Common base class for IP-information providers; defines no behavior."""
class IpInfoByIpApi(IpInfoBase):
    """Look up a two-character code per IP via ip-api.com, cached in a CSV file.

    Results are keyed by the masked (/24) form of the address, so one
    remote query serves a whole subnet. The cache is loaded lazily from
    ``filepath`` and written back by ``flush``. The code is presumably an
    ISO country code (see ``find_country``); this class only checks that
    it is two characters long.
    """

    def __init__(self, filepath, expiry=one_month):
        """Set up the provider; no file or network access happens here.

        filepath -- path of the CSV cache file.
        expiry   -- seconds after which a cached entry is considered stale.
        """
        self.filepath = filepath
        self.expiry = expiry
        self.cooldown = 0  # time() before which no request may be sent
        self.give_up = False  # set once the network proves unusable
        self.stored = None  # masked ip -> CacheLine; None until prepare()
        self.encoding = "latin-1"
        self.host = "ip-api.com"
        self.path = "/csv/{}?fields=2"
        self.csv_header = ["ip", "time", "code"]

    def decode(self, txt, mode="ignore"):
        """Decode bytes using this provider's encoding and error mode."""
        return txt.decode(self.encoding, mode)

    def encode(self, txt, mode="ignore"):
        """Encode a string using this provider's encoding and error mode."""
        return txt.encode(self.encoding, mode)

    def _parse_headers(self, it):
        """Consume and validate response header lines from the iterator *it*.

        Reads up to and including the blank line that ends the headers.
        When the X-Rl header is 0, ``self.cooldown`` is pushed into the
        future by X-Ttl seconds (plus a small margin).

        Returns an error message as a string, or None.
        """
        x_cooldown = None
        x_remaining = None

        for line in it:
            if line == b"":
                break  # end of header

            head, _, tail = line.partition(b":")

            # do some very basic validation.
            if tail[0:1] == b" ":
                tail = tail[1:]
            else:
                s = self.decode(line, "replace")
                return "bad tail: " + s

            if head in (b"Date", b"Content-Type", b"Content-Length",
                        b"Access-Control-Allow-Origin"):
                pass
            elif head == b"X-Ttl":
                if tail.isdigit():
                    x_cooldown = int(tail)
                else:
                    s = self.decode(tail, "replace")
                    return "X-Ttl not an integer: " + s
            elif head == b"X-Rl":
                if tail.isdigit():
                    x_remaining = int(tail)
                else:
                    s = self.decode(tail, "replace")
                    return "X-Rl not an integer: " + s
            else:
                s = self.decode(head, "replace")
                return "unexpected field: " + s

        if x_remaining == 0:
            if x_cooldown is None:
                # X-Rl without X-Ttl: back off for a conservative guess
                # instead of crashing on `time() + None`.
                # NOTE(review): 60 seconds is an assumption; confirm it
                # against the API's documented rate-limit window.
                x_cooldown = 60
            self.cooldown = time() + x_cooldown
            self.cooldown += 1.0  # still too frequent according to them

        return None  # no error

    async def wait(self):
        """Sleep until the current rate-limit cooldown has passed."""
        from asyncio import sleep

        # Quote:
        # Your implementation should always check the value of the X-Rl header,
        # and if its is 0 you must not send any more requests
        # for the duration of X-Ttl in seconds.

        while time() < self.cooldown:
            delay = self.cooldown - time()
            delay = max(0, delay)  # must not be negative!
            delay += 0.1  # still too frequent according to them
            await sleep(delay)

    async def http_lookup(self, ip):
        """Query the remote API for *ip* over plain HTTP.

        Returns ``(code, None)`` on success, where ``code`` is the raw
        body line as bytes, or ``(None, message)`` when the response was
        not usable. Connection errors propagate to the caller.
        """
        from asyncio import open_connection

        await self.wait()

        path = self.path.format(ip)
        query_lines = (
            f"GET {path} HTTP/1.1",
            f"Host: {self.host}",
            "Connection: close",
        )
        query = "\r\n".join(query_lines) + "\r\n\r\n"

        # TODO: only perform DNS lookup once.
        # TODO: only open connection once if possible (keep-alive).
        reader, writer = await open_connection(self.host, 80)
        try:
            writer.write(self.encode(query, "strict"))
            # "Connection: close" lets us simply read until EOF.
            response = await reader.read()
        finally:
            # always release the socket, even if the exchange failed.
            writer.close()

        it = iter(response.splitlines())

        if next(it, None) != b"HTTP/1.1 200 OK":
            self.cooldown = time() + 30
            return None, "not ok"

        err = self._parse_headers(it)
        if err is not None:
            # the iterator stopped mid-header, so what follows is not the
            # body; report the header problem itself.
            return None, err

        code = next(it, None)
        if code is None:
            return None, "missing body"
        if next(it, None) is not None:
            return None, "too many lines"

        return code, None

    async def lookup(self, ip, timestamp):
        """Fetch a fresh ``CacheLine`` for *ip*, stamped with *timestamp*.

        Returns None when the lookup failed twice, when the response was
        not a two-character code, or when a socket-level error occurred.
        After a socket-level error no further lookups are attempted by
        this instance.
        """
        if self.give_up:
            return None

        try:
            code, err = await self.http_lookup(ip)
            if err:
                # retry once in case of rate-limiting
                code, err = await self.http_lookup(ip)
                if err:
                    return None
        except OSError:
            # covers socket.gaierror too (it is a subclass of OSError).
            # Remember the failure on the instance and bail out; there is
            # no response to decode.
            self.give_up = True
            return None

        code = self.decode(code)

        if code == "":
            code = "--"  # placeholder for an empty answer
        if len(code) != 2:
            return None

        return CacheLine(timestamp, code)

    def prepare(self):
        """Load the CSV cache into ``self.stored``, creating the file if absent."""
        from csv import reader, writer
        from os.path import exists

        self.stored = dict()

        if not exists(self.filepath):
            # same newline/encoding settings as flush(), as the csv module
            # requires newline="" to avoid mangled line endings.
            with open(self.filepath, "w", newline="", encoding="utf-8") as f:
                handle = writer(f)
                handle.writerow(self.csv_header)
            return

        with open(self.filepath, "r", newline="", encoding="utf-8") as f:
            for i, row in enumerate(reader(f)):
                if i == 0:
                    assert row == self.csv_header, row
                    continue
                if not row:
                    continue  # tolerate blank lines in the file
                ip, stamp, code = row[0], float(row[1]), row[2]
                self.stored[mask_ip(ip)] = CacheLine(stamp, code)

    def flush(self):
        """Rewrite the CSV file from ``self.stored``; no-op if it is empty or unloaded."""
        from csv import writer

        if not self.stored:
            return

        with open(self.filepath, "w", newline="", encoding="utf-8") as f:
            handle = writer(f)
            handle.writerow(self.csv_header)
            for ip, info in self.stored.items():
                timestr = "{:.2f}".format(info.time)
                masked = mask_ip(ip)
                handle.writerow([masked, timestr, info.code])

    def cache(self, ip, info=None, timestamp=None):
        """Read from or write to the in-memory cache.

        With ``info`` omitted, return the cached ``CacheLine`` for the
        masked *ip*, or None when it is missing or older than
        ``self.expiry`` relative to *timestamp* (default: now).
        With ``info`` given, store it under the masked *ip* and return None.
        """
        if self.stored is None:
            self.prepare()

        masked = mask_ip(ip)

        if info is not None:
            assert isinstance(info, CacheLine), type(info)
            self.stored[masked] = info
            return None

        now = time() if timestamp is None else timestamp
        cached = self.stored.get(masked, None)
        if cached is None:
            return None
        if now > cached.time + self.expiry:
            return None  # stale entry
        return cached

    async def find_country(self, ip, db=None):
        """Return the code for *ip*, from cache if fresh, else from the network.

        Returns None when no code could be obtained. When *db* is given,
        its ``country_code`` is read for *ip* and, if it differs from the
        result, called again with the new code.
        NOTE(review): ``db.country_code(ip, code)`` presumably stores the
        value; confirm against the db implementation.
        """
        now = time()
        info = self.cache(ip, timestamp=now)

        if info is None:
            info = await self.lookup(ip, now)
            if info is None:
                return None
            self.cache(ip, info)

        if db is not None:
            if db.country_code(ip) != info.code:
                assert info.code is not None
                db.country_code(ip, info.code)

        return info.code