
Commit

Release a Japanese joint model trained on NPCMJ/UD/Kyoto corpora with encoders including tok, pos, ner, dep, con, srl.
hankcs committed May 18, 2021
1 parent 3764f7c commit faea3fa
Showing 11 changed files with 155 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,6 +1,6 @@
# HanLP: Han Language Processing

[中文](https://github.com/hankcs/HanLP/tree/doc-zh) | [docs](https://hanlp.hankcs.com/docs/) | [1.x](https://github.com/hankcs/HanLP/tree/1.x) | [forum](https://bbs.hankcs.com/) | [![Open In Colab](https://file.hankcs.com/img/colab-badge.svg)](https://colab.research.google.com/drive/1KPX6t1y36TOzRIeB4Kt3uJ1twuj6WuFv?usp=sharing)
[中文](https://github.com/hankcs/HanLP/tree/doc-zh) | [日本語](https://github.com/hankcs/HanLP/tree/doc-ja) | [docs](https://hanlp.hankcs.com/docs/) | [1.x](https://github.com/hankcs/HanLP/tree/1.x) | [forum](https://bbs.hankcs.com/) | [![Open In Colab](https://file.hankcs.com/img/colab-badge.svg)](https://colab.research.google.com/drive/1KPX6t1y36TOzRIeB4Kt3uJ1twuj6WuFv?usp=sharing)

The multilingual NLP library for researchers and companies, built on PyTorch and TensorFlow 2.x, for advancing state-of-the-art deep learning techniques in both academia and industry. HanLP was designed from day one to be efficient, user friendly and extendable.

2 changes: 1 addition & 1 deletion docs/index.md
@@ -5,7 +5,7 @@
The multilingual NLP library for researchers and companies, built on PyTorch and TensorFlow 2.x, for advancing
state-of-the-art deep learning techniques in both academia and industry. HanLP was designed from day one to be
efficient, user friendly and extendable. It comes with pretrained models for various human languages
including English, Chinese and many others.
including English, Chinese, Japanese and many others.



33 changes: 28 additions & 5 deletions hanlp/layers/transformers/encoder.py
@@ -2,14 +2,14 @@
# Author: hankcs
# Date: 2020-06-22 21:06
import warnings
from typing import Union, Dict, Any, Sequence
from typing import Union, Dict, Any, Sequence, Tuple, Optional

import torch
from torch import nn

from hanlp.layers.dropout import WordDropout
from hanlp.layers.scalar_mix import ScalarMixWithDropout, ScalarMixWithDropoutBuilder
from hanlp.layers.transformers.pt_imports import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoModel_
from hanlp.layers.transformers.pt_imports import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoModel_, \
BertTokenizer
from hanlp.layers.transformers.utils import transformer_encode


@@ -24,7 +24,7 @@ def __init__(self,
max_sequence_length=None,
ret_raw_hidden_states=False,
transformer_args: Dict[str, Any] = None,
trainable=True,
trainable: Union[bool, Optional[Tuple[int, int]]] = True,
training=True) -> None:
"""A pre-trained transformer encoder.
@@ -77,6 +77,14 @@ def __init__(self,
self.transformer = transformer
if not trainable:
transformer.requires_grad_(False)
elif isinstance(trainable, tuple):
layers = []
if hasattr(transformer, 'embeddings'):
layers.append(transformer.embeddings)
layers.extend(transformer.encoder.layer)
for i, layer in enumerate(layers):
if i < trainable[0] or i >= trainable[1]:
layer.requires_grad_(False)

if isinstance(scalar_mix, ScalarMixWithDropoutBuilder):
self.scalar_mix: ScalarMixWithDropout = scalar_mix.build()
@@ -121,4 +129,19 @@ def build_transformer_tokenizer(config_or_str, use_fast=True, do_basic_tokenize=
transformer = config_or_str.transformer
if use_fast and not do_basic_tokenize:
warnings.warn('`do_basic_tokenize=False` might not work when `use_fast=True`')
return AutoTokenizer.from_pretrained(transformer, use_fast=use_fast, do_basic_tokenize=do_basic_tokenize)
additional_config = dict()
if transformer.startswith('voidful/albert_chinese_'):
cls = BertTokenizer
elif transformer == 'cl-tohoku/bert-base-japanese-char':
# Since it's a char-level model, it's OK to use a char-level tokenizer instead of fugashi
# from hanlp.utils.lang.ja.bert_tok import BertJapaneseTokenizerFast
# cls = BertJapaneseTokenizerFast
from transformers import BertJapaneseTokenizer
cls = BertJapaneseTokenizer
# from transformers import BertTokenizerFast
# cls = BertTokenizerFast
additional_config['word_tokenizer_type'] = 'basic'
else:
cls = AutoTokenizer
return cls.from_pretrained(transformer, use_fast=use_fast, do_basic_tokenize=do_basic_tokenize,
**additional_config)
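The new tuple form of `trainable` (added above) freezes every transformer layer whose index falls outside a half-open interval, so only a contiguous slice of the stack is fine-tuned. A minimal sketch of that behavior on a toy layer stack; the `freeze_outside_range` helper and the toy `nn.Linear` layers are illustrative, not part of HanLP:

from torch import nn

def freeze_outside_range(layers, trainable):
    # Freeze every layer whose index falls outside [trainable[0], trainable[1]),
    # mirroring the tuple branch added to TransformerEncoder above.
    for i, layer in enumerate(layers):
        if i < trainable[0] or i >= trainable[1]:
            layer.requires_grad_(False)

toy_layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(5)])  # stand-in for embeddings + encoder layers
freeze_outside_range(toy_layers, trainable=(1, 3))
for i, layer in enumerate(toy_layers):
    print(i, all(p.requires_grad for p in layer.parameters()))
# Only indices 1 and 2 remain trainable.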
4 changes: 3 additions & 1 deletion hanlp/pretrained/mtl.py
@@ -14,9 +14,11 @@

UD_ONTONOTES_TOK_POS_LEM_FEA_NER_SRL_DEP_SDP_CON_MT5_SMALL = HANLP_URL + 'mtl/ud_ontonotes_tok_pos_lem_fea_ner_srl_dep_sdp_con_mt5_small_20210228_123458.zip'
'mt5 small version of joint tok, pos, lem, fea, ner, srl, dep, sdp and con model trained on UD and OntoNotes5 corpus.'

UD_ONTONOTES_TOK_POS_LEM_FEA_NER_SRL_DEP_SDP_CON_XLMR_BASE = HANLP_URL + 'mtl/ud_ontonotes_tok_pos_lem_fea_ner_srl_dep_sdp_con_xlm_base_20210114_005825.zip'
'XLM-R base version of joint tok, pos, lem, fea, ner, srl, dep, sdp and con model trained on UD and OntoNotes5 corpus.'

NPCMJ_UD_KYOTO_TOK_POS_CON_BERT_BASE_CHAR_JA = HANLP_URL + 'mtl/npcmj_ud_kyoto_tok_pos_ner_dep_con_srl_bert_base_char_ja_20210517_225654.zip'
'BERT base char encoder trained on NPCMJ/UD/Kyoto corpora with encoders including tok, pos, ner, dep, con, srl.'

# Will be filled up during runtime
ALL = {}
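For orientation, loading the new Japanese identifier follows the same pattern as the other MTL constants. A minimal sketch (the model is downloaded on first use; the sample sentence is taken from the demo added below in plugins/hanlp_demo):

import hanlp

HanLP = hanlp.load(hanlp.pretrained.mtl.NPCMJ_UD_KYOTO_TOK_POS_CON_BERT_BASE_CHAR_JA)
doc = HanLP('奈須きのこは1973年11月28日に千葉県円空山で生まれ、ゲーム制作会社「ノーツ」の設立者だ。')
print(doc)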
27 changes: 20 additions & 7 deletions hanlp/transform/transformer_tokenizer.py
@@ -229,11 +229,11 @@ def __init__(self,
if self.ret_token_span or not self.truncate_long_sequences:
assert not self.cls_token_at_end
assert not self.pad_on_left
if self.ret_subtokens:
if not use_fast:
raise NotImplementedError(
'ret_subtokens is not available when using Python tokenizers. '
'To use this feature, set use_fast = True.')
# if self.ret_subtokens:
# if not use_fast:
# raise NotImplementedError(
# 'ret_subtokens is not available when using Python tokenizers. '
# 'To use this feature, set use_fast = True.')
self.dict: Optional[DictInterface] = dict_force # For tokenization of raw text
self.strip_cls_sep = strip_cls_sep

@@ -282,7 +282,11 @@ def tokenize_str(input_str, add_special_tokens=True):
input_ids = [self.cls_token_id] + input_ids
else:
input_tokens = tokenizer.tokenize(input_str)
subtoken_offsets = input_tokens
subtoken_offsets = []
_o = 0
for each in input_tokens:
subtoken_offsets.append((_o, _o + len(each)))
_o += len(each)
if add_special_tokens:
input_tokens = [self.cls_token] + input_tokens + [self.sep_token]
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
@@ -372,7 +376,16 @@ def tokenize_str(input_str, add_special_tokens=True):
if return_offsets_mapping:
offsets_mapping = [encoding.offsets for encoding in encodings.encodings]
else:
offsets_mapping = [None for encoding in encodings.encodings]
offsets_mapping = []
for token, subtoken_ids in zip(input_tokens, encodings.data['input_ids']):
if len(subtoken_ids) > len(token):  # e.g. a single char like '…' can expand to multiple subtokens ('...')
del subtoken_ids[len(token):]
char_per_subtoken = -(-len(token) // len(subtoken_ids))
bes = list(zip(range(0, len(token), char_per_subtoken),
range(char_per_subtoken, len(token) + char_per_subtoken, char_per_subtoken)))
if bes[-1][-1] != len(token):
bes[-1] = (bes[-1][0], len(token))
offsets_mapping.append(bes)
else:
encodings = SerializableDict()
encodings.data = {'input_ids': []}
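The fallback branch above approximates an offset mapping by splitting each token's character span evenly across its subtokens, using a ceiling division and clamping the last span to the token length. A minimal sketch of that arithmetic in isolation; the `approximate_offsets` helper is illustrative, not part of HanLP:

def approximate_offsets(token, num_subtokens):
    # Characters per subtoken, rounded up: -(-a // b) is ceil(a / b).
    chars_per_subtoken = -(-len(token) // num_subtokens)
    # Clamp the end of each span to the token length, like the last-span fix-up above.
    return [(start, min(start + chars_per_subtoken, len(token)))
            for start in range(0, len(token), chars_per_subtoken)]

print(approximate_offsets('初夏', 1))        # [(0, 2)]
print(approximate_offsets('こんにちは', 2))  # [(0, 3), (3, 5)]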
3 changes: 3 additions & 0 deletions hanlp/utils/lang/ja/__init__.py
@@ -0,0 +1,3 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2021-05-13 13:24
81 changes: 81 additions & 0 deletions hanlp/utils/lang/ja/bert_tok.py
@@ -0,0 +1,81 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2021-05-13 13:24
from typing import Union, Optional

from transformers import BertTokenizerFast, TensorType, BatchEncoding, BertJapaneseTokenizer as _BertJapaneseTokenizer
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, EncodedInput, TruncationStrategy


class BertJapaneseTokenizer(_BertJapaneseTokenizer):
# We may need to customize character level tokenization to handle English words and URLs
pass


class BertJapaneseTokenizerFast(BertTokenizerFast):
def encode_plus(
self,
text: Union[TextInput, PreTokenizedInput, EncodedInput],
text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
"""
Tokenize and prepare for the model a sequence or a pair of sequences.
.. warning::
This method is deprecated, ``__call__`` should be used instead.
Args:
text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]` (the latter only for not-fast tokenizers)):
The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
``tokenize`` method) or a list of integers (tokenized string ids using the ``convert_tokens_to_ids``
method).
text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`):
Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using
the ``tokenize`` method) or a list of integers (tokenized string ids using the
``convert_tokens_to_ids`` method).
"""
text = list(text)
is_split_into_words = True
encoding = BertJapaneseTokenizer.encode_plus(self,
text,
text_pair,
add_special_tokens,
padding,
truncation,
max_length,
stride,
is_split_into_words,
pad_to_multiple_of,
return_tensors,
return_token_type_ids,
return_attention_mask,
return_overflowing_tokens,
return_special_tokens_mask,
return_offsets_mapping,
return_length,
verbose,
**kwargs
)
offsets = encoding.encodings[0].offsets
fixed_offsets = [(b + i, e + i) for i, (b, e) in enumerate(offsets)]
# TODO: This doesn't work with rust tokenizers
encoding.encodings[0].offsets.clear()
encoding.encodings[0].offsets.extend(fixed_offsets)
return encoding
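The offset fix-up at the end exists because the text is fed in as a list of single characters with is_split_into_words=True, so each subtoken's offsets are relative to its own one-character "word" and must be shifted by the token index to become sentence-level offsets. A toy illustration of the shift, assuming every character maps to exactly one subtoken and special tokens are omitted:

per_word_offsets = [(0, 1), (0, 1), (0, 1)]  # offsets reported per single-character "word"
sentence_offsets = [(b + i, e + i) for i, (b, e) in enumerate(per_word_offsets)]
print(sentence_offsets)  # [(0, 1), (1, 2), (2, 3)]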
2 changes: 1 addition & 1 deletion hanlp/version.py
@@ -2,5 +2,5 @@
# Author: hankcs
# Date: 2019-12-28 19:26

__version__ = '2.1.0-alpha.41'
__version__ = '2.1.0-alpha.42'
"""HanLP version"""
2 changes: 1 addition & 1 deletion plugins/hanlp_common/hanlp_common/document.py
@@ -244,7 +244,7 @@ def condense(block_, extras_=None):
_srl[p_index] = '╟──►'
# _type[j] = 'V'
if len(block) != len(_srl) + 1:
warnings.warn(f'Unable to visualize overlapped spans: {pas}')
# warnings.warn(f'Unable to visualize overlapped spans: {pas}')
continue
block[0].extend(header)
for j, (_s, _t) in enumerate(zip(_srl, _type)):
3 changes: 3 additions & 0 deletions plugins/hanlp_demo/hanlp_demo/ja/__init__.py
@@ -0,0 +1,3 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2021-05-17 22:30
13 changes: 13 additions & 0 deletions plugins/hanlp_demo/hanlp_demo/ja/demo_mtl.py
@@ -0,0 +1,13 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2021-05-17 22:30
import hanlp
from hanlp_common.document import Document

HanLP = hanlp.load(hanlp.pretrained.mtl.NPCMJ_UD_KYOTO_TOK_POS_CON_BERT_BASE_CHAR_JA, devices=-1)
doc: Document = HanLP([
'2021年、HanLPv2.1は次世代の最先端多言語NLP技術を本番環境に導入します。',
'奈須きのこは1973年11月28日に千葉県円空山で生まれ、ゲーム制作会社「ノーツ」の設立者だ。',
])
print(doc)
doc.pretty_print()
