-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathbase_operator.py
42 lines (34 loc) · 1.65 KB
/
base_operator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import importlib
import subprocess
import sys

import spacy
def download_model(model_name):
    """Ensure the spaCy model package ``model_name`` is installed.

    Tries to load the model; if spaCy cannot find it (``OSError``), installs it
    with ``spacy download`` using the interpreter that is running this code.

    Args:
        model_name (str): spaCy model package name, e.g. ``"en_core_web_sm"``.

    Raises:
        OSError: If the ``spacy download`` command exits with a non-zero status.
    """
    try:
        spacy.load(model_name)
    except OSError:
        print(f"Model '{model_name}' not found. Downloading...")
        # Use sys.executable rather than a bare "python" from PATH so the model
        # is installed into the same environment that will later import it.
        result = subprocess.run([sys.executable, "-m", "spacy", "download", model_name])
        if result.returncode != 0:
            # Fail loudly here instead of letting a later spacy.load() raise a
            # confusing "model not found" error.
            raise OSError(
                f"Failed to download spaCy model '{model_name}' "
                f"(exit code {result.returncode})."
            )
class BaseOperator:
    """Base text cleaner class that provides common functionality for specific text cleaning tasks.

    Holds the configuration shared by concrete operators (which sample key
    contains the text and which language it is in) and offers a helper for
    loading the matching spaCy pipeline(s). Subclasses must override
    :meth:`process`.
    """

    def __init__(self, text_key='content', language='mix'):
        """
        Args:
            text_key (str, optional): The key in the sample dictionary containing the text data to be processed. Default is 'content'.
            language (str, optional): The language of the text to be processed, default is 'mix'. Some operators require this parameter. Currently, only Chinese ('zh'), English ('en'), and mixed text ('mix') are supported. For more accurate results, use 'zh' or 'en' if the language is known.
        """
        self.text_key = text_key
        self.language = language

    def process(self, sample):
        """Process a single sample. Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError("Each operator must implement its own `process` method.")

    def load_language_model(self, language=None):
        """Load the spaCy pipeline(s) for a language, downloading them if missing.

        Args:
            language (str, optional): 'en', 'zh' or 'mix'. When omitted (None),
                ``self.language`` is used.

        Returns:
            The loaded ``en_core_web_sm`` pipeline for 'en', the
            ``zh_core_web_sm`` pipeline for 'zh', or the tuple
            ``(en_pipeline, zh_pipeline)`` for 'mix'.

        Raises:
            ValueError: If the language is not one of 'en', 'zh', 'mix'.
        """
        if language is None:
            language = self.language
        if language == 'en':
            return self._load_or_download("en_core_web_sm")
        if language == 'zh':
            return self._load_or_download("zh_core_web_sm")
        if language == 'mix':
            return (
                self._load_or_download("en_core_web_sm"),
                self._load_or_download("zh_core_web_sm"),
            )
        raise ValueError(f"Unsupported language: {language}")

    @staticmethod
    def _load_or_download(model_name):
        """Load a spaCy model by package name, downloading it first if it is not installed."""
        try:
            # Common path: the model is installed, so it is loaded exactly once.
            return spacy.load(model_name)
        except OSError:
            # spacy.load raises OSError when the model cannot be found.
            download_model(model_name)
        # A package installed after interpreter start-up may be missed by the
        # import system's cached finders until the caches are invalidated.
        importlib.invalidate_caches()
        return spacy.load(model_name)