diff --git a/library/plots.py b/library/plots.py new file mode 100644 index 0000000..37fbd84 --- /dev/null +++ b/library/plots.py @@ -0,0 +1,121 @@ +import matplotlib.pyplot as plt + + +class Plot: + defaults = dict(figsize=(6, 4), dpi=128) + + def __init__(self, show=True, **kwargs): + self.kwargs = self.defaults.copy() + self.kwargs.update(kwargs) + self.show = show + + def __enter__(self): + self.fig = plt.figure(**self.kwargs) + self.ax = self.fig.gca() + self.fig.patch.set_facecolor("white") # hacky + return self.fig, self.ax + + def __exit__(self, exc_type, exc_val, exc_tb): + from sys import stdout, stderr + stdout.flush() # doesn't work? + stderr.flush() # doesn't work? + if self.show == "terminal": + from io import BytesIO + from autopsy import termshow + buf = BytesIO() + self.fig.savefig(buf, dpi=self.kwargs.get("dpi", 96), format="png") + termshow(buf) + elif self.show: + plt.show() + plt.close(self.fig) + + +class SquarePlot(Plot): # TODO: delete me or fix me; not terribly useful. + defaults = dict(figsize=(7, 7), dpi=128) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def __enter__(self): + ret = super().__enter__() + self.ax.set_aspect("equal", "box") + self.ax.grid(True, "major", alpha=0.8) + self.ax.grid(True, "minor", alpha=0.4) + return ret + + +class CleanPlot(Plot): + def __enter__(self): + super().__enter__() + self.ax.set_axis_off() + self.ax.set_position([0, 0, 1, 1]) + return self.fig, self.ax + + +class CleanPlotUnity(CleanPlot): + def __init__(self, size=1, dim=None, **kwargs): + if "figsize" not in kwargs and size is not None: + kwargs["figsize"] = (size, size) + if "dpi" not in kwargs and dim is not None: + kwargs["dpi"] = dim + super().__init__(**kwargs) + + def __enter__(self): + super().__enter__() + self.ax.set_xlim(0, 1) + self.ax.set_ylim(0, 1) + return self.fig, self.ax + + +class SquareCenterPlot(Plot): + defaults = dict(figsize=(7, 7), dpi=128) + + def __init__(self, size=5, subdivisions=4, **kwargs): + self.size = size + self.subdivisions = subdivisions + super().__init__(**kwargs) + + def __enter__(self): + ret = super().__enter__() + self.ax.set_aspect("equal", "box") + self.ax.set_xlim(-self.size, self.size) + self.ax.set_ylim(-self.size, self.size) + self.ax.spines["left"].set_position("zero") + self.ax.spines["bottom"].set_position("zero") + self.ax.spines["right"].set_color("none") + self.ax.spines["top"].set_color("none") + self.ax.xaxis.set_ticks_position("bottom") + self.ax.yaxis.set_ticks_position("left") + p = self.subdivisions + ticks = [x for x in range(-self.size * p, self.size * p + 1) if x != 0] + major_ticks = [x / p for x in ticks if x % p == 0] + minor_ticks = [x / p for x in ticks if x % p != 0] + self.ax.set_xticks(major_ticks, minor=False) + self.ax.set_yticks(major_ticks, minor=False) + self.ax.set_xticks(minor_ticks, minor=True) + self.ax.set_yticks(minor_ticks, minor=True) + self.ax.grid(True, "major", alpha=0.8) + self.ax.grid(True, "minor", alpha=0.4) + return ret + + +class GridPlot: + defaults = dict(figsize=(9, 6), dpi=128) + + def __init__(self, rows, cols, show=True, **kwargs): + self.rows = int(rows) + self.cols = int(cols) + self.kwargs = self.defaults.copy() + self.kwargs.update(kwargs) + self.show = bool(show) + + def __enter__(self): + self.fig, self.axes = plt.subplots(self.rows, self.cols, **self.kwargs) + self.fig.patch.set_facecolor("white") + return self.fig, self.axes + + def __exit__(self, exc_type, exc_val, exc_tb): + self.fig.tight_layout() + if self.show: + plt.show() + plt.close(self.fig)