Skip to content

Commit b0aa443

Browse files
authored
[BugFix] Fix ErnieSage model for configuration (PaddlePaddle#4425)
1 parent 31b5e19 commit b0aa443

File tree

2 files changed

+17
-20
lines changed

2 files changed

+17
-20
lines changed

examples/text_graph/erniesage/link_prediction.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
15+
import argparse
1616
import io
17+
import os
1718
import random
1819
import time
19-
import argparse
2020
from functools import partial
2121

2222
import numpy as np
23-
import yaml
2423
import paddle
2524
import pgl
25+
import yaml
26+
from data import GraphDataLoader, PredictData, TrainData, batch_fn
2627
from easydict import EasyDict as edict
27-
from paddlenlp.transformers import ErnieTokenizer, ErnieTinyTokenizer
28-
from paddlenlp.utils.log import logger
29-
3028
from models import ErnieSageForLinkPrediction
31-
from data import TrainData, PredictData, GraphDataLoader, batch_fn
29+
30+
from paddlenlp.transformers import ErnieTinyTokenizer, ErnieTokenizer
31+
from paddlenlp.utils.log import logger
3232

3333
MODEL_CLASSES = {
3434
"ernie-tiny": (ErnieSageForLinkPrediction, ErnieTinyTokenizer),
@@ -57,14 +57,14 @@ def do_train(config):
5757
base_graph, term_ids = load_data(config.graph_work_path)
5858
collate_fn = partial(batch_fn, samples=config.samples, base_graph=base_graph, term_ids=term_ids)
5959

60-
mode = "train"
60+
# mode = "train"
6161
train_ds = TrainData(config.graph_work_path)
6262

6363
model_class, tokenizer_class = MODEL_CLASSES[config.model_name_or_path]
6464
tokenizer = tokenizer_class.from_pretrained(config.model_name_or_path)
6565
config.cls_token_id = tokenizer.cls_token_id
6666

67-
model = model_class.from_pretrained(config.model_name_or_path, config=config)
67+
model = model_class.from_pretrained(config.model_name_or_path, config_file=config)
6868
model = paddle.DataParallel(model)
6969

7070
train_loader = GraphDataLoader(
@@ -113,7 +113,7 @@ def do_predict(config):
113113
paddle.distributed.init_parallel_env()
114114
set_seed(config)
115115

116-
mode = "predict"
116+
# mode = "predict"
117117
num_nodes = int(np.load(os.path.join(config.graph_work_path, "num_nodes.npy")))
118118

119119
base_graph, term_ids = load_data(config.graph_work_path)
@@ -123,7 +123,7 @@ def do_predict(config):
123123
tokenizer = tokenizer_class.from_pretrained(config.model_name_or_path)
124124
config.cls_token_id = tokenizer.cls_token_id
125125

126-
model = model_class.from_pretrained(config.infer_model, config=config)
126+
model = model_class.from_pretrained(config.infer_model, config_file=config)
127127

128128
model = paddle.DataParallel(model)
129129
predict_ds = PredictData(num_nodes)

examples/text_graph/erniesage/models/model.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import pgl
1615
import paddle
17-
import paddle.nn as nn
18-
import numpy as np
19-
from paddlenlp.transformers import ErniePretrainedModel
20-
2116
from models.encoder import Encoder
2217
from models.loss import LossFactory
2318

19+
from paddlenlp.transformers import ErnieModel, ErniePretrainedModel
20+
2421
__all__ = ["ErnieSageForLinkPrediction"]
2522

2623

2724
class ErnieSageForLinkPrediction(ErniePretrainedModel):
2825
"""ErnieSage for link prediction task."""
2926

30-
def __init__(self, ernie, config):
27+
def __init__(self, config, config_file):
3128
"""Model which Based on the PaddleNLP PretrainedModel
3229
3330
Note:
@@ -39,9 +36,9 @@ def __init__(self, ernie, config):
3936
ernie (nn.Layer): the submodule layer of ernie model.
4037
config (Dict): the config file
4138
"""
42-
super(ErnieSageForLinkPrediction, self).__init__()
43-
self.config_file = config
44-
self.ernie = ernie
39+
super(ErnieSageForLinkPrediction, self).__init__(config)
40+
self.config_file = config_file
41+
self.ernie = ErnieModel(config)
4542
self.encoder = Encoder.factory(self.config_file, self.ernie)
4643
self.loss_func = LossFactory(self.config_file)
4744

0 commit comments

Comments
 (0)