-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
91ac0bc
commit 51bf727
Showing
168 changed files
with
27,238 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import os

# Root of the per-user MatchZoo directory tree: ``~/.matchzoo``.
USER_DIR = os.path.join(os.path.expanduser("~"), ".matchzoo")

# ``datasets`` sub-directory of USER_DIR.
USER_DATA_DIR = os.path.join(USER_DIR, "datasets")

# One call creates both USER_DIR and USER_DATA_DIR.  ``exist_ok=True`` replaces
# the former "check, then mkdir" pairs, which could fail if another process
# created the directory between the check and the mkdir.
os.makedirs(USER_DATA_DIR, exist_ok=True)

# Only the path is defined here; this directory is not created at import time.
USER_TUNED_MODELS_DIR = os.path.join(USER_DIR, "tuned_models")
|
||
from .version import __version__ | ||
|
||
from .data_pack import DataPack | ||
from .data_pack import pack | ||
from .data_pack import load_data_pack | ||
|
||
# from . import metrics | ||
from . import tasks | ||
|
||
from . import preprocessors | ||
# from . import data_generator | ||
# from .data_generator import DataGenerator | ||
# from .data_generator import DataGeneratorBuilder | ||
|
||
from .preprocessors.chain_transform import chain_transform | ||
from .datasets import embeddings | ||
# from . import metrics | ||
# from . import losses | ||
from . import engine | ||
# from . import models | ||
# from . import embedding | ||
# from . import datasets | ||
# from . import layers | ||
# from . import auto | ||
# from . import contrib | ||
|
||
# from .engine import hyper_spaces | ||
# from .engine.base_model import load_model | ||
# from .engine.base_preprocessor import load_preprocessor | ||
# from .engine import callbacks | ||
# from .engine.param import Param | ||
# from .engine.param_table import ParamTable | ||
|
||
# from .embedding.embedding import Embedding | ||
|
||
from .utils import one_hot | ||
from .preprocessors.build_unit_from_data_pack import build_unit_from_data_pack | ||
from .preprocessors.build_vocab_unit import build_vocab_unit | ||
|
||
# deprecated, should be removed in v2.2 | ||
# from .contrib.legacy_data_generator import DPoolDataGenerator | ||
# from .contrib.legacy_data_generator import DPoolPairDataGenerator | ||
# from .contrib.legacy_data_generator import HistogramDataGenerator | ||
# from .contrib.legacy_data_generator import HistogramPairDataGenerator | ||
# from .contrib.legacy_data_generator import DynamicDataGenerator | ||
# from .contrib.legacy_data_generator import PairDataGenerator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from . import callbacks | ||
from .data_generator import DataGenerator | ||
from .data_generator_builder import DataGeneratorBuilder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .callback import Callback | ||
from .lambda_callback import LambdaCallback | ||
from .dynamic_pooling import DynamicPooling | ||
from .histogram import Histogram |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import numpy as np | ||
|
||
import matchzoo as mz | ||
|
||
|
||
class Callback:
    """
    DataGenerator callback base class.

    To build your own callbacks, inherit `mz.data_generator.callbacks.Callback`
    and override the corresponding methods. Both hooks are no-ops here, so a
    subclass only needs to implement the ones it cares about.

    A batch is processed in the following way:

    - slice data pack based on batch index
    - handle `on_batch_data_pack` callbacks
    - unpack data pack into x, y
    - handle `on_batch_unpacked` callbacks
    - return x, y
    """

    # The annotation is a string so that it is not evaluated when the class
    # body is executed (no attribute lookup on `mz` at definition time).
    def on_batch_data_pack(self, data_pack: 'mz.DataPack'):
        """
        `on_batch_data_pack`.

        :param data_pack: a sliced DataPack before unpacking.
        """

    def on_batch_unpacked(self, x: dict, y: np.ndarray):
        """
        `on_batch_unpacked`.

        :param x: unpacked x.
        :param y: unpacked y.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import numpy as np | ||
|
||
from matchzoo.data_generator.callbacks import Callback | ||
|
||
|
||
class DynamicPooling(Callback):
    """
    Callback that adds a dynamic pooling index to every unpacked batch.

    :param fixed_length_left: max length of left text.
    :param fixed_length_right: max length of right text.
    :param compress_ratio_left: the length change ratio,
        especially after normal pooling layers.
    :param compress_ratio_right: the length change ratio,
        especially after normal pooling layers.
    """

    def __init__(
        self,
        fixed_length_left: int,
        fixed_length_right: int,
        compress_ratio_left: float = 1,
        compress_ratio_right: float = 1,
    ):
        """Init."""
        self._fixed_length_left = fixed_length_left
        self._fixed_length_right = fixed_length_right
        self._compress_ratio_left = compress_ratio_left
        self._compress_ratio_right = compress_ratio_right

    def on_batch_unpacked(self, x, y):
        """
        Insert `dpool_index` into `x`.

        `x` is modified in place and nothing is returned; `y` is not used.

        :param x: unpacked x; its `length_left` and `length_right` entries
            are read to build the index.
        :param y: unpacked y.
        """
        x['dpool_index'] = _dynamic_pooling_index(
            x['length_left'],
            x['length_right'],
            self._fixed_length_left,
            self._fixed_length_right,
            self._compress_ratio_left,
            self._compress_ratio_right
        )
|
||
|
||
def _dynamic_pooling_index(length_left: np.ndarray,
                           length_right: np.ndarray,
                           fixed_length_left: int,
                           fixed_length_right: int,
                           compress_ratio_left: float,
                           compress_ratio_right: float) -> np.ndarray:
    """
    Build the dynamic pooling index for every example of a batch.

    Each fixed length is first divided by its compress ratio, rounding up.
    For every example, each position along that compressed length is then
    mapped to `int(position / stride)`, where `stride` is the compressed
    fixed length divided by the example's compressed text length.

    :param length_left: lengths of the left texts, one per example.
    :param length_right: lengths of the right texts, one per example.
    :param fixed_length_left: max length of left text.
    :param fixed_length_right: max length of right text.
    :param compress_ratio_left: the length change ratio of the left text.
    :param compress_ratio_right: the length change ratio of the right text.
    :return: one index per example, stacked along the first axis. Each index
        has shape `(left, right, 2)` over the compressed fixed lengths; the
        last axis holds the (left, right) position pair.
    """
    # Ceil division: add one position when the ratio does not divide evenly.
    cur_fixed_length_left = int(fixed_length_left // compress_ratio_left)
    if fixed_length_left % compress_ratio_left != 0:
        cur_fixed_length_left += 1
    cur_fixed_length_right = int(fixed_length_right // compress_ratio_right)
    if fixed_length_right % compress_ratio_right != 0:
        cur_fixed_length_right += 1

    def _axis_index(one_length, fixed_length: int) -> list:
        """Map each of `fixed_length` positions for a single axis."""
        # A zero-length text uses the full width as stride, so every position
        # maps to 0 and the division never sees a zero length.
        if one_length == 0:
            stride = fixed_length
        else:
            stride = fixed_length / one_length
        return [int(pos / stride) for pos in range(fixed_length)]

    def _dpool_index(one_length_left, one_length_right) -> np.ndarray:
        """Combine both axes into one `(left, right, 2)` index."""
        one_idx_left = _axis_index(one_length_left, cur_fixed_length_left)
        one_idx_right = _axis_index(one_length_right, cur_fixed_length_right)
        mesh1, mesh2 = np.meshgrid(one_idx_left, one_idx_right)
        # stack -> (2, right, left); transpose -> (left, right, 2).
        return np.transpose(np.stack([mesh1, mesh2]), (2, 1, 0))

    return np.array([
        _dpool_index(one_left // compress_ratio_left,
                     length_right[i] // compress_ratio_right)
        for i, one_left in enumerate(length_left)
    ])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import numpy as np | ||
|
||
import matchzoo as mz | ||
from matchzoo.data_generator.callbacks import Callback | ||
|
||
|
||
class Histogram(Callback):
    """
    Generate data with matching histogram.

    :param embedding_matrix: The embedding matrix used to generate the match
        histogram.
    :param bin_size: The `bin_size` of the histogram unit.
    :param hist_mode: The mode of the :class:`MatchingHistogram` unit, one of
        `CH`, `NH`, and `LCH`.
    """

    def __init__(
        self,
        embedding_matrix: np.ndarray,
        bin_size: int = 30,
        hist_mode: str = 'CH',
    ):
        """Init."""
        # The unit is built once here and reused for every batch.
        self._match_hist_unit = mz.preprocessors.units.MatchingHistogram(
            bin_size=bin_size,
            embedding_matrix=embedding_matrix,
            normalize=True,
            mode=hist_mode
        )

    def on_batch_unpacked(self, x, y):
        """
        Insert `match_histogram` to `x`.

        `x` is modified in place and nothing is returned; `y` is not used.

        :param x: unpacked x.
        :param y: unpacked y.
        """
        x['match_histogram'] = _build_match_histogram(x, self._match_hist_unit)
|
||
|
||
def _trunc_text(input_text: list, length: list) -> list:
    """
    Truncate each row of the input text to its corresponding length.

    :param input_text: The rows of text that need to be truncated.
    :param length: The lengths used to truncate the text; `length[idx]` is
        the slice bound applied to row `idx`.
    :return: A new list holding the truncated rows.
    """
    return [row[:length[idx]] for idx, row in enumerate(input_text)]
|
||
|
||
def _build_match_histogram(
    x: dict,
    match_hist_unit: mz.preprocessors.units.MatchingHistogram
) -> np.ndarray:
    """
    Generate the matching histogram for input.

    :param x: The input `dict`; its `text_left`, `text_right` and
        `length_right` entries are read.
    :param match_hist_unit: The histogram unit :class:`MatchingHistogram`.
    :return: The matching histogram.
    """
    text_left = x['text_left'].tolist()
    # Only the right text is truncated to its recorded length; the left text
    # is passed to the unit as is.
    text_right = _trunc_text(x['text_right'].tolist(),
                             x['length_right'].tolist())
    return np.asarray([
        match_hist_unit.transform(list(pair))
        for pair in zip(text_left, text_right)
    ])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from matchzoo.data_generator.callbacks.callback import Callback | ||
|
||
|
||
class LambdaCallback(Callback):
    """
    LambdaCallback. Just a shorthand for creating a callback class.

    See :class:`matchzoo.data_generator.callbacks.Callback` for more details.

    :param on_batch_data_pack: callable invoked with the data pack, or `None`
        to do nothing for that hook.
    :param on_batch_unpacked: callable invoked with `x` and `y`, or `None`
        to do nothing for that hook.

    Example:
        >>> from matchzoo.data_generator.callbacks import LambdaCallback
        >>> callback = LambdaCallback(on_batch_unpacked=print)
        >>> callback.on_batch_unpacked('x', 'y')
        x y

    """

    def __init__(self, on_batch_data_pack=None, on_batch_unpacked=None):
        """Init."""
        self._on_batch_unpacked = on_batch_unpacked
        self._on_batch_data_pack = on_batch_data_pack

    def on_batch_data_pack(self, data_pack):
        """`on_batch_data_pack`."""
        # Compare against None explicitly so that a callable object which
        # happens to be falsy is still invoked.
        if self._on_batch_data_pack is not None:
            self._on_batch_data_pack(data_pack)

    def on_batch_unpacked(self, x, y):
        """`on_batch_unpacked`."""
        if self._on_batch_unpacked is not None:
            self._on_batch_unpacked(x, y)
Oops, something went wrong.