From 9defbc8da041fe11c74eaf88f20d5111c2aeda88 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Fri, 6 Aug 2021 18:14:17 -0700 Subject: [PATCH] switch from storm to sqlalchemy --- requirements.txt | 2 +- respodns/db.py | 126 ++++++++++++++++++++-------------- respodns/sql.py | 117 -------------------------------- respodns/tables.py | 164 ++++++++++++++++++++++++++++++--------------- 4 files changed, 188 insertions(+), 221 deletions(-) delete mode 100644 respodns/sql.py diff --git a/requirements.txt b/requirements.txt index b41e259..52ace01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ dnspython >= 2.1.0 -storm +sqlalchemy >= 1.3.18 diff --git a/respodns/db.py b/respodns/db.py index 1e91cdb..7fcac9b 100644 --- a/respodns/db.py +++ b/respodns/db.py @@ -1,8 +1,30 @@ -from .sql import create_table_statements, create_view_statements -from .sql import table_triggers -from .tables import TException, TExecution, TAddress -from .tables import TKind, TDomain, TRecord, TMessage from .ip_util import addr_to_int +from .tables import Base, TException, TExecution, TAddress +from .tables import TKind, TDomain, TRecord, TMessage +from sqlalchemy.orm import sessionmaker + +Session = sessionmaker(autoflush=False) + +create_view_statements = [ + """ +CREATE VIEW Results AS +SELECT + Messages.ExecutionId, + ServerIps.AsStr as Server, + Kinds.Name as Kind, + Domains.Name as Name, + RecordIps.AsStr as Address, + Exceptions.Name as Exception, + Messages.Failed as Failed +FROM Messages +LEFT JOIN Domains ON Messages.DomainId = Domains.DomainId +LEFT JOIN Kinds ON Domains.KindId = Kinds.KindId +LEFT JOIN Ips AS ServerIps ON Messages.ServerId = ServerIps.IpId +LEFT JOIN Records ON Messages.RecordId = Records.RecordId +LEFT JOIN Ips as RecordIps ON Records.IpId = RecordIps.IpId +LEFT JOIN Exceptions ON Messages.ExceptionId = Exceptions.ExceptionId + """, +] class Execution: @@ -22,9 +44,8 @@ class Execution: def is_column(ref): - from storm.properties import PropertyColumn - from storm.references import Reference - return isinstance(ref, PropertyColumn) or isinstance(ref, Reference) + from sqlalchemy.orm.attributes import InstrumentedAttribute + return isinstance(ref, InstrumentedAttribute) def apply_properties(obj, d): @@ -38,15 +59,17 @@ def apply_properties(obj, d): class RespoDB: def __init__(self, uri, setup=False, create=False): - from storm.database import create_database + from sqlalchemy import create_engine self.uri = uri db_exists = self._db_exists(self.uri) - self.db = create_database(self.uri) + self.db = create_engine(self.uri) + Session.configure(bind=self.db) self._conn = None if setup or (create and not db_exists): with self: + Base.metadata.create_all(self.db) self.setup_executions() self.setup_exceptions() self.setup_ips() @@ -70,8 +93,7 @@ class RespoDB: return fp and exists(fp) def __enter__(self): - from storm.store import Store - self._conn = Store(self.db) + self._conn = Session() return self def __exit__(self, exc_type, exc_value, traceback): @@ -79,9 +101,11 @@ class RespoDB: self._conn.close() self._conn = None - def find_one(self, cls_spec, *args, **kwargs): - assert self._conn is not None - return self._conn.find(cls_spec, *args, **kwargs).one() + def find_one(self, cls_spec, **filters): + if len(filters) > 0: + return self._conn.query(cls_spec).filter_by(**filters).first() + else: + return self._conn.query(cls_spec).first() def flush(self): assert self._conn is not None @@ -93,65 +117,70 @@ class RespoDB: def new_exception(self, **kwargs): assert self._conn is not None - return self._conn.add(apply_properties(TException(), kwargs)) + res = TException(**kwargs) + self._conn.add(res) + return res def new_kind(self, **kwargs): assert self._conn is not None - return self._conn.add(apply_properties(TKind(), kwargs)) + res = TKind(**kwargs) + self._conn.add(res) + return res def new_domain(self, **kwargs): assert self._conn is not None - return self._conn.add(apply_properties(TDomain(), kwargs)) + res = TDomain(**kwargs) + self._conn.add(res) + return res def new_address(self, **kwargs): assert self._conn is not None - return self._conn.add(apply_properties(TAddress(), kwargs)) + res = TAddress(**kwargs) + self._conn.add(res) + return res def new_record(self, **kwargs): assert self._conn is not None - return self._conn.add(apply_properties(TRecord(), kwargs)) + res = TRecord(**kwargs) + self._conn.add(res) + return res def new_message(self, **kwargs): assert self._conn is not None - return self._conn.add(apply_properties(TMessage(), kwargs)) + res = TMessage(**kwargs) + self._conn.add(res) + return res def _fire(self, statement): assert self._conn is not None - self._conn.execute(statement, noresult=True) + self._conn.execute(statement).close() def setup_executions(self): - self._fire(create_table_statements["executions"]) + pass def setup_exceptions(self): # careful not to call them "errors" since NXDOMAIN is not an error. - self._fire(create_table_statements["exceptions"]) - # TODO: upsert? - self.new_exception(name="NXDOMAIN", fail=False) self.new_exception(name="NoAnswer", fail=True) self.new_exception(name="NoNameservers", fail=True) self.new_exception(name="Timeout", fail=True) def setup_ips(self): - self._fire(create_table_statements["ips"]) - self.modify_address("0.0.0.0", block_target=True) self.modify_address("127.0.0.1", block_target=True) def setup_kinds(self): - self._fire(create_table_statements["kinds"]) + pass def setup_domains(self): - self._fire(create_table_statements["domains"]) + pass def setup_records(self): - self._fire(create_table_statements["records"]) + pass def setup_messages(self): - self._fire(create_table_statements["messages"]) - for trig in table_triggers["messages"]: - self._conn.execute(trig) + pass def start_execution(self, dt): execution = TExecution() @@ -167,18 +196,19 @@ class RespoDB: def all_ips(self): assert self._conn is not None - addresses = self._conn.find(TAddress) + addresses = self._conn.query(TAddress).fetchall() return [addr.str for addr in addresses] def next_record_id(self): - from storm.expr import Add, Max, Coalesce - expr = Add(Coalesce(Max(TRecord.record_id), 0), 1) - return self.find_one(expr) + from sqlalchemy.sql.expression import func + expr = func.coalesce(func.max(TRecord.record_id), 0) + 1 + return self.find_one(expr)[0] def find_record_id(self, addresses): address_ids = list(address.address_id for address in addresses) - temp = self._conn.find(TRecord, TRecord.address_id.is_in(address_ids)) - record_ids = list(temp.values(TRecord.record_id)) + temp = self._conn.query(TRecord).filter(TRecord.address_id.in_(address_ids)) + record_ids = [t[0] for t in temp.values(TRecord.record_id)] + # TODO: why are record_ids even tuples to begin with? if not record_ids: return None unique_ids = sorted(set(record_ids)) @@ -189,16 +219,15 @@ class RespoDB: return None def push_entry(self, entry): - kind = self.find_one(TKind, TKind.name == entry.kind) + kind = self.find_one(TKind, name=entry.kind) if not kind: kind = self.new_kind(name=entry.kind) if entry.kind.startswith("bad"): - exception = self.find_one(TException, - TException.name == "NXDOMAIN") + exception = self.find_one(TException, name="NXDOMAIN") assert exception is not None kind.exception = exception - domain = self.find_one(TDomain, TDomain.name == entry.domain) + domain = self.find_one(TDomain, name=entry.domain) if not domain: domain = self.new_domain(name=entry.domain) domain.kind = kind @@ -206,7 +235,7 @@ class RespoDB: addresses = [] as_ints = sorted(set(map(addr_to_int, entry.addrs))) for numeric in as_ints: - address = self.find_one(TAddress, TAddress.ip == numeric) + address = self.find_one(TAddress, ip=numeric) if not address: address = self.new_address(ip=numeric) addresses.append(address) @@ -229,15 +258,14 @@ class RespoDB: record_id = None numeric = addr_to_int(entry.server) - server = self.find_one(TAddress, TAddress.ip == numeric) + server = self.find_one(TAddress, ip=numeric) if not server: server = self.new_address(ip=numeric) self.flush() server.server = True if entry.exception: - exception = self.find_one(TException, - TException.name == entry.exception) + exception = self.find_one(TException, name=entry.exception) assert exception is not None else: exception = None @@ -253,7 +281,7 @@ class RespoDB: def country_code(self, ip, code=None): numeric = addr_to_int(ip) - address = self.find_one(TAddress, TAddress.ip == numeric) + address = self.find_one(TAddress, ip=numeric) if code is None: if address is not None: return address.country_code @@ -268,7 +296,7 @@ class RespoDB: def modify_address(self, ip, **kwargs): numeric = addr_to_int(ip) - address = self.find_one(TAddress, TAddress.ip == numeric) + address = self.find_one(TAddress, ip=numeric) if address is None: self.new_address(ip=numeric, **kwargs) else: diff --git a/respodns/sql.py b/respodns/sql.py deleted file mode 100644 index 3f75b75..0000000 --- a/respodns/sql.py +++ /dev/null @@ -1,117 +0,0 @@ -create_table_statements = dict( - # TODO: Duration REAL GENERATED ALWAYS AS etc.? - executions=""" -CREATE TABLE IF NOT EXISTS Executions ( - ExecutionId INTEGER PRIMARY KEY, - StartDate DATE NOT NULL, - FinishDate DATE, - Completed BOOLEAN DEFAULT 0 NOT NULL) - """, - - exceptions=""" -CREATE TABLE IF NOT EXISTS Exceptions ( - ExceptionId INTEGER PRIMARY KEY, - Name TEXT NOT NULL, - Fail BOOLEAN NOT NULL) - """, - - ips=""" -CREATE TABLE IF NOT EXISTS Ips ( - IpId INTEGER PRIMARY KEY, - AsStr TEXT GENERATED ALWAYS AS ( - Cast(AsInt >> 24 & 255 AS TEXT) || '.' || - Cast(AsInt >> 16 & 255 AS TEXT) || '.' || - Cast(AsInt >> 8 & 255 AS TEXT) || '.' || - Cast(AsInt & 255 AS TEXT) - ) STORED NOT NULL, - AsInt INTEGER UNIQUE CHECK(AsInt >= 0 AND AsInt < 1 << 32) NOT NULL, - China BOOLEAN DEFAULT 0 NOT NULL, - BlockTarget BOOLEAN DEFAULT 0 NOT NULL, - Server BOOLEAN DEFAULT 0 NOT NULL, - RedirectTarget BOOLEAN DEFAULT 0 NOT NULL, - GfwTarget BOOLEAN DEFAULT 0 NOT NULL, - CountryCode TEXT) - """, - - kinds=""" -CREATE TABLE IF NOT EXISTS Kinds ( - KindId INTEGER PRIMARY KEY, - Name TEXT UNIQUE NOT NULL, - ExpectExceptionId INTEGER, - FOREIGN KEY(ExpectExceptionId) REFERENCES Exceptions(ExceptionId)) - """, - - domains=""" -CREATE TABLE IF NOT EXISTS Domains ( - DomainId INTEGER PRIMARY KEY, - Name TEXT UNIQUE NOT NULL, - KindId INTEGER, - FOREIGN KEY(KindId) REFERENCES Kinds(KindId)) - """, - - # NOTE: that RecordId is *not* the rowid here - # since records can contain multiple IPs, - # and thereby span multiple rows. - # TODO: indexing stuff, cascade deletion stuff. - records=""" -CREATE TABLE IF NOT EXISTS Records ( - RecordId INTEGER NOT NULL, - IpId INTEGER, - FOREIGN KEY(IpId) REFERENCES Ips(IpId)) - """, - - messages=""" -CREATE TABLE IF NOT EXISTS Messages ( - MessageId INTEGER PRIMARY KEY, - ExecutionId INTEGER, - ServerId INTEGER NOT NULL, - DomainId INTEGER NOT NULL, - RecordId INTEGER, - ExceptionId INTEGER, - Failed BOOLEAN DEFAULT 0 NOT NULL, - FOREIGN KEY(ServerId) REFERENCES Ips(IpId), - FOREIGN KEY(ExecutionId) REFERENCES Executions(ExecutionId), - FOREIGN KEY(DomainId) REFERENCES Domains(DomainId), - FOREIGN KEY(ExceptionId) REFERENCES Exceptions(ExceptionId)) - """, - # this fails because RecordId is not UNIQUE: - # FOREIGN KEY(RecordId) REFERENCES Records(RecordId) -) - -create_view_statements = [ - """ -CREATE VIEW Results AS -SELECT - Messages.ExecutionId, - ServerIps.AsStr as Server, - Kinds.Name as Kind, - Domains.Name as Name, - RecordIps.AsStr as Address, - Exceptions.Name as Exception, - Messages.Failed as Failed -FROM Messages -LEFT JOIN Domains ON Messages.DomainId = Domains.DomainId -LEFT JOIN Kinds ON Domains.KindId = Kinds.KindId -LEFT JOIN Ips AS ServerIps ON Messages.ServerId = ServerIps.IpId -LEFT JOIN Records ON Messages.RecordId = Records.RecordId -LEFT JOIN Ips as RecordIps ON Records.IpId = RecordIps.IpId -LEFT JOIN Exceptions ON Messages.ExceptionId = Exceptions.ExceptionId - """, -] - -table_triggers = dict( - messages=[ - # TODO: more triggers. (before update, and also for Records table) - """ -CREATE TRIGGER IF NOT EXISTS RecordExists -BEFORE INSERT -ON Messages -BEGIN - SELECT CASE - WHEN NEW.RecordId NOTNULL AND NOT EXISTS( - SELECT 1 FROM Records WHERE Records.RecordID = NEW.RecordId) - THEN raise(FAIL, "RecordId does not exist") - END; -END - """, - ]) diff --git a/respodns/tables.py b/respodns/tables.py index 22fda74..e3d2fc5 100644 --- a/respodns/tables.py +++ b/respodns/tables.py @@ -1,69 +1,125 @@ from .util import AttrCheck -import storm.locals as rain +from sqlalchemy import Column, String, Integer, Computed, Boolean, DateTime +from sqlalchemy import ForeignKey, UniqueConstraint, CheckConstraint +from sqlalchemy import column, cast +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship + +Base = declarative_base() -class TException(rain.Storm, AttrCheck): - __storm_table__ = "Exceptions" - exception_id = rain.Int("ExceptionId", primary=True) - name = rain.Unicode("Name") - fail = rain.Bool("Fail") +class TException(Base, AttrCheck): + __tablename__ = "Exceptions" + exception_id = Column("ExceptionId", Integer, primary_key=True) + name = Column("Name", String, nullable=False) + fail = Column("Fail", Boolean, nullable=False) + + # children: + kinds = relationship("TKind", back_populates="exception") + messages = relationship("TMessage", back_populates="exception") -class TExecution(rain.Storm, AttrCheck): - __storm_table__ = "Executions" - execution_id = rain.Int("ExecutionId", primary=True) - start_date = rain.DateTime("StartDate") - finish_date = rain.DateTime("FinishDate") - completed = rain.Bool("Completed") +class TExecution(Base, AttrCheck): + __tablename__ = "Executions" + execution_id = Column("ExecutionId", Integer, primary_key=True) + start_date = Column("StartDate", DateTime, nullable=False) + finish_date = Column("FinishDate", DateTime) + completed = Column("Completed", Boolean, default=0, nullable=False) + + # children: + messages = relationship("TMessage", back_populates="execution") -class TAddress(rain.Storm, AttrCheck): - __storm_table__ = "Ips" - address_id = rain.Int("IpId", primary=True) - str = rain.Unicode("AsStr") - ip = rain.Int("AsInt") - china = rain.Bool("China") - block_target = rain.Bool("BlockTarget") - server = rain.Bool("Server") - redirect_target = rain.Bool("RedirectTarget") - gfw_target = rain.Bool("GfwTarget") - country_code = rain.Unicode("CountryCode") +as_int = column("AsInt") -class TKind(rain.Storm, AttrCheck): - __storm_table__ = "Kinds" - kind_id = rain.Int("KindId", primary=True) - name = rain.Unicode("Name") - xxid = rain.Int("ExpectExceptionId") - exception = rain.Reference(xxid, "TException.exception_id") +class TAddress(Base, AttrCheck): + __tablename__ = "Ips" + __table_args__ = ( + UniqueConstraint("AsInt"), + CheckConstraint((as_int >= 0) & (as_int < 2**32)), + ) + address_id = Column("IpId", Integer, primary_key=True) + str = Column("AsStr", String, Computed( + cast(as_int.op(">>")(24).op("&")(255), String) + "." + + cast(as_int.op(">>")(16).op("&")(255), String) + "." + + cast(as_int.op(">>")(8).op("&")(255), String) + "." + + cast(as_int.op("&")(255), String) + ), nullable=False) + ip = Column("AsInt", Integer, nullable=False) + china = Column("China", Boolean, default=0, nullable=False) + block_target = Column("BlockTarget", Boolean, default=0, nullable=False) + server = Column("Server", Boolean, default=0, nullable=False) + redirect_target = Column("RedirectTarget", Boolean, default=0, nullable=False) + gfw_target = Column("GfwTarget", Boolean, default=0, nullable=False) + country_code = Column("CountryCode", String) + + # children: + messages = relationship("TMessage", back_populates="server") + records = relationship("TRecord", back_populates="address") -class TDomain(rain.Storm, AttrCheck): - __storm_table__ = "Domains" - domain_id = rain.Int("DomainId", primary=True) - name = rain.Unicode("Name") - kind_id = rain.Int("KindId") - kind = rain.Reference(kind_id, "TKind.kind_id") +class TKind(Base, AttrCheck): + __tablename__ = "Kinds" + __table_args__ = ( + UniqueConstraint("Name"), + ) + kind_id = Column("KindId", Integer, primary_key=True) + name = Column("Name", String, nullable=False) + xxid = Column("ExpectExceptionId", Integer, ForeignKey("Exceptions.ExceptionId")) + + # parents: + exception = relationship("TException", back_populates="kinds") + #exception = Column(xxid, Reference, "TException.exception_id") + + # children: + domains = relationship("TDomain", back_populates="kind") -class TRecord(rain.Storm, AttrCheck): - __storm_table__ = "Records" - row_id = rain.Int("rowid", primary=True) - record_id = rain.Int("RecordId") - address_id = rain.Int("IpId") - address = rain.Reference(address_id, "TAddress.address_id") +class TDomain(Base, AttrCheck): + __tablename__ = "Domains" + __table_args__ = ( + UniqueConstraint("Name"), + ) + domain_id = Column("DomainId", Integer, primary_key=True) + name = Column("Name", String, nullable=False) + kind_id = Column("KindId", Integer, ForeignKey("Kinds.KindId")) + + # parents: + kind = relationship("TKind", back_populates="domains") + #kind = Column(kind_id, Reference, "TKind.kind_id") + + # children: + messages = relationship("TMessage", back_populates="domain") -class TMessage(rain.Storm, AttrCheck): - __storm_table__ = "Messages" - message_id = rain.Int("MessageId", primary=True) - execution_id = rain.Int("ExecutionId") - server_id = rain.Int("ServerId") - domain_id = rain.Int("DomainId") - record_id = rain.Int("RecordId") - exception_id = rain.Int("ExceptionId") - failed = rain.Bool("Failed") - execution = rain.Reference(execution_id, "TExecution.execution_id") - server = rain.Reference(server_id, "TAddress.address_id") - domain = rain.Reference(domain_id, "TDomain.domain_id") - exception = rain.Reference(exception_id, "TException.exception_id") +class TRecord(Base, AttrCheck): + __tablename__ = "Records" + row_id = Column("rowid", Integer, primary_key=True) + record_id = Column("RecordId", Integer, nullable=False) + address_id = Column("IpId", Integer, ForeignKey("Ips.IpId")) + + # parents: + address = relationship("TAddress", back_populates="records") + #address = Column(address_id, Reference, "TAddress.address_id") + + +class TMessage(Base, AttrCheck): + __tablename__ = "Messages" + message_id = Column("MessageId", Integer, primary_key=True) + execution_id = Column("ExecutionId", Integer, ForeignKey("Executions.ExecutionId")) + server_id = Column("ServerId", Integer, ForeignKey("Ips.IpId"), nullable=False) + domain_id = Column("DomainId", Integer, ForeignKey("Domains.DomainId"), nullable=False) + record_id = Column("RecordId", Integer) + exception_id = Column("ExceptionId", Integer, ForeignKey("Exceptions.ExceptionId")) + failed = Column("Failed", Boolean, nullable=False) + + # parents: + execution = relationship("TExecution", back_populates="messages") + server = relationship("TAddress", back_populates="messages") + domain = relationship("TDomain", back_populates="messages") + exception = relationship("TException", back_populates="messages") + #execution = Column(execution_id, Reference, "TExecution.execution_id") + #server = Column(server_id, Reference, "TAddress.address_id") + #domain = Column(domain_id, Reference, "TDomain.domain_id") + #exception = Column(exception_id, Reference, "TException.exception_id")