@@ -112,7 +112,6 @@
 SCHEDULE_NAME = ["Flow_DPM_Solver"]
 DEFAULT_SCHEDULE_NAME = "Flow_DPM_Solver"
 NUM_IMAGES_PER_PROMPT = 1
-TEST_TIMES = 0
 INFER_SPEED = 0
 
 
@@ -139,10 +138,11 @@ def write_inference_count(count):
 
 def run_inference(num_imgs=1):
     write_inference_count(num_imgs)
+    count = read_inference_count()
 
     return (
         f"<span style='font-size: 16px; font-weight: bold;'>Total inference runs: </span><span style='font-size: "
-        f"16px; color:red; font-weight: bold;'>{TEST_TIMES}</span>"
+        f"16px; color:red; font-weight: bold;'>{count}</span>"
     )
 
 
@@ -242,13 +242,12 @@ def generate(
     flow_dpms_inference_steps: int = 20,
     randomize_seed: bool = False,
 ):
-    global TEST_TIMES
     global INFER_SPEED
     # seed = 823753551
     run_inference(num_imgs)
     seed = int(randomize_seed_fn(seed, randomize_seed))
     generator = torch.Generator(device=device).manual_seed(seed)
-    print(f"PORT: {DEMO_PORT}, model_path: {model_path}, time_times: {TEST_TIMES}")
+    print(f"PORT: {DEMO_PORT}, model_path: {model_path}")
     if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
         prompt = "A red heart."
 
@@ -302,7 +301,6 @@ def generate(
     )
 
 
-    TEST_TIMES = read_inference_count()
     model_size = "1.6" if "1600M" in args.model_path else "0.6"
     title = f"""
         <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
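For reference, a minimal sketch of what the two counter helpers used in this diff might look like. Only the signature write_inference_count(count) is visible in a hunk header; the file-backed persistence, the INFERENCE_COUNT_FILE path, and the accumulate-on-write behavior below are assumptions for illustration, not the repository's actual implementation.

import os

# Hypothetical location for the persisted counter (assumption).
INFERENCE_COUNT_FILE = "inference_count.txt"

def read_inference_count() -> int:
    # Return the persisted run count, or 0 if nothing has been written yet.
    if not os.path.exists(INFERENCE_COUNT_FILE):
        return 0
    with open(INFERENCE_COUNT_FILE) as f:
        return int(f.read().strip() or 0)

def write_inference_count(count: int) -> None:
    # Add `count` new runs to the persisted total and write it back.
    total = read_inference_count() + count
    with open(INFERENCE_COUNT_FILE, "w") as f:
        f.write(str(total))

With helpers along these lines, the diff's move from the TEST_TIMES global to calling read_inference_count() at display time means the "Total inference runs" banner always reflects the persisted total rather than a per-process counter that resets on restart.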