Skip to content

Commit a814144

Browse files
committed
Fix subspacelocalview
1 parent b57a5d8 commit a814144

File tree

4 files changed

+63
-55
lines changed

4 files changed

+63
-55
lines changed

ikarus/python/dirichletvalues/dirichletvalues.hh

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -39,39 +39,27 @@ namespace Impl {
3939
using FixBoundaryDOFsWithIntersectionFunction =
4040
std::function<void(Eigen::Ref<Eigen::VectorX<bool>>, int, LV&, const IS&)>;
4141

42-
template <typename Basis, bool registerBasis = false>
43-
auto registerLocalView() {
42+
template <typename Basis>
43+
auto registerSubSpaceLocalView() {
4444
pybind11::module scopedf = pybind11::module::import("dune.functions");
45-
using LocalView = Dune::Python::LocalViewWrapper<Basis>;
45+
using LocalViewWrapper = Dune::Python::LocalViewWrapper<Basis>;
4646

4747
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
4848

49-
// also register subspace basis
50-
if constexpr (registerBasis) {
51-
auto construct = [](const Basis& basis) { return new Basis(basis); };
52-
53-
// This if statement does absolutly nothing
54-
if (Dune::Python::findInTypeRegistry<Basis>().second) {
55-
auto [basisCls, isNotRegistered] = Dune::Python::insertClass<Basis>(
56-
scopedf, "SubspaceBasis", Dune::Python::GenerateTypeName(Dune::className<Basis>()), includes);
57-
if (isNotRegistered)
58-
Dune::Python::registerSubspaceBasis(scopedf, basisCls);
59-
}
60-
// Dune::Python::registerBasisType(scopedf, basisCls, construct, std::false_type{});
61-
} else {
62-
// auto [lv, isNotRegistered] = Dune::Python::insertClass<LocalView>(
63-
// scopedf, "LocalView",
64-
// Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapper", Dune::MetaType<Basis>()), includes);
65-
66-
// if (isNotRegistered) {
67-
// lv.def("bind", &LocalView::bind);
68-
// lv.def("unbind", &LocalView::unbind);
69-
// lv.def("index", [](const LocalView& localView, int index) { return localView.index(index); });
70-
// lv.def("__len__", [](LocalView& self) -> int { return self.size(); });
71-
72-
// Dune::Python::Functions::registerTree<typename LocalView::Tree>(lv);
73-
// lv.def("tree", [](const LocalView& view) { return view.tree(); });
74-
// }
49+
Dune::Python::insertClass<Basis>(scopedf, "SubspaceBasis_" + Dune::className<typename Basis::PrefixPath>(),
50+
Dune::Python::GenerateTypeName(Dune::className<Basis>()), includes);
51+
52+
auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
53+
scopedf, "LocalView_" + Dune::className<typename Basis::PrefixPath>(),
54+
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapper", Dune::MetaType<Basis>()), includes);
55+
if (isNew) {
56+
lv.def("bind", &LocalViewWrapper::bind);
57+
lv.def("unbind", &LocalViewWrapper::unbind);
58+
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
59+
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });
60+
61+
Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
62+
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
7563
}
7664
}
7765
} // namespace Impl
@@ -98,7 +86,7 @@ void forwardCorrectFunction(DirichletValues& dirichletValues, const pybind11::fu
9886
} else if (numParams == 3) {
9987
auto lambda = [&](BackendType& vec, int localIndex, auto&& lv) {
10088
using SubSpaceBasis = typename std::remove_cvref_t<decltype(lv)>::GlobalBasis;
101-
Impl::registerLocalView<SubSpaceBasis, true>();
89+
Impl::registerSubSpaceLocalView<SubSpaceBasis>();
10290

10391
using SubSpaceLocalViewWrapper = Dune::Python::LocalViewWrapper<SubSpaceBasis>;
10492
auto lvWrapper = SubSpaceLocalViewWrapper(lv);
@@ -112,7 +100,7 @@ void forwardCorrectFunction(DirichletValues& dirichletValues, const pybind11::fu
112100
} else if (numParams == 4) {
113101
auto lambda = [&](BackendType& vec, int localIndex, auto&& lv, const Intersection& intersection) {
114102
using SubSpaceBasis = typename std::remove_cvref_t<decltype(lv)>::GlobalBasis;
115-
Impl::registerLocalView<SubSpaceBasis, true>();
103+
Impl::registerSubSpaceLocalView<SubSpaceBasis>();
116104

117105
using SubSpaceLocalViewWrapper = Dune::Python::LocalViewWrapper<SubSpaceBasis>;
118106
auto lvWrapper = SubSpaceLocalViewWrapper(lv);
@@ -169,7 +157,23 @@ void registerDirichletValues(pybind11::handle scope, pybind11::class_<DirichletV
169157
using LocalView = typename Basis::LocalView;
170158
using Intersection = typename Basis::GridView::Intersection;
171159

172-
Impl::registerLocalView<Basis>();
160+
pybind11::module scopedf = pybind11::module::import("dune.functions");
161+
using LocalViewWrapper = Dune::Python::LocalViewWrapper<Basis>;
162+
163+
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
164+
auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
165+
scopedf, "LocalView", Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapper", Dune::MetaType<Basis>()),
166+
includes);
167+
168+
if (isNew) {
169+
lv.def("bind", &LocalViewWrapper::bind);
170+
lv.def("unbind", &LocalViewWrapper::unbind);
171+
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
172+
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });
173+
174+
Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
175+
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
176+
}
173177

174178
cls.def(pybind11::init([](const Basis& basis) { return new DirichletValues(basis); }), pybind11::keep_alive<1, 2>());
175179

ikarus/python/finiteelements/fe.hh

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -113,21 +113,22 @@ void registerFE(pybind11::handle scope, pybind11::class_<FE, options...> cls) {
113113
pybind11::arg("Requirement"), pybind11::arg("MatrixAffordance"), pybind11::arg("elementMatrix").noconvert());
114114

115115
pybind11::module scopedf = pybind11::module::import("dune.functions");
116+
using LocalViewWrapper = Dune::Python::LocalViewWrapper<FlatBasis>;
116117

117-
typedef Dune::Python::LocalViewWrapper<FlatBasis> LocalViewWrapper;
118-
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
119-
auto lv = Dune::Python::insertClass<LocalViewWrapper>(
120-
scopedf, "LocalViewWrapper",
121-
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapperWrapper", Dune::MetaType<FlatBasis>()),
122-
includes)
123-
.first;
124-
lv.def("bind", &LocalViewWrapper::bind);
125-
lv.def("unbind", &LocalViewWrapper::unbind);
126-
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
127-
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });
128-
129-
Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
130-
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
118+
auto includes = Dune::Python::IncludeFiles{"dune/python/functions/globalbasis.hh"};
119+
auto [lv, isNew] = Dune::Python::insertClass<LocalViewWrapper>(
120+
scopedf, "LocalViewWrapper",
121+
Dune::Python::GenerateTypeName("Dune::Python::LocalViewWrapperWrapper", Dune::MetaType<FlatBasis>()), includes);
122+
123+
if (isNew) {
124+
lv.def("bind", &LocalViewWrapper::bind);
125+
lv.def("unbind", &LocalViewWrapper::unbind);
126+
lv.def("index", [](const LocalViewWrapper& localView, int index) { return localView.index(index); });
127+
lv.def("__len__", [](LocalViewWrapper& self) -> int { return self.size(); });
128+
129+
Dune::Python::Functions::registerTree<typename LocalViewWrapper::Tree>(lv);
130+
lv.def("tree", [](const LocalViewWrapper& view) { return view.tree(); });
131+
}
131132

132133
cls.def(
133134
"localView",

ikarus/python/finiteelements/registerferequirements.hh

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#pragma once
55

6+
#include <dune/python/common/typeregistry.hh>
67
#include <dune/python/pybind11/pybind11.h>
78

89
#include <ikarus/python/finiteelements/valuewrapper.hh>
@@ -22,25 +23,26 @@ void registerFERequirement(pybind11::handle scope, pybind11::class_<FE, options.
2223
"createRequirement", [](pybind11::object /* self */) { return FERequirements(); },
2324
pybind11::return_value_policy::copy);
2425

25-
auto includes = Dune::Python::IncludeFiles{"ikarus/finiteelements/ferequirements.hh"};
26-
auto [lv, isNotRegistered] = Dune::Python::insertClass<FERequirements>(
26+
auto includes = Dune::Python::IncludeFiles{"ikarus/finiteelements/ferequirements.hh"};
27+
auto [req, isNew] = Dune::Python::insertClass<FERequirements>(
2728
scope, "FERequirements", Dune::Python::GenerateTypeName(Dune::className<FERequirements>()), includes);
28-
if (isNotRegistered) {
29-
lv.def(pybind11::init());
30-
lv.def(pybind11::init<SolutionVectorType&, ParameterType&>());
3129

32-
lv.def(
30+
if (isNew) {
31+
req.def(pybind11::init());
32+
req.def(pybind11::init<SolutionVectorType&, ParameterType&>());
33+
34+
req.def(
3335
"insertGlobalSolution",
3436
[](FERequirements& self, SolutionVectorType solVec) { self.insertGlobalSolution(solVec); },
3537
"solutionVector"_a.noconvert());
36-
lv.def(
38+
req.def(
3739
"globalSolution", [](FERequirements& self) { return self.globalSolution(); },
3840
pybind11::return_value_policy::reference_internal);
39-
lv.def(
41+
req.def(
4042
"insertParameter", [](FERequirements& self, ValueWrapper<double>& parVal) { self.insertParameter(parVal.val); },
4143
pybind11::keep_alive<1, 2>(), "parameterValue"_a.noconvert());
4244

43-
lv.def("parameter", [](const FERequirements& self) { return self.parameter(); });
45+
req.def("parameter", [](const FERequirements& self) { return self.parameter(); });
4446
}
4547
}
4648
} // namespace Ikarus::Python

ikarus/python/test/testdirichletvalues.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,5 +100,6 @@ def fixTopSide(vec, localIndex, localView, intersection):
100100
assert dirichletValues2.fixedDOFsize == 0
101101
assert sum(dirichletValues2.container) == 0
102102

103+
103104
if __name__ == "__main__":
104105
testDirichletValues()

0 commit comments

Comments
 (0)