respodns/respodns/db.py
2021-08-13 02:38:36 -07:00

305 lines
9.2 KiB
Python

from .ip_util import addr_to_int
from .tables import Base, TException, TExecution, TAddress
from .tables import TKind, TDomain, TRecord, TMessage
from sqlalchemy.orm import sessionmaker
# Session factory shared by this module; RespoDB.__init__ binds it to an engine.
# autoflush is off, so pending objects only reach the database on an explicit
# flush() or commit().
Session = sessionmaker(autoflush=False)
# Raw SQL statements executed by RespoDB.__init__ while it sets up a database.
# The "Results" view flattens each message into readable columns by joining
# its domain, kind, server address, record address(es) and exception.
create_view_statements = [
"""
CREATE VIEW Results AS
SELECT
Messages.ExecutionId,
ServerIps.AsStr as Server,
Kinds.Name as Kind,
Domains.Name as Name,
RecordIps.AsStr as Address,
Exceptions.Name as Exception,
Messages.Failed as Failed
FROM Messages
LEFT JOIN Domains ON Messages.DomainId = Domains.DomainId
LEFT JOIN Kinds ON Domains.KindId = Kinds.KindId
LEFT JOIN Ips AS ServerIps ON Messages.ServerId = ServerIps.IpId
LEFT JOIN Records ON Messages.RecordId = Records.RecordId
LEFT JOIN Ips as RecordIps ON Records.IpId = RecordIps.IpId
LEFT JOIN Exceptions ON Messages.ExceptionId = Exceptions.ExceptionId
""",
]
class Execution:
    """Context manager that records the start and finish of one run.

    Entering asks the database wrapper to start an execution stamped with
    the current time and yields that execution object. Leaving reports the
    finish time together with whether the block ended without an exception.
    Exceptions are never suppressed.
    """

    def __init__(self, db):
        # Nothing is started until the context is actually entered.
        self.execution = None
        self.db = db

    def __enter__(self):
        from .util import right_now

        started = right_now()
        self.execution = self.db.start_execution(started)
        return self.execution

    def __exit__(self, exc_type, exc_value, traceback):
        from .util import right_now

        # The run counts as completed only when no exception is in flight.
        succeeded = exc_type is None
        finished = right_now()
        self.db.finish_execution(self.execution, finished, succeeded)
def is_column(ref):
    """Return True when `ref` is a SQLAlchemy InstrumentedAttribute."""
    from sqlalchemy.orm import attributes

    return isinstance(ref, attributes.InstrumentedAttribute)
def apply_properties(obj, d):
    """Copy each key/value pair of `d` onto `obj` and return `obj`.

    Every key must name a class-level attribute of `obj` that passes
    `is_column`. A name the class does not have raises AttributeError from
    the lookup itself; a name that resolves to None or to a non-column
    attribute trips an assertion carrying `(type(obj), key)`.
    """
    for name, value in d.items():
        attr = getattr(obj.__class__, name)
        context = (type(obj), name)
        assert attr is not None, context
        assert is_column(attr), context
        setattr(obj, name, value)
    return obj
class RespoDB:
    """Session-scoped helper for reading and writing the respodns tables.

    Enter the instance as a context manager to open a session (kept in
    `self._conn`); leaving the block commits and closes it. Methods that
    touch the database expect to be called while that session is open.
    """

    def __init__(self, uri, setup=False, create=False):
        """Create the engine for `uri` and optionally initialize the database.

        uri: SQLAlchemy database URI.
        setup: always create the tables, seed rows and views.
        create: do the same, but only when no database file exists yet.

        Raises AssertionError when neither flag is given and no database
        file was found.
        """
        from sqlalchemy import create_engine

        self.uri = uri
        db_exists = self._db_exists(self.uri)
        self.db = create_engine(self.uri)
        # NOTE: Session is a module-level factory, so this rebinds it for
        # every RespoDB instance, not just this one.
        Session.configure(bind=self.db)
        self._conn = None

        if setup or (create and not db_exists):
            with self:
                Base.metadata.create_all(self.db)
                self.setup_executions()
                self.setup_exceptions()
                self.setup_ips()
                self.setup_kinds()
                self.setup_domains()
                self.setup_records()
                self.setup_messages()
                for q in create_view_statements:
                    self._fire(q)

        assert setup or create or db_exists, "database was never setup"

        self.execution = Execution(self)

    @staticmethod
    def _db_exists(uri):
        """Return True when the file path embedded in `uri` exists on disk.

        The path is whatever follows the scheme (and, for `scheme://host/path`
        style URIs, the first slash after the host part). URIs without a path,
        such as "sqlite://", yield False.
        """
        from os.path import exists

        _, _, fp = uri.partition(":")
        if fp.startswith("//"):
            _, _, fp = fp[2:].partition("/")
        # bool() keeps the result a real boolean instead of leaking "".
        return bool(fp) and exists(fp)

    def __enter__(self):
        """Open a new session and return this wrapper."""
        self._conn = Session()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Commit and close the session, even when the block raised.

        Work done before an exception is still committed. The session is
        closed and forgotten even if the commit itself fails.
        """
        try:
            self.commit()
        finally:
            if self._conn is not None:
                self._conn.close()
            self._conn = None

    def find_one(self, cls_spec, **filters):
        """Return the first result of querying `cls_spec`, or None.

        Keyword arguments are applied as equality filters. filter_by() is
        skipped entirely when there is nothing to filter on, which is how
        next_record_id() queries a bare SQL expression through here.
        """
        assert self._conn is not None
        query = self._conn.query(cls_spec)
        if filters:
            query = query.filter_by(**filters)
        return query.first()

    def flush(self):
        """Send pending changes to the database without committing."""
        assert self._conn is not None
        self._conn.flush()

    def commit(self):
        """Commit the current transaction."""
        assert self._conn is not None
        self._conn.commit()

    def _add_new(self, cls, kwargs):
        """Build a `cls` row from `kwargs`, add it to the session, return it."""
        assert self._conn is not None
        res = cls(**kwargs)
        self._conn.add(res)
        return res

    def new_exception(self, **kwargs):
        """Create a pending TException row and return it."""
        return self._add_new(TException, kwargs)

    def new_kind(self, **kwargs):
        """Create a pending TKind row and return it."""
        return self._add_new(TKind, kwargs)

    def new_domain(self, **kwargs):
        """Create a pending TDomain row and return it."""
        return self._add_new(TDomain, kwargs)

    def new_address(self, **kwargs):
        """Create a pending TAddress row and return it."""
        return self._add_new(TAddress, kwargs)

    def new_record(self, **kwargs):
        """Create a pending TRecord row and return it."""
        return self._add_new(TRecord, kwargs)

    def new_message(self, **kwargs):
        """Create a pending TMessage row and return it."""
        return self._add_new(TMessage, kwargs)

    def _fire(self, statement):
        """Execute a statement for its side effects and discard the result."""
        from sqlalchemy import text

        assert self._conn is not None
        # Plain SQL strings must be wrapped in text() for Session.execute().
        if isinstance(statement, str):
            statement = text(statement)
        self._conn.execute(statement).close()

    def setup_executions(self):
        """Seed the executions table (nothing to seed)."""
        pass

    def setup_exceptions(self):
        """Seed the known exception names and whether each counts as a failure."""
        # careful not to call them "errors" since NXDOMAIN is not an error.
        # TODO: upsert?
        self.new_exception(name="NXDOMAIN", fail=False)
        self.new_exception(name="NoAnswer", fail=True)
        self.new_exception(name="NoNameservers", fail=True)
        self.new_exception(name="Timeout", fail=True)

    def setup_ips(self):
        """Seed the addresses that are flagged as block targets from the start."""
        self.modify_address("0.0.0.0", block_target=True)
        self.modify_address("127.0.0.1", block_target=True)

    def setup_kinds(self):
        """Seed the kinds table (nothing to seed)."""
        pass

    def setup_domains(self):
        """Seed the domains table (nothing to seed)."""
        pass

    def setup_records(self):
        """Seed the records table (nothing to seed)."""
        pass

    def setup_messages(self):
        """Seed the messages table (nothing to seed)."""
        pass

    def start_execution(self, dt):
        """Create an execution row starting at `dt`, flush it, and return it."""
        assert self._conn is not None
        execution = TExecution()
        execution.start_date = dt
        # Without add() the flush below has nothing to write and the row
        # would only be stored if a message happened to reference it later.
        self._conn.add(execution)
        self.flush()
        return execution

    def finish_execution(self, execution, dt, completed):
        """Stamp `execution` with its finish time and completion flag."""
        # TODO: fail if ExecutionId is missing?
        execution.finish_date = dt
        execution.completed = completed
        self.flush()

    def all_ips(self):
        """Return the `str` column of every address row as a list."""
        assert self._conn is not None
        # Querying a single column yields one-element rows; unwrap them.
        return [row[0] for row in self._conn.query(TAddress.str)]

    def next_record_id(self):
        """Return one more than the largest record id in use (1 when empty)."""
        from sqlalchemy.sql.expression import func

        expr = func.coalesce(func.max(TRecord.record_id), 0) + 1
        return self.find_one(expr)[0]

    def find_record_id(self, addresses):
        """Return the id of the record made up of exactly `addresses`.

        addresses: TAddress objects, assumed to be free of duplicates.

        Returns the smallest matching record id, or None when no stored
        record consists of precisely these addresses.
        """
        from collections import Counter

        assert self._conn is not None
        wanted = len(addresses)
        if wanted == 0:
            return None

        address_ids = [address.address_id for address in addresses]
        rows = self._conn.query(TRecord.record_id).filter(
            TRecord.address_id.in_(address_ids))
        # Column queries return one-element rows, hence the row[0].
        hits = Counter(row[0] for row in rows)

        for needle in sorted(hits):
            if hits[needle] != wanted:
                continue
            # All requested addresses belong to this record; also make sure
            # the record holds nothing else, so a superset is not reused.
            total = self._conn.query(TRecord).filter(
                TRecord.record_id == needle).count()
            if total == wanted:
                return needle
        return None

    def push_entry(self, entry):
        """Store one lookup result, creating any rows it refers to.

        Reads `kind`, `domain`, `addrs`, `reason`, `server`, `exception`,
        `success` and `execution` from `entry`. Missing kind, domain,
        address and record rows are created on the fly; the exception named
        by the entry must already exist.
        """
        kind = self.find_one(TKind, name=entry.kind)
        if not kind:
            kind = self.new_kind(name=entry.kind)
            if entry.kind.startswith("bad"):
                exception = self.find_one(TException, name="NXDOMAIN")
                assert exception is not None
                kind.exception = exception

        domain = self.find_one(TDomain, name=entry.domain)
        if not domain:
            domain = self.new_domain(name=entry.domain)
            domain.kind = kind

        # Deduplicate and order the answers so equal sets compare equal.
        addresses = []
        as_ints = sorted(set(map(addr_to_int, entry.addrs)))
        for numeric in as_ints:
            address = self.find_one(TAddress, ip=numeric)
            if not address:
                address = self.new_address(ip=numeric)
            addresses.append(address)

        for address in addresses:
            if entry.reason == "block":
                address.block_target = True
            elif entry.reason == "redirect":
                address.redirect_target = True
            elif entry.reason == "gfw":
                address.gfw_target = True

        if addresses:
            record_id = self.find_record_id(addresses)
            if record_id is None:
                record_id = self.next_record_id()
                for address in addresses:
                    self.new_record(record_id=record_id, address=address)
            # The session does not autoflush, so write the pending address
            # rows now; otherwise the server lookup below cannot see an
            # answer address with the same IP and would create a duplicate.
            self.flush()
        else:
            record_id = None

        numeric = addr_to_int(entry.server)
        server = self.find_one(TAddress, ip=numeric)
        if not server:
            server = self.new_address(ip=numeric)
            self.flush()
        server.server = True

        if entry.exception:
            exception = self.find_one(TException, name=entry.exception)
            assert exception is not None
        else:
            exception = None

        failed = not entry.success
        self.new_message(
            execution=entry.execution,
            server=server, domain=domain,
            record_id=record_id, exception=exception,
            failed=failed)
        self.flush()

    def country_code(self, ip, code=None):
        """Get or set the country code stored for `ip`.

        With `code` left as None this is a getter: it returns the stored
        code, or None when the address is unknown. With a `code` it is a
        setter: the address is created if needed, updated, flushed, and
        None is returned.
        """
        numeric = addr_to_int(ip)
        address = self.find_one(TAddress, ip=numeric)

        if code is None:
            return None if address is None else address.country_code

        # NOTE: can't set code to null here since None was ruled out.
        if address is None:
            self.new_address(ip=numeric, country_code=code)
        else:
            address.country_code = code
        self.flush()
        return None

    def modify_address(self, ip, **kwargs):
        """Apply column values to the address row for `ip`, creating it if needed.

        Keyword arguments name TAddress columns. The change is flushed
        before returning.
        """
        numeric = addr_to_int(ip)
        address = self.find_one(TAddress, ip=numeric)
        if address is None:
            self.new_address(ip=numeric, **kwargs)
        else:
            apply_properties(address, kwargs)
        self.flush()