diff --git a/src/main.py b/src/main.py index 82ac735..876cd57 100644 --- a/src/main.py +++ b/src/main.py @@ -39,7 +39,7 @@ def parse_args(): parser.add_argument('--window-size', type=int, default=10, help='Context size for optimization. Default is 10.') - parser.add_argument('--iter', default=1, type=int, + parser.add_argument('--epochs', default=1, type=int, help='Number of epochs in SGD') parser.add_argument('--workers', type=int, default=8, @@ -84,7 +84,7 @@ def learn_embeddings(walks): Learn embeddings by optimizing the Skipgram objective using SGD. ''' walks = [map(str, walk) for walk in walks] - model = Word2Vec(walks, size=args.dimensions, window=args.window_size, min_count=0, sg=1, workers=args.workers, iter=args.iter) + model = Word2Vec(walks, vector_size=args.dimensions, window=args.window_size, min_count=0, sg=1, workers=args.workers, epochs=args.epochs) model.save_word2vec_format(args.output) return