diff --git a/octis/evaluation_metrics/coherence_metrics.py b/octis/evaluation_metrics/coherence_metrics.py
index 888d0720..99806e65 100644
--- a/octis/evaluation_metrics/coherence_metrics.py
+++ b/octis/evaluation_metrics/coherence_metrics.py
@@ -72,7 +72,7 @@ def score(self, model_output):
 
 
 class WECoherencePairwise(AbstractMetric):
-    def __init__(self, word2vec_path=None, binary=False, topk=10):
+    def __init__(self, word2vec_path=None, binary=False, topk=10, saved_kv=False):
         """
         Initialize metric
 
@@ -83,6 +83,7 @@ def __init__(self, word2vec_path=None, binary=False, topk=10):
         word2vec_path : if word2vec_file is specified retrieves word embeddings file (in word2vec format) to compute similarities, otherwise
             'word2vec-google-news-300' is downloaded
         binary : True if the word2vec file is binary, False otherwise (default False)
+        saved_kv : True if the word2vec file is saved in gensim's format (using KeyedVectors.save()) (default False)
         """
         super().__init__()
 
@@ -91,6 +92,8 @@ def __init__(self, word2vec_path=None, binary=False, topk=10):
         self.word2vec_path = word2vec_path
         if word2vec_path is None:
             self._wv = api.load('word2vec-google-news-300')
+        elif saved_kv:
+            self._wv = KeyedVectors.load(word2vec_path)
         else:
             self._wv = KeyedVectors.load_word2vec_format(
                 word2vec_path, binary=self.binary)
@@ -144,7 +147,7 @@ def score(self, model_output):
 
 
 class WECoherenceCentroid(AbstractMetric):
-    def __init__(self, topk=10, word2vec_path=None, binary=True):
+    def __init__(self, topk=10, word2vec_path=None, binary=True, saved_kv=False):
         """
         Initialize metric
 
@@ -152,6 +155,7 @@ def __init__(self, topk=10, word2vec_path=None, binary=True):
         ----------
         topk : how many most likely words to consider
         w2v_model_path : a word2vector model path, if not provided, google news 300 will be used instead
+        saved_kv : True if the word2vec file is saved in gensim's format (using KeyedVectors.save()) (default False)
         """
         super().__init__()
 
@@ -160,6 +164,8 @@ def __init__(self, topk=10, word2vec_path=None, binary=True):
         self.word2vec_path = word2vec_path
         if self.word2vec_path is None:
             self._wv = api.load('word2vec-google-news-300')
+        elif saved_kv:
+            self._wv = KeyedVectors.load(self.word2vec_path)
         else:
             self._wv = KeyedVectors.load_word2vec_format(
                 self.word2vec_path, binary=self.binary)
diff --git a/octis/evaluation_metrics/diversity_metrics.py b/octis/evaluation_metrics/diversity_metrics.py
index cbc95247..d9ee2bac 100644
--- a/octis/evaluation_metrics/diversity_metrics.py
+++ b/octis/evaluation_metrics/diversity_metrics.py
@@ -91,7 +91,7 @@ def score(self, model_output):
 
 
 class WordEmbeddingsInvertedRBO(AbstractMetric):
-    def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, binary=True):
+    def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, binary=True, saved_kv=False):
         """
         Initialize metric WE-IRBO-Match
 
@@ -102,6 +102,7 @@ def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, bina
         :param weight: Weight of each agreement at depth d. When set to 1.0, there is no weight, the rbo returns to average overlap.
             (Default 0.9)
         :param normalize: if true, normalize the cosine similarity
+        :param saved_kv: True if the word2vec file is saved in gensim's format (using KeyedVectors.save())
         """
         super().__init__()
         self.topk = topk
@@ -111,6 +112,8 @@ def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, bina
         self.word2vec_path = word2vec_path
         if word2vec_path is None:
             self._wv = api.load('word2vec-google-news-300')
+        elif saved_kv:
+            self._wv = KeyedVectors.load(word2vec_path)
         else:
             self._wv = KeyedVectors.load_word2vec_format(word2vec_path, binary=self.binary)
 
@@ -145,7 +148,7 @@ def get_word2index(list1, list2):
 
 
 class WordEmbeddingsInvertedRBOCentroid(AbstractMetric):
-    def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, binary=True):
+    def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, binary=True, saved_kv=False):
         super().__init__()
         self.topk = topk
         self.weight = weight
@@ -154,6 +157,8 @@ def __init__(self, topk=10, weight=0.9, normalize=True, word2vec_path=None, bina
         self.word2vec_path = word2vec_path
         if word2vec_path is None:
             self.wv = api.load('word2vec-google-news-300')
+        elif saved_kv:
+            self.wv = KeyedVectors.load(word2vec_path)
         else:
             self.wv = KeyedVectors.load_word2vec_format(
                 word2vec_path, binary=self.binary)
diff --git a/octis/evaluation_metrics/similarity_metrics.py b/octis/evaluation_metrics/similarity_metrics.py
index b7a529a4..0173a471 100644
--- a/octis/evaluation_metrics/similarity_metrics.py
+++ b/octis/evaluation_metrics/similarity_metrics.py
@@ -59,7 +59,7 @@ def score(self, model_output):
 
 
 class WordEmbeddingsPairwiseSimilarity(AbstractMetric):
-    def __init__(self, word2vec_path=None, topk=10, binary=False):
+    def __init__(self, word2vec_path=None, topk=10, binary=False, saved_kv=False):
         """
         Initialize metric WE pairwise similarity
 
@@ -68,10 +68,13 @@ def __init__(self, word2vec_path=None, topk=10, binary=False):
         :param topk: top k words on which the topic diversity will be computed
         :param word2vec_path: word embedding space in gensim word2vec format
         :param binary: If True, indicates whether the data is in binary word2vec format.
+        :param saved_kv: True if the word2vec file is saved in gensim's format (using KeyedVectors.save())
         """
         super().__init__()
         if word2vec_path is None:
             self.wv = api.load('word2vec-google-news-300')
+        elif saved_kv:
+            self.wv = KeyedVectors.load(word2vec_path)
         else:
             self.wv = KeyedVectors.load_word2vec_format(
                 word2vec_path, binary=binary)
@@ -104,7 +107,7 @@ def score(self, model_output):
 
 
 class WordEmbeddingsCentroidSimilarity(AbstractMetric):
-    def __init__(self, word2vec_path=None, topk=10, binary=False):
+    def __init__(self, word2vec_path=None, topk=10, binary=False, saved_kv=False):
         """
         Initialize metric WE centroid similarity
 
@@ -113,11 +116,13 @@ def __init__(self, word2vec_path=None, topk=10, binary=False):
         :param topk: top k words on which the topic diversity will be computed
         :param word2vec_path: word embedding space in gensim word2vec format
         :param binary: If True, indicates whether the data is in binary word2vec format.
-
+        :param saved_kv: True if the word2vec file is saved in gensim's format (using KeyedVectors.save())
         """
         super().__init__()
         if word2vec_path is None:
             self.wv = api.load('word2vec-google-news-300')
+        elif saved_kv:
+            self.wv = KeyedVectors.load(word2vec_path)
         else:
             self.wv = KeyedVectors.load_word2vec_format(word2vec_path, binary=binary)
         self.topk = topk
@@ -161,7 +166,7 @@ def get_word2index(list1, list2):
 
 
 class WordEmbeddingsWeightedSumSimilarity(AbstractMetric):
-    def __init__(self, id2word, word2vec_path=None, topk=10, binary=False):
+    def __init__(self, id2word, word2vec_path=None, topk=10, binary=False, saved_kv=False):
         """
         Initialize metric WE Weighted Sum similarity
 
@@ -169,11 +174,13 @@ def __init__(self, id2word, word2vec_path=None, topk=10, binary=False):
         :param topk: top k words on which the topic diversity will be computed
         :param word2vec_path: word embedding space in gensim word2vec format
         :param binary: If True, indicates whether the data is in binary word2vec format.
-
+        :param saved_kv: True if the word2vec file is saved in gensim's format (using KeyedVectors.save())
         """
         super().__init__()
         if word2vec_path is None:
             self.wv = api.load('word2vec-google-news-300')
+        elif saved_kv:
+            self.wv = KeyedVectors.load(word2vec_path)
         else:
             self.wv = KeyedVectors.load_word2vec_format(word2vec_path, binary=binary)
         self.topk = topk
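Usage sketch (not part of the patch): a minimal example of how the new saved_kv flag could be exercised, assuming an embeddings file previously persisted with gensim's KeyedVectors.save(). The file paths and topic words below are placeholders, and the model_output dict uses the "topics" layout that OCTIS metrics expect. Note that with saved_kv=True the binary flag is not consulted, since KeyedVectors.load() determines the format itself.

from gensim.models import KeyedVectors
from octis.evaluation_metrics.coherence_metrics import WECoherencePairwise

# One-time conversion: read a word2vec-format file and persist it in gensim's
# native KeyedVectors format, which is cheaper to reload (paths are placeholders).
wv = KeyedVectors.load_word2vec_format("embeddings.bin", binary=True)
wv.save("embeddings.kv")

# Point the metric at the saved file and set saved_kv=True so the patched code
# loads it with KeyedVectors.load() instead of load_word2vec_format().
metric = WECoherencePairwise(word2vec_path="embeddings.kv", saved_kv=True, topk=5)
topics = [["cat", "dog", "horse", "bird", "fish"],
          ["car", "bus", "train", "plane", "bicycle"]]
print(metric.score({"topics": topics}))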