Skip to content

Commit

Permalink
two packages
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenvo09 committed Jan 17, 2021
1 parent 91ac0bc commit 51bf727
Show file tree
Hide file tree
Showing 168 changed files with 27,238 additions and 0 deletions.
60 changes: 60 additions & 0 deletions matchzoo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os

# USER_DIR = Path.expanduser(Path('~')).joinpath('.matchzoo')
USER_DIR = os.path.expanduser("~")
USER_DIR = os.path.join(USER_DIR, ".matchzoo")
if not os.path.exists(USER_DIR):
os.mkdir(USER_DIR)
# USER_DIR.mkdir()
USER_DATA_DIR = os.path.join(USER_DIR, 'datasets')
if not os.path.exists(USER_DATA_DIR):
os.mkdir(USER_DATA_DIR)
# USER_DATA_DIR.mkdir()
USER_TUNED_MODELS_DIR = os.path.join(USER_DIR, 'tuned_models')

from .version import __version__

from .data_pack import DataPack
from .data_pack import pack
from .data_pack import load_data_pack

# from . import metrics
from . import tasks

from . import preprocessors
# from . import data_generator
# from .data_generator import DataGenerator
# from .data_generator import DataGeneratorBuilder

from .preprocessors.chain_transform import chain_transform
from .datasets import embeddings
# from . import metrics
# from . import losses
from . import engine
# from . import models
# from . import embedding
# from . import datasets
# from . import layers
# from . import auto
# from . import contrib

# from .engine import hyper_spaces
# from .engine.base_model import load_model
# from .engine.base_preprocessor import load_preprocessor
# from .engine import callbacks
# from .engine.param import Param
# from .engine.param_table import ParamTable

# from .embedding.embedding import Embedding

from .utils import one_hot
from .preprocessors.build_unit_from_data_pack import build_unit_from_data_pack
from .preprocessors.build_vocab_unit import build_vocab_unit

# deprecated, should be removed in v2.2
# from .contrib.legacy_data_generator import DPoolDataGenerator
# from .contrib.legacy_data_generator import DPoolPairDataGenerator
# from .contrib.legacy_data_generator import HistogramDataGenerator
# from .contrib.legacy_data_generator import HistogramPairDataGenerator
# from .contrib.legacy_data_generator import DynamicDataGenerator
# from .contrib.legacy_data_generator import PairDataGenerator
3 changes: 3 additions & 0 deletions matchzoo/data_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from . import callbacks
from .data_generator import DataGenerator
from .data_generator_builder import DataGeneratorBuilder
4 changes: 4 additions & 0 deletions matchzoo/data_generator/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .callback import Callback
from .lambda_callback import LambdaCallback
from .dynamic_pooling import DynamicPooling
from .histogram import Histogram
36 changes: 36 additions & 0 deletions matchzoo/data_generator/callbacks/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np

import matchzoo as mz


class Callback(object):
"""
DataGenerator callback base class.
To build your own callbacks, inherit `mz.data_generator.callbacks.Callback`
and overrides corresponding methods.
A batch is processed in the following way:
- slice data pack based on batch index
- handle `on_batch_data_pack` callbacks
- unpack data pack into x, y
- handle `on_batch_x_y` callbacks
- return x, y
"""

def on_batch_data_pack(self, data_pack: mz.DataPack):
"""
`on_batch_data_pack`.
:param data_pack: a sliced DataPack before unpacking.
"""

def on_batch_unpacked(self, x: dict, y: np.ndarray):
"""
`on_batch_unpacked`.
:param x: unpacked x.
:param y: unpacked y.
"""
92 changes: 92 additions & 0 deletions matchzoo/data_generator/callbacks/dynamic_pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import numpy as np

from matchzoo.data_generator.callbacks import Callback


class DynamicPooling(Callback):
""":class:`DPoolPairDataGenerator` constructor.
:param fixed_length_left: max length of left text.
:param fixed_length_right: max length of right text.
:param compress_ratio_left: the length change ratio,
especially after normal pooling layers.
:param compress_ratio_right: the length change ratio,
especially after normal pooling layers.
"""

def __init__(
self,
fixed_length_left: int,
fixed_length_right: int,
compress_ratio_left: float = 1,
compress_ratio_right: float = 1,
):
"""Init."""
self._fixed_length_left = fixed_length_left
self._fixed_length_right = fixed_length_right
self._compress_ratio_left = compress_ratio_left
self._compress_ratio_right = compress_ratio_right

def on_batch_unpacked(self, x, y):
"""
Insert `dpool_index` into `x`.
:param x: unpacked x.
:param y: unpacked y.
"""
x['dpool_index'] = _dynamic_pooling_index(
x['length_left'],
x['length_right'],
self._fixed_length_left,
self._fixed_length_right,
self._compress_ratio_left,
self._compress_ratio_right
)


def _dynamic_pooling_index(length_left: np.array,
length_right: np.array,
fixed_length_left: int,
fixed_length_right: int,
compress_ratio_left: float,
compress_ratio_right: float) -> np.array:
def _dpool_index(one_length_left: int,
one_length_right: int,
fixed_length_left: int,
fixed_length_right: int):
if one_length_left == 0:
stride_left = fixed_length_left
else:
stride_left = 1.0 * fixed_length_left / one_length_left

if one_length_right == 0:
stride_right = fixed_length_right
else:
stride_right = 1.0 * fixed_length_right / one_length_right

one_idx_left = [int(i / stride_left)
for i in range(fixed_length_left)]
one_idx_right = [int(i / stride_right)
for i in range(fixed_length_right)]
mesh1, mesh2 = np.meshgrid(one_idx_left, one_idx_right)
index_one = np.transpose(
np.stack([mesh1, mesh2]), (2, 1, 0))
return index_one

index = []
dpool_bias_left = dpool_bias_right = 0
if fixed_length_left % compress_ratio_left != 0:
dpool_bias_left = 1
if fixed_length_right % compress_ratio_right != 0:
dpool_bias_right = 1
cur_fixed_length_left = int(
fixed_length_left // compress_ratio_left) + dpool_bias_left
cur_fixed_length_right = int(
fixed_length_right // compress_ratio_right) + dpool_bias_right
for i in range(len(length_left)):
index.append(_dpool_index(
length_left[i] // compress_ratio_left,
length_right[i] // compress_ratio_right,
cur_fixed_length_left,
cur_fixed_length_right))
return np.array(index)
65 changes: 65 additions & 0 deletions matchzoo/data_generator/callbacks/histogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np

import matchzoo as mz
from matchzoo.data_generator.callbacks import Callback


class Histogram(Callback):
"""
Generate data with matching histogram.
:param embedding_matrix: The embedding matrix used to generator match
histogram.
:param bin_size: The number of bin size of the histogram.
:param hist_mode: The mode of the :class:`MatchingHistogramUnit`, one of
`CH`, `NH`, and `LCH`.
"""

def __init__(
self,
embedding_matrix: np.ndarray,
bin_size: int = 30,
hist_mode: str = 'CH',
):
"""Init."""
self._match_hist_unit = mz.preprocessors.units.MatchingHistogram(
bin_size=bin_size,
embedding_matrix=embedding_matrix,
normalize=True,
mode=hist_mode
)

def on_batch_unpacked(self, x, y):
"""Insert `match_histogram` to `x`."""
x['match_histogram'] = _build_match_histogram(x, self._match_hist_unit)


def _trunc_text(input_text: list, length: list) -> list:
"""
Truncating the input text according to the input length.
:param input_text: The input text need to be truncated.
:param length: The length used to truncated the text.
:return: The truncated text.
"""
return [row[:length[idx]] for idx, row in enumerate(input_text)]


def _build_match_histogram(
x: dict,
match_hist_unit: mz.preprocessors.units.MatchingHistogram
) -> np.ndarray:
"""
Generate the matching hisogram for input.
:param x: The input `dict`.
:param match_hist_unit: The histogram unit :class:`MatchingHistogramUnit`.
:return: The matching histogram.
"""
match_hist = []
text_left = x['text_left'].tolist()
text_right = _trunc_text(x['text_right'].tolist(),
x['length_right'].tolist())
for pair in zip(text_left, text_right):
match_hist.append(match_hist_unit.transform(list(pair)))
return np.asarray(match_hist)
31 changes: 31 additions & 0 deletions matchzoo/data_generator/callbacks/lambda_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from matchzoo.data_generator.callbacks.callback import Callback


class LambdaCallback(Callback):
"""
LambdaCallback. Just a shorthand for creating a callback class.
See :class:`matchzoo.data_generator.callbacks.Callback` for more details.
Example:
>>> from matchzoo.data_generator.callbacks import LambdaCallback
>>> callback = LambdaCallback(on_batch_unpacked=print)
>>> callback.on_batch_unpacked('x', 'y')
x y
"""

def __init__(self, on_batch_data_pack=None, on_batch_unpacked=None):
"""Init."""
self._on_batch_unpacked = on_batch_unpacked
self._on_batch_data_pack = on_batch_data_pack

def on_batch_data_pack(self, data_pack):
"""`on_batch_data_pack`."""
if self._on_batch_data_pack:
self._on_batch_data_pack(data_pack)

def on_batch_unpacked(self, x, y):
"""`on_batch_unpacked`."""
if self._on_batch_unpacked:
self._on_batch_unpacked(x, y)
Loading

0 comments on commit 51bf727

Please sign in to comment.