Skip to content

Commit ea9d244

Browse files
Implement GetOptimalObjectiveValue function to evaluate objective at solution point (#26)
* Initial plan * Implement GetOptimalObjectiveValue function with comprehensive tests Co-authored-by: kwesiRutledge <[email protected]> * Change FindValueOfExpression signature to return symbolic.Expression (#27) * Initial plan * Change FindValueOfExpression signature to return symbolic.Expression Co-authored-by: kwesiRutledge <[email protected]> * Add helper function to reduce code duplication in tests Co-authored-by: kwesiRutledge <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: kwesiRutledge <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: kwesiRutledge <[email protected]>
1 parent f515767 commit ea9d244

File tree

2 files changed

+219
-7
lines changed

2 files changed

+219
-7
lines changed

solution/solution.go

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ func ExtractValueOfVariable(s Solution, v symbolic.Variable) (float64, error) {
4444

4545
// FindValueOfExpression evaluates a symbolic expression using the values from a solution.
4646
// It substitutes all variables in the expression with their values from the solution
47-
// and returns the resulting scalar value.
48-
func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error) {
47+
// and returns the resulting symbolic expression (typically a constant).
48+
func FindValueOfExpression(s Solution, expr symbolic.Expression) (symbolic.Expression, error) {
4949
// Get all variables in the expression
5050
vars := expr.Variables()
5151

@@ -54,7 +54,7 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error
5454
for _, v := range vars {
5555
val, err := ExtractValueOfVariable(s, v)
5656
if err != nil {
57-
return 0.0, fmt.Errorf(
57+
return nil, fmt.Errorf(
5858
"failed to extract value for variable %v: %w",
5959
v.ID,
6060
err,
@@ -66,6 +66,31 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error
6666
// Substitute all variables with their values
6767
resultExpr := expr.SubstituteAccordingTo(subMap)
6868

69+
return resultExpr, nil
70+
}
71+
72+
// GetOptimalObjectiveValue evaluates the objective function of an optimization problem
73+
// at the solution point. It uses the FindValueOfExpression function to compute the value
74+
// of the objective expression using the variable values from the solution.
75+
func GetOptimalObjectiveValue(sol Solution) (float64, error) {
76+
// Get the problem from the solution
77+
prob := sol.GetProblem()
78+
if prob == nil {
79+
return 0.0, fmt.Errorf("solution does not have an associated problem")
80+
}
81+
82+
// Get the objective expression from the problem
83+
objectiveExpr := prob.Objective.Expression
84+
if objectiveExpr == nil {
85+
return 0.0, fmt.Errorf("problem does not have a defined objective")
86+
}
87+
88+
// Use FindValueOfExpression to evaluate the objective at the solution point
89+
resultExpr, err := FindValueOfExpression(sol, objectiveExpr)
90+
if err != nil {
91+
return 0.0, fmt.Errorf("failed to evaluate objective expression: %w", err)
92+
}
93+
6994
// Type assert to K (constant) to extract the float64 value
7095
resultK, ok := resultExpr.(symbolic.K)
7196
if !ok {

testing/solution/solution_test.go

Lines changed: 191 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ Description:
1717
(This seems like it is highly representative of the Gurobi solver; is there a reason to make it this way?)
1818
*/
1919

20+
// Helper function to convert a symbolic.Expression to float64
21+
func exprToFloat64(t *testing.T, expr symbolic.Expression) float64 {
22+
resultK, ok := expr.(symbolic.K)
23+
if !ok {
24+
t.Fatalf("Expected result to be a constant, got type %T", expr)
25+
}
26+
return float64(resultK)
27+
}
28+
2029
func TestSolution_ToMessage1(t *testing.T) {
2130
// Constants
2231
tempSol := solution.DummySolution{
@@ -161,11 +170,13 @@ func TestSolution_FindValueOfExpression1(t *testing.T) {
161170
expr := v1.Multiply(symbolic.K(2.0)).Plus(v2.Multiply(symbolic.K(3.0)))
162171

163172
// Algorithm
164-
result, err := solution.FindValueOfExpression(&tempSol, expr)
173+
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
165174
if err != nil {
166175
t.Errorf("FindValueOfExpression returned an error: %v", err)
167176
}
168177

178+
result := exprToFloat64(t, resultExpr)
179+
169180
expected := 13.0
170181
if result != expected {
171182
t.Errorf(
@@ -194,11 +205,13 @@ func TestSolution_FindValueOfExpression2(t *testing.T) {
194205
expr := symbolic.K(42.0)
195206

196207
// Algorithm
197-
result, err := solution.FindValueOfExpression(&tempSol, expr)
208+
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
198209
if err != nil {
199210
t.Errorf("FindValueOfExpression returned an error: %v", err)
200211
}
201212

213+
result := exprToFloat64(t, resultExpr)
214+
202215
expected := 42.0
203216
if result != expected {
204217
t.Errorf(
@@ -231,11 +244,13 @@ func TestSolution_FindValueOfExpression3(t *testing.T) {
231244
expr := v1.Plus(symbolic.K(10.0))
232245

233246
// Algorithm
234-
result, err := solution.FindValueOfExpression(&tempSol, expr)
247+
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
235248
if err != nil {
236249
t.Errorf("FindValueOfExpression returned an error: %v", err)
237250
}
238251

252+
result := exprToFloat64(t, resultExpr)
253+
239254
expected := 15.5
240255
if result != expected {
241256
t.Errorf(
@@ -304,11 +319,13 @@ func TestSolution_FindValueOfExpression5(t *testing.T) {
304319
expr := v1.Plus(v2).Multiply(v3).Plus(symbolic.K(5.0))
305320

306321
// Algorithm
307-
result, err := solution.FindValueOfExpression(&tempSol, expr)
322+
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
308323
if err != nil {
309324
t.Errorf("FindValueOfExpression returned an error: %v", err)
310325
}
311326

327+
result := exprToFloat64(t, resultExpr)
328+
312329
expected := 14.0
313330
if result != expected {
314331
t.Errorf(
@@ -377,3 +394,173 @@ func TestSolution_GetProblem2(t *testing.T) {
377394
t.Errorf("Expected GetProblem to return nil when no problem is set")
378395
}
379396
}
397+
398+
/*
399+
TestSolution_GetOptimalObjectiveValue1
400+
Description:
401+
402+
This function tests whether we can compute the objective value at the solution point
403+
for a simple linear objective.
404+
*/
405+
func TestSolution_GetOptimalObjectiveValue1(t *testing.T) {
406+
// Constants
407+
p := problem.NewProblem("TestProblem")
408+
v1 := p.AddVariable()
409+
v2 := p.AddVariable()
410+
411+
// Set objective: 2*v1 + 3*v2
412+
objectiveExpr := v1.Multiply(symbolic.K(2.0)).Plus(v2.Multiply(symbolic.K(3.0)))
413+
err := p.SetObjective(objectiveExpr, problem.SenseMinimize)
414+
if err != nil {
415+
t.Errorf("Failed to set objective: %v", err)
416+
}
417+
418+
// Create solution with v1=2.0, v2=3.0
419+
// Expected objective value: 2*2.0 + 3*3.0 = 4.0 + 9.0 = 13.0
420+
tempSol := solution.DummySolution{
421+
Values: map[uint64]float64{
422+
v1.ID: 2.0,
423+
v2.ID: 3.0,
424+
},
425+
Objective: 13.0,
426+
Status: solution_status.OPTIMAL,
427+
Problem: p,
428+
}
429+
430+
// Algorithm
431+
objectiveValue, err := solution.GetOptimalObjectiveValue(&tempSol)
432+
if err != nil {
433+
t.Errorf("GetOptimalObjectiveValue returned an error: %v", err)
434+
}
435+
436+
expected := 13.0
437+
if objectiveValue != expected {
438+
t.Errorf(
439+
"Expected objective value to be %v; received %v",
440+
expected,
441+
objectiveValue,
442+
)
443+
}
444+
}
445+
446+
/*
447+
TestSolution_GetOptimalObjectiveValue2
448+
Description:
449+
450+
This function tests whether GetOptimalObjectiveValue returns an error
451+
when the solution has no associated problem.
452+
*/
453+
func TestSolution_GetOptimalObjectiveValue2(t *testing.T) {
454+
// Constants
455+
v1 := symbolic.NewVariable()
456+
457+
tempSol := solution.DummySolution{
458+
Values: map[uint64]float64{
459+
v1.ID: 2.0,
460+
},
461+
Objective: 2.3,
462+
Status: solution_status.OPTIMAL,
463+
Problem: nil,
464+
}
465+
466+
// Algorithm
467+
_, err := solution.GetOptimalObjectiveValue(&tempSol)
468+
if err == nil {
469+
t.Errorf("Expected GetOptimalObjectiveValue to return an error for nil problem, but got nil")
470+
}
471+
}
472+
473+
/*
474+
TestSolution_GetOptimalObjectiveValue3
475+
Description:
476+
477+
This function tests whether we can compute the objective value
478+
for a constant objective function.
479+
*/
480+
func TestSolution_GetOptimalObjectiveValue3(t *testing.T) {
481+
// Constants
482+
p := problem.NewProblem("TestProblem")
483+
v1 := p.AddVariable()
484+
485+
// Set constant objective: 42.0
486+
objectiveExpr := symbolic.K(42.0)
487+
err := p.SetObjective(objectiveExpr, problem.SenseMaximize)
488+
if err != nil {
489+
t.Errorf("Failed to set objective: %v", err)
490+
}
491+
492+
// Create solution
493+
tempSol := solution.DummySolution{
494+
Values: map[uint64]float64{
495+
v1.ID: 1.0,
496+
},
497+
Objective: 42.0,
498+
Status: solution_status.OPTIMAL,
499+
Problem: p,
500+
}
501+
502+
// Algorithm
503+
objectiveValue, err := solution.GetOptimalObjectiveValue(&tempSol)
504+
if err != nil {
505+
t.Errorf("GetOptimalObjectiveValue returned an error: %v", err)
506+
}
507+
508+
expected := 42.0
509+
if objectiveValue != expected {
510+
t.Errorf(
511+
"Expected objective value to be %v; received %v",
512+
expected,
513+
objectiveValue,
514+
)
515+
}
516+
}
517+
518+
/*
519+
TestSolution_GetOptimalObjectiveValue4
520+
Description:
521+
522+
This function tests whether we can compute the objective value
523+
for a more complex objective with multiple variables and operations.
524+
*/
525+
func TestSolution_GetOptimalObjectiveValue4(t *testing.T) {
526+
// Constants
527+
p := problem.NewProblem("TestProblem")
528+
v1 := p.AddVariable()
529+
v2 := p.AddVariable()
530+
v3 := p.AddVariable()
531+
532+
// Set objective: (v1 + v2) * v3 + 5
533+
objectiveExpr := v1.Plus(v2).Multiply(v3).Plus(symbolic.K(5.0))
534+
err := p.SetObjective(objectiveExpr, problem.SenseMinimize)
535+
if err != nil {
536+
t.Errorf("Failed to set objective: %v", err)
537+
}
538+
539+
// Create solution with v1=1.0, v2=2.0, v3=3.0
540+
// Expected objective: (1.0 + 2.0) * 3.0 + 5 = 3.0 * 3.0 + 5 = 14.0
541+
tempSol := solution.DummySolution{
542+
Values: map[uint64]float64{
543+
v1.ID: 1.0,
544+
v2.ID: 2.0,
545+
v3.ID: 3.0,
546+
},
547+
Objective: 14.0,
548+
Status: solution_status.OPTIMAL,
549+
Problem: p,
550+
}
551+
552+
// Algorithm
553+
objectiveValue, err := solution.GetOptimalObjectiveValue(&tempSol)
554+
if err != nil {
555+
t.Errorf("GetOptimalObjectiveValue returned an error: %v", err)
556+
}
557+
558+
expected := 14.0
559+
if objectiveValue != expected {
560+
t.Errorf(
561+
"Expected objective value to be %v; received %v",
562+
expected,
563+
objectiveValue,
564+
)
565+
}
566+
}

0 commit comments

Comments
 (0)