# respodns/respodns/db.py
# 2020-08-29 10:16:06 +02:00
#
# 435 lines
# 14 KiB
# Python

import storm.locals as rain
import re
# Dotted-quad IPv4 address: four runs of ASCII digits separated by dots.
# The 0..255 range check is done by addr_to_int, not by the pattern.
ipv4_pattern = re.compile(r"(\d+)\.(\d+)\.(\d+)\.(\d+)", re.ASCII)


def addr_to_int(ip):
    """
    Convert a dotted-quad IPv4 string to its 32-bit integer value.

    For example ``"1.2.3.4"`` becomes ``0x01020304``.

    Raises AssertionError (when assertions are enabled) if *ip* is not
    exactly four dot-separated decimal numbers, each in 0..255.
    """
    match = ipv4_pattern.fullmatch(ip)
    # Report the offending input itself (previously referenced an
    # undefined name and raised NameError instead).
    assert match is not None, ip
    segs = [int(seg) for seg in match.groups()]
    assert all(0 <= seg <= 255 for seg in segs), match.group(0)
    numeric = segs[0] << 24 | segs[1] << 16 | segs[2] << 8 | segs[3]
    return numeric
# CREATE TABLE DDL for every table, keyed by the name RespoDB's setup_*
# methods look up.  Every statement uses IF NOT EXISTS.
create_table_statements = {
    # TODO: Duration REAL GENERATED ALWAYS AS etc.?
    "executions": """
CREATE TABLE IF NOT EXISTS Executions (
ExecutionId INTEGER PRIMARY KEY,
StartDate DATE NOT NULL,
FinishDate DATE,
Completed BOOLEAN DEFAULT 0 NOT NULL)
""",
    "exceptions": """
CREATE TABLE IF NOT EXISTS Exceptions (
ExceptionId INTEGER PRIMARY KEY,
Name TEXT NOT NULL,
Fail BOOLEAN NOT NULL)
""",
    # AsStr is derived from AsInt by SQLite (generated, stored column).
    "ips": """
CREATE TABLE IF NOT EXISTS Ips (
IpId INTEGER PRIMARY KEY,
AsStr TEXT GENERATED ALWAYS AS (
Cast(AsInt >> 24 & 255 AS TEXT) || '.' ||
Cast(AsInt >> 16 & 255 AS TEXT) || '.' ||
Cast(AsInt >> 8 & 255 AS TEXT) || '.' ||
Cast(AsInt & 255 AS TEXT)
) STORED NOT NULL,
AsInt INTEGER UNIQUE CHECK(AsInt >= 0 AND AsInt < 1 << 32) NOT NULL,
China BOOLEAN DEFAULT 0 NOT NULL,
BlockTarget BOOLEAN DEFAULT 0 NOT NULL,
Server BOOLEAN DEFAULT 0 NOT NULL,
RedirectTarget BOOLEAN DEFAULT 0 NOT NULL,
GfwTarget BOOLEAN DEFAULT 0 NOT NULL)
""",
    "kinds": """
CREATE TABLE IF NOT EXISTS Kinds (
KindId INTEGER PRIMARY KEY,
Name TEXT UNIQUE NOT NULL,
ExpectExceptionId INTEGER,
FOREIGN KEY(ExpectExceptionId) REFERENCES Exceptions(ExceptionId))
""",
    "domains": """
CREATE TABLE IF NOT EXISTS Domains (
DomainId INTEGER PRIMARY KEY,
Name TEXT UNIQUE NOT NULL,
KindId INTEGER,
FOREIGN KEY(KindId) REFERENCES Kinds(KindId))
""",
    # NOTE: RecordId is *not* the rowid here: a record can hold several
    # IPs, one row each, so it spans multiple rows.
    # TODO: indexing stuff, cascade deletion stuff.
    "records": """
CREATE TABLE IF NOT EXISTS Records (
RecordId INTEGER NOT NULL,
IpId INTEGER,
FOREIGN KEY(IpId) REFERENCES Ips(IpId))
""",
    # There is no FOREIGN KEY(RecordId) REFERENCES Records(RecordId):
    # SQLite rejects it because Records.RecordId is not UNIQUE.
    "messages": """
CREATE TABLE IF NOT EXISTS Messages (
MessageId INTEGER PRIMARY KEY,
ExecutionId INTEGER,
ServerId INTEGER NOT NULL,
DomainId INTEGER NOT NULL,
RecordId INTEGER,
ExceptionId INTEGER,
FOREIGN KEY(ServerId) REFERENCES Ips(IpId),
FOREIGN KEY(ExecutionId) REFERENCES Executions(ExecutionId),
FOREIGN KEY(DomainId) REFERENCES Domains(DomainId),
FOREIGN KEY(ExceptionId) REFERENCES Exceptions(ExceptionId))
""",
}
# CREATE VIEW statements run during RespoDB setup, after all tables exist.
# IF NOT EXISTS keeps re-running setup idempotent, matching the table and
# trigger DDL.
#
# Results flattens Messages into readable names/addresses.  Records holds
# one row per address, so a message whose record has N addresses shows
# up as N rows here.
create_view_statements = [
"""
CREATE VIEW IF NOT EXISTS Results AS
SELECT
Messages.ExecutionId,
ServerIps.AsStr as Server,
Kinds.Name as Kind,
Domains.Name as Name,
RecordIps.AsStr as Address,
Exceptions.Name as Exception
FROM Messages
LEFT JOIN Domains ON Messages.DomainId = Domains.DomainId
LEFT JOIN Kinds ON Domains.KindId = Kinds.KindId
LEFT JOIN Ips AS ServerIps ON Messages.ServerId = ServerIps.IpId
LEFT JOIN Records ON Messages.RecordId = Records.RecordId
LEFT JOIN Ips as RecordIps ON Records.IpId = RecordIps.IpId
LEFT JOIN Exceptions ON Messages.ExceptionId = Exceptions.ExceptionId
-- GROUP BY Records.IpId
""",
]
# Trigger DDL per table, run by RespoDB.setup_messages after the table is
# created.  RecordExists stands in for the foreign key SQLite won't allow
# on the non-unique Records.RecordId.
table_triggers = {
    "messages": [
        # TODO: more triggers. (before update, and also for Records table)
        """
CREATE TRIGGER IF NOT EXISTS RecordExists
BEFORE INSERT
ON Messages
BEGIN
SELECT CASE
WHEN NEW.RecordId NOTNULL AND NOT EXISTS(SELECT 1 FROM Records WHERE Records.RecordID = NEW.RecordId)
THEN raise(FAIL, "RecordId does not exist")
END;
END
""",
    ],
}
class Execution:
    """
    Context manager bracketing one run in the Executions table.

    ``__enter__`` calls ``db.start_execution`` with the current time from
    ``util.right_now`` and returns the resulting row.  ``__exit__`` passes
    that row to ``db.finish_execution``.  It marks the row completed only
    when the ``with`` body did not raise.  Exceptions are never suppressed.
    """
    def __init__(self, db):
        self.db = db
        # The row from the most recent __enter__; None before the first one.
        self.execution = None

    def __enter__(self):
        from .util import right_now
        started = self.db.start_execution(right_now())
        self.execution = started
        return started

    def __exit__(self, exc_type, exc_value, traceback):
        from .util import right_now
        succeeded = exc_type is None
        finished_at = right_now()
        self.db.finish_execution(self.execution, finished_at, succeeded)
class AttrCheck:
    """
    Mixin that refuses to create new public attributes by assignment.

    Assigning to a name that does not already show up in ``dir(self)``
    raises AttributeError.  For example, a misspelled column
    name raises instead of silently adding a new attribute.  Names
    beginning with ``_`` are always allowed.
    """
    def __setattr__(self, name, value):
        # NOTE: hasattr doesn't do what we want here. dir does.
        permitted = name.startswith("_") or name in dir(self)
        if not permitted:
            raise AttributeError(name)
        super().__setattr__(name, value)
class TException(rain.Storm, AttrCheck):
    """
    Storm mapping for the Exceptions table: the named outcomes a lookup
    can end in (RespoDB.setup_exceptions seeds NXDOMAIN, NoAnswer,
    NoNameservers and Timeout).
    """
    __storm_table__ = "Exceptions"
    exception_id = rain.Int("ExceptionId", primary=True)
    name = rain.Unicode("Name")
    # Whether the outcome counts as a failure; NXDOMAIN is seeded with False.
    fail = rain.Bool("Fail")
class TExecution(rain.Storm, AttrCheck):
    """
    Storm mapping for the Executions table: one row per run, opened by
    RespoDB.start_execution and closed by RespoDB.finish_execution.
    """
    __storm_table__ = "Executions"
    execution_id = rain.Int("ExecutionId", primary=True)
    start_date = rain.DateTime("StartDate")
    # NULL until finish_execution runs.
    finish_date = rain.DateTime("FinishDate")
    # False unless the run finished without an exception.
    completed = rain.Bool("Completed")
class TAddress(rain.Storm, AttrCheck):
    """
    Storm mapping for the Ips table: one row per IPv4 address seen either
    as a DNS server or in an answer, plus classification flags.
    """
    __storm_table__ = "Ips"
    address_id = rain.Int("IpId", primary=True)
    # Dotted-quad text.  AsStr is a GENERATED column computed by SQLite from
    # AsInt, so treat it as read-only.  (Shadows the builtin only inside
    # this class body.)
    str = rain.Unicode("AsStr")
    # The address as an unsigned 32-bit integer, as produced by addr_to_int.
    ip = rain.Int("AsInt")
    # Set for addresses from the .ips ``china`` list during setup.
    china = rain.Bool("China")
    # Set for 0.0.0.0, 127.0.0.1, the .ips ``blocks`` list, and answers
    # whose entry reason is "block".
    block_target = rain.Bool("BlockTarget")
    # Set once the address has been used as the queried server.
    server = rain.Bool("Server")
    # Set for answers whose entry reason is "redirect".
    redirect_target = rain.Bool("RedirectTarget")
    # Set for answers whose entry reason is "gfw".
    gfw_target = rain.Bool("GfwTarget")
class TKind(rain.Storm, AttrCheck):
    """
    Storm mapping for the Kinds table: a named category of domains,
    optionally with the exception their lookups are expected to produce.
    """
    __storm_table__ = "Kinds"
    kind_id = rain.Int("KindId", primary=True)
    name = rain.Unicode("Name")
    # Foreign key to Exceptions; RespoDB.push_entry sets it to NXDOMAIN for
    # kinds whose name starts with "bad".
    xxid = rain.Int("ExpectExceptionId")
    exception = rain.Reference(xxid, "TException.exception_id")
class TDomain(rain.Storm, AttrCheck):
    """Storm mapping for the Domains table: a queried name and its kind."""
    __storm_table__ = "Domains"
    domain_id = rain.Int("DomainId", primary=True)
    name = rain.Unicode("Name")
    kind_id = rain.Int("KindId")
    kind = rain.Reference(kind_id, "TKind.kind_id")
class TRecord(rain.Storm, AttrCheck):
    """
    Storm mapping for the Records table: one row per (record, address)
    pair.  All rows sharing a RecordId together make up one answer's
    address set.
    """
    __storm_table__ = "Records"
    # RecordId repeats across rows, so SQLite's implicit rowid serves as
    # the primary key Storm requires.
    row_id = rain.Int("rowid", primary=True)
    record_id = rain.Int("RecordId")
    address_id = rain.Int("IpId")
    address = rain.Reference(address_id, "TAddress.address_id")
class TMessage(rain.Storm, AttrCheck):
    """
    Storm mapping for the Messages table: one lookup result (which server,
    which domain, which answer record and/or exception) as stored by
    RespoDB.push_entry.
    """
    __storm_table__ = "Messages"
    message_id = rain.Int("MessageId", primary=True)
    execution_id = rain.Int("ExecutionId")
    server_id = rain.Int("ServerId")
    domain_id = rain.Int("DomainId")
    # Shared by all Records rows of the answer; None when there were no
    # addresses.  Existence is enforced by the RecordExists trigger.
    record_id = rain.Int("RecordId")
    exception_id = rain.Int("ExceptionId")
    execution = rain.Reference(execution_id, "TExecution.execution_id")
    server = rain.Reference(server_id, "TAddress.address_id")
    domain = rain.Reference(domain_id, "TDomain.domain_id")
    # No Reference for the record: Records.RecordId is not unique (one row
    # per address), so it cannot identify a single TRecord.
    #record = rain.Reference(record_id, "TRecord.record_id")
    exception = rain.Reference(exception_id, "TException.exception_id")
def apply_properties(obj, d):
    """
    Assign each ``name: value`` pair from *d* onto the Storm object *obj*.

    Every name must be declared on obj's class as a column property or a
    ``Reference``.  Other names fail an assertion.  A name missing
    from the class entirely raises AttributeError from ``getattr`` first.

    Returns *obj* itself, so the call can feed straight into ``Store.add``.
    """
    from storm.properties import PropertyColumn
    for name, value in d.items():
        descriptor = getattr(obj.__class__, name)
        assert descriptor is not None, (type(obj), name)
        assert isinstance(descriptor, (PropertyColumn, rain.Reference)), \
            (type(obj), name)
        setattr(obj, name, value)
    return obj
class RespoDB:
    """
    Storm-backed storage for respodns lookup results.

    Use as a context manager: ``with db:`` opens a ``storm.locals.Store``
    on the database; leaving the block commits and then closes it.  Every
    method that touches the store must run inside such a block.
    ``db.execution`` is an :class:`Execution` context manager that
    brackets one run.
    """

    # Entry reason -> TAddress flag set on every address of the answer.
    _reason_flags = {
        "block": "block_target",
        "redirect": "redirect_target",
        "gfw": "gfw_target",
    }

    def __init__(self, uri, setup=False, create=False):
        """
        Open (and optionally initialise) the database at *uri*.

        :param uri: Storm database URI, e.g. ``sqlite:path/to/file.db``.
        :param setup: run the schema/seed setup unconditionally.
        :param create: run the setup only if the file named by *uri*
            does not exist yet.
        :raises AssertionError: if neither flag is given and the database
            file does not exist.
        """
        self.uri = uri
        db_exists = self._db_exists(self.uri)
        self.db = rain.create_database(self.uri)
        self._conn = None
        if setup or (create and not db_exists):
            # NOTE: the seed inserts are not upserts, so re-running setup on
            # a populated database can trip the Ips.AsInt UNIQUE constraint.
            with self:
                self.setup_executions()
                self.setup_exceptions()
                self.setup_ips()
                self.setup_kinds()
                self.setup_domains()
                self.setup_records()
                self.setup_messages()
                for q in create_view_statements:
                    self._conn.execute(q, noresult=True)
        assert setup or create or db_exists, "database was never setup"
        self.execution = Execution(self)

    @staticmethod
    def _db_exists(uri):
        """
        Return True if the file named by a ``scheme:path`` *uri* exists.

        With a ``scheme://netloc/path`` form, everything up to the first
        ``/`` after the ``//`` is dropped as the netloc.  An empty path
        (e.g. ``sqlite:``) yields False.
        """
        from os.path import exists
        _, _, fp = uri.partition(":")
        if fp.startswith("//"):
            _, _, fp = fp[2:].partition("/")
        return bool(fp) and exists(fp)

    def __enter__(self):
        self._conn = rain.Store(self.db)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # NOTE: commits even when the body raised, so work done before the
        # failure is kept.
        try:
            self.commit()
        finally:
            # Close even if the commit fails so the Store is not leaked.
            self._conn.close()
            self._conn = None

    def find_one(self, cls_spec, *args, **kwargs):
        """Run ``Store.find(cls_spec, ...)`` and return ``.one()`` of it."""
        assert self._conn is not None
        return self._conn.find(cls_spec, *args, **kwargs).one()

    def flush(self):
        assert self._conn is not None
        self._conn.flush()

    def commit(self):
        assert self._conn is not None
        self._conn.commit()

    def _new(self, cls, props):
        """Instantiate *cls*, apply *props* to it, and add it to the store."""
        assert self._conn is not None
        return self._conn.add(apply_properties(cls(), props))

    def new_exception(self, **kwargs):
        return self._new(TException, kwargs)

    def new_kind(self, **kwargs):
        return self._new(TKind, kwargs)

    def new_domain(self, **kwargs):
        return self._new(TDomain, kwargs)

    def new_address(self, **kwargs):
        return self._new(TAddress, kwargs)

    def new_record(self, **kwargs):
        return self._new(TRecord, kwargs)

    def new_message(self, **kwargs):
        return self._new(TMessage, kwargs)

    def _create_table(self, key):
        """Execute the DDL stored under ``create_table_statements[key]``."""
        self._conn.execute(create_table_statements[key], noresult=True)

    def setup_executions(self):
        self._create_table("executions")

    def setup_exceptions(self):
        # careful not to call them "errors" since NXDOMAIN is not an error.
        self._create_table("exceptions")
        # TODO: upsert?
        self.new_exception(name="NXDOMAIN", fail=False)
        self.new_exception(name="NoAnswer", fail=True)
        self.new_exception(name="NoNameservers", fail=True)
        self.new_exception(name="Timeout", fail=True)

    def setup_ips(self):
        from .ips import china, blocks
        self._create_table("ips")
        # TODO: upsert?
        self.new_address(ip=addr_to_int("0.0.0.0"), block_target=True)
        self.new_address(ip=addr_to_int("127.0.0.1"), block_target=True)
        for ip in china:
            self.new_address(ip=addr_to_int(ip), china=True)
        for ip in blocks:
            self.new_address(ip=addr_to_int(ip), block_target=True)

    def setup_kinds(self):
        # Kinds are created lazily by push_entry, which also links "bad*"
        # kinds to the NXDOMAIN exception.
        self._create_table("kinds")

    def setup_domains(self):
        self._create_table("domains")

    def setup_records(self):
        self._create_table("records")

    def setup_messages(self):
        self._create_table("messages")
        for trig in table_triggers["messages"]:
            self._conn.execute(trig, noresult=True)

    def start_execution(self, dt):
        """Insert a new Executions row started at *dt* and return it."""
        assert self._conn is not None
        execution = TExecution()
        execution.start_date = dt
        # Add explicitly so the row is written by the flush below rather
        # than only if something later references it.
        self._conn.add(execution)
        self.flush()
        return execution

    def finish_execution(self, execution, dt, completed):
        """Record the finish time *dt* and completion flag on *execution*."""
        # TODO: fail if ExecutionId is missing?
        execution.finish_date = dt
        execution.completed = completed
        self.flush()

    def next_record_id(self):
        """Return one more than the largest RecordId in use (1 if none)."""
        from storm.expr import Add, Max, Coalesce
        expr = Add(Coalesce(Max(TRecord.record_id), 0), 1)
        return self.find_one(expr)

    def find_record_id(self, addresses):
        """
        Return the RecordId whose rows reference exactly *addresses*.

        *addresses* is a sequence of TAddress objects without duplicates
        (push_entry passes a deduplicated list).  A record that contains
        all of them plus further addresses does not match.  If several
        records match, the smallest RecordId is returned; None if none do.
        """
        from collections import Counter
        # Flush first so addresses added but not yet written have an IpId.
        self.flush()
        address_ids = [address.address_id for address in addresses]
        hits = Counter(
            self._conn.find(TRecord, TRecord.address_id.is_in(address_ids))
            .values(TRecord.record_id))
        for record_id in sorted(hits):
            if hits[record_id] != len(address_ids):
                continue
            # All requested addresses are in this record; reject it if it
            # also holds others.
            size = self._conn.find(
                TRecord, TRecord.record_id == record_id).count()
            if size == len(address_ids):
                return record_id
        return None

    def _find_or_create_kind(self, name):
        """Return the Kinds row named *name*, creating it if missing."""
        kind = self.find_one(TKind, TKind.name == name)
        if not kind:
            kind = self.new_kind(name=name)
            # "bad*" kinds are expected to come back as NXDOMAIN.
            if name.startswith("bad"):
                exception = self.find_one(
                    TException, TException.name == "NXDOMAIN")
                assert exception is not None
                kind.exception = exception
        return kind

    def _find_or_create_domain(self, name, kind):
        """Return the Domains row named *name*; new rows get *kind*."""
        domain = self.find_one(TDomain, TDomain.name == name)
        if not domain:
            domain = self.new_domain(name=name)
            domain.kind = kind
        return domain

    def _find_or_create_address(self, numeric):
        """Return the Ips row for the integer address *numeric*."""
        address = self.find_one(TAddress, TAddress.ip == numeric)
        if not address:
            address = self.new_address(ip=numeric)
        return address

    def _find_or_create_record_id(self, addresses):
        """Return a RecordId covering exactly *addresses*; None if empty."""
        if not addresses:
            return None
        record_id = self.find_record_id(addresses)
        if record_id is None:
            record_id = self.next_record_id()
            for address in addresses:
                self.new_record(record_id=record_id, address=address)
        return record_id

    def push_entry(self, entry):
        """
        Store one lookup result as a Messages row, creating any missing
        Kinds/Domains/Ips/Records rows it refers to.

        *entry* must provide ``kind``, ``domain``, ``addrs`` (dotted-quad
        strings), ``reason``, ``server`` (dotted quad), ``exception`` (an
        Exceptions name, or falsy) and ``execution`` (presumably the row
        yielded by ``db.execution`` — confirm against callers).
        """
        kind = self._find_or_create_kind(entry.kind)
        domain = self._find_or_create_domain(entry.domain, kind)

        # Deduplicate and sort so the same answer always yields the same
        # address list regardless of response order.
        as_ints = sorted(set(map(addr_to_int, entry.addrs)))
        addresses = [self._find_or_create_address(n) for n in as_ints]
        for address in addresses:
            flag = self._reason_flags.get(entry.reason)
            if flag is not None:
                setattr(address, flag, True)

        record_id = self._find_or_create_record_id(addresses)

        server = self._find_or_create_address(addr_to_int(entry.server))
        self.flush()
        server.server = True

        if entry.exception:
            exception = self.find_one(
                TException, TException.name == entry.exception)
            assert exception is not None
        else:
            exception = None

        self.new_message(
            execution=entry.execution,
            server=server, domain=domain,
            record_id=record_id, exception=exception)
        self.flush()