Коммит 36fb7aae создал по автору Dealga McArdle's avatar Dealga McArdle
Просмотр файлов

does this introduce slight overhead?

владелец aaaedfc2
...@@ -59,6 +59,21 @@ TAU = PI * 2 ...@@ -59,6 +59,21 @@ TAU = PI * 2
TWO_PI = TAU TWO_PI = TAU
N = identity_matrix N = identity_matrix
# ---- testing numba speed ups, so expect to relocate this
local_numba_storage = {}
def get_fastest_implementation(numba, function_to_compile):
if numba:
function_name = function_to_compile.__name__
if function_name not in local_numba_storage:
local_numba_storage[function_name] = numba.njit(function_to_compile)
return local_numba_storage[function_name]
else:
return function_to_compile
# ----------------- vectorize wrapper --------------- # ----------------- vectorize wrapper ---------------
...@@ -191,8 +206,6 @@ class Spline(object): ...@@ -191,8 +206,6 @@ class Spline(object):
self._single_eval_cache[t] = result self._single_eval_cache[t] = result
return result return result
local_numba_storage = {}
class CubicSpline(Spline): class CubicSpline(Spline):
def __init__(self, vertices, tknots = None, metric = None, is_cyclic = False): def __init__(self, vertices, tknots = None, metric = None, is_cyclic = False):
""" """
...@@ -235,8 +248,7 @@ class CubicSpline(Spline): ...@@ -235,8 +248,7 @@ class CubicSpline(Spline):
if n < 2: if n < 2:
raise Exception("Cubic spline can't be built from less than 3 vertices") raise Exception("Cubic spline can't be built from less than 3 vertices")
# a = locs def calc_cubic_splines(tknots, n, locs):
def perform_stage(tknots, n, locs):
""" """
returns splines returns splines
""" """
...@@ -282,12 +294,8 @@ class CubicSpline(Spline): ...@@ -282,12 +294,8 @@ class CubicSpline(Spline):
splines[:, 4] = tknots[:-1].reshape((-1, 1)) splines[:, 4] = tknots[:-1].reshape((-1, 1))
return splines return splines
if numba: calc_cubic_splines = get_fastest_implementation(numba, calc_cubic_splines)
if 'perform_stage' not in local_numba_storage: self.splines = calc_cubic_splines(tknots, n, locs)
local_numba_storage['perform_stage'] = numba.njit(perform_stage)
perform_stage = local_numba_storage['perform_stage']
self.splines = perform_stage(tknots, n, locs)
def eval(self, t_in, tknots = None): def eval(self, t_in, tknots = None):
""" """
......
Поддерживает Markdown
0% или .
You are about to add 0 people to the discussion. Proceed with caution.
Сначала завершите редактирование этого сообщения!
Пожалуйста, зарегистрируйтесь или чтобы прокомментировать