# backyard/library/plots.py
#
# (file-viewer metadata captured with this source: 122 lines, 3.8 KiB, Python)

import matplotlib.pyplot as plt
class Plot:
    """Context manager that creates a single-axes matplotlib figure.

    Usage::

        with Plot() as (fig, ax):
            ax.plot(xs, ys)

    On exit the figure is displayed according to ``show`` and then always
    closed, even if displaying it fails.

    Parameters
    ----------
    show : bool or "terminal"
        ``"terminal"`` renders the figure to PNG and hands it to
        ``autopsy.termshow``; any other truthy value calls ``plt.show()``;
        a falsy value skips display.
    **kwargs
        Passed to ``plt.figure``; merged over the class-level ``defaults``.
    """

    # Figure keyword defaults; subclasses override this class attribute.
    defaults = dict(figsize=(6, 4), dpi=128)

    def __init__(self, show=True, **kwargs):
        # Build a fresh dict so per-instance kwargs never mutate ``defaults``.
        self.kwargs = {**self.defaults, **kwargs}
        self.show = show

    def __enter__(self):
        """Create the figure and its axes; return ``(fig, ax)``."""
        self.fig = plt.figure(**self.kwargs)
        self.ax = self.fig.gca()
        # HACK: force a white figure background regardless of style defaults.
        self.fig.patch.set_facecolor("white")
        return self.fig, self.ax

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Display the figure (per ``show``) and close it.

        Returns None, so any exception raised inside the ``with`` body
        propagates to the caller.
        """
        from sys import stderr, stdout

        # Flush pending text output before the figure is displayed.
        # NOTE(review): the original author flagged these flushes as possibly
        # ineffective ("doesn't work?") -- unverified.
        stdout.flush()
        stderr.flush()
        try:
            if self.show == "terminal":
                from io import BytesIO

                from autopsy import termshow

                buf = BytesIO()
                self.fig.savefig(buf, dpi=self.kwargs.get("dpi", 96), format="png")
                # NOTE(review): buf is left positioned at its end; presumably
                # termshow reads it via getvalue() or seeks itself -- confirm.
                termshow(buf)
            elif self.show:
                plt.show()
        finally:
            # Always release the figure: pyplot keeps every unclosed figure
            # alive, so skipping this on an error would leak it.
            plt.close(self.fig)
class SquarePlot(Plot):  # TODO: delete me or fix me; not terribly useful.
    """Square, equal-aspect variant of ``Plot`` with major and minor grids.

    Accepts the same arguments as ``Plot`` (``show`` plus figure kwargs);
    only the default figure size differs.
    """

    defaults = dict(figsize=(7, 7), dpi=128)

    def __enter__(self):
        """Create the figure, apply equal aspect and grids; return ``(fig, ax)``."""
        fig, ax = super().__enter__()
        ax.set_aspect("equal", "box")
        # Minor ticks are off by default on linear axes, so without this the
        # "minor" grid below would draw no lines at all.
        ax.minorticks_on()
        ax.grid(True, "major", alpha=0.8)
        ax.grid(True, "minor", alpha=0.4)
        return fig, ax
class CleanPlot(Plot):
    """``Plot`` whose axes cover the entire figure with all decorations hidden."""

    def __enter__(self):
        """Create the figure, strip the axes, stretch them full-canvas; return ``(fig, ax)``."""
        super().__enter__()
        axes = self.ax
        # Hide axis lines, ticks, tick labels and axis labels.
        axes.set_axis_off()
        # Occupy the whole canvas: [left, bottom, width, height] in figure fractions.
        axes.set_position([0, 0, 1, 1])
        return self.fig, axes
class CleanPlotUnity(CleanPlot):
    """``CleanPlot`` whose data limits are fixed to the unit square [0, 1] x [0, 1].

    Parameters
    ----------
    size : float or None
        Figure width and height (``figsize`` is ``(size, size)``). Ignored
        when ``None`` or when ``figsize`` is passed explicitly.
    dim : float or None
        Figure ``dpi``. Ignored when ``None`` or when ``dpi`` is passed
        explicitly. With the default ``size=1`` this is effectively the
        side length of the rendered image in pixels.
    **kwargs
        Forwarded to ``Plot`` (``show`` and figure kwargs).
    """

    def __init__(self, size=1, dim=None, **kwargs):
        # Explicit figsize/dpi kwargs win over the size/dim shorthands.
        if size is not None:
            kwargs.setdefault("figsize", (size, size))
        if dim is not None:
            kwargs.setdefault("dpi", dim)
        super().__init__(**kwargs)

    def __enter__(self):
        """Create the clean figure and pin both axes to [0, 1]; return ``(fig, ax)``."""
        fig, ax = super().__enter__()
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        return fig, ax
class SquareCenterPlot(Plot):
    """Square plot centred on the origin with crossing axes and a tick grid.

    The view spans ``[-size, size]`` on both axes. Major ticks fall on whole
    units, and each unit is split into ``subdivisions`` minor steps. The
    origin itself gets no tick because the spines cross there.

    Parameters
    ----------
    size : int or float
        Half-width of the visible square, in data units. Non-integer values
        are allowed; only ticks that fit inside the view are placed.
    subdivisions : int
        Number of minor steps per unit; must be >= 1.
    **kwargs
        Forwarded to ``Plot`` (``show`` and figure kwargs).

    Raises
    ------
    ValueError
        If ``subdivisions`` is less than 1.
    """

    defaults = dict(figsize=(7, 7), dpi=128)

    def __init__(self, size=5, subdivisions=4, **kwargs):
        # Fail early: 0 would divide by zero and negatives would yield no ticks.
        if subdivisions < 1:
            raise ValueError(f"subdivisions must be >= 1, got {subdivisions!r}")
        self.size = size
        self.subdivisions = subdivisions
        super().__init__(**kwargs)

    @staticmethod
    def _ticks(size, subdivisions):
        """Return ``(major, minor)`` tick positions within ``[-size, size]``, excluding 0."""
        p = subdivisions
        # Count in integer steps of 1/p so the major/minor split is exact.
        # int() floors positive sizes, so a non-integer size (e.g. 2.5) works
        # instead of crashing range(); integer sizes are unaffected.
        n = int(size * p)
        steps = [k for k in range(-n, n + 1) if k != 0]
        major = [k / p for k in steps if k % p == 0]
        minor = [k / p for k in steps if k % p != 0]
        return major, minor

    def __enter__(self):
        """Create the figure, centre the axes, add ticks and grids; return ``(fig, ax)``."""
        ret = super().__enter__()
        ax = self.ax
        ax.set_aspect("equal", "box")
        ax.set_xlim(-self.size, self.size)
        ax.set_ylim(-self.size, self.size)
        # Move the left/bottom spines through the origin and hide the other two.
        ax.spines["left"].set_position("zero")
        ax.spines["bottom"].set_position("zero")
        ax.spines["right"].set_color("none")
        ax.spines["top"].set_color("none")
        ax.xaxis.set_ticks_position("bottom")
        ax.yaxis.set_ticks_position("left")
        major, minor = self._ticks(self.size, self.subdivisions)
        ax.set_xticks(major, minor=False)
        ax.set_yticks(major, minor=False)
        ax.set_xticks(minor, minor=True)
        ax.set_yticks(minor, minor=True)
        ax.grid(True, "major", alpha=0.8)
        ax.grid(True, "minor", alpha=0.4)
        return ret
class GridPlot:
    """Context manager that creates a ``rows`` x ``cols`` grid of subplots.

    Usage::

        with GridPlot(2, 3) as (fig, axes):
            axes[0, 0].plot(xs, ys)

    On exit the layout is tightened, the figure is displayed according to
    ``show``, and the figure is then always closed, even if any of that fails.

    Parameters
    ----------
    rows, cols : int
        Grid shape; coerced with ``int()``.
    show : bool or "terminal"
        ``"terminal"`` renders the figure to PNG and hands it to
        ``autopsy.termshow``; any other truthy value calls ``plt.show()``;
        a falsy value skips display. Other values are normalised to ``bool``.
    **kwargs
        Passed to ``plt.subplots``; merged over the class-level ``defaults``.
    """

    # Figure keyword defaults; subclasses may override this class attribute.
    defaults = dict(figsize=(9, 6), dpi=128)

    def __init__(self, rows, cols, show=True, **kwargs):
        self.rows = int(rows)
        self.cols = int(cols)
        # Build a fresh dict so per-instance kwargs never mutate ``defaults``.
        self.kwargs = {**self.defaults, **kwargs}
        # Keep the "terminal" sentinel; previously bool() collapsed it into
        # an ordinary plt.show().
        self.show = show if show == "terminal" else bool(show)

    def __enter__(self):
        """Create the figure and subplot grid; return ``(fig, axes)``."""
        self.fig, self.axes = plt.subplots(self.rows, self.cols, **self.kwargs)
        # Force a white figure background regardless of style defaults.
        self.fig.patch.set_facecolor("white")
        return self.fig, self.axes

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Tighten the layout, display the figure (per ``show``) and close it.

        Returns None, so any exception raised inside the ``with`` body
        propagates to the caller.
        """
        try:
            self.fig.tight_layout()
            if self.show == "terminal":
                from io import BytesIO

                from autopsy import termshow

                buf = BytesIO()
                self.fig.savefig(buf, dpi=self.kwargs.get("dpi", 96), format="png")
                # NOTE(review): buf is left positioned at its end; presumably
                # termshow reads it via getvalue() or seeks itself -- confirm.
                termshow(buf)
            elif self.show:
                plt.show()
        finally:
            # Always release the figure: pyplot keeps every unclosed figure
            # alive, so skipping this on an error would leak it.
            plt.close(self.fig)