Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add Quadrature Data to Sidre #1258

Open
wants to merge 21 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/serac/numerics/functional/quadrature_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ class ArrayView<serac::Empty, 2, MemorySpace::Dynamic> {
} // namespace axom

namespace serac {
namespace detail {
constexpr std::array<mfem::Geometry::Type, 5> qdata_geometries = {mfem::Geometry::SEGMENT, mfem::Geometry::TRIANGLE,
mfem::Geometry::SQUARE, mfem::Geometry::TETRAHEDRON,
mfem::Geometry::CUBE};
constexpr std::array<std::string_view, 5> qdata_geometry_names = {"Segment", "Triangle", "Square", "Tetrahedron",
"Cube"};
} // namespace detail

/**
* @brief A class for storing and access user-defined types at quadrature points
Expand All @@ -122,10 +129,7 @@ struct QuadratureData {
*/
QuadratureData(geom_array_t elements, geom_array_t qpts_per_element, T value = T{})
{
constexpr std::array geometries = {mfem::Geometry::SEGMENT, mfem::Geometry::TRIANGLE, mfem::Geometry::SQUARE,
mfem::Geometry::TETRAHEDRON, mfem::Geometry::CUBE};

for (auto geom : geometries) {
for (auto geom : detail::qdata_geometries) {
if (elements[uint32_t(geom)] > 0) {
data[geom] = axom::Array<T, 2>(elements[uint32_t(geom)], qpts_per_element[uint32_t(geom)]);
data[geom].fill(value);
Expand All @@ -140,7 +144,7 @@ struct QuadratureData {
axom::ArrayView<T, 2> operator[](mfem::Geometry::Type geom) { return axom::ArrayView<T, 2>(data.at(geom)); }

/// @brief a 3D array indexed by (which geometry, which element, which quadrature point)
std::map<mfem::Geometry::Type, axom::Array<T, 2> > data;
std::map<mfem::Geometry::Type, axom::Array<T, 2>> data;
};

/// @cond
Expand Down Expand Up @@ -168,7 +172,7 @@ struct QuadratureData<Empty> {
/// @endcond

/// these values exist to serve as default arguments for materials without material state
extern std::shared_ptr<QuadratureData<Nothing> > NoQData;
extern std::shared_ptr<QuadratureData<Empty> > EmptyQData;
extern std::shared_ptr<QuadratureData<Nothing>> NoQData;
extern std::shared_ptr<QuadratureData<Empty>> EmptyQData;

} // namespace serac
2 changes: 1 addition & 1 deletion src/serac/physics/solid_mechanics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ class SolidMechanics<order, dim, Parameters<parameter_space...>, std::integer_se
qdata_type<T> createQuadratureDataBuffer(T initial_state, const std::optional<Domain>& optional_domain = std::nullopt)
{
Domain domain = (optional_domain) ? *optional_domain : EntireDomain(mesh_);
return StateManager::newQuadratureDataBuffer(domain, order, dim, initial_state);
return StateManager::newQuadratureDataBuffer(mesh_tag_, domain, order, dim, initial_state);
}

/**
Expand Down
5 changes: 5 additions & 0 deletions src/serac/physics/state/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,8 @@ install(TARGETS serac_state
EXPORT serac-targets
DESTINATION lib
)

if(SERAC_ENABLE_TESTS)
add_subdirectory(tests)
endif()

24 changes: 12 additions & 12 deletions src/serac/physics/state/state_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ std::string StateManag
std::unordered_map<std::string, mfem::ParGridFunction*> StateManager::named_states_;
std::unordered_map<std::string, mfem::ParGridFunction*> StateManager::named_duals_;

double StateManager::newDataCollection(const std::string& name, const std::optional<int> cycle_to_load)
double StateManager::newDataCollection(const std::string& mesh_tag, const std::optional<int> cycle_to_load)
{
SLIC_ERROR_ROOT_IF(!ds_, "Cannot construct a DataCollection without a DataStore");
std::string coll_name = name + "_datacoll";
std::string coll_name = getCollectionName(mesh_tag);

auto global_grp = ds_->getRoot()->createGroup(coll_name + "_global");
auto bp_index_grp = global_grp->createGroup("blueprint_index/" + coll_name);
auto domain_grp = ds_->getRoot()->createGroup(coll_name);

// Needs to be configured to own the mesh data so all mesh data is saved to datastore/output file
constexpr bool owns_mesh_data = true;
auto [iter, _] = datacolls_.emplace(std::piecewise_construct, std::forward_as_tuple(name),
auto [iter, _] = datacolls_.emplace(std::piecewise_construct, std::forward_as_tuple(mesh_tag),
std::forward_as_tuple(coll_name, bp_index_grp, domain_grp, owns_mesh_data));
auto& datacoll = iter->second;
datacoll.SetComm(MPI_COMM_WORLD);
Expand All @@ -56,7 +56,7 @@ double StateManager::newDataCollection(const std::string& name, const std::optio
// indicates that the mesh is periodic and the new nodal grid function must also
// be discontinuous.
bool is_discontinuous = false;
auto nodes = mesh(name).GetNodes();
auto nodes = mesh(mesh_tag).GetNodes();
if (nodes) {
is_discontinuous = nodes->FESpace()->FEColl()->GetContType() == mfem::FiniteElementCollection::DISCONTINUOUS;
SLIC_WARNING_ROOT_IF(
Expand All @@ -71,17 +71,17 @@ double StateManager::newDataCollection(const std::string& name, const std::optio
// 2. Uses the existing continuity of the mesh finite element space (periodic meshes are discontinuous)
// 3. Uses the spatial dimension as the mesh dimension (i.e. it is not a lower dimension manifold)
// 4. Uses the ordering set by serac::ordering
mesh(name).SetCurvature(1, is_discontinuous, -1, serac::ordering);
mesh(mesh_tag).SetCurvature(1, is_discontinuous, -1, serac::ordering);

// Sidre will destruct the nodal grid function instead of the mesh
mesh(name).SetNodesOwner(false);
mesh(mesh_tag).SetNodesOwner(false);

// Generate the face neighbor information in the mesh. This is needed by the face restriction
// operators used by Functional
mesh(name).ExchangeFaceNbrData();
mesh(mesh_tag).ExchangeFaceNbrData();

// Construct and store the shape displacement fields and sensitivities associated with this mesh
constructShapeFields(name);
constructShapeFields(mesh_tag);

} else {
datacoll.SetCycle(0); // Iteration counter
Expand All @@ -94,10 +94,10 @@ double StateManager::newDataCollection(const std::string& name, const std::optio
void StateManager::loadCheckpointedStates(int cycle_to_load, std::vector<FiniteElementState*> states_to_load)
{
SERAC_MARK_FUNCTION;
mfem::ParMesh* meshPtr = &(*states_to_load.begin())->mesh();
std::string mesh_name = collectionID(meshPtr);
mfem::ParMesh* meshPtr = &(*states_to_load.begin())->mesh();
std::string mesh_tag = collectionID(meshPtr);

std::string coll_name = mesh_name + "_datacoll";
std::string coll_name = getCollectionName(mesh_tag);

axom::sidre::MFEMSidreDataCollection previous_datacoll(coll_name);

Expand All @@ -107,7 +107,7 @@ void StateManager::loadCheckpointedStates(int cycle_to_load, std::vector<FiniteE

for (auto state : states_to_load) {
meshPtr = &state->mesh();
SLIC_ERROR_ROOT_IF(collectionID(meshPtr) != mesh_name,
SLIC_ERROR_ROOT_IF(collectionID(meshPtr) != mesh_tag,
"Loading FiniteElementStates from two different meshes at one time is not allowed.");
mfem::ParGridFunction* datacoll_owned_grid_function = previous_datacoll.GetParField(state->name());

Expand Down
119 changes: 114 additions & 5 deletions src/serac/physics/state/state_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,119 @@ class StateManager {
*/
static void storeState(FiniteElementState& state);

/**
* @brief Store a pre-constructed Quadrature Data in the state manager
*
* @tparam T the type to be created at each quadrature point
* @param mesh_tag The tag for the stored mesh used to locate the datacollection
* @param qdata The quadrature data to store
*/
template <typename T>
static void storeQuadratureData(const std::string& mesh_tag, std::shared_ptr<QuadratureData<T>> qdata)
{
SLIC_ERROR_ROOT_IF(!ds_, "Serac's data store was not initialized - call StateManager::initialize first");
SLIC_ERROR_ROOT_IF(!hasMesh(mesh_tag),
axom::fmt::format("Serac's state manager does not have a mesh with given tag '{}'", mesh_tag));

constexpr const char* qds_group_name = "quadraturedatas";

// Get Sidre location for quadrature data inside data collection
auto& datacoll = datacolls_.at(mesh_tag);
axom::sidre::Group* bp_group = datacoll.GetBPGroup(); // mesh_datacoll
// For each geometry type, use i to get both type and name from matching arrays
for (std::size_t i = 0; i < detail::qdata_geometries.size(); ++i) {
auto geom_type = detail::qdata_geometries[i];

// Check if geometry type has any data
if ((*qdata).data.find(geom_type) != (*qdata).data.end()) {
auto geom_name = detail::qdata_geometry_names[i];

// Get axom::Array of states in map
auto states = (*qdata)[geom_type];

// Get various size information
auto num_states = static_cast<axom::IndexType>(states.size());
SLIC_ERROR_ROOT_IF(num_states == 0, "Number of States should be more than 0 at this point.");
auto state_size = static_cast<axom::IndexType>(sizeof(*(states.begin())));
auto total_size = num_states * state_size;
// Sidre treats information as an array of uint8s
auto num_uint8s = total_size / static_cast<axom::IndexType>(sizeof(std::uint8_t));

if (!is_restart_) {
axom::sidre::Group* qdatas_group = bp_group->createGroup(qds_group_name);

// Create Sidre group, store basic information, and point Sidre at the array external to Sidre
// Note: Sidre will not own this data.
axom::sidre::Group* geom_group = qdatas_group->createGroup(std::string(geom_name));
geom_group->createViewScalar("num_states", num_states);
geom_group->createViewScalar("state_size", state_size);
geom_group->createViewScalar("total_size", total_size);

// Tell Sidre where the external array is, how large it is (calculated above), and what is in it (faking
// uint8)
axom::sidre::View* states_view = geom_group->createView("states");
states_view->setExternalDataPtr(axom::sidre::UINT8_ID, num_uint8s, states.data());
} else {
// Get Sidre group of where the states were stored.
// Note: this data is not owned by Sidre and the array should have been created at this point but
// the previous data has not been loaded yet into the array.
SLIC_ERROR_ROOT_IF(!bp_group->hasGroup(qds_group_name),
axom::fmt::format("Loaded Sidre Datastore did not have group for Quadrature Datas"));
axom::sidre::Group* qdatas_group = bp_group->getGroup(qds_group_name);
SLIC_ERROR_ROOT_IF(
!qdatas_group->hasGroup(std::string(geom_name)),
axom::fmt::format("Loaded Sidre Datastore did not have group for Quadrature Data geometry type '{}'",
std::string(geom_name)));
axom::sidre::Group* geom_group = qdatas_group->getGroup(std::string(geom_name));

// Verify size correctness
auto verify_size = [](axom::sidre::Group* group, int value, const std::string& view_name,
const std::string& err_msg) {
SLIC_ERROR_IF(
!group->hasView(view_name),
axom::fmt::format("Loaded Sidre Datastore does not have value '{}' for Quadrature Data.", view_name));
auto prev_value = group->getView(view_name)->getData<axom::IndexType>();
SLIC_ERROR_IF(value != prev_value, axom::fmt::format(err_msg, value, prev_value));
};
verify_size(geom_group, num_states, "num_states",
"Current number of Quadrature Data States '{}' does not match value in restart '{}'.");
verify_size(geom_group, state_size, "state_size",
"Current size of Quadrature Data State '{}' does not match value in restart '{}'.");
verify_size(geom_group, total_size, "total_size",
"Current total size of Quadrature Data States '{}' does not match value in restart '{}'.");

// Tell Sidre where the external array is
SLIC_ERROR_ROOT_IF(!geom_group->hasView("states"),
"Loaded Quadrature Data geometry Sidre view did not have 'states'");
axom::sidre::View* states_view = geom_group->getView("states");
states_view->setExternalDataPtr(states.data());
}
}
}

if (is_restart_) {
// NOTE: This call will reload all external buffers from file stored in the DataStore
// TODO: This should be changed to load only the current material quadrature data after
// MFEMSidreDatacollection::LoadExternalData and SPIO is enhanced to allow loading the
// external data piecemeal
datacoll.LoadExternalData();
}
}

/**
* @brief Create a shared ptr to a quadrature data buffer for the given material type
*
* @tparam T the type to be created at each quadrature point
* @param mesh_tag The tag for the stored mesh used to locate the datacollection
* @param domain The spatial domain over which to allocate the quadrature data
* @param order The order of the discretization of the primal fields
* @param dim The spatial dimension of the mesh
* @param initial_state the value to be broadcast to each quadrature point
* @return shared pointer to quadrature data buffer
*/
template <typename T>
static std::shared_ptr<QuadratureData<T>> newQuadratureDataBuffer(const Domain& domain, int order, int dim,
T initial_state)
static std::shared_ptr<QuadratureData<T>> newQuadratureDataBuffer(const std::string& mesh_tag, const Domain& domain,
int order, int dim, T initial_state)
{
int Q = order + 1;

Expand All @@ -125,7 +225,9 @@ class StateManager {
qpts_per_elem[size_t(geom)] = uint32_t(num_quadrature_points(geom, Q));
}

return std::make_shared<QuadratureData<T>>(elems, qpts_per_elem, initial_state);
auto qdata = std::make_shared<QuadratureData<T>>(elems, qpts_per_elem, initial_state);
storeQuadratureData<T>(mesh_tag, qdata);
return qdata;
}

/**
Expand Down Expand Up @@ -336,11 +438,18 @@ class StateManager {
private:
/**
* @brief Creates a new datacollection based on a registered mesh
* @param[in] name The name of the new datacollection
* @param[in] mesh_tag The mesh name used to name the new datacollection
* @param[in] cycle_to_load What cycle to load the DataCollection from, if applicable
* @return The time from specified restart cycle. Otherwise zero.
*/
static double newDataCollection(const std::string& name, const std::optional<int> cycle_to_load = {});
static double newDataCollection(const std::string& mesh_tag, const std::optional<int> cycle_to_load = {});

/**
* @brief Returns the name of the data collection's name for a given mesh_tag
* @param[in] mesh_tag The mesh name used to name the new datacollection
* @return The name of the data collection for the given mesh_tag
*/
static std::string getCollectionName(const std::string& mesh_tag) { return mesh_tag + "_datacoll"; }

/**
* @brief Construct the shape displacement field for the requested mesh
Expand Down
23 changes: 23 additions & 0 deletions src/serac/physics/state/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) 2019-2024, Lawrence Livermore National Security, LLC and
# other Serac Project Developers. See the top-level LICENSE file for
# details.
#
# SPDX-License-Identifier: (BSD-3-Clause)

set(state_manager_tests_depends serac_state serac_physics_materials serac_functional serac_mesh gtest)

set(state_manager_serial_test_sources
state_manager.cpp
)

serac_add_tests(SOURCES ${state_manager_serial_test_sources}
DEPENDS_ON ${state_manager_tests_depends}
NUM_MPI_TASKS 1)

# set(state_manager_parallel_test_sources
# state_manager.cpp
# )

# serac_add_tests(SOURCES ${state_manager_parallel_test_sources}
# DEPENDS_ON ${state_manager_tests_depends}
# NUM_MPI_TASKS 2)
Loading