Skip to content

Commit

Permalink
Added programmatic differentiation
Browse files Browse the repository at this point in the history
  • Loading branch information
Antonio Kim committed Apr 28, 2020
1 parent db8312b commit 4e2278a
Show file tree
Hide file tree
Showing 22 changed files with 490 additions and 143 deletions.
2 changes: 1 addition & 1 deletion MathEngine/.utils/.templates/BinaryOperatorDirectory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "../../Utils/exceptions.h"
#include "Operators.h"
#include "BinaryOperatorDirectory.h"
#include "OperatorDerivatives/BasicOperators.h"
#include "OperatorDerivatives/OperatorDerivatives.h"
#include "OperatorDirectory/BinaryOperators.h"
#include "OperatorExpressions/Addition.h"
#include "OperatorExpressions/Division.h"
Expand Down
4 changes: 4 additions & 0 deletions MathEngine/.utils/.templates/FunctionDirectory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
#include "MultiFunctions/ExpLogFunctions.h"
#include "MultiFunctions/SpecialFunctions.h"
#include "MultiFunctions/Statistics.h"
#include "MultiFunctionExprs/calculus.h"
#include "MultiFunctionExprs/fft.h"
#include "MultiFunctionExprs/polynomial.h"
#include "MultiFunctionExprs/tuple.h"
#include "UnaryFunctions/BasicFunctions.h"
#include "UnaryFunctions/ExpLogFunctions.h"
#include "UnaryFunctions/SpecialFunctions.h"
#include "UnaryFunctions/TrigFunctions.h"
#include "UnaryFunctionDerivatives/BasicFunctions.h"
#include "UnaryFunctionDerivatives/ExpLogFunctions.h"
#include "UnaryFunctionDerivatives/TrigFunctions.h"
#include "UnaryFunctionExprs/basic.h"
#include "UnaryFunctionExprs/linalg.h"

Expand Down
63 changes: 62 additions & 1 deletion MathEngine/.utils/functions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@ functions:
# Basic
neg:
expression:
derivative:
abs:
derivative:
sqr:
derivative:
sqrt:
derivative:
cb:
derivative:
cbrt:
derivative:
hypot:
nargs: -1
ldexp:
Expand All @@ -31,62 +37,117 @@ functions:
expression:
# Trigonometry
sin:
derivative:
cos:
derivative:
tan:
derivative:
csc:
derivative:
sec:
derivative:
cot:
derivative:
asin:
derivative:
acos:
derivative:
atan:
derivative:
acsc:
derivative:
asec:
derivative:
acot:
derivative:
arcsin:
derivative:
arccos:
derivative:
arctan:
derivative:
arccsc:
derivative:
arcsec:
derivative:
arccot:
derivative:
sinh:
derivative:
cosh:
derivative:
tanh:
derivative:
csch:
derivative:
sech:
derivative:
coth:
derivative:
asinh:
derivative:
acosh:
derivative:
atanh:
derivative:
acsch:
derivative:
asech:
derivative:
acoth:
derivative:
arsinh:
derivative:
arcosh:
derivative:
artanh:
derivative:
arcsch:
derivative:
arsech:
derivative:
arcoth:
derivative:
arcsinh:
derivative:
arccosh:
derivative:
arctanh:
derivative:
arccsch:
derivative:
arcsech:
derivative:
arccoth:
derivative:
# Exponentials
exp:
derivative:
exp_2:
derivative:
expm1:
derivative:
log:
derivative:
log_2:
derivative:
log1p:
derivative:
ln:
derivative:
ln_2:
derivative:
ln1p:
derivative:
logn:
nargs: 2
derivative:
# Calculus
diff:
nargs: -1
expression:
deriv:
nargs: 2
nargs: -1
integral:
nargs: 3
# Statistics
Expand Down
2 changes: 2 additions & 0 deletions MathEngine/Expressions/Expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ struct BinExpression;
struct HexExpression;
struct InvalidExpression;
struct MatrixExpression;
struct BinaryOperatorExpression;
struct TupleExpression;
struct VariableExpression;

Expand Down Expand Up @@ -46,6 +47,7 @@ struct Expression: public std::enable_shared_from_this<Expression> {
virtual const inline HexExpression* hex() const { return nullptr; }
virtual const inline InvalidExpression* invalid() const { return nullptr; }
virtual const inline MatrixExpression* matrix() const { return nullptr; }
virtual const inline BinaryOperatorExpression* binaryOperator() const { return nullptr; }
virtual const inline TupleExpression* tuple() const { return nullptr; }
virtual const inline VariableExpression* variable() const { return nullptr; }

Expand Down
16 changes: 14 additions & 2 deletions MathEngine/Expressions/Expression/ExpressionFunctions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,21 @@

using namespace std;

expression ln(expression expr1){
return UnaryFunctionExpression::construct("ln", expr1);
expression abs(expression expr1){
return UnaryFunctionExpression::construct("abs", expr1);
}
expression sqr(expression expr1){
return UnaryFunctionExpression::construct("sqr", expr1);
}
expression sqrt(expression expr1){
return UnaryFunctionExpression::construct("sqrt", expr1);
}
expression cb(expression expr1){
return UnaryFunctionExpression::construct("cb", expr1);
}
expression cbrt(expression expr1){
return UnaryFunctionExpression::construct("cbrt", expr1);
}
expression ln(expression expr1){
return UnaryFunctionExpression::construct("ln", expr1);
}
99 changes: 97 additions & 2 deletions MathEngine/Expressions/Expression/ExpressionOperations.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

#include <utility>

#include "../ExpressionFunctions.h"
#include "../ExpressionOperations.h"
#include "../FunctionExpression.h"
#include "../NumericalExpression.h"
Expand All @@ -9,12 +10,24 @@
using namespace std;

expression operator+(const expression expr1, const expression expr2) {
if (expr1->evaluable()){
return expr2 + expr1->value();
}
if (expr2->evaluable()){
return expr1 + expr2->value();
}
return BinaryOperatorExpression::construct("+", expr1, expr2);
}
expression operator+(double expr1, const expression expr2) {
return BinaryOperatorExpression::construct("+", NumExpression::construct(expr1), expr2);
return expr2 + expr1;
}
expression operator+(const expression expr1, double expr2) {
if (expr1->evaluable()){
return NumExpression::construct(expr1->value() + expr2);
}
if (expr2 == 0){
return expr1;
}
return BinaryOperatorExpression::construct("+", expr1, NumExpression::construct(expr2));
}
expression operator+(const gsl_complex& expr1, const expression expr2) {
Expand All @@ -25,15 +38,27 @@ expression operator+(const expression expr1, const gsl_complex& expr2) {
}

expression operator-(const expression expr){
if (expr->evaluable()){
return NumExpression::construct(-expr->value());
}
return UnaryFunctionExpression::construct("neg", expr);
}
expression operator-(const expression expr1, const expression expr2) {
if (expr1->evaluable() && expr2->evaluable()){
return NumExpression::construct(expr1->value() - expr2->value());
}
return BinaryOperatorExpression::construct("-", expr1, expr2);
}
expression operator-(double expr1, const expression expr2) {
if (expr2->evaluable()){
return NumExpression::construct(expr1 - expr2->value());
}
return BinaryOperatorExpression::construct("-", NumExpression::construct(expr1), expr2);
}
expression operator-(const expression expr1, double expr2) {
if (expr1->evaluable()){
return NumExpression::construct(expr1->value() - expr2);
}
return BinaryOperatorExpression::construct("-", expr1, NumExpression::construct(expr2));
}
expression operator-(const gsl_complex& expr1, const expression expr2) {
Expand All @@ -44,28 +69,68 @@ expression operator-(const expression expr1, const gsl_complex& expr2) {
}

expression operator*(const expression expr1, const expression expr2) {
if (expr1->evaluable()){
return expr1->value() * expr2;
}
if (expr2->evaluable()){
return expr2->value() * expr1;
}
return BinaryOperatorExpression::construct("*", expr1, expr2);
}
expression operator*(double expr1, const expression expr2) {
if (expr2->evaluable()){
return NumExpression::construct(expr1 * expr2->value());
}
if (expr1 == 1){
return expr2;
}
if (expr1 == 0){
return NumExpression::construct(0);
}
return BinaryOperatorExpression::construct("*", NumExpression::construct(expr1), expr2);
}
expression operator*(const expression expr1, double expr2) {
return BinaryOperatorExpression::construct("*", expr1, NumExpression::construct(expr2));
return expr2 * expr1;
}
expression operator*(const gsl_complex& expr1, const expression expr2) {
if (GSL_IMAG(expr1) == 0){
if (GSL_REAL(expr1) == 1){
return expr2;
}
if (GSL_REAL(expr1) == 0){
return NumExpression::construct(0);
}
}
return BinaryOperatorExpression::construct("*", NumExpression::construct(expr1), expr2);
}
expression operator*(const expression expr1, const gsl_complex& expr2) {
if (GSL_IMAG(expr2) == 0){
if (GSL_REAL(expr2) == 1){
return expr1;
}
if (GSL_REAL(expr2) == 0){
return NumExpression::construct(0);
}
}
return BinaryOperatorExpression::construct("*", expr1, NumExpression::construct(expr2));
}

expression operator/(const expression expr1, const expression expr2) {
if (expr1->evaluable() && expr2->evaluable()){
return NumExpression::construct(expr1->value() / expr2->value());
}
return BinaryOperatorExpression::construct("/", expr1, expr2);
}
expression operator/(double expr1, const expression expr2) {
if (expr2->evaluable()){
return NumExpression::construct(expr1 / expr2->value());
}
return BinaryOperatorExpression::construct("/", NumExpression::construct(expr1), expr2);
}
expression operator/(const expression expr1, double expr2) {
if (expr1->evaluable()){
return NumExpression::construct(expr1->value() / expr2);
}
return BinaryOperatorExpression::construct("/", expr1, NumExpression::construct(expr2));
}
expression operator/(const gsl_complex& expr1, const expression expr2) {
Expand All @@ -76,12 +141,42 @@ expression operator/(const expression expr1, const gsl_complex& expr2) {
}

expression operator^(const expression expr1, const expression expr2) {
if (expr1->evaluable()){
return expr1->value() ^ expr2;
}
else if (expr2->evaluable()){
return expr1 ^ expr2->value();
}
return BinaryOperatorExpression::construct("^", expr1, expr2);
}
expression operator^(double expr1, const expression expr2) {
if (expr2->evaluable()){
return NumExpression::construct(pow(expr1, expr2->value()));
}
if (expr1 == 1){
return NumExpression::construct(1);
}
if (expr1 == 2){
return UnaryFunctionExpression::construct("exp2", expr2);
}
return BinaryOperatorExpression::construct("^", NumExpression::construct(expr1), expr2);
}
expression operator^(const expression expr1, double expr2) {
if (expr1->evaluable()){
return NumExpression::construct(pow(expr1->value(), expr2));
}
if (expr2 == 0){
return NumExpression::construct(1);
}
if (expr2 == 1){
return expr1;
}
if (expr2 == 2){
return sqr(expr1);
}
if (expr2 == 3){
return cb(expr1);
}
return BinaryOperatorExpression::construct("^", expr1, NumExpression::construct(expr2));
}
expression operator^(const gsl_complex& expr1, const expression expr2) {
Expand Down
7 changes: 6 additions & 1 deletion MathEngine/Expressions/ExpressionFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,10 @@

#include "Expression.h"

expression ln(expression expr1);
expression abs(expression expr1);
expression sqr(expression expr1);
expression sqrt(expression expr1);
expression cb(expression expr1);
expression cbrt(expression expr1);

expression ln(expression expr1);
Loading

0 comments on commit 4e2278a

Please sign in to comment.