Skip to content

Commit 1716d14

Browse files
committed
Parametrise CrossMeshInterpolator tests
1 parent 122a61a commit 1716d14

File tree

1 file changed

+44
-48
lines changed

1 file changed

+44
-48
lines changed

tests/firedrake/regression/test_interpolate_cross_mesh.py

Lines changed: 44 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -277,26 +277,6 @@ def parameters(request):
277277
return m_src, m_dest, coords, expr_src, expr_dest, expected, V_src, V_dest, V_dest_2
278278

279279

280-
def test_interpolate_cross_mesh(parameters):
281-
(
282-
m_src,
283-
m_dest,
284-
coords,
285-
expr_src,
286-
expr_dest,
287-
expected,
288-
V_src,
289-
V_dest,
290-
V_dest_2,
291-
) = parameters
292-
get_expected_values(
293-
m_src, m_dest, V_src, V_dest, coords, expected, expr_src, expr_dest
294-
)
295-
get_expected_values(
296-
m_src, m_dest, V_src, V_dest_2, coords, expected, expr_src, expr_dest
297-
)
298-
299-
300280
def test_interpolate_unitsquare_mixed():
301281
# this has to be in its own test because UFL expressions on mixed function
302282
# spaces are not supported.
@@ -461,28 +441,10 @@ def test_interpolate_cross_mesh_not_point_eval():
461441
)
462442

463443

464-
def get_expected_values(
465-
m_src, m_dest, V_src, V_dest, coords, expected, expr_src, expr_dest
466-
):
467-
if m_src.name == "src_sphere" and m_dest.name == "dest_sphere":
468-
# Between immersed manifolds we will often be doing projection so we
469-
# need a higher tolerance for our tests
470-
atol = 1e-3
471-
else:
472-
atol = 1e-8 # default
473-
interpolate_function(
474-
m_src, m_dest, V_src, V_dest, coords, expected, expr_src, expr_dest, atol
475-
)
476-
interpolate_expression(
477-
m_src, m_dest, V_src, V_dest, coords, expected, expr_src, expr_dest, atol
478-
)
479-
480-
481444
def interpolate_function(
482445
m_src, m_dest, V_src, V_dest, coords, expected, expr_src, expr_dest, atol
483446
):
484-
f_src = Function(V_src).interpolate(expr_src)
485-
f_dest = assemble(interpolate(f_src, V_dest))
447+
f_dest = Function(V_dest).interpolate(expr_src)
486448
assert extract_unique_domain(f_dest) is m_dest
487449
got = f_dest.at(coords)
488450
assert np.allclose(got, expected, atol=atol)
@@ -496,17 +458,17 @@ def interpolate_function(
496458
f_dest_2 = Function(V_dest).interpolate(expr_dest)
497459
assert np.allclose(f_dest.dat.data_ro, f_dest_2.dat.data_ro, atol=atol)
498460

499-
# works with Function interpolate method
461+
# test Function.interpolate(...)
500462
f_dest = Function(V_dest)
501463
f_dest.interpolate(f_src)
502464
assert extract_unique_domain(f_dest) is m_dest
503465
got = f_dest.at(coords)
504466
assert np.allclose(got, expected, atol=atol)
505467
assert np.allclose(f_dest.dat.data_ro, f_dest_2.dat.data_ro, atol=atol)
506468

507-
# output argument works
469+
# test assemble(interpolate(Function, ...), tensor=...)
508470
f_dest = Function(V_dest)
509-
assemble(Interpolate(f_src, V_dest), tensor=f_dest)
471+
assemble(interpolate(f_src, V_dest), tensor=f_dest)
510472
assert extract_unique_domain(f_dest) is m_dest
511473
got = f_dest.at(coords)
512474
assert np.allclose(got, expected, atol=atol)
@@ -516,27 +478,35 @@ def interpolate_function(
516478
def interpolate_expression(
517479
m_src, m_dest, V_src, V_dest, coords, expected, expr_src, expr_dest, atol
518480
):
519-
f_src = Function(V_src).interpolate(expr_src)
520-
521-
f_dest = assemble(interpolate(expr_src, V_dest))
481+
f_dest = Function(V_dest).interpolate(expr_src)
522482
assert extract_unique_domain(f_dest) is m_dest
483+
484+
# test Function.interpolate(...)
523485
got = f_dest.at(coords)
524486
assert np.allclose(got, expected, atol=atol)
525487
f_dest_2 = Function(V_dest).interpolate(expr_dest)
526488
assert np.allclose(f_dest.dat.data_ro, f_dest_2.dat.data_ro, atol=atol)
527489

528-
# output argument works for expressions
490+
# test assemble(interpolate(expr, ...), tensor=...)
529491
f_dest = Function(V_dest)
530-
assemble(Interpolate(expr_src, V_dest), tensor=f_dest)
492+
assemble(interpolate(expr_src, V_dest), tensor=f_dest)
531493
assert extract_unique_domain(f_dest) is m_dest
532494
got = f_dest.at(coords)
533495
assert np.allclose(got, expected, atol=atol)
534496
assert np.allclose(f_dest.dat.data_ro, f_dest_2.dat.data_ro, atol=atol)
535497

536-
# adjoint
498+
499+
def interpolate_cofunction(
500+
m_src, m_dest, V_src, V_dest, coords, expected, expr_src, expr_dest, atol
501+
):
502+
f_dest = Function(V_dest).interpolate(expr_src)
503+
assert extract_unique_domain(f_dest) is m_dest
504+
537505
cofunction_dest = assemble(inner(f_dest, TestFunction(V_dest)) * dx)
538506
cofunction_dest_on_src = assemble(interpolate(TestFunction(V_src), cofunction_dest))
539507
assert cofunction_dest_on_src.function_space().mesh() is m_src
508+
509+
f_src = Function(V_src).interpolate(expr_src)
540510
assert np.isclose(
541511
assemble(action(cofunction_dest_on_src, f_src)),
542512
assemble(action(cofunction_dest, f_dest)), atol=atol
@@ -552,6 +522,32 @@ def interpolate_expression(
552522
)
553523

554524

525+
@pytest.mark.parametrize("space", [0, 1])
526+
@pytest.mark.parametrize("run_test", [interpolate_expression, interpolate_function, interpolate_cofunction])
527+
def test_interpolate_cross_mesh(run_test, space, parameters):
528+
(
529+
m_src,
530+
m_dest,
531+
coords,
532+
expr_src,
533+
expr_dest,
534+
expected,
535+
V_src,
536+
V_dest,
537+
V_dest_2,
538+
) = parameters
539+
V_dest = (V_dest, V_dest_2)[space]
540+
if m_src.name == "src_sphere" and m_dest.name == "dest_sphere":
541+
# Between immersed manifolds we will often be doing projection so we
542+
# need a higher tolerance for our tests
543+
atol = 1e-3
544+
else:
545+
atol = 1e-8 # default
546+
run_test(
547+
m_src, m_dest, V_src, V_dest, coords, expected, expr_src, expr_dest, atol
548+
)
549+
550+
555551
def test_missing_dofs():
556552
m_src = UnitSquareMesh(2, 3)
557553
m_dest = UnitSquareMesh(4, 5)

0 commit comments

Comments
 (0)