2323from random import random
2424from typing import Any , Hashable , Iterable , Optional , TypeAlias
2525
26- from mip import BINARY , Model , maximize , minimize , xsum
26+ import pulp
2727
28- from . utils import InfeasibleError , consecutive_ngrams , get_nested_key , sort_replacements
28+ from modelopt . torch . _compress . mip . utils import consecutive_ngrams , get_nested_key , sort_replacements
2929
3030ReplacementID : TypeAlias = Hashable
3131Replacement : TypeAlias = dict [str , Any ]
@@ -38,6 +38,7 @@ def run_mip(
3838 constraints : dict [str , float ],
3939 bigger_is_better : bool ,
4040 max_seconds_per_solution : Optional [float ] = None ,
41+ verbose : bool = True ,
4142) -> tuple [ChosenReplacements , float , dict [str , float ]]:
4243 orig_num_replacements = len (replacements )
4344 replacements = {
@@ -52,13 +53,15 @@ def run_mip(
5253 )
5354 print ("\n \n \n " )
5455
55- mip_model = Model ()
56+ # Create pulp problem with appropriate sense (minimize or maximize)
57+ sense = pulp .LpMaximize if bigger_is_better else pulp .LpMinimize
58+ problem = pulp .LpProblem (name = "multi_layer_replacement" , sense = sense )
5659
5760 objective_vars = []
5861 constraint_vars = {constraint_key : [] for constraint_key in constraints .keys ()}
5962 choice_indicators_by_layer = defaultdict (list )
60- for replacement_id , replacement in replacements .items ():
61- is_chosen = mip_model . add_var ( var_type = BINARY )
63+ for i , ( replacement_id , replacement ) in enumerate ( replacements .items () ):
64+ is_chosen = pulp . LpVariable ( f"choice_ { i } " , cat = pulp . LpBinary )
6265 replacement ["is_chosen" ] = is_chosen
6366
6467 for parent_layer_idx in replacement ["parent_layer_indices" ]:
@@ -73,30 +76,29 @@ def run_mip(
7376
7477 # MIP constraints: each parent layer must come from exactly one chosen replacement
7578 for parent_layer_idx , curr_choice_indicators in choice_indicators_by_layer .items ():
76- mip_model += xsum (curr_choice_indicators ) == 1
79+ problem += pulp . lpSum (curr_choice_indicators ) == 1
7780
7881 # MIP constraints: the sum of chosen replacement costs must be lower than the max cost
7982 for constraint_key , max_cost in constraints .items ():
8083 min_cost = None
8184 if isinstance (max_cost , Iterable ):
8285 min_cost , max_cost = max_cost
8386
84- if max_cost is not None :
85- mip_model += xsum (constraint_vars [constraint_key ]) <= max_cost
86- if min_cost is not None :
87- mip_model += xsum (constraint_vars [constraint_key ]) >= min_cost
87+ # PuLP is stricter than mip - it doesn't allow NaN/inf in constraints
88+ if max_cost is not None and math .isfinite (max_cost ):
89+ problem += pulp .lpSum (constraint_vars [constraint_key ]) <= max_cost
90+ if min_cost is not None and math .isfinite (min_cost ):
91+ problem += pulp .lpSum (constraint_vars [constraint_key ]) >= min_cost
8892
8993 # MIP objective
90- mip_model .objective = (
91- maximize (xsum (objective_vars )) if bigger_is_better else minimize (xsum (objective_vars ))
92- )
93-
94- if max_seconds_per_solution is not None :
95- mip_model .max_seconds = max_seconds_per_solution
94+ problem += (pulp .lpSum (objective_vars ), "objective" )
9695
97- mip_model .optimize ()
96+ # Configure and run solver
97+ solver = pulp .PULP_CBC_CMD (msg = verbose , timeLimit = max_seconds_per_solution )
98+ problem .solve (solver )
9899
99- if is_chosen .x is None :
100+ # Check if solution is feasible
101+ if problem .status != pulp .LpStatusOptimal :
100102 return []
101103 # raise InfeasibleError()
102104
@@ -106,7 +108,7 @@ def run_mip(
106108 chosen_replacements : ChosenReplacements = []
107109 chosen_layers = []
108110 for replacement_id , replacement in replacements .items ():
109- is_chosen = replacement ["is_chosen" ].x >= 0.99
111+ is_chosen = replacement ["is_chosen" ].varValue >= 0.99
110112 if is_chosen :
111113 assert replacement not in chosen_replacements
112114 chosen_replacements .append (replacement )
0 commit comments