Skip to content

Commit 29a1a7c

Browse files
authored
Homogenize Python Bindings for DirichletValues + Support for fixing Subspace Basis (#305)
1 parent 4210b89 commit 29a1a7c

File tree

15 files changed

+524
-143
lines changed

15 files changed

+524
-143
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ SPDX-License-Identifier: LGPL-3.0-or-later
5252
([#304](https://github.com/ikarus-project/ikarus/pull/304))
5353
- This can be used, for instance, to apply concentrated forces or to add spring stiffness in a particular direction.
5454
- Furthermore, a helper function to get the global index of a Lagrange node at the given global position is added.
55+
- Rework Python Interface for `DirichletValues` plus adding support to easily fix boundary DOFs of subspacebasis in C++ and Python ([#305](https://github.com/ikarus-project/ikarus/pull/305))
56+
- Rework the Python Interface for `DirichletValues` plus add support to easily fix boundary DOFs of `Subspacebasis` in C++ and Python ([#305](https://github.com/ikarus-project/ikarus/pull/305))
5557

5658
## Release v0.4 (Ganymede)
5759

docs/website/01_framework/dirichletBCs.md

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,25 @@ The interface of `#!cpp Ikarus::DirichletValues` is represented by the following
2424
Ikarus::DirichletValues dirichletValues2(basis); // (1)!
2525
void fixBoundaryDOFs(f); // (2)!
2626
void fixDOFs(f); // (3)!
27-
void fixIthDOF(i); // (4)!
27+
void setSingleDOF(i, flag); // (4)!
2828
const auto& basis() const; // (5)!
29-
bool isConstrained(std::size_t i) const; // (6)!
29+
bool isConstrained(i) const; // (6)!
3030
auto fixedDOFsize() const; // (7)!
3131
auto size() const ; // (8)!
32+
auto reset(); // (9)!
3233
```
3334
3435
1. Create class by inserting a global basis, [@sander2020dune] Chapter 10.
35-
2. Accepts a functor to fix boundary degrees of freedom. `f` is a functor that will be called with the boolean vector of fixed boundary.
36+
2. Accepts a functor to fix boundary degrees of freedom. `f` is a functor that will be called with the Boolean vector of fixed boundary.
3637
degrees of freedom and the usual arguments of `Dune::Functions::forEachBoundaryDOF`, as defined on page 388 of the Dune
3738
[@sander2020dune] book.
38-
3. A more general version of `fixBoundaryDOFs`. Here, a functor is to be provided that accepts a basis and the corresponding boolean
39-
4. A function that helps to fix the $i$-th degree of freedom
39+
3. A more general version of `fixBoundaryDOFs`. Here, a functor is to be provided that accepts a basis and the corresponding Boolean
40+
4. A function that helps to fix or unfix the $i$-th degree of freedom
4041
vector considering the Dirichlet degrees of freedom.
4142
5. Returns the underlying basis.
42-
6. Indicates whether the degree of freedom `i` is fixed.
43+
6. Indicates whether the degree of freedom $i$ is fixed.
4344
7. Returns the number of fixed degrees of freedom.
44-
8. Returns the number of all dirichlet degrees of freedom.
45+
8. Returns the number of all Dirichlet degrees of freedom.
46+
9. Resets the whole container
4547
4648
\bibliography

ikarus/python/dirichletvalues/dirichletvalues.hh

Lines changed: 131 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,133 @@
88

99
#pragma once
1010

11+
#include <cstdlib>
12+
#include <string>
13+
14+
#include "dune/common/classname.hh"
1115
#include <dune/functions/functionspacebases/lagrangebasis.hh>
1216
#include <dune/functions/functionspacebases/powerbasis.hh>
1317
#include <dune/grid/yaspgrid.hh>
1418
#include <dune/python/common/typeregistry.hh>
1519
#include <dune/python/functions/globalbasis.hh>
20+
#include <dune/python/functions/subspacebasis.hh>
1621
#include <dune/python/pybind11/eigen.h>
1722
#include <dune/python/pybind11/functional.h>
1823
#include <dune/python/pybind11/pybind11.h>
1924
#include <dune/python/pybind11/stl.h>
2025
#include <dune/python/pybind11/stl_bind.h>
2126

2227
#include <ikarus/finiteelements/ferequirements.hh>
28+
2329
// PYBIND11_MAKE_OPAQUE(std::vector<bool>);
2430
namespace Ikarus::Python {
2531

32+
namespace Impl {
33+
using FixBoundaryDOFsWithGlobalIndexFunction = std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int)>;
34+
35+
template <typename LV>
36+
using FixBoundaryDOFsWithLocalViewFunction = std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int, LV&)>;
37+
38+
template <typename LV, typename IS>
39+
using FixBoundaryDOFsWithIntersectionFunction =
40+
std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int, LV&, const IS&)>;
41+
42+
template <typename Basis>
43+
auto registerSubSpaceLocalView() {
44+
pybind11::module scopedf = pybind11::module::import("dune.functions");
45+
using LocalViewWrapper = Dune::Python::LocalViewWrapper<Basis>;
46+
47+
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
48+
49+
Dune::Python::insertClass<Basis>(scopedf, "SubspaceBasis_" + Dune::className<typename Basis::PrefixPath>(),
50+
Dune::Python::GenerateTypeName(Dune::className<Basis>()), includes);
51+
52+
auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
53+
scopedf, "LocalView_" + Dune::className<typename Basis::PrefixPath>(),
54+
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapper", Dune::MetaType<Basis>()), includes);
55+
if (isNew) {
56+
lv.def("bind", &LocalViewWrapper::bind);
57+
lv.def("unbind", &LocalViewWrapper::unbind);
58+
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
59+
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });
60+
61+
Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
62+
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
63+
}
64+
}
65+
} // namespace Impl
66+
67+
template <class DirichletValues>
68+
void forwardCorrectFunction(DirichletValues& dirichletValues, const pybind11::function& functor, auto&& cppFunction) {
69+
using Basis = typename DirichletValues::Basis;
70+
using Intersection = typename Basis::GridView::Intersection;
71+
using BackendType = typename DirichletValues::BackendType;
72+
using MultiIndex = typename Basis::MultiIndex;
73+
74+
// Disambiguate by number of arguments
75+
pybind11::module inspect_module = pybind11::module::import("inspect");
76+
pybind11::object result = inspect_module.attr("signature")(functor).attr("parameters");
77+
size_t numParams = pybind11::len(result);
78+
79+
if (numParams == 2) {
80+
auto function = functor.template cast<const Impl::FixBoundaryDOFsWithGlobalIndexFunction>();
81+
auto lambda = [&](BackendType& vec, const MultiIndex& indexGlobal) { function(vec.vector(), indexGlobal); };
82+
cppFunction(lambda);
83+
84+
} else if (numParams == 3) {
85+
auto lambda = [&](BackendType& vec, int localIndex, auto&& lv) {
86+
using SubSpaceBasis = typename std::remove_cvref_t<decltype(lv)>::GlobalBasis;
87+
Impl::registerSubSpaceLocalView<SubSpaceBasis>();
88+
89+
using SubSpaceLocalViewWrapper = Dune::Python::LocalViewWrapper<SubSpaceBasis>;
90+
auto lvWrapper = SubSpaceLocalViewWrapper(lv);
91+
92+
auto function =
93+
functor.template cast<const Impl::FixBoundaryDOFsWithLocalViewFunction<SubSpaceLocalViewWrapper>>();
94+
function(vec.vector(), localIndex, lvWrapper);
95+
};
96+
cppFunction(lambda);
97+
98+
} else if (numParams == 4) {
99+
auto lambda = [&](BackendType& vec, int localIndex, auto&& lv, const Intersection& intersection) {
100+
using SubSpaceBasis = typename std::remove_cvref_t<decltype(lv)>::GlobalBasis;
101+
Impl::registerSubSpaceLocalView<SubSpaceBasis>();
102+
103+
using SubSpaceLocalViewWrapper = Dune::Python::LocalViewWrapper<SubSpaceBasis>;
104+
auto lvWrapper = SubSpaceLocalViewWrapper(lv);
105+
106+
auto function = functor.template cast<
107+
const Impl::FixBoundaryDOFsWithIntersectionFunction<SubSpaceLocalViewWrapper, Intersection>>();
108+
function(vec.vector(), localIndex, lvWrapper, intersection);
109+
};
110+
cppFunction(lambda);
111+
112+
} else {
113+
DUNE_THROW(Dune::NotImplemented, "fixBoundaryDOFs: A function with this signature is not supported");
114+
}
115+
}
116+
26117
/**
27118
* \brief Register Python bindings for a DirichletValues class.
28119
*
29120
* This function registers Python bindings for a DirichletValues class, allowing it to be used in Python scripts.
30121
* The registered class will have an initializer that takes a `Basis` object. It exposes several member functions to
31122
* Python:
32-
* - `fixBoundaryDOFs(f)`: Fixes boundary degrees of freedom using a user-defined function `f`.
33-
* - `fixBoundaryDOFsUsingLocalView(f)`: Fixes boundary degrees of freedom using a user-defined function `f` with a
34-
* `LocalView` argument.
35-
* - `fixBoundaryDOFsUsingLocalViewAndIntersection(f)`: Fixes boundary degrees of freedom using a user-defined
36-
* function `f` with `LocalView` and `Intersection` arguments.
37-
* - `fixDOFs(f)`: Fixes boundary degrees of freedom using a user-defined function `f` with the boolean vector and
38-
* the basis as arguments.
123+
* - `fixBoundaryDOFs(f)`: Fixes boundary degrees of freedom using a user-defined function `f` than can be called
124+
* with the following arguments:
125+
* - with the boolean vector and the global index.
126+
* - with the boolean vector, the local index and the `LocalView`.
127+
* - with the boolean vector, the local index, the `LocalView` and the `Intersection`.
128+
* - `fixDOFs(f)`: Fixes boundary degrees of freedom using a user-defined function `f` with the basis and the boolean
129+
* vector as arguments.
130+
* - `setSingleDOF(i, flag: bool): Fixes or unfixes DOF with index i
131+
* - `isConstrained(i)`: Checks whether index i is constrained
132+
* - `reset()`: Resets the whole container
133+
*
134+
* The following properties can be accessed:
135+
* - `container`: the underlying container of dirichlet flags (as a const reference)
136+
* - `size`: the size of the underlying basis
137+
* - `fixedDOFsize`: the amount of DOFs currently fixed
39138
*
40139
* \tparam DirichletValues The DirichletValues class to be registered.
41140
* \tparam options Variadic template parameters for additional options when defining the Python class.
@@ -57,60 +156,38 @@ void registerDirichletValues(pybind11::handle scope, pybind11::class_<DirichletV
57156
using Intersection = typename Basis::GridView::Intersection;
58157

59158
pybind11::module scopedf = pybind11::module::import("dune.functions");
60-
typedef Dune::Python::LocalViewWrapper<Basis> LocalViewWrapper;
61-
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
62-
auto lv = Dune::Python::insertClass<LocalViewWrapper>(
63-
scopedf, "LocalView",
64-
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapper", Dune::MetaType<Basis>()), includes)
65-
.first;
159+
using LocalViewWrapper = Dune::Python::LocalViewWrapper<Basis>;
66160

67-
cls.def(pybind11::init([](const Basis& basis) { return new DirichletValues(basis); }), pybind11::keep_alive<1, 2>());
161+
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
162+
auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
163+
scopedf, "LocalView", Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapper", Dune::MetaType<Basis>()),
164+
includes);
68165

69-
// Eigen::Ref needed due to https://pybind11.readthedocs.io/en/stable/advanced/cast/eigen.html#pass-by-reference
70-
cls.def("fixBoundaryDOFs",
71-
[](DirichletValues& self, const std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int)>& f) {
72-
auto lambda = [&](BackendType& vec, const MultiIndex& indexGlobal) {
73-
// we explicitly only allow flat indices
74-
f(vec.vector(), indexGlobal[0]);
75-
};
76-
self.fixBoundaryDOFs(lambda);
77-
});
166+
if (isNew) {
167+
lv.def("bind", &LocalViewWrapper::bind);
168+
lv.def("unbind", &LocalViewWrapper::unbind);
169+
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
170+
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });
78171

79-
cls.def("fixBoundaryDOFsUsingLocalView",
80-
[](DirichletValues& self,
81-
const std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int, LocalViewWrapper&)>& f) {
82-
auto lambda = [&](BackendType& vec, int localIndex, LocalView& lv) {
83-
auto lvWrapper = LocalViewWrapper(lv.globalBasis());
84-
// this can be simplified when
85-
// https://gitlab.dune-project.org/staging/dune-functions/-/merge_requests/418 becomes available
86-
pybind11::object obj = pybind11::cast(lv.element());
87-
lvWrapper.bind(obj);
88-
f(vec.vector(), localIndex, lvWrapper);
89-
};
90-
self.fixBoundaryDOFs(lambda);
91-
});
172+
Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
173+
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
174+
}
175+
176+
cls.def(pybind11::init([](const Basis& basis) { return new DirichletValues(basis); }), pybind11::keep_alive<1, 2>());
92177

93-
cls.def(
94-
"fixBoundaryDOFsUsingLocalViewAndIntersection",
95-
[](DirichletValues& self,
96-
const std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int, LocalViewWrapper&, const Intersection&)>& f) {
97-
auto lambda = [&](BackendType& vec, int localIndex, LocalView& lv, const Intersection& intersection) {
98-
auto lvWrapper = LocalViewWrapper(lv.globalBasis());
99-
// this can be simplified when
100-
// https://gitlab.dune-project.org/staging/dune-functions/-/merge_requests/418 becomes available
101-
pybind11::object obj = pybind11::cast(lv.element());
102-
lvWrapper.bind(obj);
103-
f(vec.vector(), localIndex, lvWrapper, intersection);
104-
};
105-
self.fixBoundaryDOFs(lambda);
106-
});
178+
cls.def_property_readonly("container", &DirichletValues::container);
179+
cls.def_property_readonly("size", &DirichletValues::size);
180+
cls.def("__len__", &DirichletValues::size);
181+
cls.def_property_readonly("fixedDOFsize", &DirichletValues::fixedDOFsize);
182+
cls.def("isConstrained", [](DirichletValues& self, std::size_t i) -> bool { return self.isConstrained(i); });
183+
cls.def("setSingleDOF", [](DirichletValues& self, std::size_t i, bool flag) { self.setSingleDOF(i, flag); });
184+
cls.def("isConstrained", [](DirichletValues& self, MultiIndex i) -> bool { return self.isConstrained(i); });
185+
cls.def("setSingleDOF", [](DirichletValues& self, MultiIndex i, bool flag) { self.setSingleDOF(i, flag); });
186+
cls.def("reset", &DirichletValues::reset);
107187

108188
cls.def("fixDOFs",
109189
[](DirichletValues& self, const std::function<void(const Basis&, Eigen::Ref<Eigen::VectorX<bool>>)>& f) {
110-
auto lambda = [&](const Basis& basis, BackendType& vec) {
111-
// we explicitly only allow flat indices
112-
f(basis, vec.vector());
113-
};
190+
auto lambda = [&](const Basis& basis, BackendType& vec) { f(basis, vec.vector()); };
114191
self.fixDOFs(lambda);
115192
});
116193
}

ikarus/python/finiteelements/fe.hh

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ namespace Ikarus::Python {
4848
* \throws Dune::NotImplemented If the specified resultType is not supported by the finite element.
4949
*/
5050
template <class FE, class... options>
51-
void registerCalculateAt(pybind11::handle scope, pybind11::class_<FE, options...> cls, auto restultTypesTuple) {
51+
void registerCalculateAt(pybind11::handle scope, pybind11::class_<FE, options...> cls, auto resultTypesTuple) {
5252
using Traits = typename FE::Traits;
5353
using FERequirements = typename FE::Requirement;
5454
cls.def(
@@ -57,7 +57,7 @@ void registerCalculateAt(pybind11::handle scope, pybind11::class_<FE, options...
5757
std::string resType) {
5858
Eigen::VectorXd result;
5959
bool success = false;
60-
Dune::Hybrid::forEach(restultTypesTuple, [&]<typename RT>(RT i) {
60+
Dune::Hybrid::forEach(resultTypesTuple, [&]<typename RT>(RT i) {
6161
if (resType == toString(i)) {
6262
success = true;
6363
result = self.template calculateAt<RT::template Rebind>(req, local).asVec();
@@ -113,21 +113,22 @@ void registerFE(pybind11::handle scope, pybind11::class_<FE, options...> cls) {
113113
pybind11::arg("Requirement"), pybind11::arg("MatrixAffordance"), pybind11::arg("elementMatrix").noconvert());
114114

115115
pybind11::module scopedf = pybind11::module::import("dune.functions");
116+
using LocalViewWrapper = Dune::Python::LocalViewWrapper<FlatBasis>;
116117

117-
typedef Dune::Python::LocalViewWrapper<FlatBasis> LocalViewWrapper;
118-
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
119-
auto lv = Dune::Python::insertClass<LocalViewWrapper>(
120-
scopedf, "LocalViewWrapper",
121-
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapperWrapper", Dune::MetaType<FlatBasis>()),
122-
includes)
123-
.first;
124-
lv.def("bind", &LocalViewWrapper::bind);
125-
lv.def("unbind", &LocalViewWrapper::unbind);
126-
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
127-
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });
128-
129-
Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
130-
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
118+
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
119+
auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
120+
scopedf, "LocalViewWrapper",
121+
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapperWrapper", Dune::MetaType<FlatBasis>()), includes);
122+
123+
if (isNew) {
124+
lv.def("bind", &LocalViewWrapper::bind);
125+
lv.def("unbind", &LocalViewWrapper::unbind);
126+
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
127+
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });
128+
129+
Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
130+
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
131+
}
131132

132133
cls.def(
133134
"localView",

ikarus/python/finiteelements/registerferequirements.hh

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#pragma once
55

6+
#include <dune/python/common/typeregistry.hh>
67
#include <dune/python/pybind11/pybind11.h>
78

89
#include <ikarus/python/finiteelements/scalarwrapper.hh>
@@ -22,26 +23,27 @@ void registerFERequirement(pybind11::handle scope, pybind11::class_<FE, options.
2223
"createRequirement", [](pybind11::object /* self */) { return FERequirements(); },
2324
pybind11::return_value_policy::copy);
2425

25-
auto includes = Dune::Python::IncludeFiles{"ikarus/finiteelements/ferequirements.hh"};
26-
auto [lv, isNotRegistered] = Dune::Python::insertClass<FERequirements>(
26+
auto includes = Dune::Python::IncludeFiles{"ikarus/finiteelements/ferequirements.hh"};
27+
auto [req, isNew] = Dune::Python::insertClass<FERequirements>(
2728
scope, "FERequirements", Dune::Python::GenerateTypeName(Dune::className<FERequirements>()), includes);
28-
if (isNotRegistered) {
29-
lv.def(pybind11::init());
30-
lv.def(pybind11::init<SolutionVectorType&, ParameterType&>());
3129

32-
lv.def(
30+
if (isNew) {
31+
req.def(pybind11::init());
32+
req.def(pybind11::init<SolutionVectorType&, ParameterType&>());
33+
34+
req.def(
3335
"insertGlobalSolution",
3436
[](FERequirements& self, SolutionVectorType solVec) { self.insertGlobalSolution(solVec); },
3537
"solutionVector"_a.noconvert());
36-
lv.def(
38+
req.def(
3739
"globalSolution", [](FERequirements& self) { return self.globalSolution(); },
3840
pybind11::return_value_policy::reference_internal);
39-
lv.def(
41+
req.def(
4042
"insertParameter",
4143
[](FERequirements& self, ScalarWrapper<double>& parVal) { self.insertParameter(parVal.value()); },
4244
pybind11::keep_alive<1, 2>(), "parameterValue"_a.noconvert());
4345

44-
lv.def("parameter", [](const FERequirements& self) { return self.parameter(); });
46+
req.def("parameter", [](const FERequirements& self) { return self.parameter(); });
4547
}
4648
}
4749
} // namespace Ikarus::Python

ikarus/python/test/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ dune_python_add_test(
3434
python
3535
)
3636

37+
dune_python_add_test(
38+
NAME
39+
pydirichletvalues
40+
SCRIPT
41+
dirichletvaluetest.py
42+
WORKING_DIRECTORY
43+
${CMAKE_CURRENT_SOURCE_DIR}
44+
LABELS
45+
python
46+
)
47+
3748
if(HAVE_DUNE_IGA)
3849
dune_python_add_test(
3950
NAME

0 commit comments

Comments
 (0)