@@ -30,11 +30,13 @@ public enum GradientMethod {
3030 private final GradientMethod gradientMethod ;
3131
3232 private final int iterations ;
33- private boolean runnning = false ;
3433 private double [] learningRate ;
3534 private final double eps ;
3635 private final double [] betas ;
3736
37+ private boolean runnning = false ;
38+ private int iteration ;
39+
3840 private final RandomVariableDifferentiable [] parameters ;
3941 private RandomVariableDifferentiable [] bestFitParameters ;
4042 private double bestValue = Double .MAX_VALUE ;
@@ -92,7 +94,7 @@ public void run() {
9294 final double [] m = new double [parameters .length ];
9395 final double [] v = new double [parameters .length ];
9496
95- for (int k =0 ; k <iterations && runnning ; k ++) {
97+ for (iteration =0 ; iteration <iterations && runnning ; iteration ++) {
9698 final RandomVariable value = setValue (parameters );
9799 if (value .getAverage () < bestValue || bestFitParameters == null ) {
98100 bestValue = value .getAverage ();
@@ -113,18 +115,18 @@ public void run() {
113115 m [i ] = (betas [0 ]*m [i ] + (1 -betas [0 ])*gradient );
114116 v [i ] = (betas [1 ]*v [i ] + (1 -betas [1 ])*gradient *gradient );
115117
116- final double update_m = m [i ] / (1 -Math .pow (betas [0 ],k +1 ));
117- final double update_v = v [i ] / (1 -Math .pow (betas [1 ],k +1 ));
118+ final double update_m = m [i ] / (1 -Math .pow (betas [0 ], iteration +1 ));
119+ final double update_v = v [i ] / (1 -Math .pow (betas [1 ], iteration +1 ));
118120 final double stepDirection = update_m / (Math .sqrt (update_v )+eps );
119121
120122 parameters [i ] = ((RandomVariableDifferentiable ) parameters [i ].sub (learningRate [i ]*stepDirection )).getCloneIndependent ();
121123 }
122124
123- if ( k % 10 == 0 ) {
125+ if ( iteration % 10 == 0 ) {
124126 final double valueForPrinting = (gradientMethod == GradientMethod .AVERAGE ) ? value .getAverage () :
125127 -RandomOperators .expectedShortFall (value .mult (-1.0 ),0.05 ).doubleValue ();
126- if (k % 100 == 0 ) {
127- System .out .printf ("iteration %8d \t \t value %8.4f %n" , k , -valueForPrinting );
128+ if (iteration % 100 == 0 ) {
129+ System .out .printf ("iteration %8d \t \t value %8.4f %n" , iteration , -valueForPrinting );
128130 } else {
129131 // System.out.printf("iteration %8d \t\t value %8.4f \r", k, -valueForPrinting);
130132 }
@@ -140,7 +142,7 @@ public void run() {
140142 v [i ] = randomVariableFactory .createRandomVariable (0 );
141143 }
142144
143- for (int k =0 ; k <iterations && runnning ; k ++) {
145+ for (iteration =0 ; iteration <iterations && runnning ; iteration ++) {
144146 final RandomVariable value = setValue (parameters );
145147 if (value .getAverage () < bestValue || bestFitParameters == null ) {
146148 bestValue = value .getAverage ();
@@ -160,16 +162,16 @@ public void run() {
160162 m [i ] = m [i ].mult (betas [0 ]).add (gradient .mult (1 -betas [0 ]));
161163 v [i ] = v [i ].mult (betas [1 ]).add (gradient .squared ().mult (1 -betas [1 ]));
162164
163- final RandomVariable update_m = m [i ].div (1 -Math .pow (betas [0 ],k +1 ));
164- final RandomVariable update_v = v [i ].div (1 -Math .pow (betas [1 ],k +1 ));
165+ final RandomVariable update_m = m [i ].div (1 -Math .pow (betas [0 ], iteration +1 ));
166+ final RandomVariable update_v = v [i ].div (1 -Math .pow (betas [1 ], iteration +1 ));
165167 final RandomVariable stepDirection = update_m .div (update_v .sqrt ().add (eps ));
166168
167169 parameters [i ] =
168170 ((RandomVariableDifferentiable ) parameters [i ].sub (stepDirection .mult (learningRate [i ]))).getCloneIndependent ();
169171 }
170172
171- if (k % 100 == 0 ) {
172- System .out .printf ("iteration %8.4f \t \t value %8.4f %n" , (double ) k ,value .getAverage ());
173+ if (iteration % 100 == 0 ) {
174+ System .out .printf ("iteration %8.4f \t \t value %8.4f %n" , (double )iteration ,value .getAverage ());
173175 }
174176 }
175177 }
@@ -209,4 +211,8 @@ private RandomVariable[] getGradient(RandomVariable[] parameters, RandomVariable
209211
210212 return gradient ;
211213 }
214+
215+ public double getIteration () {
216+ return iteration ;
217+ }
212218}
0 commit comments