From 50e0311051d9b28262c5ffb395899cb7c7653315 Mon Sep 17 00:00:00 2001 From: Connor Olding Date: Tue, 5 Feb 2019 22:16:46 +0100 Subject: [PATCH] allow passing model through Ritual init --- onn/ritual_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onn/ritual_base.py b/onn/ritual_base.py index 80c6366..d5c51fa 100644 --- a/onn/ritual_base.py +++ b/onn/ritual_base.py @@ -9,9 +9,11 @@ Losses = namedtuple("Losses", ["avg_loss", "avg_mloss", "losses", "mlosses"]) class Ritual: # i'm just making up names at this point. - def __init__(self, learner=None): + def __init__(self, learner=None, model=None): self.learner = learner if learner is not None else Learner(Optimizer()) self.model = None + if model is not None: + self.prepare(model) def reset(self): self.learner.reset(optim=True)