Skip to content

Commit 35be940

Browse files
authored
Add semantic search taskflow API for pipelines (PaddlePaddle#4792)
* Add semantic search taskflow API for pipelines * code refactoring and Add title indexing * Add text_feature_extraction.py * Add rocketqav2 and ernie search support * Add rocketqa 3.0 models * Try to fix the windows dtype bug * Add rocketv2 and ernie-search support * Add ernie-search comments * Add unittest for text encoder and text similarity * Add text feature extraction unittest * Fix protobuf to 3.20 * Update gpu configs * Add feature extraction usage examples * Update test_multimodal_feature_extraction.py * Update README docs
1 parent be80a3e commit 35be940

17 files changed

+873
-129
lines changed

docs/model_zoo/taskflow.md

+59-3
Original file line numberDiff line numberDiff line change
@@ -1249,13 +1249,14 @@ from paddlenlp import Taskflow
12491249

12501250
| 模型 | 结构 | 语言 |
12511251
| :---: | :--------: | :--------: |
1252-
| `rocketqa-zh-dureader-cross-encoder` (默认) | 12-layers, 768-hidden, 12-heads | 中文 |
1253-
| `simbert-base-chinese` | 12-layers, 768-hidden, 12-heads | 中文 |
1252+
| `rocketqa-zh-dureader-cross-encoder` | 12-layers, 768-hidden, 12-heads | 中文 |
1253+
| `simbert-base-chinese` (默认) | 12-layers, 768-hidden, 12-heads | 中文 |
12541254
| `rocketqa-base-cross-encoder` | 12-layers, 768-hidden, 12-heads | 中文 |
12551255
| `rocketqa-medium-cross-encoder` | 6-layers, 768-hidden, 12-heads | 中文 |
12561256
| `rocketqa-mini-cross-encoder` | 6-layers, 384-hidden, 12-heads | 中文 |
12571257
| `rocketqa-micro-cross-encoder` | 4-layers, 384-hidden, 12-heads | 中文 |
12581258
| `rocketqa-nano-cross-encoder` | 4-layers, 312-hidden, 12-heads | 中文 |
1259+
| `rocketqav2-en-marco-cross-encoder` | 12-layers, 768-hidden, 12-heads | 英文 |
12591260

12601261

12611262
#### 可配置参数说明
@@ -1785,7 +1786,7 @@ from paddlenlp import Taskflow
17851786

17861787
<details><summary>&emsp; 基于百度自研中文图文跨模态预训练模型ERNIE-ViL 2.0</summary><div>
17871788

1788-
#### 支持单条、批量预测
1789+
#### 多模态特征提取
17891790

17901791
```python
17911792
>>> from paddlenlp import Taskflow
@@ -1846,6 +1847,61 @@ Tensor(shape=[1, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
18461847
* `_static_mode`:静态图模式,默认开启。
18471848
* `model`:选择任务使用的模型,默认为`PaddlePaddle/ernie_vil-2.0-base-zh`
18481849

1850+
#### 文本特征提取
1851+
1852+
```python
1853+
>>> from paddlenlp import Taskflow
1854+
>>> import paddle.nn.functional as F
1855+
>>> text_encoder = Taskflow("feature_extraction", model='rocketqa-zh-base-query-encoder')
1856+
>>> text_embeds = text_encoder(['春天适合种什么花?','谁有狂三这张高清的?'])
1857+
>>> text_features1 = text_embeds["features"]
1858+
>>> text_features1
1859+
Tensor(shape=[2, 768], dtype=float32, place=Place(gpu:0), stop_gradient=True,
1860+
[[ 0.27640465, -0.13405125, 0.00612330, ..., -0.15600294,
1861+
-0.18932408, -0.03029604],
1862+
[-0.12041329, -0.07424965, 0.07895312, ..., -0.17068857,
1863+
0.04485796, -0.18887770]])
1864+
>>> text_embeds = text_encoder('春天适合种什么菜?')
1865+
>>> text_features2 = text_embeds["features"]
1866+
>>> text_features2
1867+
Tensor(shape=[1, 768], dtype=float32, place=Place(gpu:0), stop_gradient=True,
1868+
[[ 0.32578075, -0.02398480, -0.18929179, -0.18639392, -0.04062131,
1869+
......
1870+
>>> probs = F.cosine_similarity(text_features1, text_features2)
1871+
>>> probs
1872+
Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
1873+
[0.86455142, 0.41222256])
1874+
```
1875+
1876+
#### 模型选择
1877+
1878+
- 多模型选择,满足精度、速度要求
1879+
1880+
| 模型 | 层数| 维度 | 语言|
1881+
| :---: | :--------: | :--------: | :--------: |
1882+
| `rocketqa-zh-dureader-query-encoder` | 12 | 768 | 中文|
1883+
| `rocketqa-zh-dureader-para-encoder` | 12 | 768 | 中文|
1884+
| `rocketqa-zh-base-query-encoder` | 12 | 768 | 中文|
1885+
| `rocketqa-zh-base-para-encoder` | 12 | 768 | 中文|
1886+
| `rocketqa-zh-medium-query-encoder` | 6 | 768 | 中文|
1887+
| `rocketqa-zh-medium-para-encoder` | 6 | 768 | 中文|
1888+
| `rocketqa-zh-mini-query-encoder` | 6 | 384 | 中文|
1889+
| `rocketqa-zh-mini-para-encoder` | 6 | 384 | 中文|
1890+
| `rocketqa-zh-micro-query-encoder` | 4 | 384 | 中文|
1891+
| `rocketqa-zh-micro-para-encoder` | 4 | 384 | 中文|
1892+
| `rocketqa-zh-nano-query-encoder` | 4 | 312 | 中文|
1893+
| `rocketqa-zh-nano-para-encoder` | 4 | 312 | 中文|
1894+
| `rocketqav2-en-marco-query-encoder` | 12 | 768 | 英文|
1895+
| `rocketqav2-en-marco-para-encoder` | 12 | 768 | 英文|
1896+
| `ernie-search-base-dual-encoder-marco-en` | 12 | 768 | 英文|
1897+
1898+
#### 可配置参数说明
1899+
* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1
1900+
* `max_seq_len`:文本序列的最大长度,默认为128
1901+
* `return_tensors`: 返回的类型,有pd和np,默认为pd。
1902+
* `model`:选择任务使用的模型,默认为`PaddlePaddle/ernie_vil-2.0-base-zh`
1903+
1904+
18491905
</div></details>
18501906

18511907
## PART Ⅱ &emsp; 定制化训练

paddlenlp/taskflow/feature_extraction.py paddlenlp/taskflow/multimodal_feature_extraction.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
import os
1516

1617
import numpy as np
@@ -25,9 +26,8 @@
2526
usage = r"""
2627
from paddlenlp import Taskflow
2728
from PIL import Image
28-
29-
# multi modal feature_extraction with ernie_vil-2.0-base-zh
30-
vision_language = Taskflow("feature_extraction")
29+
# Multi modal feature_extraction with ernie_vil-2.0-base-zh
30+
vision_language = Taskflow("feature_extraction", model='PaddlePaddle/ernie_vil-2.0-base-zh')
3131
image_embeds = vision_language([Image.open("demo/000000039769.jpg")])
3232
print(image_embeds)
3333
'''
@@ -211,7 +211,6 @@ class MultimodalFeatureExtractionTask(Task):
211211
def __init__(self, task, model, batch_size=1, is_static_model=True, max_length=128, return_tensors="pd", **kwargs):
212212
super().__init__(task=task, model=model, **kwargs)
213213
self._seed = None
214-
# we do not use batch
215214
self.export_type = "text"
216215
self._batch_size = batch_size
217216
self.return_tensors = return_tensors

paddlenlp/taskflow/taskflow.py

+85-1
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
from .dependency_parsing import DDParserTask
2424
from .dialogue import DialogueTask
2525
from .document_intelligence import DocPromptTask
26-
from .feature_extraction import MultimodalFeatureExtractionTask
2726
from .fill_mask import FillMaskTask
2827
from .information_extraction import GPTask, UIETask
2928
from .knowledge_mining import NPTagTask, WordTagTask
3029
from .lexical_analysis import LacTask
30+
from .multimodal_feature_extraction import MultimodalFeatureExtractionTask
3131
from .named_entity_recognition import NERLACTask, NERWordTagTask
3232
from .poetry_generation import PoetryGenerationTask
3333
from .pos_tagging import POSTaggingTask
@@ -36,6 +36,7 @@
3636
from .sentiment_analysis import SentaTask, SkepTask, UIESentaTask
3737
from .text_classification import TextClassificationTask
3838
from .text_correction import CSCTask
39+
from .text_feature_extraction import TextFeatureExtractionTask
3940
from .text_similarity import TextSimilarityTask
4041
from .text_summarization import TextSummarizationTask
4142
from .text_to_image import (
@@ -243,6 +244,14 @@
243244
"task_class": TextSimilarityTask,
244245
"task_flag": "text_similarity-rocketqa-nano-cross-encoder",
245246
},
247+
"rocketqav2-en-marco-cross-encoder": {
248+
"task_class": TextSimilarityTask,
249+
"task_flag": "text_similarity-rocketqav2-en-marco-cross-encoder",
250+
},
251+
"ernie-search-large-cross-encoder-marco-en": {
252+
"task_class": TextSimilarityTask,
253+
"task_flag": "text_similarity-ernie-search-large-cross-encoder-marco-en",
254+
},
246255
"__internal_testing__/tiny-random-bert": {
247256
"task_class": TextSimilarityTask,
248257
"task_flag": "text_similarity-tiny-random-bert",
@@ -563,6 +572,81 @@
563572
},
564573
"feature_extraction": {
565574
"models": {
575+
"rocketqa-zh-dureader-query-encoder": {
576+
"task_class": TextFeatureExtractionTask,
577+
"task_flag": "feature_extraction-rocketqa-zh-dureader-query-encoder",
578+
"task_priority_path": "rocketqa-zh-dureader-query-encoder",
579+
},
580+
"rocketqa-zh-dureader-para-encoder": {
581+
"task_class": TextFeatureExtractionTask,
582+
"task_flag": "feature_extraction-rocketqa-zh-dureader-para-encoder",
583+
"task_priority_path": "rocketqa-rocketqa-zh-dureader-para-encoder",
584+
},
585+
"rocketqa-zh-base-query-encoder": {
586+
"task_class": TextFeatureExtractionTask,
587+
"task_flag": "feature_extraction-rocketqa-zh-base-query-encoder",
588+
"task_priority_path": "rocketqa-zh-base-query-encoder",
589+
},
590+
"rocketqa-zh-base-para-encoder": {
591+
"task_class": TextFeatureExtractionTask,
592+
"task_flag": "feature_extraction-rocketqa-zh-base-para-encoder",
593+
"task_priority_path": "rocketqa-zh-base-para-encoder",
594+
},
595+
"rocketqa-zh-medium-query-encoder": {
596+
"task_class": TextFeatureExtractionTask,
597+
"task_flag": "feature_extraction-rocketqa-zh-medium-query-encoder",
598+
"task_priority_path": "rocketqa-zh-medium-query-encoder",
599+
},
600+
"rocketqa-zh-medium-para-encoder": {
601+
"task_class": TextFeatureExtractionTask,
602+
"task_flag": "feature_extraction-rocketqa-zh-medium-para-encoder",
603+
"task_priority_path": "rocketqa-zh-medium-para-encoder",
604+
},
605+
"rocketqa-zh-mini-query-encoder": {
606+
"task_class": TextFeatureExtractionTask,
607+
"task_flag": "feature_extraction-rocketqa-zh-mini-query-encoder",
608+
"task_priority_path": "rocketqa-zh-mini-query-encoder",
609+
},
610+
"rocketqa-zh-mini-para-encoder": {
611+
"task_class": TextFeatureExtractionTask,
612+
"task_flag": "feature_extraction-rocketqa-rocketqa-zh-mini-para-encoder",
613+
"task_priority_path": "rocketqa-zh-mini-para-encoder",
614+
},
615+
"rocketqa-zh-micro-query-encoder": {
616+
"task_class": TextFeatureExtractionTask,
617+
"task_flag": "feature_extraction-rocketqa-zh-micro-query-encoder",
618+
"task_priority_path": "rocketqa-zh-micro-query-encoder",
619+
},
620+
"rocketqa-zh-micro-para-encoder": {
621+
"task_class": TextFeatureExtractionTask,
622+
"task_flag": "feature_extraction-rocketqa-zh-micro-para-encoder",
623+
"task_priority_path": "rocketqa-zh-micro-para-encoder",
624+
},
625+
"rocketqa-zh-nano-query-encoder": {
626+
"task_class": TextFeatureExtractionTask,
627+
"task_flag": "feature_extraction-rocketqa-zh-nano-query-encoder",
628+
"task_priority_path": "rocketqa-zh-nano-query-encoder",
629+
},
630+
"rocketqa-zh-nano-para-encoder": {
631+
"task_class": TextFeatureExtractionTask,
632+
"task_flag": "feature_extraction-rocketqa-zh-nano-para-encoder",
633+
"task_priority_path": "rocketqa-zh-nano-para-encoder",
634+
},
635+
"rocketqav2-en-marco-query-encoder": {
636+
"task_class": TextFeatureExtractionTask,
637+
"task_flag": "feature_extraction-rocketqav2-en-marco-query-encoder",
638+
"task_priority_path": "rocketqav2-en-marco-query-encoder",
639+
},
640+
"rocketqav2-en-marco-para-encoder": {
641+
"task_class": TextFeatureExtractionTask,
642+
"task_flag": "feature_extraction-rocketqav2-en-marco-para-encoder",
643+
"task_priority_path": "rocketqav2-en-marco-para-encoder",
644+
},
645+
"ernie-search-base-dual-encoder-marco-en": {
646+
"task_class": TextFeatureExtractionTask,
647+
"task_flag": "feature_extraction-ernie-search-base-dual-encoder-marco-en",
648+
"task_priority_path": "ernie-search-base-dual-encoder-marco-en",
649+
},
566650
"PaddlePaddle/ernie_vil-2.0-base-zh": {
567651
"task_class": MultimodalFeatureExtractionTask,
568652
"task_flag": "feature_extraction-PaddlePaddle/ernie_vil-2.0-base-zh",

0 commit comments

Comments
 (0)