Skip to content

Commit 661080f

Browse files
authoredJul 12, 2023
Add pscpu infer (PaddlePaddle#935)
* [add cpups infer] infer_from_dataset * [add cpups infer] infer_from_dataset * [bug fix] gpups benchmark * [bug fix] gpups benchmark adaptation
1 parent 3605123 commit 661080f

29 files changed

+372
-53
lines changed
 

‎models/demo/movie_recommand/utils/static_ps/infer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import print_function
1616
from reader_helper import get_reader, get_infer_reader, get_example_num, get_file_list, get_word_num
1717
from program_helper import get_model, get_strategy
18-
from common import YamlHelper, is_number
18+
from common_ps import YamlHelper, is_number
1919
import os
2020
import numpy as np
2121
import warnings

‎models/demo/movie_recommand/utils/static_ps/program_helper.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import paddle
2020
import paddle.distributed.fleet.base.role_maker as role_maker
2121
import paddle.distributed.fleet as fleet
22-
import common
22+
import common_ps
2323
import sys
2424

2525
logging.basicConfig(
@@ -28,7 +28,7 @@
2828

2929

3030
def get_strategy(config):
31-
if not common.is_distributed_env():
31+
if not common_ps.is_distributed_env():
3232
logger.warn(
3333
"Not Find Distributed env, Change To local train mode. If you want train with fleet, please use [fleetrun] command."
3434
)

‎models/demo/movie_recommand/utils/static_ps/reader_helper.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import paddle.distributed.fleet as fleet
2323
__dir__ = os.path.dirname(os.path.abspath(__file__))
2424
sys.path.append(__dir__)
25-
import common
25+
import common_ps
2626

2727
logging.basicConfig(
2828
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
@@ -95,7 +95,7 @@ def get_word_num(file_list):
9595

9696

9797
def get_reader_generator(path, reader_name="Reader"):
98-
reader_class = common.lazy_instance_by_fliename(path, reader_name)()
98+
reader_class = common_ps.lazy_instance_by_fliename(path, reader_name)()
9999
return reader_class
100100

101101

@@ -116,7 +116,8 @@ def get_reader(self):
116116
logger.info("Reader Path: {}".format(reader_path))
117117

118118
from paddle.io import DataLoader
119-
dataset = common.lazy_instance_by_fliename(reader_path, "RecDataset")
119+
dataset = common_ps.lazy_instance_by_fliename(reader_path,
120+
"RecDataset")
120121
print("dataset: {}".format(dataset))
121122

122123
use_cuda = int(self.config.get("runner.use_gpu"))
@@ -197,7 +198,7 @@ def __init__(self, input_var, file_list, config):
197198
self.pipe_command = self.config.get("runner.pipe_command")
198199
self.train_reader = self.config.get("runner.train_reader_path")
199200
assert self.pipe_command != None
200-
utils_path = common.get_utils_file_path()
201+
utils_path = common_ps.get_utils_file_path()
201202
print("utils_path: {}".format(utils_path))
202203
abs_train_reader = os.path.join(config["config_abs_dir"],
203204
self.train_reader)

‎models/match/dssm/bq_reader_train_insid.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def reader():
7272
yaml_path = sys.argv[1]
7373
utils_path = sys.argv[2]
7474
sys.path.append(utils_path)
75-
import common
76-
yaml_helper = common.YamlHelper()
75+
import common_ps
76+
yaml_helper = common_ps.YamlHelper()
7777
config = yaml_helper.load_yaml(yaml_path)
7878

7979
r = Reader()

‎models/rank/deepfm/benchmark_reader.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def reader():
8181
yaml_path = sys.argv[1]
8282
utils_path = sys.argv[2]
8383
sys.path.append(utils_path)
84-
import common
85-
yaml_helper = common.YamlHelper()
84+
import common_ps
85+
yaml_helper = common_ps.YamlHelper()
8686
config = yaml_helper.load_yaml(yaml_path)
8787

8888
r = Reader()

‎models/rank/dnn/benchmark_reader.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def reader():
8484
yaml_path = sys.argv[1]
8585
utils_path = sys.argv[2]
8686
sys.path.append(utils_path)
87-
import common
88-
yaml_helper = common.YamlHelper()
87+
import common_ps
88+
yaml_helper = common_ps.YamlHelper()
8989
config = yaml_helper.load_yaml(yaml_path)
9090

9191
r = Reader()

‎models/rank/dnn/queuedataset_reader.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def reader():
8686
yaml_path = sys.argv[1]
8787
utils_path = sys.argv[2]
8888
sys.path.append(utils_path)
89-
import common
90-
yaml_helper = common.YamlHelper()
89+
import common_ps
90+
yaml_helper = common_ps.YamlHelper()
9191
config = yaml_helper.load_yaml(yaml_path)
9292

9393
r = Reader()

‎models/rank/slot_dnn/inmemorydataset_reader.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def reader():
101101
yaml_path = sys.argv[1]
102102
utils_path = sys.argv[2]
103103
sys.path.append(utils_path)
104-
import common
105-
yaml_helper = common.YamlHelper()
104+
import common_ps
105+
yaml_helper = common_ps.YamlHelper()
106106
config = yaml_helper.load_yaml(yaml_path)
107107

108108
r = Reader()

‎models/rank/slot_dnn/queuedataset_reader.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def reader():
9898
yaml_path = sys.argv[1]
9999
utils_path = sys.argv[2]
100100
sys.path.append(utils_path)
101-
import common
102-
yaml_helper = common.YamlHelper()
101+
import common_ps
102+
yaml_helper = common_ps.YamlHelper()
103103
config = yaml_helper.load_yaml(yaml_path)
104104

105105
r = Reader()

‎models/rank/wide_deep/benchmark_reader.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def reader():
8484
yaml_path = sys.argv[1]
8585
utils_path = sys.argv[2]
8686
sys.path.append(utils_path)
87-
import common
88-
yaml_helper = common.YamlHelper()
87+
import common_ps
88+
yaml_helper = common_ps.YamlHelper()
8989
config = yaml_helper.load_yaml(yaml_path)
9090

9191
r = Reader()

‎models/rank/wide_deep/queuedataset_reader.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def reader():
8686
yaml_path = sys.argv[1]
8787
utils_path = sys.argv[2]
8888
sys.path.append(utils_path)
89-
import common
90-
yaml_helper = common.YamlHelper()
89+
import common_ps
90+
yaml_helper = common_ps.YamlHelper()
9191
config = yaml_helper.load_yaml(yaml_path)
9292

9393
r = Reader()

‎models/recall/ncf/queuedataset_reader.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def reader():
5555
yaml_path = sys.argv[1]
5656
utils_path = sys.argv[2]
5757
sys.path.append(utils_path)
58-
import common
59-
yaml_helper = common.YamlHelper()
58+
import common_ps
59+
yaml_helper = common_ps.YamlHelper()
6060
config = yaml_helper.load_yaml(yaml_path)
6161

6262
r = Reader()

‎models/recall/word2vec/benchmark/benchmark_reader.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ def _replace_oov(self, original_vocab, line):
235235
yaml_path = sys.argv[1]
236236
utils_path = sys.argv[2]
237237
sys.path.append(utils_path)
238-
import common
239-
yaml_helper = common.YamlHelper()
238+
import common_ps
239+
yaml_helper = common_ps.YamlHelper()
240240
config = yaml_helper.load_yaml(yaml_path)
241241
abs_dir = os.path.dirname(os.path.abspath(yaml_path))
242242
config["config_abs_dir"] = abs_dir

‎models/recall/word2vec/utils/static_ps/infer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
__dir__ = os.path.dirname(os.path.abspath(__file__))
2828
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
29-
from common import YamlHelper, is_number
29+
from common_ps import YamlHelper, is_number
3030
from program_helper import get_model, get_strategy
3131
from reader_helper import get_reader, get_infer_reader, get_example_num, get_file_list, get_word_num
3232

‎models/recall/word2vec/utils/static_ps/program_helper.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import paddle
2020
import paddle.distributed.fleet.base.role_maker as role_maker
2121
import paddle.distributed.fleet as fleet
22-
import common
22+
import common_ps
2323
import sys
2424

2525
logging.basicConfig(
@@ -28,7 +28,7 @@
2828

2929

3030
def get_strategy(config):
31-
if not common.is_distributed_env():
31+
if not common_ps.is_distributed_env():
3232
logger.warn(
3333
"Not Find Distributed env, Change To local train mode. If you want train with fleet, please use [fleetrun] command."
3434
)

‎models/recall/word2vec/utils/static_ps/reader_helper.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import paddle
2020
import paddle.distributed.fleet.base.role_maker as role_maker
2121
import paddle.distributed.fleet as fleet
22-
import common
22+
import common_ps
2323

2424
logging.basicConfig(
2525
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
@@ -92,7 +92,7 @@ def get_word_num(file_list):
9292

9393

9494
def get_reader_generator(path, reader_name="Reader"):
95-
reader_class = common.lazy_instance_by_fliename(path, reader_name)()
95+
reader_class = common_ps.lazy_instance_by_fliename(path, reader_name)()
9696
return reader_class
9797

9898

@@ -113,7 +113,8 @@ def get_reader(self):
113113
logger.info("Reader Path: {}".format(reader_path))
114114

115115
from paddle.io import DataLoader
116-
dataset = common.lazy_instance_by_fliename(reader_path, "RecDataset")
116+
dataset = common_ps.lazy_instance_by_fliename(reader_path,
117+
"RecDataset")
117118
print("dataset: {}".format(dataset))
118119

119120
use_cuda = int(self.config.get("runner.use_gpu"))
@@ -194,7 +195,7 @@ def __init__(self, input_var, file_list, config):
194195
self.pipe_command = self.config.get("runner.pipe_command")
195196
self.train_reader = self.config.get("runner.train_reader_path")
196197
assert self.pipe_command != None
197-
utils_path = common.get_utils_file_path()
198+
utils_path = common_ps.get_utils_file_path()
198199
print("utils_path: {}".format(utils_path))
199200
abs_train_reader = os.path.join(config["config_abs_dir"],
200201
self.train_reader)

‎models/treebased/tdm/get_leaf_embedding.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ def get_emb_numpy(tree_node_num, node_emb_size, init_model_path=""):
5454
os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd()))))
5555
sys.path.append(utils_path)
5656
print(utils_path)
57-
import common
57+
import common_ps
5858

59-
yaml_helper = common.YamlHelper()
59+
yaml_helper = common_ps.YamlHelper()
6060
config = yaml_helper.load_yaml(sys.argv[1])
6161

6262
tree_name = config.get("hyper_parameters.tree_name")

‎models/treebased/tdm/infer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,8 @@ def infer(filelist, process_idx, init_model_path, id_code_map, code_id_map,
269269
os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd()))))
270270
sys.path.append(utils_path)
271271
print(utils_path)
272-
import common
273-
yaml_helper = common.YamlHelper()
272+
import common_ps
273+
yaml_helper = common_ps.YamlHelper()
274274
config = yaml_helper.load_yaml(sys.argv[1])
275275

276276
test_files_path = "../data/demo_test_data/"

‎models/treebased/tdm/reader.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def reader():
9797
yaml_path = sys.argv[1]
9898
utils_path = sys.argv[2]
9999
sys.path.append(utils_path)
100-
import common
101-
yaml_helper = common.YamlHelper()
100+
import common_ps
101+
yaml_helper = common_ps.YamlHelper()
102102
config = yaml_helper.load_yaml(yaml_path)
103103

104104
r = MyDataset()

‎tools/static_gpubox_trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import print_function
1616
from utils.static_ps.reader_helper import get_reader, get_example_num, get_file_list, get_word_num
1717
from utils.static_ps.program_helper import get_model, get_strategy
18-
from utils.static_ps.common import YamlHelper, is_distributed_env
18+
from utils.static_ps.common_ps import YamlHelper, is_distributed_env
1919
from utils.utils_single import auc
2020
import argparse
2121
import time

‎tools/static_ps_infer.py

+316
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
from utils.static_ps.reader_helper import get_reader, get_example_num, get_file_list, get_word_num
17+
from utils.static_ps.program_helper import get_model, get_strategy, set_dump_config
18+
from utils.static_ps.metric_helper import set_zero, get_global_auc
19+
from utils.static_ps.common import YamlHelper, is_distributed_env
20+
import argparse
21+
import time
22+
import sys
23+
import paddle.distributed.fleet as fleet
24+
import paddle.distributed.fleet.base.role_maker as role_maker
25+
import paddle
26+
import os
27+
import warnings
28+
import logging
29+
import ast
30+
import numpy as np
31+
import struct
32+
from utils.utils_single import auc
33+
34+
__dir__ = os.path.dirname(os.path.abspath(__file__))
35+
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
36+
37+
logging.basicConfig(
38+
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
39+
logger = logging.getLogger(__name__)
40+
41+
42+
def parse_args():
43+
parser = argparse.ArgumentParser("PaddleRec train script")
44+
parser.add_argument("-o", "--opt", nargs='*', type=str)
45+
parser.add_argument(
46+
'-m',
47+
'--config_yaml',
48+
type=str,
49+
required=True,
50+
help='config file path')
51+
parser.add_argument(
52+
'-bf16',
53+
'--pure_bf16',
54+
type=ast.literal_eval,
55+
default=False,
56+
help="whether use bf16")
57+
args = parser.parse_args()
58+
args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
59+
yaml_helper = YamlHelper()
60+
config = yaml_helper.load_yaml(args.config_yaml)
61+
# modify config from command
62+
if args.opt:
63+
for parameter in args.opt:
64+
parameter = parameter.strip()
65+
key, value = parameter.split("=")
66+
if type(config.get(key)) is int:
67+
value = int(value)
68+
if type(config.get(key)) is float:
69+
value = float(value)
70+
if type(config.get(key)) is bool:
71+
value = (True if value.lower() == "true" else False)
72+
config[key] = value
73+
config["yaml_path"] = args.config_yaml
74+
config["config_abs_dir"] = args.abs_dir
75+
config["pure_bf16"] = args.pure_bf16
76+
yaml_helper.print_yaml(config)
77+
return config
78+
79+
80+
def bf16_to_fp32(val):
81+
return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])
82+
83+
84+
class Main(object):
85+
def __init__(self, config):
86+
self.metrics = {}
87+
self.config = config
88+
self.input_data = None
89+
self.reader = None
90+
self.exe = None
91+
self.train_result_dict = {}
92+
self.train_result_dict["speed"] = []
93+
self.train_result_dict["auc"] = []
94+
self.model = None
95+
self.pure_bf16 = self.config['pure_bf16']
96+
97+
def run(self):
98+
self.init_fleet_with_gloo()
99+
self.network()
100+
if fleet.is_server():
101+
self.run_server()
102+
elif fleet.is_worker():
103+
self.run_worker()
104+
fleet.stop_worker()
105+
self.record_result()
106+
logger.info("Run Success, Exit.")
107+
108+
def init_fleet_with_gloo(use_gloo=True):
109+
if use_gloo:
110+
os.environ["PADDLE_WITH_GLOO"] = "1"
111+
role = role_maker.PaddleCloudRoleMaker()
112+
fleet.init(role)
113+
else:
114+
fleet.init()
115+
116+
def network(self):
117+
self.model = get_model(self.config)
118+
self.input_data = self.model.create_feeds()
119+
self.inference_feed_var = self.model.create_feeds()
120+
self.init_reader()
121+
self.metrics = self.model.net(self.input_data)
122+
self.inference_target_var = self.model.inference_target_var
123+
logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
124+
self.model.create_optimizer(get_strategy(self.config))
125+
126+
def run_server(self):
127+
logger.info("Run Server Begin")
128+
fleet.init_server(config.get("runner.warmup_model_path"))
129+
fleet.run_server()
130+
131+
def run_worker(self):
132+
logger.info("Run Worker Begin")
133+
use_cuda = int(config.get("runner.use_gpu"))
134+
use_auc = config.get("runner.use_auc", False)
135+
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
136+
self.exe = paddle.static.Executor(place)
137+
138+
with open("./{}_worker_main_program.prototxt".format(
139+
fleet.worker_index()), 'w+') as f:
140+
f.write(str(paddle.static.default_main_program()))
141+
with open("./{}_worker_startup_program.prototxt".format(
142+
fleet.worker_index()), 'w+') as f:
143+
f.write(str(paddle.static.default_startup_program()))
144+
145+
self.exe.run(paddle.static.default_startup_program())
146+
if self.pure_bf16:
147+
self.model.optimizer.amp_init(self.exe.place)
148+
fleet.init_worker()
149+
150+
init_model_path = config.get("runner.infer_load_path")
151+
model_mode = config.get("runner.model_mode", 0)
152+
#if fleet.is_first_worker():
153+
#fleet.load_inference_model(init_model_path, mode=int(model_mode))
154+
#fleet.barrier_worker()
155+
156+
save_model_path = self.config.get("runner.model_save_path")
157+
if save_model_path and (not os.path.exists(save_model_path)):
158+
os.makedirs(save_model_path)
159+
160+
reader_type = self.config.get("runner.reader_type", "QueueDataset")
161+
epochs = int(self.config.get("runner.epochs"))
162+
sync_mode = self.config.get("runner.sync_mode")
163+
opt_info = paddle.static.default_main_program()._fleet_opt
164+
if use_auc is True:
165+
opt_info['stat_var_names'] = [
166+
self.model.stat_pos.name, self.model.stat_neg.name
167+
]
168+
else:
169+
opt_info['stat_var_names'] = []
170+
171+
if reader_type == "InmemoryDataset":
172+
self.reader.load_into_memory()
173+
174+
for epoch in range(epochs):
175+
fleet.load_inference_model(
176+
os.path.join(init_model_path, str(epoch)),
177+
mode=int(model_mode))
178+
epoch_start_time = time.time()
179+
180+
if sync_mode == "heter":
181+
self.heter_train_loop(epoch)
182+
elif reader_type == "QueueDataset":
183+
self.dataset_train_loop(epoch)
184+
elif reader_type == "InmemoryDataset":
185+
self.dataset_train_loop(epoch)
186+
187+
epoch_time = time.time() - epoch_start_time
188+
epoch_speed = self.example_nums / epoch_time
189+
if use_auc is True:
190+
global_auc = get_global_auc(paddle.static.global_scope(),
191+
self.model.stat_pos.name,
192+
self.model.stat_neg.name)
193+
self.train_result_dict["auc"].append(global_auc)
194+
set_zero(self.model.stat_pos.name,
195+
paddle.static.global_scope())
196+
set_zero(self.model.stat_neg.name,
197+
paddle.static.global_scope())
198+
set_zero(self.model.batch_stat_pos.name,
199+
paddle.static.global_scope())
200+
set_zero(self.model.batch_stat_neg.name,
201+
paddle.static.global_scope())
202+
logger.info(
203+
"Epoch: {}, using time: {} second, ips: {} {}/sec. auc: {}".
204+
format(epoch, epoch_time, epoch_speed, self.count_method,
205+
global_auc))
206+
else:
207+
logger.info(
208+
"Epoch: {}, using time {} second, ips {} {}/sec.".format(
209+
epoch, epoch_time, epoch_speed, self.count_method))
210+
211+
self.train_result_dict["speed"].append(epoch_speed)
212+
213+
model_dir = "{}/{}".format(save_model_path, epoch)
214+
215+
if reader_type == "InmemoryDataset":
216+
self.reader.release_memory()
217+
218+
def init_reader(self):
219+
if fleet.is_server():
220+
return
221+
self.config["runner.reader_type"] = self.config.get(
222+
"runner.reader_type", "QueueDataset")
223+
self.reader, self.file_list = get_reader(self.input_data, config)
224+
self.example_nums = 0
225+
self.count_method = self.config.get("runner.example_count_method",
226+
"example")
227+
if self.count_method == "example":
228+
self.example_nums = get_example_num(self.file_list)
229+
elif self.count_method == "word":
230+
self.example_nums = get_word_num(self.file_list)
231+
else:
232+
raise ValueError(
233+
"Set static_benchmark.example_count_method for example / word for example count."
234+
)
235+
236+
def dataset_train_loop(self, epoch):
237+
logger.info("Epoch: {}, Running Dataset Begin.".format(epoch))
238+
fetch_info = [
239+
"Epoch {} Var {}".format(epoch, var_name)
240+
for var_name in self.metrics
241+
]
242+
fetch_vars = [var for _, var in self.metrics.items()]
243+
print_step = int(config.get("runner.print_interval"))
244+
245+
debug = config.get("runner.dataset_debug", False)
246+
if config.get("runner.need_dump"):
247+
debug = True
248+
dump_fields_path = "{}/{}".format(
249+
config.get("runner.dump_fields_path"), epoch)
250+
set_dump_config(paddle.static.default_main_program(), {
251+
"dump_fields_path": dump_fields_path,
252+
"dump_fields": config.get("runner.dump_fields")
253+
})
254+
print(paddle.static.default_main_program()._fleet_opt)
255+
self.exe.infer_from_dataset(
256+
program=paddle.static.default_main_program(),
257+
dataset=self.reader,
258+
fetch_list=fetch_vars,
259+
fetch_info=fetch_info,
260+
print_period=print_step,
261+
debug=debug)
262+
263+
def heter_train_loop(self, epoch):
264+
logger.info(
265+
"Epoch: {}, Running Begin. Check running metrics at heter_log".
266+
format(epoch))
267+
reader_type = self.config.get("runner.reader_type")
268+
if reader_type == "QueueDataset":
269+
self.exe.infer_from_dataset(
270+
program=paddle.static.default_main_program(),
271+
dataset=self.reader,
272+
debug=config.get("runner.dataset_debug"))
273+
elif reader_type == "DataLoader":
274+
batch_id = 0
275+
train_run_cost = 0.0
276+
total_examples = 0
277+
self.reader.start()
278+
while True:
279+
try:
280+
train_start = time.time()
281+
# --------------------------------------------------- #
282+
self.exe.run(program=paddle.static.default_main_program())
283+
# --------------------------------------------------- #
284+
train_run_cost += time.time() - train_start
285+
total_examples += self.config.get("runner.batch_size")
286+
batch_id += 1
287+
print_step = int(config.get("runner.print_period"))
288+
if batch_id % print_step == 0:
289+
profiler_string = ""
290+
profiler_string += "avg_batch_cost: {} sec, ".format(
291+
format((train_run_cost) / print_step, '.5f'))
292+
profiler_string += "avg_samples: {}, ".format(
293+
format(total_examples / print_step, '.5f'))
294+
profiler_string += "ips: {} {}/sec ".format(
295+
format(total_examples / (train_run_cost), '.5f'),
296+
self.count_method)
297+
logger.info("Epoch: {}, Batch: {}, {}".format(
298+
epoch, batch_id, profiler_string))
299+
train_run_cost = 0.0
300+
total_examples = 0
301+
except paddle.core.EOFException:
302+
self.reader.reset()
303+
break
304+
305+
def record_result(self):
306+
logger.info("train_result_dict: {}".format(self.train_result_dict))
307+
with open("./train_result_dict.txt", 'w+') as f:
308+
f.write(str(self.train_result_dict))
309+
310+
311+
if __name__ == "__main__":
312+
paddle.enable_static()
313+
config = parse_args()
314+
os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
315+
benchmark_main = Main(config)
316+
benchmark_main.run()

‎tools/static_ps_trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from utils.static_ps.reader_helper import get_reader, get_example_num, get_file_list, get_word_num
1717
from utils.static_ps.program_helper import get_model, get_strategy, set_dump_config
1818
from utils.static_ps.metric_helper import set_zero, get_global_auc
19-
from utils.static_ps.common import YamlHelper, is_distributed_env
19+
from utils.static_ps.common_ps import YamlHelper, is_distributed_env
2020
import argparse
2121
import time
2222
import sys
File renamed without changes.

‎tools/utils/static_ps/flow_helper.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import paddle.distributed.fleet as fleet
3131
__dir__ = os.path.dirname(os.path.abspath(__file__))
3232
sys.path.append(__dir__)
33-
import common
33+
import common_ps
3434

3535
logging.basicConfig(
3636
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

‎tools/utils/static_ps/infer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import print_function
1616
from reader_helper import get_reader, get_infer_reader, get_example_num, get_file_list, get_word_num
1717
from program_helper import get_model, get_strategy
18-
from common import YamlHelper, is_number
18+
from common_ps import YamlHelper, is_number
1919
import os
2020
import numpy as np
2121
import warnings

‎tools/utils/static_ps/metric_helper.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import paddle.distributed.fleet as fleet
2525
__dir__ = os.path.dirname(os.path.abspath(__file__))
2626
sys.path.append(__dir__)
27-
import common
27+
import common_ps
2828

2929
logging.basicConfig(
3030
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

‎tools/utils/static_ps/program_helper.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import paddle
2020
import paddle.distributed.fleet.base.role_maker as role_maker
2121
import paddle.distributed.fleet as fleet
22-
import common
22+
from . import common_ps
2323
import sys
2424

2525
logging.basicConfig(
@@ -28,7 +28,7 @@
2828

2929

3030
def get_strategy(config):
31-
if not common.is_distributed_env():
31+
if not common_ps.is_distributed_env():
3232
logger.warn(
3333
"Not Find Distributed env, Change To local train mode. If you want train with fleet, please use [fleetrun] command."
3434
)

‎tools/utils/static_ps/reader_helper.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import paddle.distributed.fleet as fleet
2323
__dir__ = os.path.dirname(os.path.abspath(__file__))
2424
sys.path.append(__dir__)
25-
import common
25+
from . import common_ps
2626

2727
logging.basicConfig(
2828
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
@@ -106,7 +106,7 @@ def get_word_num(file_list):
106106

107107

108108
def get_reader_generator(path, reader_name="Reader"):
109-
reader_class = common.lazy_instance_by_fliename(path, reader_name)()
109+
reader_class = common_ps.lazy_instance_by_fliename(path, reader_name)()
110110
return reader_class
111111

112112

@@ -127,7 +127,8 @@ def get_reader(self):
127127
logger.info("Reader Path: {}".format(reader_path))
128128

129129
from paddle.io import DataLoader
130-
dataset = common.lazy_instance_by_fliename(reader_path, "RecDataset")
130+
dataset = common_ps.lazy_instance_by_fliename(reader_path,
131+
"RecDataset")
131132
print("dataset: {}".format(dataset))
132133

133134
use_cuda = int(self.config.get("runner.use_gpu"))
@@ -219,7 +220,7 @@ def __init__(self, input_var, file_list, config):
219220
self.pipe_command = self.config.get("runner.pipe_command")
220221
self.train_reader = self.config.get("runner.train_reader_path")
221222
assert self.pipe_command != None
222-
utils_path = common.get_utils_file_path()
223+
utils_path = common_ps.get_utils_file_path()
223224
print("utils_path: {}".format(utils_path))
224225
abs_train_reader = os.path.join(config["config_abs_dir"],
225226
self.train_reader)
@@ -273,7 +274,7 @@ def __init__(self, input_var, file_list, config):
273274
self.pipe_command = self.config.get("runner.pipe_command")
274275
self.train_reader = self.config.get("runner.train_reader_path")
275276
assert self.pipe_command != None
276-
utils_path = common.get_utils_file_path()
277+
utils_path = common_ps.get_utils_file_path()
277278
print("utils_path: {}".format(utils_path))
278279
abs_train_reader = os.path.join(config["config_abs_dir"],
279280
self.train_reader)

‎tools/utils/static_ps/time_helper.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import paddle.distributed.fleet as fleet
2525
__dir__ = os.path.dirname(os.path.abspath(__file__))
2626
sys.path.append(__dir__)
27-
import common
27+
import common_ps
2828

2929
logging.basicConfig(
3030
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

0 commit comments

Comments
 (0)
Please sign in to comment.