diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py index d35a9b4599..babbc5a731 100755 --- a/gensim/models/word2vec.py +++ b/gensim/models/word2vec.py @@ -170,11 +170,11 @@ CORPUSFILE_VERSION = -1 def train_epoch_sg(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expected_examples, _expected_words, - _work, _neu1, compute_loss): + _work, _neu1): raise RuntimeError("Training with corpus_file argument is not supported") def train_epoch_cbow(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expected_examples, _expected_words, - _work, _neu1, compute_loss): + _work, _neu1): raise RuntimeError("Training with corpus_file argument is not supported") @@ -182,7 +182,7 @@ class Word2Vec(utils.SaveLoad): def __init__(self, sentences=None, corpus_file=None, vector_size=100, alpha=0.025, window=5, min_count=5, max_vocab_size=None, sample=1e-3, seed=1, workers=3, min_alpha=0.0001, sg=0, hs=0, negative=5, ns_exponent=0.75, cbow_mean=1, hashfxn=hash, epochs=5, null_word=0, - trim_rule=None, sorted_vocab=1, batch_words=MAX_WORDS_IN_BATCH, compute_loss=False, callbacks=(), + trim_rule=None, sorted_vocab=1, batch_words=MAX_WORDS_IN_BATCH, callbacks=(), comment=None, max_final_vocab=None): """Train, use and evaluate neural networks described in https://code.google.com/p/word2vec/. @@ -282,9 +282,6 @@ def __init__(self, sentences=None, corpus_file=None, vector_size=100, alpha=0.02 Target size (in words) for batches of examples passed to worker threads (and thus cython routines).(Larger batches will be passed if individual texts are longer than 10000 words, but the standard cython code truncates to that maximum.) - compute_loss: bool, optional - If True, computes and stores loss value which can be retrieved using - :meth:`~gensim.models.word2vec.Word2Vec.get_latest_training_loss`. callbacks : iterable of :class:`~gensim.models.callbacks.CallbackAny2Vec`, optional Sequence of callbacks to be executed at specific stages during training. 
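Note: with `compute_loss` and `get_latest_training_loss()` removed by this patch, the loss is always tallied and exposed directly as model attributes. A minimal usage sketch, not part of the patch, assuming a gensim build from this branch; `common_texts` is just a stand-in toy corpus from gensim's test utilities:

```python
# Hedged sketch of the post-patch API: loss is tallied unconditionally and read
# from plain attributes instead of get_latest_training_loss().
from gensim.test.utils import common_texts
from gensim.models import Word2Vec

model = Word2Vec(sentences=common_texts, vector_size=50, min_count=1, epochs=5)

print(model.epoch_loss)          # loss tallied during the most recent epoch
print(model.epoch_loss_history)  # one entry per completed epoch
```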
@@ -325,8 +322,8 @@ def __init__(self, sentences=None, corpus_file=None, vector_size=100, alpha=0.02 self.negative = int(negative) self.ns_exponent = ns_exponent self.cbow_mean = int(cbow_mean) - self.compute_loss = bool(compute_loss) - self.running_training_loss = 0 + self.epoch_loss = 0.0 + self.epoch_loss_history = [] self.min_alpha_yet_reached = float(alpha) self.corpus_count = 0 self.corpus_total_words = 0 @@ -380,7 +377,7 @@ def build_vocab_and_train(self, corpus_iterable=None, corpus_file=None, trim_rul self.train( corpus_iterable=corpus_iterable, corpus_file=corpus_file, total_examples=self.corpus_count, total_words=self.corpus_total_words, epochs=self.epochs, start_alpha=self.alpha, - end_alpha=self.min_alpha, compute_loss=self.compute_loss, callbacks=callbacks) + end_alpha=self.min_alpha, callbacks=callbacks) def build_vocab(self, corpus_iterable=None, corpus_file=None, update=False, progress_per=10000, keep_raw_vocab=False, trim_rule=None, **kwargs): @@ -838,10 +835,10 @@ def _do_train_epoch(self, corpus_file, thread_id, offset, cython_vocab, thread_p if self.sg: examples, tally, raw_tally = train_epoch_sg(self, corpus_file, offset, cython_vocab, cur_epoch, - total_examples, total_words, work, neu1, self.compute_loss) + total_examples, total_words, work, neu1) else: examples, tally, raw_tally = train_epoch_cbow(self, corpus_file, offset, cython_vocab, cur_epoch, - total_examples, total_words, work, neu1, self.compute_loss) + total_examples, total_words, work, neu1) return examples, tally, raw_tally @@ -866,9 +863,9 @@ def _do_train_job(self, sentences, alpha, inits): work, neu1 = inits tally = 0 if self.sg: - tally += train_batch_sg(self, sentences, alpha, work, self.compute_loss) + tally += train_batch_sg(self, sentences, alpha, work) else: - tally += train_batch_cbow(self, sentences, alpha, work, neu1, self.compute_loss) + tally += train_batch_cbow(self, sentences, alpha, work, neu1) return tally, self._raw_word_count(sentences) def _clear_post_train(self): @@ -877,7 +874,7 @@ def _clear_post_train(self): def train(self, corpus_iterable=None, corpus_file=None, total_examples=None, total_words=None, epochs=None, start_alpha=None, end_alpha=None, word_count=0, - queue_factor=2, report_delay=1.0, compute_loss=False, callbacks=(), + queue_factor=2, report_delay=1.0, callbacks=(), **kwargs): """Update the model's neural weights from a sequence of sentences. @@ -931,9 +928,6 @@ def train(self, corpus_iterable=None, corpus_file=None, total_examples=None, tot Multiplier for size of queue (number of workers * queue_factor). report_delay : float, optional Seconds to wait before reporting progress. - compute_loss: bool, optional - If True, computes and stores loss value which can be retrieved using - :meth:`~gensim.models.word2vec.Word2Vec.get_latest_training_loss`. callbacks : iterable of :class:`~gensim.models.callbacks.CallbackAny2Vec`, optional Sequence of callbacks to be executed at specific stages during training. 
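Note: because the per-epoch tally is zeroed at the start of each epoch and appended to `epoch_loss_history` just before the `on_epoch_end` callbacks fire (see the `train()` hunks below), per-epoch loss can be monitored with gensim's existing callback machinery. A hedged sketch, again assuming a model built from this branch:

```python
from gensim.models.callbacks import CallbackAny2Vec

class EpochLossLogger(CallbackAny2Vec):
    """Log the loss tallied for each epoch as training proceeds."""

    def __init__(self):
        self.epoch = 0

    def on_epoch_end(self, model):
        # By this point train() has already appended model.epoch_loss to
        # model.epoch_loss_history for the epoch that just finished.
        print("epoch %d: loss %.1f" % (self.epoch, model.epoch_loss))
        self.epoch += 1

# Usage, e.g.:
# model.train(corpus, total_examples=model.corpus_count,
#             epochs=model.epochs, callbacks=[EpochLossLogger()])
```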
@@ -959,8 +953,7 @@ def train(self, corpus_iterable=None, corpus_file=None, total_examples=None, tot total_examples=total_examples, total_words=total_words) - self.compute_loss = compute_loss - self.running_training_loss = 0.0 + self.epoch_loss_history = [] for callback in callbacks: callback.on_train_begin(self) @@ -971,6 +964,7 @@ def train(self, corpus_iterable=None, corpus_file=None, total_examples=None, tot job_tally = 0 for cur_epoch in range(self.epochs): + self.epoch_loss = 0.0 for callback in callbacks: callback.on_epoch_begin(self) @@ -988,6 +982,7 @@ def train(self, corpus_iterable=None, corpus_file=None, total_examples=None, tot raw_word_count += raw_word_count_epoch job_tally += job_tally_epoch + self.epoch_loss_history.append(self.epoch_loss) for callback in callbacks: callback.on_epoch_end(self) @@ -1820,17 +1815,6 @@ def save(self, *args, **kwargs): kwargs['ignore'] = kwargs.get('ignore', []) + ['cum_table', ] super(Word2Vec, self).save(*args, **kwargs) - def get_latest_training_loss(self): - """Get current value of the training loss. - - Returns - ------- - float - Current training loss. - - """ - return self.running_training_loss - @classmethod def load(cls, *args, rethrow=False, **kwargs): """Load a previously saved :class:`~gensim.models.word2vec.Word2Vec` model. diff --git a/gensim/models/word2vec_corpusfile.pyx b/gensim/models/word2vec_corpusfile.pyx index 467b6a2d45..3eb9d9bd81 100644 --- a/gensim/models/word2vec_corpusfile.pyx +++ b/gensim/models/word2vec_corpusfile.pyx @@ -251,7 +251,7 @@ cdef REAL_t get_next_alpha( def train_epoch_sg(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expected_examples, _expected_words, _work, - _neu1, compute_loss): + _neu1): """Train Skipgram model for one epoch by training on an input stream. This function is used only in multistream mode. Called internally from :meth:`~gensim.models.word2vec.Word2Vec.train`. @@ -268,8 +268,6 @@ def train_epoch_sg(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expec Private working memory for each worker. _neu1 : np.ndarray Private working memory for each worker. - compute_loss : bool - Whether or not the training loss should be computed in this batch. 
Returns ------- @@ -297,7 +295,7 @@ def train_epoch_sg(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expec cdef long long total_effective_words = 0, total_words = 0 cdef int sent_idx, idx_start, idx_end - init_w2v_config(&c, model, _alpha, compute_loss, _work) + init_w2v_config(&c, model, _alpha, _work) cdef vector[vector[string]] sentences @@ -330,14 +328,14 @@ def train_epoch_sg(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expec if c.hs: w2v_fast_sentence_sg_hs( c.points[i], c.codes[i], c.codelens[i], c.syn0, c.syn1, c.size, c.indexes[j], - c.alpha, c.work, c.words_lockf, c.words_lockf_len, c.compute_loss, - &c.running_training_loss) + c.alpha, c.work, c.words_lockf, c.words_lockf_len, + &c.minibatch_loss) if c.negative: c.next_random = w2v_fast_sentence_sg_neg( c.negative, c.cum_table, c.cum_table_len, c.syn0, c.syn1neg, c.size, c.indexes[i], c.indexes[j], c.alpha, c.work, c.next_random, c.words_lockf, c.words_lockf_len, - c.compute_loss, &c.running_training_loss) + &c.minibatch_loss) total_sentences += sentences.size() total_effective_words += effective_words @@ -346,12 +344,12 @@ def train_epoch_sg(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expec start_alpha, end_alpha, total_sentences, total_words, expected_examples, expected_words, cur_epoch, num_epochs) - model.running_training_loss = c.running_training_loss + model.epoch_loss += c.minibatch_loss return total_sentences, total_effective_words, total_words def train_epoch_cbow(model, corpus_file, offset, _cython_vocab, _cur_epoch, _expected_examples, _expected_words, _work, - _neu1, compute_loss): + _neu1): """Train CBOW model for one epoch by training on an input stream. This function is used only in multistream mode. Called internally from :meth:`~gensim.models.word2vec.Word2Vec.train`. @@ -368,8 +366,6 @@ def train_epoch_cbow(model, corpus_file, offset, _cython_vocab, _cur_epoch, _exp Private working memory for each worker. _neu1 : np.ndarray Private working memory for each worker. - compute_loss : bool - Whether or not the training loss should be computed in this batch. 
Returns ------- @@ -397,7 +393,7 @@ def train_epoch_cbow(model, corpus_file, offset, _cython_vocab, _cur_epoch, _exp cdef long long total_effective_words = 0, total_words = 0 cdef int sent_idx, idx_start, idx_end - init_w2v_config(&c, model, _alpha, compute_loss, _work, _neu1) + init_w2v_config(&c, model, _alpha, _work, _neu1) cdef vector[vector[string]] sentences @@ -427,15 +423,15 @@ def train_epoch_cbow(model, corpus_file, offset, _cython_vocab, _cur_epoch, _exp if c.hs: w2v_fast_sentence_cbow_hs( c.points[i], c.codes[i], c.codelens, c.neu1, c.syn0, c.syn1, c.size, c.indexes, c.alpha, - c.work, i, j, k, c.cbow_mean, c.words_lockf, c.words_lockf_len, c.compute_loss, - &c.running_training_loss) + c.work, i, j, k, c.cbow_mean, c.words_lockf, c.words_lockf_len, + &c.minibatch_loss) if c.negative: c.next_random = w2v_fast_sentence_cbow_neg( c.negative, c.cum_table, c.cum_table_len, c.codelens, c.neu1, c.syn0, c.syn1neg, c.size, c.indexes, c.alpha, c.work, i, j, k, c.cbow_mean, - c.next_random, c.words_lockf, c.words_lockf_len, c.compute_loss, - &c.running_training_loss) + c.next_random, c.words_lockf, c.words_lockf_len, + &c.minibatch_loss) total_sentences += sentences.size() total_effective_words += effective_words @@ -444,7 +440,7 @@ def train_epoch_cbow(model, corpus_file, offset, _cython_vocab, _cur_epoch, _exp start_alpha, end_alpha, total_sentences, total_words, expected_examples, expected_words, cur_epoch, num_epochs) - model.running_training_loss = c.running_training_loss + model.epoch_loss += c.minibatch_loss return total_sentences, total_effective_words, total_words diff --git a/gensim/models/word2vec_inner.pxd b/gensim/models/word2vec_inner.pxd index 82abad2f05..53dfa84dde 100644 --- a/gensim/models/word2vec_inner.pxd +++ b/gensim/models/word2vec_inner.pxd @@ -49,8 +49,9 @@ cdef our_saxpy_ptr our_saxpy cdef struct Word2VecConfig: - int hs, negative, sample, compute_loss, size, window, cbow_mean, workers - REAL_t running_training_loss, alpha + int hs, negative, sample, size, window, cbow_mean, workers + REAL_t alpha + np.float64_t minibatch_loss REAL_t *syn0 REAL_t *words_lockf @@ -96,7 +97,7 @@ cdef void w2v_fast_sentence_sg_hs( const np.uint32_t *word_point, const np.uint8_t *word_code, const int codelen, REAL_t *syn0, REAL_t *syn1, const int size, const np.uint32_t word2_index, const REAL_t alpha, REAL_t *work, REAL_t *words_lockf, - const np.uint32_t lockf_len, const int _compute_loss, REAL_t *_running_training_loss_param) nogil + const np.uint32_t lockf_len, np.float64_t *minibatch_loss_ptr) nogil cdef unsigned long long w2v_fast_sentence_sg_neg( @@ -104,7 +105,7 @@ cdef unsigned long long w2v_fast_sentence_sg_neg( REAL_t *syn0, REAL_t *syn1neg, const int size, const np.uint32_t word_index, const np.uint32_t word2_index, const REAL_t alpha, REAL_t *work, unsigned long long next_random, REAL_t *words_lockf, - const np.uint32_t lockf_len, const int _compute_loss, REAL_t *_running_training_loss_param) nogil + const np.uint32_t lockf_len, np.float64_t *minibatch_loss_ptr) nogil cdef void w2v_fast_sentence_cbow_hs( @@ -112,7 +113,7 @@ cdef void w2v_fast_sentence_cbow_hs( REAL_t *neu1, REAL_t *syn0, REAL_t *syn1, const int size, const np.uint32_t indexes[MAX_SENTENCE_LEN], const REAL_t alpha, REAL_t *work, int i, int j, int k, int cbow_mean, REAL_t *words_lockf, - const np.uint32_t lockf_len, const int _compute_loss, REAL_t *_running_training_loss_param) nogil + const np.uint32_t lockf_len, np.float64_t *minibatch_loss_ptr) nogil cdef unsigned long long w2v_fast_sentence_cbow_neg( 
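Note: the header changes above drop the `compute_loss` flag and the float32 `running_training_loss` from `Word2VecConfig` in favour of a float64 `minibatch_loss` that each kernel receives by pointer, presumably so that long accumulations do not lose precision in float32. For orientation, the quantity the negative-sampling kernels add to that tally is the usual -log(sigmoid(...)) term; a plain-NumPy illustration of that quantity (not the Cython code, which reads it from the precomputed LOG_TABLE):

```python
import numpy as np

def ns_step_loss(f_dot, is_positive_pair):
    # f_dot: dot product between the input word vector and the output vector of
    # either the true context word (positive pair) or a drawn negative sample.
    signed = f_dot if is_positive_pair else -f_dot
    return -np.log(1.0 / (1.0 + np.exp(-signed)))  # -log sigmoid(+/- f_dot)

# Each call's result corresponds to one
#   minibatch_loss_ptr[0] = minibatch_loss_ptr[0] - log_e_f_dot
# update inside w2v_fast_sentence_sg_neg / w2v_fast_sentence_cbow_neg.
```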
@@ -120,7 +121,7 @@ cdef unsigned long long w2v_fast_sentence_cbow_neg( REAL_t *neu1, REAL_t *syn0, REAL_t *syn1neg, const int size, const np.uint32_t indexes[MAX_SENTENCE_LEN], const REAL_t alpha, REAL_t *work, int i, int j, int k, int cbow_mean, unsigned long long next_random, REAL_t *words_lockf, - const np.uint32_t lockf_len, const int _compute_loss, REAL_t *_running_training_loss_param) nogil + const np.uint32_t lockf_len, np.float64_t *minibatch_loss_ptr) nogil -cdef init_w2v_config(Word2VecConfig *c, model, alpha, compute_loss, _work, _neu1=*) +cdef init_w2v_config(Word2VecConfig *c, model, alpha, _work, _neu1=*) diff --git a/gensim/models/word2vec_inner.pyx b/gensim/models/word2vec_inner.pyx index 50bfc803bd..ec2293abc1 100755 --- a/gensim/models/word2vec_inner.pyx +++ b/gensim/models/word2vec_inner.pyx @@ -75,7 +75,7 @@ cdef void w2v_fast_sentence_sg_hs( const np.uint32_t *word_point, const np.uint8_t *word_code, const int codelen, REAL_t *syn0, REAL_t *syn1, const int size, const np.uint32_t word2_index, const REAL_t alpha, REAL_t *work, REAL_t *words_lockf, - const np.uint32_t lockf_len, const int _compute_loss, REAL_t *_running_training_loss_param) nogil: + const np.uint32_t lockf_len, np.float64_t *minibatch_loss_ptr) nogil: """Train on a single effective word from the current batch, using the Skip-Gram model. In this model we are using a given word to predict a context word (a word that is @@ -104,9 +104,7 @@ cdef void w2v_fast_sentence_sg_hs( Private working memory for each worker. words_lockf Lock factors for each word. A value of 0 will block training. - _compute_loss - Whether or not the loss should be computed at this step. - _running_training_loss_param + minibatch_loss_ptr Running loss, used to debug or inspect how training progresses. """ @@ -124,13 +122,13 @@ cdef void w2v_fast_sentence_sg_hs( f = EXP_TABLE[((f_dot + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))] g = (1 - word_code[b] - f) * alpha - if _compute_loss == 1: - sgn = (-1)**word_code[b] # ch function: 0-> 1, 1 -> -1 - lprob = -1*sgn*f_dot - if lprob <= -MAX_EXP or lprob >= MAX_EXP: - continue - lprob = LOG_TABLE[((lprob + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))] - _running_training_loss_param[0] = _running_training_loss_param[0] - lprob + # tally loss + sgn = (-1)**word_code[b] # ch function: 0-> 1, 1 -> -1 + lprob = -1*sgn*f_dot + if lprob <= -MAX_EXP or lprob >= MAX_EXP: + continue + lprob = LOG_TABLE[((lprob + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))] + minibatch_loss_ptr[0] = minibatch_loss_ptr[0] - lprob our_saxpy(&size, &g, &syn1[row2], &ONE, work, &ONE) our_saxpy(&size, &g, &syn0[row1], &ONE, &syn1[row2], &ONE) @@ -161,7 +159,7 @@ cdef unsigned long long w2v_fast_sentence_sg_neg( REAL_t *syn0, REAL_t *syn1neg, const int size, const np.uint32_t word_index, const np.uint32_t word2_index, const REAL_t alpha, REAL_t *work, unsigned long long next_random, REAL_t *words_lockf, - const np.uint32_t lockf_len, const int _compute_loss, REAL_t *_running_training_loss_param) nogil: + const np.uint32_t lockf_len, np.float64_t *minibatch_loss_ptr) nogil: """Train on a single effective word from the current batch, using the Skip-Gram model. In this model we are using a given word to predict a context word (a word that is @@ -195,9 +193,7 @@ cdef unsigned long long w2v_fast_sentence_sg_neg( Seed to produce the index for the next word to be randomly sampled. words_lockf Lock factors for each word. A value of 0 will block training. - _compute_loss - Whether or not the loss should be computed at this step. 
- _running_training_loss_param + minibatch_loss_ptr Running loss, used to debug or inspect how training progresses. Returns @@ -232,12 +228,12 @@ cdef unsigned long long w2v_fast_sentence_sg_neg( f = EXP_TABLE[((f_dot + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))] g = (label - f) * alpha - if _compute_loss == 1: - f_dot = (f_dot if d == 0 else -f_dot) - if f_dot <= -MAX_EXP or f_dot >= MAX_EXP: - continue - log_e_f_dot = LOG_TABLE[((f_dot + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))] - _running_training_loss_param[0] = _running_training_loss_param[0] - log_e_f_dot + # tally loss + f_dot = (f_dot if d == 0 else -f_dot) + if f_dot <= -MAX_EXP or f_dot >= MAX_EXP: + continue + log_e_f_dot = LOG_TABLE[((f_dot + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))] + minibatch_loss_ptr[0] = minibatch_loss_ptr[0] - log_e_f_dot our_saxpy(&size, &g, &syn1neg[row2], &ONE, work, &ONE) our_saxpy(&size, &g, &syn0[row1], &ONE, &syn1neg[row2], &ONE) @@ -252,7 +248,7 @@ cdef void w2v_fast_sentence_cbow_hs( REAL_t *neu1, REAL_t *syn0, REAL_t *syn1, const int size, const np.uint32_t indexes[MAX_SENTENCE_LEN], const REAL_t alpha, REAL_t *work, int i, int j, int k, int cbow_mean, REAL_t *words_lockf, const np.uint32_t lockf_len, - const int _compute_loss, REAL_t *_running_training_loss_param) nogil: + np.float64_t *minibatch_loss_ptr) nogil: """Train on a single effective word from the current batch, using the CBOW method. Using this method we train the trainable neural network by attempting to predict a @@ -291,9 +287,7 @@ cdef void w2v_fast_sentence_cbow_hs( If 0, use the sum of the context word vectors as the prediction. If 1, use the mean. words_lockf Lock factors for each word. A value of 0 will block training. - _compute_loss - Whether or not the loss should be computed at this step. - _running_training_loss_param + minibatch_loss_ptr Running loss, used to debug or inspect how training progresses. """ @@ -324,13 +318,13 @@ cdef void w2v_fast_sentence_cbow_hs( f = EXP_TABLE[((f_dot + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))] g = (1 - word_code[b] - f) * alpha - if _compute_loss == 1: - sgn = (-1)**word_code[b] # ch function: 0-> 1, 1 -> -1 - lprob = -1*sgn*f_dot - if lprob <= -MAX_EXP or lprob >= MAX_EXP: - continue - lprob = LOG_TABLE[((lprob + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))] - _running_training_loss_param[0] = _running_training_loss_param[0] - lprob + # tally loss + sgn = (-1)**word_code[b] # ch function: 0-> 1, 1 -> -1 + lprob = -1*sgn*f_dot + if lprob <= -MAX_EXP or lprob >= MAX_EXP: + continue + lprob = LOG_TABLE[((lprob + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))] + minibatch_loss_ptr[0] = minibatch_loss_ptr[0] - lprob our_saxpy(&size, &g, &syn1[row2], &ONE, work, &ONE) our_saxpy(&size, &g, neu1, &ONE, &syn1[row2], &ONE) @@ -350,7 +344,7 @@ cdef unsigned long long w2v_fast_sentence_cbow_neg( REAL_t *neu1, REAL_t *syn0, REAL_t *syn1neg, const int size, const np.uint32_t indexes[MAX_SENTENCE_LEN], const REAL_t alpha, REAL_t *work, int i, int j, int k, int cbow_mean, unsigned long long next_random, REAL_t *words_lockf, - const np.uint32_t lockf_len, const int _compute_loss, REAL_t *_running_training_loss_param) nogil: + const np.uint32_t lockf_len, np.float64_t *minibatch_loss_ptr) nogil: """Train on a single effective word from the current batch, using the CBOW method. 
Using this method we train the trainable neural network by attempting to predict a @@ -394,9 +388,7 @@ cdef unsigned long long w2v_fast_sentence_cbow_neg( Seed for the drawing the predicted word for the next iteration of the same routine. words_lockf Lock factors for each word. A value of 0 will block training. - _compute_loss - Whether or not the loss should be computed at this step. - _running_training_loss_param + minibatch_loss_ptr Running loss, used to debug or inspect how training progresses. """ @@ -442,12 +434,12 @@ cdef unsigned long long w2v_fast_sentence_cbow_neg( f = EXP_TABLE[((f_dot + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))] g = (label - f) * alpha - if _compute_loss == 1: - f_dot = (f_dot if d == 0 else -f_dot) - if f_dot <= -MAX_EXP or f_dot >= MAX_EXP: - continue - log_e_f_dot = LOG_TABLE[((f_dot + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))] - _running_training_loss_param[0] = _running_training_loss_param[0] - log_e_f_dot + # tally loss + f_dot = (f_dot if d == 0 else -f_dot) + if f_dot <= -MAX_EXP or f_dot >= MAX_EXP: + continue + log_e_f_dot = LOG_TABLE[((f_dot + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))] + minibatch_loss_ptr[0] = minibatch_loss_ptr[0] - log_e_f_dot our_saxpy(&size, &g, &syn1neg[row2], &ONE, work, &ONE) our_saxpy(&size, &g, neu1, &ONE, &syn1neg[row2], &ONE) @@ -464,7 +456,7 @@ cdef unsigned long long w2v_fast_sentence_cbow_neg( return next_random -cdef init_w2v_config(Word2VecConfig *c, model, alpha, compute_loss, _work, _neu1=None): +cdef init_w2v_config(Word2VecConfig *c, model, alpha, _work, _neu1=None): c[0].hs = model.hs c[0].negative = model.negative c[0].sample = (model.sample != 0) @@ -472,8 +464,7 @@ cdef init_w2v_config(Word2VecConfig *c, model, alpha, compute_loss, _work, _neu1 c[0].window = model.window c[0].workers = model.workers - c[0].compute_loss = (1 if compute_loss else 0) - c[0].running_training_loss = model.running_training_loss + c[0].minibatch_loss = 0.0 c[0].syn0 = (np.PyArray_DATA(model.wv.vectors)) c[0].words_lockf = (np.PyArray_DATA(model.wv.vectors_lockf)) @@ -498,7 +489,7 @@ cdef init_w2v_config(Word2VecConfig *c, model, alpha, compute_loss, _work, _neu1 c[0].neu1 = np.PyArray_DATA(_neu1) -def train_batch_sg(model, sentences, alpha, _work, compute_loss): +def train_batch_sg(model, sentences, alpha, _work): """Update skip-gram model by training on a batch of sentences. Called internally from :meth:`~gensim.models.word2vec.Word2Vec.train`. @@ -513,8 +504,6 @@ def train_batch_sg(model, sentences, alpha, _work, compute_loss): The learning rate _work : np.ndarray Private working memory for each worker. - compute_loss : bool - Whether or not the training loss should be computed in this batch. 
Returns ------- @@ -529,7 +518,7 @@ def train_batch_sg(model, sentences, alpha, _work, compute_loss): cdef int sent_idx, idx_start, idx_end cdef np.uint32_t *vocab_sample_ints - init_w2v_config(&c, model, alpha, compute_loss, _work) + init_w2v_config(&c, model, alpha, _work) if c.sample: vocab_sample_ints = np.PyArray_DATA(model.wv.expandos['sample_int']) if c.hs: @@ -585,15 +574,14 @@ def train_batch_sg(model, sentences, alpha, _work, compute_loss): if j == i: continue if c.hs: - w2v_fast_sentence_sg_hs(c.points[i], c.codes[i], c.codelens[i], c.syn0, c.syn1, c.size, c.indexes[j], c.alpha, c.work, c.words_lockf, c.words_lockf_len, c.compute_loss, &c.running_training_loss) + w2v_fast_sentence_sg_hs(c.points[i], c.codes[i], c.codelens[i], c.syn0, c.syn1, c.size, c.indexes[j], c.alpha, c.work, c.words_lockf, c.words_lockf_len, &c.minibatch_loss) if c.negative: - c.next_random = w2v_fast_sentence_sg_neg(c.negative, c.cum_table, c.cum_table_len, c.syn0, c.syn1neg, c.size, c.indexes[i], c.indexes[j], c.alpha, c.work, c.next_random, c.words_lockf, c.words_lockf_len, c.compute_loss, &c.running_training_loss) - - model.running_training_loss = c.running_training_loss + c.next_random = w2v_fast_sentence_sg_neg(c.negative, c.cum_table, c.cum_table_len, c.syn0, c.syn1neg, c.size, c.indexes[i], c.indexes[j], c.alpha, c.work, c.next_random, c.words_lockf, c.words_lockf_len, &c.minibatch_loss) + model.epoch_loss += c.minibatch_loss return effective_words -def train_batch_cbow(model, sentences, alpha, _work, _neu1, compute_loss): +def train_batch_cbow(model, sentences, alpha, _work, _neu1): """Update CBOW model by training on a batch of sentences. Called internally from :meth:`~gensim.models.word2vec.Word2Vec.train`. @@ -610,8 +598,6 @@ def train_batch_cbow(model, sentences, alpha, _work, _neu1, compute_loss): Private working memory for each worker. _neu1 : np.ndarray Private working memory for each worker. - compute_loss : bool - Whether or not the training loss should be computed in this batch. 
Returns ------- @@ -625,7 +611,7 @@ def train_batch_cbow(model, sentences, alpha, _work, _neu1, compute_loss): cdef int sent_idx, idx_start, idx_end cdef np.uint32_t *vocab_sample_ints - init_w2v_config(&c, model, alpha, compute_loss, _work, _neu1) + init_w2v_config(&c, model, alpha, _work, _neu1) if c.sample: vocab_sample_ints = np.PyArray_DATA(model.wv.expandos['sample_int']) if c.hs: @@ -678,11 +664,11 @@ def train_batch_cbow(model, sentences, alpha, _work, _neu1, compute_loss): if k > idx_end: k = idx_end if c.hs: - w2v_fast_sentence_cbow_hs(c.points[i], c.codes[i], c.codelens, c.neu1, c.syn0, c.syn1, c.size, c.indexes, c.alpha, c.work, i, j, k, c.cbow_mean, c.words_lockf, c.words_lockf_len, c.compute_loss, &c.running_training_loss) + w2v_fast_sentence_cbow_hs(c.points[i], c.codes[i], c.codelens, c.neu1, c.syn0, c.syn1, c.size, c.indexes, c.alpha, c.work, i, j, k, c.cbow_mean, c.words_lockf, c.words_lockf_len, &c.minibatch_loss) if c.negative: - c.next_random = w2v_fast_sentence_cbow_neg(c.negative, c.cum_table, c.cum_table_len, c.codelens, c.neu1, c.syn0, c.syn1neg, c.size, c.indexes, c.alpha, c.work, i, j, k, c.cbow_mean, c.next_random, c.words_lockf, c.words_lockf_len, c.compute_loss, &c.running_training_loss) + c.next_random = w2v_fast_sentence_cbow_neg(c.negative, c.cum_table, c.cum_table_len, c.codelens, c.neu1, c.syn0, c.syn1neg, c.size, c.indexes, c.alpha, c.work, i, j, k, c.cbow_mean, c.next_random, c.words_lockf, c.words_lockf_len, &c.minibatch_loss) - model.running_training_loss = c.running_training_loss + model.epoch_loss += c.minibatch_loss return effective_words diff --git a/gensim/test/test_word2vec.py b/gensim/test/test_word2vec.py index f7a73ee375..88a22a6a20 100644 --- a/gensim/test/test_word2vec.py +++ b/gensim/test/test_word2vec.py @@ -11,6 +11,7 @@ import logging import unittest +import pytest import os import bz2 import sys @@ -589,7 +590,7 @@ def testEvaluateWordPairsFromFile(self): self.assertTrue(0.1 < spearman < 1.0, "spearman %f not between 0.1 and 1.0" % spearman) self.assertTrue(0.0 <= oov < 90.0, "oov %f not between 0.0 and 90.0" % oov) - def model_sanity(self, model, train=True, with_corpus_file=False): + def model_sanity(self, model, train=True, with_corpus_file=False, ranks=None): """Even tiny models trained on LeeCorpus should pass these sanity checks""" # run extra before/after training tests if train=True if train: @@ -603,14 +604,18 @@ def model_sanity(self, model, train=True, with_corpus_file=False): else: model.train(lee_corpus_list, total_examples=model.corpus_count, epochs=model.epochs) self.assertFalse((orig0 == model.wv.vectors[1]).all()) # vector should vary after training - sims = model.wv.most_similar('war', topn=len(model.wv.index2word)) - t_rank = [word for word, score in sims].index('terrorism') + query_word = 'attacks' + expected_word = 'bombings' + sims = model.wv.most_similar(query_word, topn=len(model.wv.index2word)) + t_rank = [word for word, score in sims].index(expected_word) # in >200 calibration runs w/ calling parameters, 'terrorism' in 50-most_sim for 'war' + if ranks is not None: + ranks.append(t_rank) # tabulate trial rank if requested self.assertLess(t_rank, 50) - war_vec = model.wv['war'] - sims2 = model.wv.most_similar([war_vec], topn=51) - self.assertTrue('war' in [word for word, score in sims2]) - self.assertTrue('terrorism' in [word for word, score in sims2]) + query_vec = model.wv[query_word] + sims2 = model.wv.most_similar([query_vec], topn=51) + self.assertTrue(query_word in [word for word, score in sims2]) + 
self.assertTrue(expected_word in [word for word, score in sims2]) def test_sg_hs(self): """Test skipgram w/ hierarchical softmax""" @@ -632,29 +637,51 @@ def test_sg_neg_fromfile(self): model = word2vec.Word2Vec(sg=1, window=4, hs=0, negative=15, min_count=5, epochs=10, workers=2) self.model_sanity(model, with_corpus_file=True) - def test_cbow_hs(self): + @pytest.mark.skipif('BULK_TEST_REPS' not in os.environ, reason="bulk test only occasionally run locally") + def test_method_in_bulk(self): + """Not run by default testing, but can be run locally to help tune stochastic aspects of tests + to very-very-rarely fail. EG: + % BULK_TEST_REPS=200 METHOD_NAME=test_cbow_hs pytest test_word2vec.py -k "test_method_in_bulk" + Method must accept `ranks` keyword-argument, empty list into which salient internal result can be reported. + """ + failures = 0 + ranks = [] + reps = int(os.environ['BULK_TEST_REPS']) + method_name = os.environ.get('METHOD_NAME', 'test_cbow_hs') # by default test that specially-troublesome one + method_fn = getattr(self, method_name) + for i in range(reps): + try: + method_fn(ranks=ranks) + except Exception as ex: + print('%s failed: %s' % (method_name, ex)) + failures = failures + 1 + print(ranks) + print(np.mean(ranks)) + self.assertEquals(failures, 0, "too many failures") + + def test_cbow_hs(self, ranks=None): """Test CBOW w/ hierarchical softmax""" model = word2vec.Word2Vec( - sg=0, cbow_mean=1, alpha=0.05, window=8, hs=1, negative=0, - min_count=5, epochs=20, workers=2, batch_words=1000 + sg=0, cbow_mean=1, alpha=0.1, window=2, hs=1, negative=0, + min_count=5, epochs=60, workers=2, batch_words=1000 ) - self.model_sanity(model) + self.model_sanity(model, ranks=ranks) @unittest.skipIf(os.name == 'nt' and six.PY2, "CythonLineSentence is not supported on Windows + Py27") def test_cbow_hs_fromfile(self): model = word2vec.Word2Vec( - sg=0, cbow_mean=1, alpha=0.05, window=8, hs=1, negative=0, - min_count=5, epochs=20, workers=2, batch_words=1000 + sg=0, cbow_mean=1, alpha=0.1, window=2, hs=1, negative=0, + min_count=5, epochs=60, workers=2, batch_words=1000 ) self.model_sanity(model, with_corpus_file=True) - def test_cbow_neg(self): + def test_cbow_neg(self, ranks=None): """Test CBOW w/ negative sampling""" model = word2vec.Word2Vec( sg=0, cbow_mean=1, alpha=0.05, window=5, hs=0, negative=15, min_count=5, epochs=10, workers=2, sample=0 ) - self.model_sanity(model) + self.model_sanity(model, ranks=ranks) @unittest.skipIf(os.name == 'nt' and six.PY2, "CythonLineSentence is not supported on Windows + Py27") def test_cbow_neg_fromfile(self): @@ -983,11 +1010,11 @@ def test_reset_from(self): model.reset_from(other_model) self.assertEqual(model.wv.key_to_index, other_model.wv.key_to_index) - def test_compute_training_loss(self): + def test_epoch_loss(self): model = word2vec.Word2Vec(min_count=1, sg=1, negative=5, hs=1) model.build_vocab(sentences) - model.train(sentences, compute_loss=True, total_examples=model.corpus_count, epochs=model.epochs) - training_loss_val = model.get_latest_training_loss() + model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs) + training_loss_val = model.epoch_loss self.assertTrue(training_loss_val > 0.0)
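Note: beyond the replaced loss test above, the new per-epoch history could also be checked directly. A hedged sketch of such a test, not part of the patch, reusing the same toy `sentences` fixture from this test module:

```python
def test_epoch_loss_history(self):
    model = word2vec.Word2Vec(min_count=1, sg=1, negative=5, hs=1)
    model.build_vocab(sentences)
    model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs)
    # train() clears the history, then appends one tally per completed epoch
    self.assertEqual(len(model.epoch_loss_history), model.epochs)
    self.assertTrue(all(loss > 0.0 for loss in model.epoch_loss_history))
```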