12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import os
15
+ import argparse
16
16
import io
17
+ import os
17
18
import random
18
19
import time
19
- import argparse
20
20
from functools import partial
21
21
22
22
import numpy as np
23
- import yaml
24
23
import paddle
25
24
import pgl
25
+ import yaml
26
+ from data import GraphDataLoader , PredictData , TrainData , batch_fn
26
27
from easydict import EasyDict as edict
27
- from paddlenlp .transformers import ErnieTokenizer , ErnieTinyTokenizer
28
- from paddlenlp .utils .log import logger
29
-
30
28
from models import ErnieSageForLinkPrediction
31
- from data import TrainData , PredictData , GraphDataLoader , batch_fn
29
+
30
+ from paddlenlp .transformers import ErnieTinyTokenizer , ErnieTokenizer
31
+ from paddlenlp .utils .log import logger
32
32
33
33
MODEL_CLASSES = {
34
34
"ernie-tiny" : (ErnieSageForLinkPrediction , ErnieTinyTokenizer ),
@@ -57,14 +57,14 @@ def do_train(config):
57
57
base_graph , term_ids = load_data (config .graph_work_path )
58
58
collate_fn = partial (batch_fn , samples = config .samples , base_graph = base_graph , term_ids = term_ids )
59
59
60
- mode = "train"
60
+ # mode = "train"
61
61
train_ds = TrainData (config .graph_work_path )
62
62
63
63
model_class , tokenizer_class = MODEL_CLASSES [config .model_name_or_path ]
64
64
tokenizer = tokenizer_class .from_pretrained (config .model_name_or_path )
65
65
config .cls_token_id = tokenizer .cls_token_id
66
66
67
- model = model_class .from_pretrained (config .model_name_or_path , config = config )
67
+ model = model_class .from_pretrained (config .model_name_or_path , config_file = config )
68
68
model = paddle .DataParallel (model )
69
69
70
70
train_loader = GraphDataLoader (
@@ -113,7 +113,7 @@ def do_predict(config):
113
113
paddle .distributed .init_parallel_env ()
114
114
set_seed (config )
115
115
116
- mode = "predict"
116
+ # mode = "predict"
117
117
num_nodes = int (np .load (os .path .join (config .graph_work_path , "num_nodes.npy" )))
118
118
119
119
base_graph , term_ids = load_data (config .graph_work_path )
@@ -123,7 +123,7 @@ def do_predict(config):
123
123
tokenizer = tokenizer_class .from_pretrained (config .model_name_or_path )
124
124
config .cls_token_id = tokenizer .cls_token_id
125
125
126
- model = model_class .from_pretrained (config .infer_model , config = config )
126
+ model = model_class .from_pretrained (config .infer_model , config_file = config )
127
127
128
128
model = paddle .DataParallel (model )
129
129
predict_ds = PredictData (num_nodes )
0 commit comments