rewrite ipinfo into a class

parent  d0d29bd0d4
commit  0603081c49

5 changed files with 197 additions and 174 deletions
@@ -1,5 +1,4 @@
 from .structs import Options
-from .ip_info import find_country, flush
 
 
 def detect_gfw(r, ip, check):
@@ -149,11 +148,12 @@ async def try_ip(db, server_ip, checks, opts: Options):
     else:
         await pooler()
 
-    await find_country(server_ip, db)
-    for entry in entries:
-        for addr in entry.addrs:
-            await find_country(addr, db)
-    flush()
+    if opts.ipinfo is not None:
+        await opts.ipinfo.find_country(server_ip, db)
+        for entry in entries:
+            for addr in entry.addrs:
+                await opts.ipinfo.find_country(addr, db)
+        opts.ipinfo.flush()
 
     if not opts.dry:
         for entry in entries:
@@ -173,16 +173,18 @@ async def try_ip(db, server_ip, checks, opts: Options):
     return server_ip, None
 
 
-async def sync_database(db):
+async def sync_database(db, opts: Options):
     # TODO: handle addresses that were removed from respodns.ips.china.
     from .ips import china, blocks
-    for ip in china:
-        code = await find_country(ip)
-        db.modify_address(ip, china=True, country_code=code)
-    for ip in blocks:
-        code = await find_country(ip)
-        db.modify_address(ip, block_target=True, country_code=code)
-    flush()
+    for ips, kw in ((china, "china"), (blocks, "block_target")):
+        for ip in ips:
+            kwargs = dict()
+            kwargs[kw] = True
+            if opts.ipinfo is not None:
+                kwargs["country_code"] = await opts.ipinfo.find_country(ip)
+            db.modify_address(ip, **kwargs)
+    if opts.ipinfo is not None:
+        opts.ipinfo.flush()
 
 
 async def main(db, filepath, checks, opts: Options):
@@ -192,7 +194,7 @@ async def main(db, filepath, checks, opts: Options):
     from sys import stdin
 
     if db is not None:
-        await sync_database(db)
+        await sync_database(db, opts)
 
     def finisher(done, pending):
         for task in done:
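The hunks above come from the module that defines try_ip, sync_database and main (imported later in this commit as "from .dns import main", so presumably the package's dns module). The change they make is that country lookups now go through an optional opts.ipinfo object and are skipped entirely when it is None. A minimal runnable sketch of that guarded pattern; FakeDb and FakeIpInfo are hypothetical stand-ins, not part of the commit:

    import asyncio

    class FakeDb:
        def modify_address(self, ip, **kwargs):
            print(ip, kwargs)

    class FakeIpInfo:
        async def find_country(self, ip, db=None):
            return "CN"  # pretend lookup result

        def flush(self):
            print("cache flushed")

    async def sync(db, ipinfo=None):
        china, blocks = ["1.2.3.4"], ["5.6.7.8"]
        for ips, kw in ((china, "china"), (blocks, "block_target")):
            for ip in ips:
                kwargs = {kw: True}
                if ipinfo is not None:  # lookups stay optional
                    kwargs["country_code"] = await ipinfo.find_country(ip)
                db.modify_address(ip, **kwargs)
        if ipinfo is not None:
            ipinfo.flush()

    asyncio.run(sync(FakeDb(), FakeIpInfo()))

Passing ipinfo=None (the default) performs the same database writes without any network lookups, which is what the None guards in the diff are for.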
@@ -1,66 +1,41 @@
-from asyncio import open_connection, sleep
 from collections import namedtuple
-from socket import gaierror
 from sys import stderr
 from time import time
 
 CacheLine = namedtuple("CacheLine", ("time", "code"))
-header = ["ip", "time", "code"]
 
 one_month = 365.25 / 12 * 24 * 60 * 60 # in seconds
-encoding = "latin-1"
-
-cache_filepath = "ipinfo_cache.csv"
-
-http_cooldown = 0
-give_up = False
-prepared = False
-stored = None
 
 
-async def http_lookup(ip):
-    global http_cooldown
-
-    host = "ip-api.com"
-    path = f"/csv/{ip}?fields=2"
-
-    err = None
-
-    # Quote:
-    # Your implementation should always check the value of the X-Rl header,
-    # and if its is 0 you must not send any more requests
-    # for the duration of X-Ttl in seconds.
-    while time() < http_cooldown:
-        wait = http_cooldown - time()
-        wait = max(wait, 0.1) # wait at least a little bit
-        await sleep(wait)
-
-    query_lines = (
-        f"GET {path} HTTP/1.1",
-        f"Host: {host}",
-        f"Connection: close",
-    )
-    query = "\r\n".join(query_lines) + "\r\n\r\n"
-
-    reader, writer = await open_connection(host, 80)
-
-    writer.write(query.encode(encoding, "strict"))
-    response = await reader.read()
-    lines = response.splitlines()
-
-    it = iter(lines)
-    line = next(it)
-    if line != b"HTTP/1.1 200 OK":
-        http_cooldown = time() + 60
-        err = "not ok"
-        it = () # exhaust iterator (not really)
-
-    x_cooldown = None
-    x_remaining = None
-
-    for line in it:
-        if line == b"":
-            break
-
-        head, _, tail = line.partition(b":")
+class IpInfoBase:
+    pass
+
+
+class IpInfoByIpApi(IpInfoBase):
+    def __init__(self, filepath, expiry=one_month):
+        self.filepath = filepath
+        self.expiry = expiry
+        self.cooldown = 0
+        self.give_up = False
+        self.prepared = False
+        self.stored = None
+
+        self.encoding = "latin-1"
+        self.host = "ip-api.com"
+        self.path = "/csv/{}?fields=2"
+        self.csv_header = ["ip", "time", "code"]
+
+    def decode(self, txt, mode="ignore"):
+        return txt.decode(self.encoding, mode)
+
+    def encode(self, txt, mode="ignore"):
+        return txt.encode(self.encoding, mode)
+
+    def _parse_headers(self, it):
+        # returns an error message as a string, or None.
+        for line in it:
+            if line == b"":
+                break # end of header
+
+            head, _, tail = line.partition(b":")
 
@@ -68,8 +43,8 @@ async def http_lookup(ip):
             if tail[0:1] == b" ":
                 tail = tail[1:]
             else:
-                err = "bad tail"
-                break
+                s = self.decode(line, "replace")
+                return "bad tail: " + s
             if head in (b"Date", b"Content-Type", b"Content-Length",
                         b"Access-Control-Allow-Origin"):
                 pass
@@ -77,14 +52,69 @@ async def http_lookup(ip):
                 if tail.isdigit():
                     x_cooldown = int(tail)
                 else:
-                    err = "X-Ttl not integer"
-                    break
+                    s = self.decode(tail, "replace")
+                    return "X-Ttl not an integer: " + s
             elif head == b"X-Rl":
                 if tail.isdigit():
                     x_remaining = int(tail)
                 else:
-                    err = "X-Rl not integer"
-                    break
+                    s = self.decode(tail, "replace")
+                    return "X-Rl not an integer: " + s
+            else:
+                s = self.decode(head, "replace")
+                return "unexpected field: " + s
+
+        if x_remaining == 0:
+            self.cooldown = time() + x_cooldown
+            self.cooldown += 1.0 # still too frequent according to them
+
+        return None # no error
+
+    async def wait(self):
+        from asyncio import sleep
+
+        # Quote:
+        # Your implementation should always check the value of the X-Rl header,
+        # and if its is 0 you must not send any more requests
+        # for the duration of X-Ttl in seconds.
+        while time() < self.cooldown:
+            wait = self.cooldown - time()
+            wait = max(wait, 0.1) # wait at least a little bit
+            await sleep(wait)
+
+    async def http_lookup(self, ip):
+        from asyncio import open_connection
+
+        await self.wait()
+
+        path = self.path.format(ip)
+        query_lines = (
+            f"GET {path} HTTP/1.1",
+            f"Host: {self.host}",
+            f"Connection: close",
+        )
+        query = "\r\n".join(query_lines) + "\r\n\r\n"
+
+        reader, writer = await open_connection(self.host, 80)
+
+        writer.write(self.encode(query, "strict"))
+        response = await reader.read()
+        lines = response.splitlines()
+
+        err = None
+        it = iter(lines)
+
+        line = next(it)
+        if line != b"HTTP/1.1 200 OK":
+            self.cooldown = time() + 60
+            err = "not ok"
+            it = () # exhaust iterator (not really)
+
+        x_cooldown = None
+        x_remaining = None
+
+        if err is None:
+            err = self._parse_headers(it)
 
         for i, line in enumerate(it):
             if i == 0:
@@ -95,26 +125,19 @@ async def http_lookup(ip):
 
         writer.close()
 
-    if x_remaining == 0:
-        http_cooldown = time() + x_cooldown
-        http_cooldown += 1.0 # still too frequent according to them
-
-    if err:
-        return None, err
-    else:
-        return code, None
-
-
-async def lookup(ip, timestamp):
-    global give_up
-    if give_up:
+        return (code, None) if err is None else (None, err)
+
+    async def lookup(self, ip, timestamp):
+        from socket import gaierror
+
+        if self.give_up:
             return None
 
         try:
-            code, err = await http_lookup(ip)
+            code, err = await self.http_lookup(ip)
             if err:
                 # retry once in case of rate-limiting
-                code, err = await http_lookup(ip)
+                code, err = await self.http_lookup(ip)
                 if err:
                     return None
         except gaierror:
@@ -122,7 +145,7 @@ async def lookup(ip, timestamp):
         except OSError:
             give_up = True
 
-        code = code.decode(encoding, "ignore")
+        code = self.decode(code)
         if code == "":
             code = "--"
         if len(code) != 2:
@@ -130,69 +153,63 @@ async def lookup(ip, timestamp):
         info = CacheLine(timestamp, code)
         return info
 
-
-def prepare():
-    from csv import reader, writer
-    from os.path import exists
-
-    global stored
-    stored = dict()
-
-    if not exists(cache_filepath):
-        with open(cache_filepath, "w") as f:
-            handle = writer(f)
-            handle.writerow(header)
-        return
-
-    with open(cache_filepath, "r", newline="", encoding="utf-8") as f:
-        for i, row in enumerate(reader(f)):
-            if i == 0:
-                assert row == header, row
-                continue
-            ip, time, code = row[0], float(row[1]), row[2]
-            info = CacheLine(time, code)
-            stored[ip] = info
-
-
-def flush():
-    from csv import writer
-    if not stored:
-        return
-    with open(cache_filepath, "w", newline="", encoding="utf-8") as f:
-        handle = writer(f)
-        handle.writerow(header)
-        for ip, info in stored.items():
-            timestr = "{:.2f}".format(info.time)
-            handle.writerow([ip, timestr, info.code])
-
-
-def cache(ip, info=None, timestamp=None, expiry=one_month):
-    global stored
-    if stored is None:
-        prepare()
-
-    now = time() if timestamp is None else timestamp
-
-    if info is None:
-        cached = stored.get(ip, None)
-        if cached is None:
-            return None
-        if now > cached.time + expiry:
-            return None
-        return cached
-    else:
-        assert isinstance(info, CacheLine)
-        stored[ip] = info
-
-
-async def find_country(ip, db=None):
-    now = time()
-    info = cache(ip, timestamp=now)
-    if info is None:
-        info = await lookup(ip, now)
-        if info is None:
-            return None
-    cache(ip, info)
-    if db is not None:
-        if db.country_code(ip) != info.code:
-            assert info.code is not None
+    def prepare(self):
+        from csv import reader, writer
+        from os.path import exists
+
+        self.stored = dict()
+
+        if not exists(self.filepath):
+            with open(self.filepath, "w") as f:
+                handle = writer(f)
+                handle.writerow(self.csv_header)
+            return
+
+        with open(self.filepath, "r", newline="", encoding="utf-8") as f:
+            for i, row in enumerate(reader(f)):
+                if i == 0:
+                    assert row == self.csv_header, row
+                    continue
+                ip, time, code = row[0], float(row[1]), row[2]
+                info = CacheLine(time, code)
+                self.stored[ip] = info
+
+    def flush(self):
+        from csv import writer
+        if not self.stored:
+            return
+        with open(self.filepath, "w", newline="", encoding="utf-8") as f:
+            handle = writer(f)
+            handle.writerow(self.csv_header)
+            for ip, info in self.stored.items():
+                timestr = "{:.2f}".format(info.time)
+                handle.writerow([ip, timestr, info.code])
+
+    def cache(self, ip, info=None, timestamp=None):
+        if self.stored is None:
+            self.prepare()
+
+        now = time() if timestamp is None else timestamp
+
+        if info is None:
+            cached = self.stored.get(ip, None)
+            if cached is None:
+                return None
+            if now > cached.time + self.expiry:
+                return None
+            return cached
+        else:
+            assert isinstance(info, CacheLine), type(info)
+            self.stored[ip] = info
+
+    async def find_country(self, ip, db=None):
+        now = time()
+        info = self.cache(ip, timestamp=now)
+        if info is None:
+            info = await self.lookup(ip, now)
+            if info is None:
+                return None
+        self.cache(ip, info)
+        if db is not None:
+            if db.country_code(ip) != info.code:
+                assert info.code is not None
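The hunks above rewrite the ip_info module: the module-level find_country/flush functions and their globals (http_cooldown, give_up, stored, ...) become an IpInfoByIpApi object that owns the rate-limit cooldown and the CSV cache. A hypothetical usage sketch; the constructor argument and method names are taken from the diff, while the package path is an assumption based on the relative imports and the respodns.ips reference in the TODO comment:

    import asyncio
    from respodns.ip_info import IpInfoByIpApi  # assumed package path

    async def demo():
        ipinfo = IpInfoByIpApi("ipinfo_cache.csv")   # CSV cache file, created on first use
        code = await ipinfo.find_country("1.1.1.1")  # expected: a two-letter code, or None on failure
        print(code)
        ipinfo.flush()                               # write the cache back to disk

    asyncio.run(demo())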
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 @dataclass
 class Options:
     execution: object = None
+    ipinfo: object = None
 
     ip_simul: int = 10 # how many IPs to connect to at once
     domain_simul: int = 3 # how many domains per IP to request at once
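Because the new ipinfo field defaults to None, callers that build Options() without configuring it keep working, and the "is not None" guards shown earlier simply skip the lookups. A trimmed sketch using only the fields visible in the diff:

    from dataclasses import dataclass

    @dataclass
    class Options:  # trimmed copy of the dataclass above
        execution: object = None
        ipinfo: object = None
        ip_simul: int = 10
        domain_simul: int = 3

    opts = Options()
    assert opts.ipinfo is None  # no lookup object configured, so country lookups are skipped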
@@ -1,6 +1,7 @@
 def ui(program, args):
     from .db import RespoDB
     from .dns import main
+    from .ip_info import IpInfoByIpApi
     from .structs import Options
     from argparse import ArgumentParser
     from asyncio import run
@@ -26,6 +27,8 @@ def ui(program, args):
     opts.dry = a.database is None
     opts.early_stopping = opts.dry
 
+    opts.ipinfo = IpInfoByIpApi("ipinfo_cache.csv")
+
     if a.database is not None:
         if a.database.startswith("sqlite:"):
             uri = a.database
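The ui function is where the object gets wired in: it attaches the cache-backed lookup object to opts, and main() then passes opts down to sync_database() and try_ip(). A simplified sketch of that wiring; the package path is assumed, and the argument values are only examples:

    from respodns.ip_info import IpInfoByIpApi  # assumed package path
    from respodns.structs import Options

    opts = Options()
    opts.dry = True  # example: no database configured
    opts.early_stopping = opts.dry
    opts.ipinfo = IpInfoByIpApi("ipinfo_cache.csv")
    # later: run(main(db, filepath, checks, opts)) hands opts.ipinfo down to
    # sync_database() and try_ip(), which only use it when it is not None.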