Skip to content

Commit e5f101e

Browse files
committed
initial pass at warmstarted FWPH
1 parent 3c6d534 commit e5f101e

File tree

1 file changed

+172
-45
lines changed

1 file changed

+172
-45
lines changed

mpisppy/opt/fwph.py

Lines changed: 172 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import time
2424
import re # For manipulating scenario names
2525
import random
26+
import math
2627

2728
from mpisppy import MPI
2829
from mpisppy import global_toc
@@ -110,7 +111,56 @@ def fwph_main(self, finalize=True):
110111
self._reenable_W()
111112

112113
if (self.ph_converger):
113-
self.convobject = self.ph_converger(self, self.cylinder_rank, self.n_proc)
114+
self.convobject = self.ph_converger(self)
115+
116+
if self.options.get("FW_LP_start_iterations", 1000) > 0:
117+
global_toc("Starting LP PH...")
118+
lp_iterations = self.options.get("FW_LP_start_iterations", 1000)
119+
total_iterations = self.options["PHIterLimit"]
120+
self.options["PHIterLimit"] = lp_iterations
121+
integer_relaxer = pyo.TransformationFactory('core.relax_integer_vars')
122+
for s in self.local_subproblems.values():
123+
integer_relaxer.apply_to(s)
124+
if sputils.is_persistent(s._solver_plugin):
125+
for v,_ in s._relaxed_integer_vars[None].values():
126+
s._solver_plugin.update_var(v)
127+
self.attach_PH_to_objective(add_duals=False, add_prox=True)
128+
self._reenable_prox()
129+
super().iterk_loop()
130+
self._disable_prox()
131+
for s in self.local_subproblems.values():
132+
for v, d in s._relaxed_integer_vars[None].values():
133+
v.domain = d
134+
if sputils.is_persistent(s._solver_plugin):
135+
s._solver_plugin.update_var(v)
136+
# s._solver_plugin.update_var(v)
137+
s.del_component("_relaxed_integer_vars")
138+
self.options["PHIterLimit"] = total_iterations
139+
self._PHIter -= 1
140+
141+
global_toc("Finished LP PH; Starting FW PH crossover")
142+
teeme = (
143+
self.options.get("tee-rank0-solves", False)
144+
and self.cylinder_rank == 0
145+
)
146+
# teeme = True
147+
self.fwph_solve_loop(
148+
mip_solver_options=self.current_solver_options,
149+
dtiming=self.options["display_timing"],
150+
tee=teeme,
151+
verbose=self.options["verbose"],
152+
# sdm_iter_limit=20,
153+
# FW_conv_thresh=-1,
154+
)
155+
global_toc("Starting FW PH")
156+
157+
else:
158+
# FWPH can take some time to initialize
159+
# If run as a spoke, check for convergence here
160+
if self.spcomm and self.spcomm.is_converged():
161+
if finalize:
162+
return 0, None, None
163+
return 0
114164

115165
self.iterk_loop()
116166

@@ -131,20 +181,17 @@ def iterk_loop(self):
131181
and self.options["tee-rank0-solves"]
132182
and self.cylinder_rank == 0
133183
)
184+
# teeme = True
134185

135186
self.conv = None
136187

137188
max_iterations = int(self.options["PHIterLimit"])
138-
# FWPH can take some time to initialize
139-
# If run as a spoke, check for convergence here
140-
if self.spcomm and self.spcomm.is_converged():
141-
return
142189

143190
# The body of the algorithm
144-
for self._PHIter in range(1, max_iterations+1):
191+
while (self._PHIter < max_iterations):
145192
iteration_start_time = time.perf_counter()
146193
if dprogress:
147-
global_toc(f"Initiating FWPH Iteration {self._PHIter}\n", self.cylinder_rank == 0)
194+
global_toc(f"Initiating FWPH Major Iteration {self._PHIter+1}\n", self.cylinder_rank == 0)
148195

149196
# tbphloop = time.perf_counter()
150197
# TODO: should implement our own Xbar / W computation
@@ -157,18 +204,20 @@ def iterk_loop(self):
157204
if hasattr(self.spcomm, "sync_Ws"):
158205
self.spcomm.sync_Ws()
159206

160-
self.conv = self.convergence_diff()
207+
self.conv = self.fwph_convergence_diff()
161208

162209
if (self.extensions):
163210
self.extobject.miditer()
164211

165212
if (self.ph_converger):
166-
diff = self.convobject.convergence_value()
213+
self._swap_nonant_vars()
167214
if (self.convobject.is_converged()):
168215
secs = time.perf_counter() - self.start_time
169-
self._output(self._local_bound, self._fwph_best_bound, diff, secs)
216+
self._output(self._local_bound, self._fwph_best_bound, self.conv, secs)
170217
global_toc('FWPH converged to user-specified criteria', self.cylinder_rank == 0)
218+
self._swap_nonant_vars_back()
171219
break
220+
self._swap_nonant_vars_back()
172221
if self.conv is not None: # Convergence check from Boland
173222
if (self.conv < self.options['convthresh']):
174223
secs = time.perf_counter() - self.start_time
@@ -197,12 +246,7 @@ def iterk_loop(self):
197246
self._output(self._local_bound, self._fwph_best_bound, self.conv, secs)
198247

199248
## Hubs/spokes take precedence over convergers
200-
if hasattr(self.spcomm, "sync_bounds"):
201-
self.spcomm.sync_bounds()
202-
self.spcomm.sync_extensions()
203-
elif hasattr(self.spcomm, "sync"):
204-
self.spcomm.sync()
205-
if self.spcomm and self.spcomm.is_converged():
249+
if self.spcomm and self.spcomm.is_converged(screen_trace=False):
206250
secs = time.perf_counter() - self.start_time
207251
self._output(self._local_bound, self._fwph_best_bound, np.nan, secs)
208252
global_toc("Cylinder convergence", self.cylinder_rank == 0)
@@ -238,22 +282,82 @@ def fwph_solve_loop(
238282
dtiming=False,
239283
tee=False,
240284
verbose=False,
285+
sdm_iter_limit=None,
286+
FW_conv_thresh=None,
241287
):
288+
if sdm_iter_limit is None:
289+
sdm_iter_limit = self.FW_options["FW_iter_limit"]
290+
if FW_conv_thresh is None:
291+
FW_conv_thresh = self.FW_options["FW_conv_thresh"]
292+
max_iterations = int(self.options["PHIterLimit"])
293+
# print(f"{sdm_iter_limit=}")
242294
self._swap_nonant_vars()
243295
self._local_bound = 0
244296
# tbsdm = time.perf_counter()
297+
_sdm_generators = {}
298+
stop = False
245299
for name in self.local_subproblems:
246-
dual_bound = self.SDM(name, mip_solver_options, dtiming, tee, verbose)
247-
if dual_bound is None:
248-
dual_bound = np.nan
300+
_sdm_generators[name] = self.SDM(name, mip_solver_options, dtiming, tee, verbose, sdm_iter_limit, FW_conv_thresh)
301+
try:
302+
dual_bound = next(_sdm_generators[name])
303+
except StopIteration as e:
304+
dual_bound = e.value
305+
stop = True
249306
self._local_bound += self.local_subproblems[name]._mpisppy_probability * \
250307
dual_bound
308+
self._update_dual_bounds()
309+
self._PHIter += 1
310+
if self._PHIter == max_iterations:
311+
stop = True
312+
if self._sync_after_mip_solve():
313+
stop = True
314+
stop = self.allreduce_or(stop)
315+
while not stop:
316+
stop = False
317+
for col_generator in _sdm_generators.values():
318+
try:
319+
next(col_generator)
320+
except StopIteration:
321+
stop = True
322+
self._PHIter += 1
323+
if self._PHIter == max_iterations:
324+
stop = True
325+
if self._sync_after_mip_solve():
326+
stop = True
327+
stop = self.allreduce_or(stop)
251328
# tsdm = time.perf_counter() - tbsdm
252329
# print(f"PH iter {self._PHIter}, total SDM time: {tsdm}")
253-
self._update_dual_bounds()
330+
331+
# Re-set the mip._mpisppy_model.W so that the QP objective
332+
# is correct in the next major iteration
333+
for model_name, mip in self.local_subproblems.items():
334+
qp = self.local_QP_subproblems[model_name]
335+
mip_source = mip.scen_list if self.bundling else [model_name]
336+
for scenario_name in mip_source:
337+
scen_mip = self.local_scenarios[scenario_name]
338+
for (node_name, ix) in scen_mip._mpisppy_data.nonant_indices:
339+
scen_mip._mpisppy_model.W[node_name, ix]._value = \
340+
qp._mpisppy_model.W[node_name, ix]._value
341+
254342
self._swap_nonant_vars_back()
255343

256-
def SDM(self, model_name, mip_solver_options, dtiming, tee, verbose):
344+
def _sync_after_mip_solve(self):
345+
# add columns from cylinder(s)
346+
self._swap_nonant_vars_back()
347+
if hasattr(self.spcomm, "add_cylinder_columns"):
348+
self.spcomm.sync_nonants()
349+
self.spcomm.add_cylinder_columns()
350+
if hasattr(self.spcomm, "sync_bounds"):
351+
self.spcomm.sync_bounds()
352+
self.spcomm.sync_extensions()
353+
elif hasattr(self.spcomm, "sync"):
354+
self.spcomm.sync()
355+
self._swap_nonant_vars()
356+
if self.spcomm and self.spcomm.is_converged():
357+
return True
358+
return False
359+
360+
def SDM(self, model_name, mip_solver_options, dtiming, tee, verbose, sdm_iter_limit, FW_conv_thresh):
257361
''' Algorithm 2 in Boland et al. (with small tweaks)
258362
'''
259363
mip = self.local_subproblems[model_name]
@@ -280,7 +384,7 @@ def SDM(self, model_name, mip_solver_options, dtiming, tee, verbose):
280384
for ndn_i, xvar in arb_scen_mip._mpisppy_data.nonant_indices.items()
281385
}
282386

283-
for itr in range(self.FW_options['FW_iter_limit']):
387+
for itr in range(sdm_iter_limit):
284388
# loop_start = time.perf_counter()
285389
# Algorithm 2 line 4
286390
for scenario_name in mip_source:
@@ -292,7 +396,9 @@ def SDM(self, model_name, mip_solver_options, dtiming, tee, verbose):
292396
* (xt[ndn_i]
293397
- scen_mip._mpisppy_model.xbars[ndn_i]._value))
294398

399+
self._fix_fixings(model_name, mip, qp)
295400
cutoff = self._add_objective_cutoff(mip, qp)
401+
# print(f"{model_name=}, {cutoff=}")
296402
# Algorithm 2 line 5
297403
self.solve_one(
298404
mip_solver_options,
@@ -304,15 +410,14 @@ def SDM(self, model_name, mip_solver_options, dtiming, tee, verbose):
304410
)
305411
self._remove_objective_cutoff(mip)
306412
# tmipsolve = time.perf_counter() - tbmipsolve
307-
if mip._mpisppy_data.scenario_feasible:
308413

309-
# Algorithm 2 lines 6--8
310-
if (itr == 0):
311-
dual_bound = mip._mpisppy_data.outer_bound
414+
if mip._mpisppy_data.scenario_feasible:
312415

313416
# Algorithm 2 line 9 (compute \Gamma^t)
314417
inner_bound = mip._mpisppy_data.inner_bound
418+
# print(f"{model_name=}, {inner_bound=}")
315419
gamma_t = self._compute_gamma_t(cutoff, inner_bound)
420+
# print(f"{itr=}, {model_name=}, {gamma_t=}")
316421

317422
# tbcol = time.perf_counter()
318423
self._add_QP_column(model_name)
@@ -329,12 +434,6 @@ def SDM(self, model_name, mip_solver_options, dtiming, tee, verbose):
329434
self._swap_nonant_vars_back()
330435
self._add_shared_columns(shared_columns)
331436
self._swap_nonant_vars()
332-
# add columns from cylinder(s)
333-
if hasattr(self.spcomm, "add_cylinder_columns"):
334-
self._swap_nonant_vars_back()
335-
self.spcomm.sync_nonants()
336-
self.spcomm.add_cylinder_columns()
337-
self._swap_nonant_vars()
338437

339438
# tbqpsol = time.perf_counter()
340439
# QPs are weird if bundled
@@ -353,23 +452,25 @@ def SDM(self, model_name, mip_solver_options, dtiming, tee, verbose):
353452
# print(f"{model_name}, solve + add_col time: {tmipsolve + tcol + tqpsol}")
354453
# fwloop = time.perf_counter() - loop_start
355454
# print(f"{model_name}, total loop time: {fwloop}")
455+
# Algorithm 2 lines 6--8
456+
# Stopping after the MIP solve will give a point
457+
# to synchronize with spokes
458+
dual_bound = None
459+
if (itr == 0):
460+
if mip._mpisppy_data.scenario_feasible:
461+
dual_bound = mip._mpisppy_data.outer_bound
462+
else:
463+
dual_bound = np.nan
356464

357-
if not mip._mpisppy_data.scenario_feasible or (gamma_t < self.FW_options['FW_conv_thresh']):
358-
break
465+
if itr + 1 == sdm_iter_limit or not mip._mpisppy_data.scenario_feasible or gamma_t < FW_conv_thresh:
466+
return dual_bound
467+
else:
468+
yield dual_bound
359469

360470
# reset for next loop
361471
for ndn_i, xvar in arb_scen_mip._mpisppy_data.nonant_indices.items():
362472
xt[ndn_i] = xvar._value
363473

364-
# Re-set the mip._mpisppy_model.W so that the QP objective
365-
# is correct in the next major iteration
366-
for scenario_name in mip_source:
367-
scen_mip = self.local_scenarios[scenario_name]
368-
for (node_name, ix) in scen_mip._mpisppy_data.nonant_indices:
369-
scen_mip._mpisppy_model.W[node_name, ix]._value = \
370-
qp._mpisppy_model.W[node_name, ix]._value
371-
372-
return dual_bound
373474

374475
def _add_shared_columns(self, shared_columns):
375476
self.mpicomm.Barrier()
@@ -382,6 +483,28 @@ def _add_shared_columns(self, shared_columns):
382483
self._generate_shared_column(shared_columns)
383484
self._reenable_W()
384485

486+
def _fix_fixings(self, model_name, mip, qp):
487+
""" If some variable is fixed in the mip, but its value in the QP does
488+
not agree with that fixed value, we will have a bad time. This method
489+
removes such fixings.
490+
"""
491+
for var in qp.x.values():
492+
if var.fixed:
493+
raise RuntimeError(f"var {var.name} is fixed in QP!!")
494+
solver = mip._solver_plugin
495+
unfixed = 0
496+
target = mip.ref_vars if self.bundling else mip._mpisppy_data.nonant_vars
497+
mip_to_qp = mip._mpisppy_data.mip_to_qp
498+
for ndn_i, var in target.items():
499+
if var.fixed:
500+
if not math.isclose(mip_to_qp[id(var)].value, var.value, abs_tol=1e-5):
501+
var.unfix()
502+
if sputils.is_persistent(solver):
503+
solver.update_var(var)
504+
unfixed += 1
505+
if unfixed > 0:
506+
global_toc(f"{self.__class__.__name__}: unfixed {unfixed} nonant variables in {model_name}", True)
507+
385508
def _add_QP_column(self, model_name, disable_W=False):
386509
''' Add a column to the QP, with values taken from the most recent MIP
387510
solve. Assumes the inner_bound is up-to-date in the MIP model.
@@ -458,8 +581,10 @@ def _add_objective_cutoff(self, mip, qp):
458581
an improving direction in the QP subproblem is generated
459582
"""
460583
assert not hasattr(mip._mpisppy_model, "obj_cutoff_constraint")
584+
# print(f"\tnonants part: {pyo.value(qp._mpisppy_model.mip_obj_in_qp)}")
585+
# print(f"\trecoursepart: {pyo.value(qp.recourse_cost)}")
461586
cutoff = pyo.value(qp._mpisppy_model.mip_obj_in_qp) + pyo.value(qp.recourse_cost)
462-
epsilon = 0 #1e-6
587+
epsilon = 0 #1e-4
463588
rel_epsilon = abs(cutoff)*epsilon
464589
epsilon = max(epsilon, rel_epsilon)
465590
# tbmipsolve = time.perf_counter()
@@ -559,8 +684,10 @@ def _update_dual_bounds(self):
559684
self._fwph_best_bound = np.fmin(self._fwph_best_bound, self._local_bound)
560685
if self._can_update_best_bound():
561686
self.best_bound_obj_val = self._fwph_best_bound
687+
# if self.cylinder_rank == 0:
688+
# print(f"{self._local_bound=}")
562689

563-
def convergence_diff(self):
690+
def fwph_convergence_diff(self):
564691
''' Perform the convergence check of Algorithm 3 in Boland et al. '''
565692
diff = 0.
566693
for name in self.local_subproblems.keys():

0 commit comments

Comments
 (0)