Skip to content
4 changes: 2 additions & 2 deletions octis/models/CTM.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def train_model(self, dataset, hyperparameters=None, top_words=10):
topic_prior_variance=self.hyperparameters["prior_variance"],
top_words=top_words)

self.model.fit(x_train, x_valid, verbose=False)
self.model.fit(x_train, x_valid, verbose=self.hyperparameters["verbose"], save_dir=self.hyperparameters["save_dir"])
result = self.inference(x_test)
return result

Expand All @@ -161,7 +161,7 @@ def train_model(self, dataset, hyperparameters=None, top_words=10):
topic_prior_variance=self.hyperparameters["prior_variance"],
top_words=top_words)

self.model.fit(x_train, None, verbose=False)
self.model.fit(x_train, None, verbose=self.hyperparameters["verbose"], save_dir=self.hyperparameters["save_dir"])
result = self.model.get_info()
return result

Expand Down
4 changes: 2 additions & 2 deletions octis/models/contextualized_topic_models/models/ctm.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def fit(self, train_dataset, validation_dataset=None,

train_loader = DataLoader(
self.train_data, batch_size=self.batch_size, shuffle=True,
num_workers=self.num_data_loader_workers)
num_workers=self.num_data_loader_workers, drop_last=True)

# init training variables
train_loss = 0
Expand Down Expand Up @@ -301,7 +301,7 @@ def fit(self, train_dataset, validation_dataset=None,
if self.validation_data is not None:
validation_loader = DataLoader(
self.validation_data, batch_size=self.batch_size,
shuffle=True, num_workers=self.num_data_loader_workers)
shuffle=True, num_workers=self.num_data_loader_workers, drop_last=True)
# train epoch
s = datetime.datetime.now()
val_samples_processed, val_loss = self._validation(
Expand Down
6 changes: 3 additions & 3 deletions octis/models/pytorchavitm/AVITM.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ def train_model(self, dataset, hyperparameters=None, top_words=10):
solver=self.hyperparameters['solver'], num_epochs=self.hyperparameters['num_epochs'],
reduce_on_plateau=self.hyperparameters['reduce_on_plateau'], num_samples=self.hyperparameters[
'num_samples'], topic_prior_mean=self.hyperparameters["prior_mean"],
topic_prior_variance=self.hyperparameters["prior_variance"]
topic_prior_variance=self.hyperparameters["prior_variance"], verbose=self.hyperparameters["verbose"], top_words=top_words,
)

if self.use_partitions:
self.model.fit(x_train, x_valid)
self.model.fit(x_train, x_valid, save_dir=self.hyperparameters["save_dir"])
result = self.inference(x_test)
else:
self.model.fit(x_train, None)
self.model.fit(x_train, None, save_dir=self.hyperparameters["save_dir"])
result = self.model.get_info()
return result

Expand Down
9 changes: 5 additions & 4 deletions octis/models/pytorchavitm/avitm/avitm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class AVITM_model(object):
def __init__(self, input_size, num_topics=10, model_type='prodLDA', hidden_sizes=(100, 100),
activation='softplus', dropout=0.2, learn_priors=True, batch_size=64, lr=2e-3, momentum=0.99,
solver='adam', num_epochs=100, reduce_on_plateau=False, topic_prior_mean=0.0,
topic_prior_variance=None, num_samples=10, num_data_loader_workers=0, verbose=False):
topic_prior_variance=None, num_samples=10, num_data_loader_workers=0, verbose=False, top_words=10):
"""
Initialize AVITM model.

Expand Down Expand Up @@ -68,6 +68,7 @@ def __init__(self, input_size, num_topics=10, model_type='prodLDA', hidden_sizes
# assert isinstance(topic_prior_variance, float), \
# "topic prior_variance must be type float"

self.top_words = top_words
self.input_size = input_size
self.num_topics = num_topics
self.verbose = verbose
Expand Down Expand Up @@ -240,7 +241,7 @@ def fit(self, train_dataset, validation_dataset, save_dir=None):
self.validation_data = validation_dataset
train_loader = DataLoader(
self.train_data, batch_size=self.batch_size, shuffle=True,
num_workers=self.num_data_loader_workers)
num_workers=self.num_data_loader_workers, drop_last=True)

# init training variables
train_loss = 0
Expand All @@ -267,7 +268,7 @@ def fit(self, train_dataset, validation_dataset, save_dir=None):
if self.validation_data is not None:
validation_loader = DataLoader(
self.validation_data, batch_size=self.batch_size, shuffle=True,
num_workers=self.num_data_loader_workers)
num_workers=self.num_data_loader_workers, drop_last=True)
# train epoch
s = datetime.datetime.now()
val_samples_processed, val_loss = self._validation(validation_loader)
Expand Down Expand Up @@ -347,7 +348,7 @@ def get_topics(self, k=10):

def get_info(self):
info = {}
topic_word = self.get_topics()
topic_word = self.get_topics(k=self.top_words) # or self.input_size
topic_word_dist = self.get_topic_word_mat()
# topic_document_dist = self.get_topic_document_mat()
info['topics'] = topic_word
Expand Down