switch from storm to sqlalchemy

This commit is contained in:
Connor Olding 2021-08-06 18:14:17 -07:00
parent 1395e7d6f6
commit 9defbc8da0
4 changed files with 188 additions and 221 deletions

View File

@ -1,2 +1,2 @@
dnspython >= 2.1.0 dnspython >= 2.1.0
storm sqlalchemy >= 1.3.18

View File

@ -1,8 +1,30 @@
from .sql import create_table_statements, create_view_statements
from .sql import table_triggers
from .tables import TException, TExecution, TAddress
from .tables import TKind, TDomain, TRecord, TMessage
from .ip_util import addr_to_int from .ip_util import addr_to_int
from .tables import Base, TException, TExecution, TAddress
from .tables import TKind, TDomain, TRecord, TMessage
from sqlalchemy.orm import sessionmaker
Session = sessionmaker(autoflush=False)
create_view_statements = [
"""
CREATE VIEW Results AS
SELECT
Messages.ExecutionId,
ServerIps.AsStr as Server,
Kinds.Name as Kind,
Domains.Name as Name,
RecordIps.AsStr as Address,
Exceptions.Name as Exception,
Messages.Failed as Failed
FROM Messages
LEFT JOIN Domains ON Messages.DomainId = Domains.DomainId
LEFT JOIN Kinds ON Domains.KindId = Kinds.KindId
LEFT JOIN Ips AS ServerIps ON Messages.ServerId = ServerIps.IpId
LEFT JOIN Records ON Messages.RecordId = Records.RecordId
LEFT JOIN Ips as RecordIps ON Records.IpId = RecordIps.IpId
LEFT JOIN Exceptions ON Messages.ExceptionId = Exceptions.ExceptionId
""",
]
class Execution: class Execution:
@ -22,9 +44,8 @@ class Execution:
def is_column(ref): def is_column(ref):
from storm.properties import PropertyColumn from sqlalchemy.orm.attributes import InstrumentedAttribute
from storm.references import Reference return isinstance(ref, InstrumentedAttribute)
return isinstance(ref, PropertyColumn) or isinstance(ref, Reference)
def apply_properties(obj, d): def apply_properties(obj, d):
@ -38,15 +59,17 @@ def apply_properties(obj, d):
class RespoDB: class RespoDB:
def __init__(self, uri, setup=False, create=False): def __init__(self, uri, setup=False, create=False):
from storm.database import create_database from sqlalchemy import create_engine
self.uri = uri self.uri = uri
db_exists = self._db_exists(self.uri) db_exists = self._db_exists(self.uri)
self.db = create_database(self.uri) self.db = create_engine(self.uri)
Session.configure(bind=self.db)
self._conn = None self._conn = None
if setup or (create and not db_exists): if setup or (create and not db_exists):
with self: with self:
Base.metadata.create_all(self.db)
self.setup_executions() self.setup_executions()
self.setup_exceptions() self.setup_exceptions()
self.setup_ips() self.setup_ips()
@ -70,8 +93,7 @@ class RespoDB:
return fp and exists(fp) return fp and exists(fp)
def __enter__(self): def __enter__(self):
from storm.store import Store self._conn = Session()
self._conn = Store(self.db)
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
@ -79,9 +101,11 @@ class RespoDB:
self._conn.close() self._conn.close()
self._conn = None self._conn = None
def find_one(self, cls_spec, *args, **kwargs): def find_one(self, cls_spec, **filters):
assert self._conn is not None if len(filters) > 0:
return self._conn.find(cls_spec, *args, **kwargs).one() return self._conn.query(cls_spec).filter_by(**filters).first()
else:
return self._conn.query(cls_spec).first()
def flush(self): def flush(self):
assert self._conn is not None assert self._conn is not None
@ -93,65 +117,70 @@ class RespoDB:
def new_exception(self, **kwargs): def new_exception(self, **kwargs):
assert self._conn is not None assert self._conn is not None
return self._conn.add(apply_properties(TException(), kwargs)) res = TException(**kwargs)
self._conn.add(res)
return res
def new_kind(self, **kwargs): def new_kind(self, **kwargs):
assert self._conn is not None assert self._conn is not None
return self._conn.add(apply_properties(TKind(), kwargs)) res = TKind(**kwargs)
self._conn.add(res)
return res
def new_domain(self, **kwargs): def new_domain(self, **kwargs):
assert self._conn is not None assert self._conn is not None
return self._conn.add(apply_properties(TDomain(), kwargs)) res = TDomain(**kwargs)
self._conn.add(res)
return res
def new_address(self, **kwargs): def new_address(self, **kwargs):
assert self._conn is not None assert self._conn is not None
return self._conn.add(apply_properties(TAddress(), kwargs)) res = TAddress(**kwargs)
self._conn.add(res)
return res
def new_record(self, **kwargs): def new_record(self, **kwargs):
assert self._conn is not None assert self._conn is not None
return self._conn.add(apply_properties(TRecord(), kwargs)) res = TRecord(**kwargs)
self._conn.add(res)
return res
def new_message(self, **kwargs): def new_message(self, **kwargs):
assert self._conn is not None assert self._conn is not None
return self._conn.add(apply_properties(TMessage(), kwargs)) res = TMessage(**kwargs)
self._conn.add(res)
return res
def _fire(self, statement): def _fire(self, statement):
assert self._conn is not None assert self._conn is not None
self._conn.execute(statement, noresult=True) self._conn.execute(statement).close()
def setup_executions(self): def setup_executions(self):
self._fire(create_table_statements["executions"]) pass
def setup_exceptions(self): def setup_exceptions(self):
# careful not to call them "errors" since NXDOMAIN is not an error. # careful not to call them "errors" since NXDOMAIN is not an error.
self._fire(create_table_statements["exceptions"])
# TODO: upsert? # TODO: upsert?
self.new_exception(name="NXDOMAIN", fail=False) self.new_exception(name="NXDOMAIN", fail=False)
self.new_exception(name="NoAnswer", fail=True) self.new_exception(name="NoAnswer", fail=True)
self.new_exception(name="NoNameservers", fail=True) self.new_exception(name="NoNameservers", fail=True)
self.new_exception(name="Timeout", fail=True) self.new_exception(name="Timeout", fail=True)
def setup_ips(self): def setup_ips(self):
self._fire(create_table_statements["ips"])
self.modify_address("0.0.0.0", block_target=True) self.modify_address("0.0.0.0", block_target=True)
self.modify_address("127.0.0.1", block_target=True) self.modify_address("127.0.0.1", block_target=True)
def setup_kinds(self): def setup_kinds(self):
self._fire(create_table_statements["kinds"]) pass
def setup_domains(self): def setup_domains(self):
self._fire(create_table_statements["domains"]) pass
def setup_records(self): def setup_records(self):
self._fire(create_table_statements["records"]) pass
def setup_messages(self): def setup_messages(self):
self._fire(create_table_statements["messages"]) pass
for trig in table_triggers["messages"]:
self._conn.execute(trig)
def start_execution(self, dt): def start_execution(self, dt):
execution = TExecution() execution = TExecution()
@ -167,18 +196,19 @@ class RespoDB:
def all_ips(self): def all_ips(self):
assert self._conn is not None assert self._conn is not None
addresses = self._conn.find(TAddress) addresses = self._conn.query(TAddress).fetchall()
return [addr.str for addr in addresses] return [addr.str for addr in addresses]
def next_record_id(self): def next_record_id(self):
from storm.expr import Add, Max, Coalesce from sqlalchemy.sql.expression import func
expr = Add(Coalesce(Max(TRecord.record_id), 0), 1) expr = func.coalesce(func.max(TRecord.record_id), 0) + 1
return self.find_one(expr) return self.find_one(expr)[0]
def find_record_id(self, addresses): def find_record_id(self, addresses):
address_ids = list(address.address_id for address in addresses) address_ids = list(address.address_id for address in addresses)
temp = self._conn.find(TRecord, TRecord.address_id.is_in(address_ids)) temp = self._conn.query(TRecord).filter(TRecord.address_id.in_(address_ids))
record_ids = list(temp.values(TRecord.record_id)) record_ids = [t[0] for t in temp.values(TRecord.record_id)]
# TODO: why are record_ids even tuples to begin with?
if not record_ids: if not record_ids:
return None return None
unique_ids = sorted(set(record_ids)) unique_ids = sorted(set(record_ids))
@ -189,16 +219,15 @@ class RespoDB:
return None return None
def push_entry(self, entry): def push_entry(self, entry):
kind = self.find_one(TKind, TKind.name == entry.kind) kind = self.find_one(TKind, name=entry.kind)
if not kind: if not kind:
kind = self.new_kind(name=entry.kind) kind = self.new_kind(name=entry.kind)
if entry.kind.startswith("bad"): if entry.kind.startswith("bad"):
exception = self.find_one(TException, exception = self.find_one(TException, name="NXDOMAIN")
TException.name == "NXDOMAIN")
assert exception is not None assert exception is not None
kind.exception = exception kind.exception = exception
domain = self.find_one(TDomain, TDomain.name == entry.domain) domain = self.find_one(TDomain, name=entry.domain)
if not domain: if not domain:
domain = self.new_domain(name=entry.domain) domain = self.new_domain(name=entry.domain)
domain.kind = kind domain.kind = kind
@ -206,7 +235,7 @@ class RespoDB:
addresses = [] addresses = []
as_ints = sorted(set(map(addr_to_int, entry.addrs))) as_ints = sorted(set(map(addr_to_int, entry.addrs)))
for numeric in as_ints: for numeric in as_ints:
address = self.find_one(TAddress, TAddress.ip == numeric) address = self.find_one(TAddress, ip=numeric)
if not address: if not address:
address = self.new_address(ip=numeric) address = self.new_address(ip=numeric)
addresses.append(address) addresses.append(address)
@ -229,15 +258,14 @@ class RespoDB:
record_id = None record_id = None
numeric = addr_to_int(entry.server) numeric = addr_to_int(entry.server)
server = self.find_one(TAddress, TAddress.ip == numeric) server = self.find_one(TAddress, ip=numeric)
if not server: if not server:
server = self.new_address(ip=numeric) server = self.new_address(ip=numeric)
self.flush() self.flush()
server.server = True server.server = True
if entry.exception: if entry.exception:
exception = self.find_one(TException, exception = self.find_one(TException, name=entry.exception)
TException.name == entry.exception)
assert exception is not None assert exception is not None
else: else:
exception = None exception = None
@ -253,7 +281,7 @@ class RespoDB:
def country_code(self, ip, code=None): def country_code(self, ip, code=None):
numeric = addr_to_int(ip) numeric = addr_to_int(ip)
address = self.find_one(TAddress, TAddress.ip == numeric) address = self.find_one(TAddress, ip=numeric)
if code is None: if code is None:
if address is not None: if address is not None:
return address.country_code return address.country_code
@ -268,7 +296,7 @@ class RespoDB:
def modify_address(self, ip, **kwargs): def modify_address(self, ip, **kwargs):
numeric = addr_to_int(ip) numeric = addr_to_int(ip)
address = self.find_one(TAddress, TAddress.ip == numeric) address = self.find_one(TAddress, ip=numeric)
if address is None: if address is None:
self.new_address(ip=numeric, **kwargs) self.new_address(ip=numeric, **kwargs)
else: else:

View File

@ -1,117 +0,0 @@
create_table_statements = dict(
# TODO: Duration REAL GENERATED ALWAYS AS etc.?
executions="""
CREATE TABLE IF NOT EXISTS Executions (
ExecutionId INTEGER PRIMARY KEY,
StartDate DATE NOT NULL,
FinishDate DATE,
Completed BOOLEAN DEFAULT 0 NOT NULL)
""",
exceptions="""
CREATE TABLE IF NOT EXISTS Exceptions (
ExceptionId INTEGER PRIMARY KEY,
Name TEXT NOT NULL,
Fail BOOLEAN NOT NULL)
""",
ips="""
CREATE TABLE IF NOT EXISTS Ips (
IpId INTEGER PRIMARY KEY,
AsStr TEXT GENERATED ALWAYS AS (
Cast(AsInt >> 24 & 255 AS TEXT) || '.' ||
Cast(AsInt >> 16 & 255 AS TEXT) || '.' ||
Cast(AsInt >> 8 & 255 AS TEXT) || '.' ||
Cast(AsInt & 255 AS TEXT)
) STORED NOT NULL,
AsInt INTEGER UNIQUE CHECK(AsInt >= 0 AND AsInt < 1 << 32) NOT NULL,
China BOOLEAN DEFAULT 0 NOT NULL,
BlockTarget BOOLEAN DEFAULT 0 NOT NULL,
Server BOOLEAN DEFAULT 0 NOT NULL,
RedirectTarget BOOLEAN DEFAULT 0 NOT NULL,
GfwTarget BOOLEAN DEFAULT 0 NOT NULL,
CountryCode TEXT)
""",
kinds="""
CREATE TABLE IF NOT EXISTS Kinds (
KindId INTEGER PRIMARY KEY,
Name TEXT UNIQUE NOT NULL,
ExpectExceptionId INTEGER,
FOREIGN KEY(ExpectExceptionId) REFERENCES Exceptions(ExceptionId))
""",
domains="""
CREATE TABLE IF NOT EXISTS Domains (
DomainId INTEGER PRIMARY KEY,
Name TEXT UNIQUE NOT NULL,
KindId INTEGER,
FOREIGN KEY(KindId) REFERENCES Kinds(KindId))
""",
# NOTE: that RecordId is *not* the rowid here
# since records can contain multiple IPs,
# and thereby span multiple rows.
# TODO: indexing stuff, cascade deletion stuff.
records="""
CREATE TABLE IF NOT EXISTS Records (
RecordId INTEGER NOT NULL,
IpId INTEGER,
FOREIGN KEY(IpId) REFERENCES Ips(IpId))
""",
messages="""
CREATE TABLE IF NOT EXISTS Messages (
MessageId INTEGER PRIMARY KEY,
ExecutionId INTEGER,
ServerId INTEGER NOT NULL,
DomainId INTEGER NOT NULL,
RecordId INTEGER,
ExceptionId INTEGER,
Failed BOOLEAN DEFAULT 0 NOT NULL,
FOREIGN KEY(ServerId) REFERENCES Ips(IpId),
FOREIGN KEY(ExecutionId) REFERENCES Executions(ExecutionId),
FOREIGN KEY(DomainId) REFERENCES Domains(DomainId),
FOREIGN KEY(ExceptionId) REFERENCES Exceptions(ExceptionId))
""",
# this fails because RecordId is not UNIQUE:
# FOREIGN KEY(RecordId) REFERENCES Records(RecordId)
)
create_view_statements = [
"""
CREATE VIEW Results AS
SELECT
Messages.ExecutionId,
ServerIps.AsStr as Server,
Kinds.Name as Kind,
Domains.Name as Name,
RecordIps.AsStr as Address,
Exceptions.Name as Exception,
Messages.Failed as Failed
FROM Messages
LEFT JOIN Domains ON Messages.DomainId = Domains.DomainId
LEFT JOIN Kinds ON Domains.KindId = Kinds.KindId
LEFT JOIN Ips AS ServerIps ON Messages.ServerId = ServerIps.IpId
LEFT JOIN Records ON Messages.RecordId = Records.RecordId
LEFT JOIN Ips as RecordIps ON Records.IpId = RecordIps.IpId
LEFT JOIN Exceptions ON Messages.ExceptionId = Exceptions.ExceptionId
""",
]
table_triggers = dict(
messages=[
# TODO: more triggers. (before update, and also for Records table)
"""
CREATE TRIGGER IF NOT EXISTS RecordExists
BEFORE INSERT
ON Messages
BEGIN
SELECT CASE
WHEN NEW.RecordId NOTNULL AND NOT EXISTS(
SELECT 1 FROM Records WHERE Records.RecordID = NEW.RecordId)
THEN raise(FAIL, "RecordId does not exist")
END;
END
""",
])

View File

@ -1,69 +1,125 @@
from .util import AttrCheck from .util import AttrCheck
import storm.locals as rain from sqlalchemy import Column, String, Integer, Computed, Boolean, DateTime
from sqlalchemy import ForeignKey, UniqueConstraint, CheckConstraint
from sqlalchemy import column, cast
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
Base = declarative_base()
class TException(rain.Storm, AttrCheck): class TException(Base, AttrCheck):
__storm_table__ = "Exceptions" __tablename__ = "Exceptions"
exception_id = rain.Int("ExceptionId", primary=True) exception_id = Column("ExceptionId", Integer, primary_key=True)
name = rain.Unicode("Name") name = Column("Name", String, nullable=False)
fail = rain.Bool("Fail") fail = Column("Fail", Boolean, nullable=False)
# children:
kinds = relationship("TKind", back_populates="exception")
messages = relationship("TMessage", back_populates="exception")
class TExecution(rain.Storm, AttrCheck): class TExecution(Base, AttrCheck):
__storm_table__ = "Executions" __tablename__ = "Executions"
execution_id = rain.Int("ExecutionId", primary=True) execution_id = Column("ExecutionId", Integer, primary_key=True)
start_date = rain.DateTime("StartDate") start_date = Column("StartDate", DateTime, nullable=False)
finish_date = rain.DateTime("FinishDate") finish_date = Column("FinishDate", DateTime)
completed = rain.Bool("Completed") completed = Column("Completed", Boolean, default=0, nullable=False)
# children:
messages = relationship("TMessage", back_populates="execution")
class TAddress(rain.Storm, AttrCheck): as_int = column("AsInt")
__storm_table__ = "Ips"
address_id = rain.Int("IpId", primary=True)
str = rain.Unicode("AsStr")
ip = rain.Int("AsInt")
china = rain.Bool("China")
block_target = rain.Bool("BlockTarget")
server = rain.Bool("Server")
redirect_target = rain.Bool("RedirectTarget")
gfw_target = rain.Bool("GfwTarget")
country_code = rain.Unicode("CountryCode")
class TKind(rain.Storm, AttrCheck): class TAddress(Base, AttrCheck):
__storm_table__ = "Kinds" __tablename__ = "Ips"
kind_id = rain.Int("KindId", primary=True) __table_args__ = (
name = rain.Unicode("Name") UniqueConstraint("AsInt"),
xxid = rain.Int("ExpectExceptionId") CheckConstraint((as_int >= 0) & (as_int < 2**32)),
exception = rain.Reference(xxid, "TException.exception_id") )
address_id = Column("IpId", Integer, primary_key=True)
str = Column("AsStr", String, Computed(
cast(as_int.op(">>")(24).op("&")(255), String) + "." +
cast(as_int.op(">>")(16).op("&")(255), String) + "." +
cast(as_int.op(">>")(8).op("&")(255), String) + "." +
cast(as_int.op("&")(255), String)
), nullable=False)
ip = Column("AsInt", Integer, nullable=False)
china = Column("China", Boolean, default=0, nullable=False)
block_target = Column("BlockTarget", Boolean, default=0, nullable=False)
server = Column("Server", Boolean, default=0, nullable=False)
redirect_target = Column("RedirectTarget", Boolean, default=0, nullable=False)
gfw_target = Column("GfwTarget", Boolean, default=0, nullable=False)
country_code = Column("CountryCode", String)
# children:
messages = relationship("TMessage", back_populates="server")
records = relationship("TRecord", back_populates="address")
class TDomain(rain.Storm, AttrCheck): class TKind(Base, AttrCheck):
__storm_table__ = "Domains" __tablename__ = "Kinds"
domain_id = rain.Int("DomainId", primary=True) __table_args__ = (
name = rain.Unicode("Name") UniqueConstraint("Name"),
kind_id = rain.Int("KindId") )
kind = rain.Reference(kind_id, "TKind.kind_id") kind_id = Column("KindId", Integer, primary_key=True)
name = Column("Name", String, nullable=False)
xxid = Column("ExpectExceptionId", Integer, ForeignKey("Exceptions.ExceptionId"))
# parents:
exception = relationship("TException", back_populates="kinds")
#exception = Column(xxid, Reference, "TException.exception_id")
# children:
domains = relationship("TDomain", back_populates="kind")
class TRecord(rain.Storm, AttrCheck): class TDomain(Base, AttrCheck):
__storm_table__ = "Records" __tablename__ = "Domains"
row_id = rain.Int("rowid", primary=True) __table_args__ = (
record_id = rain.Int("RecordId") UniqueConstraint("Name"),
address_id = rain.Int("IpId") )
address = rain.Reference(address_id, "TAddress.address_id") domain_id = Column("DomainId", Integer, primary_key=True)
name = Column("Name", String, nullable=False)
kind_id = Column("KindId", Integer, ForeignKey("Kinds.KindId"))
# parents:
kind = relationship("TKind", back_populates="domains")
#kind = Column(kind_id, Reference, "TKind.kind_id")
# children:
messages = relationship("TMessage", back_populates="domain")
class TMessage(rain.Storm, AttrCheck): class TRecord(Base, AttrCheck):
__storm_table__ = "Messages" __tablename__ = "Records"
message_id = rain.Int("MessageId", primary=True) row_id = Column("rowid", Integer, primary_key=True)
execution_id = rain.Int("ExecutionId") record_id = Column("RecordId", Integer, nullable=False)
server_id = rain.Int("ServerId") address_id = Column("IpId", Integer, ForeignKey("Ips.IpId"))
domain_id = rain.Int("DomainId")
record_id = rain.Int("RecordId") # parents:
exception_id = rain.Int("ExceptionId") address = relationship("TAddress", back_populates="records")
failed = rain.Bool("Failed") #address = Column(address_id, Reference, "TAddress.address_id")
execution = rain.Reference(execution_id, "TExecution.execution_id")
server = rain.Reference(server_id, "TAddress.address_id")
domain = rain.Reference(domain_id, "TDomain.domain_id") class TMessage(Base, AttrCheck):
exception = rain.Reference(exception_id, "TException.exception_id") __tablename__ = "Messages"
message_id = Column("MessageId", Integer, primary_key=True)
execution_id = Column("ExecutionId", Integer, ForeignKey("Executions.ExecutionId"))
server_id = Column("ServerId", Integer, ForeignKey("Ips.IpId"), nullable=False)
domain_id = Column("DomainId", Integer, ForeignKey("Domains.DomainId"), nullable=False)
record_id = Column("RecordId", Integer)
exception_id = Column("ExceptionId", Integer, ForeignKey("Exceptions.ExceptionId"))
failed = Column("Failed", Boolean, nullable=False)
# parents:
execution = relationship("TExecution", back_populates="messages")
server = relationship("TAddress", back_populates="messages")
domain = relationship("TDomain", back_populates="messages")
exception = relationship("TException", back_populates="messages")
#execution = Column(execution_id, Reference, "TExecution.execution_id")
#server = Column(server_id, Reference, "TAddress.address_id")
#domain = Column(domain_id, Reference, "TDomain.domain_id")
#exception = Column(exception_id, Reference, "TException.exception_id")