
Commit a85752d

Authored by MooTong123 (Mu Tong) and w5688414
enhancement: add simpleServing code for sentence transformers (PaddlePaddle#4795)
* enhancement: add simpleServing code for sentence transformers
* Update export_model.py
* enhancement: update simpleServing for sentence transformers
* pre-commit
* checkout code format

---------

Co-authored-by: Mu Tong <[email protected]>
Co-authored-by: w5688414 <[email protected]>
Co-authored-by: w5688414 <[email protected]>
1 parent 3ae3ea5 commit a85752d

4 files changed: +319 -0 lines changed
README.md (new file)
@@ -0,0 +1,39 @@
# Service Deployment with PaddleNLP SimpleServing

## Contents
- [Environment Setup](#environment-setup)
- [Starting the Server](#starting-the-server)
- [Other Parameter Settings](#other-parameter-settings)

## Environment Setup
Use a PaddleNLP version that includes the SimpleServing feature:
```shell
pip install "paddlenlp>=2.4.4"
```

## Starting the Server
### Starting the Text Matching Task
#### Start the text matching Server
```bash
paddlenlp server server:app --host 0.0.0.0 --port 8189
```

#### Start the text matching Client
```bash
python client.py
```

## Other Parameter Settings
The `max_seq_len`, `batch_size`, and `prob_limit` parameters can be set on the client side:
```python
data = {
    'data': {
        'text': texts,
        'text_pair': text_pairs,
    },
    'parameters': {
        'max_seq_len': args.max_seq_len,
        'batch_size': args.batch_size,
        'prob_limit': args.prob_limit
    }
}
```
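For reference, the service's post handler (defined in `server.py` below) returns one predicted label and one similarity score per input pair, so a successful request yields a response shaped roughly like the following (values are illustrative, not actual model output):

```python
# Illustrative response shape; "label" is 1 when the pair's probability of
# being similar exceeds prob_limit, and "similarity" is that probability.
{"label": [1, 0], "similarity": [0.93, 0.12]}
```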
client.py (new file)
@@ -0,0 +1,44 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json

import requests

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_seq_len", default=128, type=int, help="The maximum total input sequence length after tokenization."
)
parser.add_argument("--batch_size", default=1, type=int, help="Batch size per GPU/CPU for predicting.")
parser.add_argument(
    "--prob_limit", default=0.5, type=float, help="The probability threshold for labeling a pair as similar."
)
args = parser.parse_args()

url = "http://0.0.0.0:8189/models/text_matching"
headers = {"Content-Type": "application/json"}

if __name__ == "__main__":
    texts = ["三亚是一个美丽的城市", "北京烤鸭怎么样"]
    text_pair = ["三亚是个漂亮的城市", "北京烤鸭多少钱"]

    data = {
        "data": {
            "text": texts,
            "text_pair": text_pair,
        },
        "parameters": {"max_seq_len": args.max_seq_len, "batch_size": args.batch_size, "prob_limit": args.prob_limit},
    }
    r = requests.post(url=url, headers=headers, data=json.dumps(data))
    result_json = json.loads(r.text)
    print(result_json)
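The client assumes the server started in the README is already listening on port 8189. The `--prob_limit` flag is the decision threshold on the similarity probability, so, for example, `python client.py --prob_limit 0.7` labels a pair as similar only when the model assigns it at least 0.7 probability.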
server.py (new file)
@@ -0,0 +1,135 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from scipy.special import softmax

from paddlenlp import SimpleServer
from paddlenlp.data import Pad, Tuple
from paddlenlp.server import BaseModelHandler, BasePostHandler


class TextMatchingModelHandler(BaseModelHandler):
    def __init__(self):
        super().__init__()

    @classmethod
    def process(cls, predictor, tokenizer, data, parameters):
        max_seq_len = 128
        batch_size = 1
        if "max_seq_len" in parameters:
            max_seq_len = parameters["max_seq_len"]
        if "batch_size" in parameters:
            batch_size = parameters["batch_size"]
        text = None
        if "text" in data:
            text = data["text"]
        if text is None:
            return {}
        if isinstance(text, str):
            text = [text]
        has_pair = False
        if "text_pair" in data and data["text_pair"] is not None:
            text_pair = data["text_pair"]
            if isinstance(text_pair, str):
                text_pair = [text_pair]
            if len(text) != len(text_pair):
                raise ValueError("The length of text and text_pair must be the same.")
            has_pair = True
        if not has_pair:
            # Text matching needs a sentence pair for every input sentence.
            raise ValueError("Text matching requires 'text_pair' to be provided along with 'text'.")

        # Tokenize each side of the pair separately
        examples = []
        for idx, _ in enumerate(text):
            text_a = tokenizer(text=text[idx], max_length=max_seq_len)
            text_b = tokenizer(text=text_pair[idx], max_length=max_seq_len)
            examples.append((text_a["input_ids"], text_b["input_ids"]))

        # Separate the data into batches.
        batches = [examples[i : i + batch_size] for i in range(0, len(examples), batch_size)]

        def batchify_fn(samples):
            return Tuple(
                Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),
                Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),
            )(samples)

        # One result list per model output (a list comprehension avoids
        # aliasing a single shared list)
        results = [[] for _ in range(predictor._output_num)]
        for batch in batches:
            query_input_ids, title_input_ids = batchify_fn(batch)
            if predictor._predictor_type == "paddle_inference":
                predictor._input_handles[0].copy_from_cpu(query_input_ids)
                predictor._input_handles[1].copy_from_cpu(title_input_ids)
                predictor._predictor.run()
                output = [output_handle.copy_to_cpu() for output_handle in predictor._output_handles]
                for i, out in enumerate(output):
                    results[i].append(out)

        # Concatenate the per-batch logits back into a single array
        results_concat = []
        for i in range(0, len(results)):
            results_concat.append(np.concatenate(results[i], axis=0))

        out_dict = {"logits": results_concat[0].tolist(), "data": data}

        return out_dict


class TextMatchingPostHandler(BasePostHandler):
    def __init__(self):
        super().__init__()

    @classmethod
    def process(cls, data, parameters):
        if "logits" not in data:
            raise ValueError(
                "The output of the model handler does not include 'logits'; "
                "please check the model handler output. The model handler output:\n{}".format(data)
            )

        prob_limit = 0.5
        if "prob_limit" in parameters:
            prob_limit = parameters["prob_limit"]
        logits = data["logits"]
        # Convert logits to probabilities with softmax
        logits = softmax(logits, axis=-1)

        labels = []
        probs = []
        for logit in logits:
            # logit[1] is the probability that the pair is similar
            if logit[1] > prob_limit:
                labels.append(1)
            else:
                labels.append(0)
            probs.append(logit[1])

        out_dict = {"label": labels, "similarity": probs}
        return out_dict


app = SimpleServer()
app.register(
    task_name="models/text_matching",
    model_path="../../export_model",
    tokenizer_name="ernie-3.0-medium-zh",
    model_handler=TextMatchingModelHandler,
    post_handler=TextMatchingPostHandler,
    precision="fp32",
    device_id=0,
)
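A note on the registration above: `model_path="../../export_model"` assumes the static graph model has already been produced by the `export_model.py` script that follows, with its `--output_path` pointed at that directory, and `tokenizer_name` should match the pretrained model the export was based on.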
export_model.py (new file)
@@ -0,0 +1,101 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle
import paddle.nn as nn

from paddlenlp.transformers import AutoModel, AutoTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--params_path", type=str, default="ernie-1.0", help="The path to model parameters to be loaded.")
parser.add_argument(
    "--output_path", type=str, default="./export", help="The directory where the static graph model will be saved."
)
args = parser.parse_args()


class SentenceTransformer(nn.Layer):
    def __init__(self, pretrained_model, dropout=None):
        super().__init__()
        self.ptm = pretrained_model
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
        # num_labels = 2 (similar or dissimilar)
        self.classifier = nn.Linear(self.ptm.config["hidden_size"] * 3, 2)

    def forward(
        self,
        query_input_ids,
        title_input_ids,
        query_token_type_ids=None,
        query_position_ids=None,
        query_attention_mask=None,
        title_token_type_ids=None,
        title_position_ids=None,
        title_attention_mask=None,
    ):
        query_token_embedding, _ = self.ptm(
            query_input_ids, query_token_type_ids, query_position_ids, query_attention_mask
        )
        query_token_embedding = self.dropout(query_token_embedding)
        query_attention_mask = paddle.unsqueeze(
            (query_input_ids != self.ptm.pad_token_id).astype(self.ptm.pooler.dense.weight.dtype), axis=2
        )
        # Set token embeddings to 0 for padding tokens
        query_token_embedding = query_token_embedding * query_attention_mask
        query_sum_embedding = paddle.sum(query_token_embedding, axis=1)
        query_sum_mask = paddle.sum(query_attention_mask, axis=1)
        query_mean = query_sum_embedding / query_sum_mask

        title_token_embedding, _ = self.ptm(
            title_input_ids, title_token_type_ids, title_position_ids, title_attention_mask
        )
        title_token_embedding = self.dropout(title_token_embedding)
        title_attention_mask = paddle.unsqueeze(
            (title_input_ids != self.ptm.pad_token_id).astype(self.ptm.pooler.dense.weight.dtype), axis=2
        )
        # Set token embeddings to 0 for padding tokens
        title_token_embedding = title_token_embedding * title_attention_mask
        title_sum_embedding = paddle.sum(title_token_embedding, axis=1)
        title_sum_mask = paddle.sum(title_attention_mask, axis=1)
        title_mean = title_sum_embedding / title_sum_mask

        # Classify on [query_mean; title_mean; |query_mean - title_mean|]
        sub = paddle.abs(paddle.subtract(query_mean, title_mean))
        projection = paddle.concat([query_mean, title_mean, sub], axis=-1)

        logits = self.classifier(projection)

        return logits


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(args.params_path)
    pretrained_model = AutoModel.from_pretrained(args.params_path)

    model = SentenceTransformer(pretrained_model)
    model.eval()

    input_spec = [
        paddle.static.InputSpec(shape=[None, None], dtype="int64", name="query_input_ids"),
        paddle.static.InputSpec(shape=[None, None], dtype="int64", name="title_input_ids"),
    ]
    # Convert to static graph with specific input description
    model = paddle.jit.to_static(model, input_spec=input_spec)

    # Save the static graph model.
    save_path = os.path.join(args.output_path, "float32")
    paddle.jit.save(model, save_path)
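`paddle.jit.save` writes the exported files with the `float32` prefix under `--output_path` (a `float32.pdmodel` file plus the parameter files), which is presumably the layout `server.py` expects at its `model_path`. An illustrative invocation matching the server's defaults might be `python export_model.py --params_path ernie-3.0-medium-zh --output_path ../../export_model`, though the exact paths depend on your checkout layout.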
