Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions octis/evaluation_metrics/coherence_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def score(self, model_output):


class WECoherencePairwise(AbstractMetric):
def __init__(self, word2vec_path=None, binary=False, topk=10):
def __init__(self, word2vec_path=None, binary=False, topk=10, saved_kv=False):
"""
Initialize metric

Expand All @@ -83,6 +83,7 @@ def __init__(self, word2vec_path=None, binary=False, topk=10):
word2vec_path : if word2vec_path is specified, retrieves the word embeddings file (in word2vec format)
to compute similarities, otherwise 'word2vec-google-news-300' is downloaded
binary : True if the word2vec file is binary, False otherwise (default False)
saved_kv : True if the word2vec file is saved in gensim's format (using KeyedVectors.save()) (default False)
"""
super().__init__()

Expand All @@ -91,6 +92,8 @@ def __init__(self, word2vec_path=None, binary=False, topk=10):
self.word2vec_path = word2vec_path
if word2vec_path is None:
self._wv = api.load('word2vec-google-news-300')
elif saved_kv:
self._wv = KeyedVectors.load(word2vec_path)
else:
self._wv = KeyedVectors.load_word2vec_format(
word2vec_path, binary=self.binary)
Expand Down Expand Up @@ -144,14 +147,15 @@ def score(self, model_output):


class WECoherenceCentroid(AbstractMetric):
def __init__(self, topk=10, word2vec_path=None, binary=True):
def __init__(self, topk=10, word2vec_path=None, binary=True, saved_kv=False):
"""
Initialize metric

Parameters
----------
topk : how many most likely words to consider
word2vec_path : a word2vec model path; if not provided, 'word2vec-google-news-300' will be used instead
saved_kv : True if the word2vec file is saved in gensim's format (using KeyedVectors.save()) (default False)
"""
super().__init__()

Expand All @@ -160,6 +164,8 @@ def __init__(self, topk=10, word2vec_path=None, binary=True):
self.word2vec_path = word2vec_path
if self.word2vec_path is None:
self._wv = api.load('word2vec-google-news-300')
elif saved_kv:
self._wv = KeyedVectors.load(self.word2vec_path)
else:
self._wv = KeyedVectors.load_word2vec_format(
self.word2vec_path, binary=self.binary)
Expand Down
9 changes: 7 additions & 2 deletions octis/evaluation_metrics/diversity_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def score(self, model_output):

class WordEmbeddingsInvertedRBO(AbstractMetric):

def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, binary=True):
def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, binary=True, saved_kv=False):
"""
Initialize metric WE-IRBO-Match

Expand All @@ -102,6 +102,7 @@ def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, bina
:param weight: Weight of each agreement at depth d. When set to 1.0, there is no weight, the rbo returns to
average overlap. (Default 0.9)
:param normalize: if true, normalize the cosine similarity
:param saved_kv: True if the word2vec file is saved in gensim's format (using KeyedVectors.save()). (Default False)
"""
super().__init__()
self.topk = topk
Expand All @@ -111,6 +112,8 @@ def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, bina
self.word2vec_path = word2vec_path
if word2vec_path is None:
self._wv = api.load('word2vec-google-news-300')
elif saved_kv:
self._wv = KeyedVectors.load(word2vec_path)
else:
self._wv = KeyedVectors.load_word2vec_format(word2vec_path, binary=self.binary)

Expand Down Expand Up @@ -145,7 +148,7 @@ def get_word2index(list1, list2):


class WordEmbeddingsInvertedRBOCentroid(AbstractMetric):
def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, binary=True):
def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, binary=True, saved_kv=False):
super().__init__()
self.topk = topk
self.weight = weight
Expand All @@ -154,6 +157,8 @@ def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, bina
self.word2vec_path = word2vec_path
if word2vec_path is None:
self.wv = api.load('word2vec-google-news-300')
elif saved_kv:
self.wv = KeyedVectors.load(word2vec_path)
else:
self.wv = KeyedVectors.load_word2vec_format( word2vec_path, binary=self.binary)

Expand Down
17 changes: 12 additions & 5 deletions octis/evaluation_metrics/similarity_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def score(self, model_output):


class WordEmbeddingsPairwiseSimilarity(AbstractMetric):
def __init__(self, word2vec_path=None, topk=10, binary=False):
def __init__(self, word2vec_path=None, topk=10, binary=False, saved_kv=False):
"""
Initialize metric WE pairwise similarity

Expand All @@ -68,10 +68,13 @@ def __init__(self, word2vec_path=None, topk=10, binary=False):
:param topk: top k words on which the topic similarity will be computed
:param word2vec_path: word embedding space in gensim word2vec format
:param binary: True if the data is in binary word2vec format, False otherwise.
:param saved_kv: True if the word2vec file is saved in gensim's format (using KeyedVectors.save())
"""
super().__init__()
if word2vec_path is None:
self.wv = api.load('word2vec-google-news-300')
elif saved_kv:
self.wv = KeyedVectors.load(word2vec_path)
else:
self.wv = KeyedVectors.load_word2vec_format( word2vec_path, binary=binary)

Expand Down Expand Up @@ -104,7 +107,7 @@ def score(self, model_output):


class WordEmbeddingsCentroidSimilarity(AbstractMetric):
def __init__(self, word2vec_path=None, topk=10, binary=False):
def __init__(self, word2vec_path=None, topk=10, binary=False, saved_kv=False):
"""
Initialize metric WE centroid similarity

Expand All @@ -113,11 +116,13 @@ def __init__(self, word2vec_path=None, topk=10, binary=False):
:param topk: top k words on which the topic similarity will be computed
:param word2vec_path: word embedding space in gensim word2vec format
:param binary: True if the data is in binary word2vec format, False otherwise.

:param saved_kv: True if the word2vec file is saved in gensim's format (using KeyedVectors.save())
"""
super().__init__()
if word2vec_path is None:
self.wv = api.load('word2vec-google-news-300')
elif saved_kv:
self.wv = KeyedVectors.load(word2vec_path)
else:
self.wv = KeyedVectors.load_word2vec_format(word2vec_path, binary=binary)
self.topk = topk
Expand Down Expand Up @@ -161,19 +166,21 @@ def get_word2index(list1, list2):


class WordEmbeddingsWeightedSumSimilarity(AbstractMetric):
def __init__(self, id2word, word2vec_path=None, topk=10, binary=False):
def __init__(self, id2word, word2vec_path=None, topk=10, binary=False, saved_kv=False):
"""
Initialize metric WE Weighted Sum similarity

:param id2word: dictionary mapping each id to the word of the vocabulary
:param topk: top k words on which the topic similarity will be computed
:param word2vec_path: word embedding space in gensim word2vec format
:param binary: True if the data is in binary word2vec format, False otherwise.

:param saved_kv: True if the word2vec file is saved in gensim's format (using KeyedVectors.save())
"""
super().__init__()
if word2vec_path is None:
self.wv = api.load('word2vec-google-news-300')
elif saved_kv:
self.wv = KeyedVectors.load(word2vec_path)
else:
self.wv = KeyedVectors.load_word2vec_format(word2vec_path, binary=binary)
self.topk = topk
Expand Down