diff --git a/go_benchmark_functions/go_benchmark.py b/go_benchmark_functions/go_benchmark.py index 0ee1d3c..8fb3802 100644 --- a/go_benchmark_functions/go_benchmark.py +++ b/go_benchmark_functions/go_benchmark.py @@ -2,7 +2,22 @@ import numpy as np from numpy import abs, asarray -from ..common import safe_import + +class safe_import: + def __enter__(self): + self.error = False + return self + + def __exit__(self, type_, value, traceback): + if type_ is not None: + self.error = True + suppress = not ( + os.getenv("SCIPY_ALLOW_BENCH_IMPORT_ERRORS", "1").lower() + in ("0", "false") + or not issubclass(type_, ImportError) + ) + return suppress + with safe_import(): from scipy.special import factorial