Skip to content

Commit d537876

Browse files
authored
Merge pull request #33 from EconForge/albop/extrap
FIX: interp extrapolates with closest value on the grid.
2 parents 53a1c6a + 931335f commit d537876

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

interpolation/multilinear/fungen.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
def clamp(x,a,b):
2121
return min(b,max(a,x))
2222

23+
24+
25+
2326
# returns the index of a 1d point along a 1d dimension
2427
@generated_jit(nopython=True)
2528
def get_index(gc, x):
@@ -195,3 +198,20 @@ def extract_row(a, n, tup):
195198
s = "def extract_row(a, n, tup): return ({},)".format(str.join(', ', [f"a[n,{i}]" for i in range(d)]))
196199
eval(compile(ast.parse(s),'<string>','exec'), dd)
197200
return dd['extract_row']
201+
202+
203+
204+
# find closest point inside the grid domain
205+
@generated_jit
206+
def project(grid, point):
207+
s = "def __project(grid, point):\n"
208+
d = len(grid.types)
209+
for i in range(d):
210+
if isinstance(grid.types[i], numba.types.Array):
211+
s += f" x_{i} = min(max(point[{i}], grid[{i}][0]), grid[{i}][grid[{i}].shape[0]-1])\n"
212+
else:
213+
s += f" x_{i} = min(max(point[{i}], grid[{i}][0]), grid[{i}][1])\n"
214+
s += f" return ({str.join(', ', ['x_{}'.format(i) for i in range(d)])},)"
215+
d = {}
216+
eval(compile(ast.parse(s),'<string>','exec'), d)
217+
return d['__project']

interpolation/multilinear/mlinterp.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# Actual interpolation function #
1919
#################################
2020

21-
from .fungen import fmap, funzip, get_coeffs, tensor_reduction, get_index, extract_row
21+
from .fungen import fmap, funzip, get_coeffs, tensor_reduction, get_index, extract_row, project
2222

2323
from numba import njit
2424
from typing import Tuple
@@ -136,7 +136,8 @@ def {funname}(*args):
136136
grid = {grid_s}
137137
C = args[{it.d}]
138138
point = {point_s}
139-
res = mlinterp(grid, C, point)
139+
ppoint = project(grid, point)
140+
res = mlinterp(grid, C, ppoint)
140141
return res
141142
"""
142143
return source
@@ -152,7 +153,8 @@ def {funname}(*args):
152153
res = zeros(N)
153154
# return res
154155
for n in range(N):
155-
res[n] = mlinterp(grid, C, {p_s})
156+
ppoint = project(grid, {p_s})
157+
res[n] = mlinterp(grid, C, ppoint)
156158
return res
157159
"""
158160
return source
@@ -168,7 +170,8 @@ def {funname}(*args):
168170
N = points_x.shape[0]
169171
res = zeros(N)
170172
for n in range(N):
171-
res[n] = mlinterp(grid, C, (points_x[n],))
173+
ppoint = project(grid,(points_x[n],))
174+
res[n] = mlinterp(grid, C, ppoint)
172175
return res
173176
"""
174177
elif it.d==2:
@@ -184,7 +187,8 @@ def {funname}(*args):
184187
res = zeros((N,M))
185188
for n in range(N):
186189
for m in range(M):
187-
res[n,m] = mlinterp(grid, C, (points_x[n], points_y[m]))
190+
ppoint = project(grid,(points_x[n], points_y[m]))
191+
res[n,m] = mlinterp(grid, C, ppoint)
188192
return res
189193
"""
190194
else:

0 commit comments

Comments
 (0)