8
8
9
9
#pragma once
10
10
11
+ #include < cstdlib>
12
+ #include < string>
13
+
14
+ #include " dune/common/classname.hh"
11
15
#include < dune/functions/functionspacebases/lagrangebasis.hh>
12
16
#include < dune/functions/functionspacebases/powerbasis.hh>
13
17
#include < dune/grid/yaspgrid.hh>
14
18
#include < dune/python/common/typeregistry.hh>
15
19
#include < dune/python/functions/globalbasis.hh>
20
+ #include < dune/python/functions/subspacebasis.hh>
16
21
#include < dune/python/pybind11/eigen.h>
17
22
#include < dune/python/pybind11/functional.h>
18
23
#include < dune/python/pybind11/pybind11.h>
19
24
#include < dune/python/pybind11/stl.h>
20
25
#include < dune/python/pybind11/stl_bind.h>
21
26
22
27
#include < ikarus/finiteelements/ferequirements.hh>
28
+
23
29
// PYBIND11_MAKE_OPAQUE(std::vector<bool>);
24
30
namespace Ikarus ::Python {
25
31
32
+ namespace Impl {
33
+ using FixBoundaryDOFsWithGlobalIndexFunction = std::function<void (Eigen::Ref<Eigen::VectorX<bool >>, int )>;
34
+
35
+ template <typename LV>
36
+ using FixBoundaryDOFsWithLocalViewFunction = std::function<void (Eigen::Ref<Eigen::VectorX<bool >>, int , LV&)>;
37
+
38
+ template <typename LV, typename IS>
39
+ using FixBoundaryDOFsWithIntersectionFunction =
40
+ std::function<void (Eigen::Ref<Eigen::VectorX<bool >>, int , LV&, const IS&)>;
41
+
42
+ template <typename Basis>
43
+ auto registerSubSpaceLocalView () {
44
+ pybind11::module scopedf = pybind11::module::import (" dune.functions" );
45
+ using LocalViewWrapper = Dune::Python::LocalViewWrapper<Basis>;
46
+
47
+ auto includes = Dune::Python::IncludeFiles{" dune/python/functions/globalbasis.hh" };
48
+
49
+ Dune::Python::insertClass<Basis>(scopedf, " SubspaceBasis_" + Dune::className<typename Basis::PrefixPath>(),
50
+ Dune::Python::GenerateTypeName (Dune::className<Basis>()), includes);
51
+
52
+ auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
53
+ scopedf, " LocalView_" + Dune::className<typename Basis::PrefixPath>(),
54
+ Dune::Python::GenerateTypeName (" Dune::Python::LocalViewWrapper" , Dune::MetaType<Basis>()), includes);
55
+ if (isNew) {
56
+ lv.def (" bind" , &LocalViewWrapper::bind);
57
+ lv.def (" unbind" , &LocalViewWrapper::unbind);
58
+ lv.def (" index" , [](const LocalViewWrapper& localView, int index) { return localView.index (index); });
59
+ lv.def (" __len__" , [](LocalViewWrapper& self) -> int { return self.size (); });
60
+
61
+ Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
62
+ lv.def (" tree" , [](const LocalViewWrapper& view) { return view.tree (); });
63
+ }
64
+ }
65
+ } // namespace Impl
66
+
67
+ template <class DirichletValues >
68
+ void forwardCorrectFunction (DirichletValues& dirichletValues, const pybind11::function& functor, auto && cppFunction) {
69
+ using Basis = typename DirichletValues::Basis;
70
+ using Intersection = typename Basis::GridView::Intersection;
71
+ using BackendType = typename DirichletValues::BackendType;
72
+ using MultiIndex = typename Basis::MultiIndex;
73
+
74
+ // Disambiguate by number of arguments
75
+ pybind11::module inspect_module = pybind11::module::import (" inspect" );
76
+ pybind11::object result = inspect_module.attr (" signature" )(functor).attr (" parameters" );
77
+ size_t numParams = pybind11::len (result);
78
+
79
+ if (numParams == 2 ) {
80
+ auto function = functor.template cast <const Impl::FixBoundaryDOFsWithGlobalIndexFunction>();
81
+ auto lambda = [&](BackendType& vec, const MultiIndex& indexGlobal) { function (vec.vector (), indexGlobal); };
82
+ cppFunction (lambda);
83
+
84
+ } else if (numParams == 3 ) {
85
+ auto lambda = [&](BackendType& vec, int localIndex, auto && lv) {
86
+ using SubSpaceBasis = typename std::remove_cvref_t <decltype (lv)>::GlobalBasis;
87
+ Impl::registerSubSpaceLocalView<SubSpaceBasis>();
88
+
89
+ using SubSpaceLocalViewWrapper = Dune::Python::LocalViewWrapper<SubSpaceBasis>;
90
+ auto lvWrapper = SubSpaceLocalViewWrapper (lv);
91
+
92
+ auto function =
93
+ functor.template cast <const Impl::FixBoundaryDOFsWithLocalViewFunction<SubSpaceLocalViewWrapper>>();
94
+ function (vec.vector (), localIndex, lvWrapper);
95
+ };
96
+ cppFunction (lambda);
97
+
98
+ } else if (numParams == 4 ) {
99
+ auto lambda = [&](BackendType& vec, int localIndex, auto && lv, const Intersection& intersection) {
100
+ using SubSpaceBasis = typename std::remove_cvref_t <decltype (lv)>::GlobalBasis;
101
+ Impl::registerSubSpaceLocalView<SubSpaceBasis>();
102
+
103
+ using SubSpaceLocalViewWrapper = Dune::Python::LocalViewWrapper<SubSpaceBasis>;
104
+ auto lvWrapper = SubSpaceLocalViewWrapper (lv);
105
+
106
+ auto function = functor.template cast <
107
+ const Impl::FixBoundaryDOFsWithIntersectionFunction<SubSpaceLocalViewWrapper, Intersection>>();
108
+ function (vec.vector (), localIndex, lvWrapper, intersection);
109
+ };
110
+ cppFunction (lambda);
111
+
112
+ } else {
113
+ DUNE_THROW (Dune::NotImplemented, " fixBoundaryDOFs: A function with this signature is not supported" );
114
+ }
115
+ }
116
+
26
117
/* *
27
118
* \brief Register Python bindings for a DirichletValues class.
28
119
*
29
120
* This function registers Python bindings for a DirichletValues class, allowing it to be used in Python scripts.
30
121
* The registered class will have an initializer that takes a `Basis` object. It exposes several member functions to
31
122
* Python:
32
- * - `fixBoundaryDOFs(f)`: Fixes boundary degrees of freedom using a user-defined function `f`.
33
- * - `fixBoundaryDOFsUsingLocalView(f)`: Fixes boundary degrees of freedom using a user-defined function `f` with a
34
- * `LocalView` argument.
35
- * - `fixBoundaryDOFsUsingLocalViewAndIntersection(f)`: Fixes boundary degrees of freedom using a user-defined
36
- * function `f` with `LocalView` and `Intersection` arguments.
37
- * - `fixDOFs(f)`: Fixes boundary degrees of freedom using a user-defined function `f` with the boolean vector and
38
- * the basis as arguments.
123
+ * - `fixBoundaryDOFs(f)`: Fixes boundary degrees of freedom using a user-defined function `f` than can be called
124
+ * with the following arguments:
125
+ * - with the boolean vector and the global index.
126
+ * - with the boolean vector, the local index and the `LocalView`.
127
+ * - with the boolean vector, the local index, the `LocalView` and the `Intersection`.
128
+ * - `fixDOFs(f)`: Fixes boundary degrees of freedom using a user-defined function `f` with the basis and the boolean
129
+ * vector as arguments.
130
+ * - `setSingleDOF(i, flag: bool): Fixes or unfixes DOF with index i
131
+ * - `isConstrained(i)`: Checks whether index i is constrained
132
+ * - `reset()`: Resets the whole container
133
+ *
134
+ * The following properties can be accessed:
135
+ * - `container`: the underlying container of dirichlet flags (as a const reference)
136
+ * - `size`: the size of the underlying basis
137
+ * - `fixedDOFsize`: the amount of DOFs currently fixed
39
138
*
40
139
* \tparam DirichletValues The DirichletValues class to be registered.
41
140
* \tparam options Variadic template parameters for additional options when defining the Python class.
@@ -57,60 +156,38 @@ void registerDirichletValues(pybind11::handle scope, pybind11::class_<DirichletV
57
156
using Intersection = typename Basis::GridView::Intersection;
58
157
59
158
pybind11::module scopedf = pybind11::module::import (" dune.functions" );
60
- typedef Dune::Python::LocalViewWrapper<Basis> LocalViewWrapper;
61
- auto includes = Dune::Python::IncludeFiles{" dune/python/functions/globalbasis.hh" };
62
- auto lv = Dune::Python::insertClass<LocalViewWrapper>(
63
- scopedf, " LocalView" ,
64
- Dune::Python::GenerateTypeName (" Dune::Python::LocalViewWrapper" , Dune::MetaType<Basis>()), includes)
65
- .first ;
159
+ using LocalViewWrapper = Dune::Python::LocalViewWrapper<Basis>;
66
160
67
- cls.def (pybind11::init ([](const Basis& basis) { return new DirichletValues (basis); }), pybind11::keep_alive<1 , 2 >());
161
+ auto includes = Dune::Python::IncludeFiles{" dune/python/functions/globalbasis.hh" };
162
+ auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
163
+ scopedf, " LocalView" , Dune::Python::GenerateTypeName (" Dune::Python::LocalViewWrapper" , Dune::MetaType<Basis>()),
164
+ includes);
68
165
69
- // Eigen::Ref needed due to https://pybind11.readthedocs.io/en/stable/advanced/cast/eigen.html#pass-by-reference
70
- cls.def (" fixBoundaryDOFs" ,
71
- [](DirichletValues& self, const std::function<void (Eigen::Ref<Eigen::VectorX<bool >>, int )>& f) {
72
- auto lambda = [&](BackendType& vec, const MultiIndex& indexGlobal) {
73
- // we explicitly only allow flat indices
74
- f (vec.vector (), indexGlobal[0 ]);
75
- };
76
- self.fixBoundaryDOFs (lambda);
77
- });
166
+ if (isNew) {
167
+ lv.def (" bind" , &LocalViewWrapper::bind);
168
+ lv.def (" unbind" , &LocalViewWrapper::unbind);
169
+ lv.def (" index" , [](const LocalViewWrapper& localView, int index) { return localView.index (index); });
170
+ lv.def (" __len__" , [](LocalViewWrapper& self) -> int { return self.size (); });
78
171
79
- cls.def (" fixBoundaryDOFsUsingLocalView" ,
80
- [](DirichletValues& self,
81
- const std::function<void (Eigen::Ref<Eigen::VectorX<bool >>, int , LocalViewWrapper&)>& f) {
82
- auto lambda = [&](BackendType& vec, int localIndex, LocalView& lv) {
83
- auto lvWrapper = LocalViewWrapper (lv.globalBasis ());
84
- // this can be simplified when
85
- // https://gitlab.dune-project.org/staging/dune-functions/-/merge_requests/418 becomes available
86
- pybind11::object obj = pybind11::cast (lv.element ());
87
- lvWrapper.bind (obj);
88
- f (vec.vector (), localIndex, lvWrapper);
89
- };
90
- self.fixBoundaryDOFs (lambda);
91
- });
172
+ Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
173
+ lv.def (" tree" , [](const LocalViewWrapper& view) { return view.tree (); });
174
+ }
175
+
176
+ cls.def (pybind11::init ([](const Basis& basis) { return new DirichletValues (basis); }), pybind11::keep_alive<1 , 2 >());
92
177
93
- cls.def (
94
- " fixBoundaryDOFsUsingLocalViewAndIntersection" ,
95
- [](DirichletValues& self,
96
- const std::function<void (Eigen::Ref<Eigen::VectorX<bool >>, int , LocalViewWrapper&, const Intersection&)>& f) {
97
- auto lambda = [&](BackendType& vec, int localIndex, LocalView& lv, const Intersection& intersection) {
98
- auto lvWrapper = LocalViewWrapper (lv.globalBasis ());
99
- // this can be simplified when
100
- // https://gitlab.dune-project.org/staging/dune-functions/-/merge_requests/418 becomes available
101
- pybind11::object obj = pybind11::cast (lv.element ());
102
- lvWrapper.bind (obj);
103
- f (vec.vector (), localIndex, lvWrapper, intersection);
104
- };
105
- self.fixBoundaryDOFs (lambda);
106
- });
178
+ cls.def_property_readonly (" container" , &DirichletValues::container);
179
+ cls.def_property_readonly (" size" , &DirichletValues::size);
180
+ cls.def (" __len__" , &DirichletValues::size);
181
+ cls.def_property_readonly (" fixedDOFsize" , &DirichletValues::fixedDOFsize);
182
+ cls.def (" isConstrained" , [](DirichletValues& self, std::size_t i) -> bool { return self.isConstrained (i); });
183
+ cls.def (" setSingleDOF" , [](DirichletValues& self, std::size_t i, bool flag) { self.setSingleDOF (i, flag); });
184
+ cls.def (" isConstrained" , [](DirichletValues& self, MultiIndex i) -> bool { return self.isConstrained (i); });
185
+ cls.def (" setSingleDOF" , [](DirichletValues& self, MultiIndex i, bool flag) { self.setSingleDOF (i, flag); });
186
+ cls.def (" reset" , &DirichletValues::reset);
107
187
108
188
cls.def (" fixDOFs" ,
109
189
[](DirichletValues& self, const std::function<void (const Basis&, Eigen::Ref<Eigen::VectorX<bool >>)>& f) {
110
- auto lambda = [&](const Basis& basis, BackendType& vec) {
111
- // we explicitly only allow flat indices
112
- f (basis, vec.vector ());
113
- };
190
+ auto lambda = [&](const Basis& basis, BackendType& vec) { f (basis, vec.vector ()); };
114
191
self.fixDOFs (lambda);
115
192
});
116
193
}
0 commit comments