Skip to content

Commit

Permalink
IncrementalSVDFastUpdate with faster updates + bug fix in Incremental…
Browse files Browse the repository at this point in the history
…SVD (#227)

* Update IncrementalSVD for faster updates
* created a separate class for new features
* minor fix in CMakeLists.txt
* another minor fix in CMakeLists
* add unit test for IncrementalSVDBrand
* fixed mpi bug in test_IncrementalSVDBrand
* minor styling fix
* update args order in unit test

---------

Co-authored-by: swsuh28 <[email protected]>
  • Loading branch information
swsuh28 and swsuh28 authored Jul 21, 2023
1 parent 8505c23 commit d3ef2cd
Show file tree
Hide file tree
Showing 9 changed files with 886 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/run_tests/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ runs:
mpirun -n 3 --oversubscribe tests/test_RandomizedSVD
./tests/test_StaticSVD
mpirun -n 3 --oversubscribe tests/test_StaticSVD
./tests/test_IncrementalSVDBrand
mpirun -n 3 --oversubscribe tests/test_IncrementalSVDBrand
shell: bash

- name: Run regression tests
Expand Down
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ if(GTEST_FOUND)
StaticSVD
RandomizedSVD
IncrementalSVD
IncrementalSVDBrand
GreedyCustomSampler)
foreach(stem IN LISTS unit_test_stems)
add_executable(test_${stem} unit_tests/test_${stem}.cpp)
Expand Down
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ set(module_list
linalg/svd/IncrementalSVD
linalg/svd/IncrementalSVDFastUpdate
linalg/svd/IncrementalSVDStandard
linalg/svd/IncrementalSVDBrand
linalg/svd/RandomizedSVD
linalg/svd/SVD
linalg/svd/StaticSVD
Expand Down
9 changes: 8 additions & 1 deletion lib/linalg/BasisGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "svd/RandomizedSVD.h"
#include "svd/IncrementalSVDStandard.h"
#include "svd/IncrementalSVDFastUpdate.h"
#include "svd/IncrementalSVDBrand.h"

namespace CAROM {

Expand Down Expand Up @@ -65,7 +66,13 @@ BasisGenerator::BasisGenerator(
d_dt = options.initial_dt;
d_next_sample_time = 0.0;

if (options.fast_update) {
if (options.fast_update_brand) {
d_svd.reset(
new IncrementalSVDBrand(
options,
basis_file_name));
}
else if (options.fast_update) {
d_svd.reset(
new IncrementalSVDFastUpdate(
options,
Expand Down
8 changes: 8 additions & 0 deletions lib/linalg/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class Options
double sampling_tol_,
double max_time_between_samples_,
bool fast_update_ = false,
bool fast_update_brand_ = false,
bool skip_linearly_dependent_ = false
)
{
Expand All @@ -154,6 +155,7 @@ class Options
sampling_tol = sampling_tol_;
max_time_between_samples = max_time_between_samples_;
fast_update = fast_update_;
fast_update_brand = fast_update_brand_;
skip_linearly_dependent = skip_linearly_dependent_;
return *this;
}
Expand Down Expand Up @@ -299,6 +301,12 @@ class Options
*/
bool fast_update = false;

/**
* @brief If true use the exact Brand's algorithm for the
* incremental SVD.
*/
bool fast_update_brand = false;

/**
* @brief If true skip linearly dependent samples of the
* incremental SVD algorithm.
Expand Down
13 changes: 7 additions & 6 deletions lib/linalg/svd/IncrementalSVD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,12 +301,13 @@ IncrementalSVD::buildIncrementalSVD(
// basisl = basis * l
Vector* basisl = d_basis->mult(l);

// Compute k = sqrt(u.u - 2.0*l.l + basisl.basisl) which is ||u -
// basisl||_{2}. This is the error in the projection of u into the
// reduced order space and subsequent lifting back to the full
// order space.
double k = u_vec.inner_product(u_vec) - 2.0*l->inner_product(l) +
basisl->inner_product(basisl);
// Computing as k = sqrt(u.u - 2.0*l.l + basisl.basisl)
// results in catastrophic cancellation, and must be avoided.
// Instead we compute as k = sqrt((u-basisl).(u-basisl)).
Vector* e_proj = u_vec.minus(basisl);
double k = e_proj->inner_product(e_proj);
delete e_proj;

if (k <= 0) {
if(d_rank == 0) printf("linearly dependent sample!\n");
k = 0;
Expand Down
Loading

0 comments on commit d3ef2cd

Please sign in to comment.