diff --git a/onmt/modules/embeddings.py b/onmt/modules/embeddings.py
index 72d112c18a..39d1a1a8d8 100644
--- a/onmt/modules/embeddings.py
+++ b/onmt/modules/embeddings.py
@@ -156,8 +156,11 @@ def __init__(self, word_vec_size,
         # (these have no effect if feat_vocab_sizes is empty)
         if feat_merge == 'sum':
             feat_dims = [word_vec_size] * len(feat_vocab_sizes)
-        elif feat_vec_size > 0:
-            feat_dims = [feat_vec_size] * len(feat_vocab_sizes)
+        elif len(feat_vec_size) != 0:
+            if len(feat_vocab_sizes) == 0:
+                feat_dims = []
+            else:
+                feat_dims = feat_vec_size
         else:
             feat_dims = [int(vocab ** feat_vec_exponent)
                          for vocab in feat_vocab_sizes]
@@ -209,10 +212,10 @@ def _validate_args(self, feat_merge, feat_vocab_sizes, feat_vec_exponent,
             if feat_vec_exponent != 0.7:
                 warnings.warn("Merging with sum, but got non-default "
                               "feat_vec_exponent. It will be unused.")
-            if feat_vec_size != -1:
+            if len(feat_vec_size) != 0:
                 warnings.warn("Merging with sum, but got non-default "
                               "feat_vec_size. It will be unused.")
-        elif feat_vec_size > 0:
+        elif len(feat_vec_size) != 0:
             # features will use feat_vec_size
             if feat_vec_exponent != -1:
                 warnings.warn("Not merging with sum and positive "
diff --git a/onmt/opts.py b/onmt/opts.py
index 5606daa82a..8cf8a5fba1 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -47,7 +47,8 @@ def model_opts(parser):
               choices=['concat', 'sum', 'mlp'],
               help="Merge action for incorporating features embeddings. "
                    "Options [concat|sum|mlp].")
-    group.add('--feat_vec_size', '-feat_vec_size', type=int, default=-1,
+    group.add('--feat_vec_size', '-feat_vec_size', type=int,
+              default=[], nargs='*',
               help="If specified, feature embedding sizes "
                    "will be set to this. Otherwise, feat_vec_exponent "
                    "will be used.")
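
A minimal sketch (not part of the patch) of how the changed -feat_vec_size option behaves once it accepts a list: it uses the standard-library argparse directly rather than OpenNMT's own parser, and the feature vocabulary sizes below are made up for illustration.

import argparse

# Same option definition as in the patched opts.py: with nargs='*' and
# default=[], -feat_vec_size takes zero or more integers.
parser = argparse.ArgumentParser()
parser.add_argument('--feat_vec_size', '-feat_vec_size', type=int,
                    default=[], nargs='*')

opt = parser.parse_args(['-feat_vec_size', '10', '20'])
print(opt.feat_vec_size)                    # [10, 20]
print(parser.parse_args([]).feat_vec_size)  # [] -> feat_vec_exponent is used instead

# Mirrors the new branch in Embeddings.__init__: a non-empty list is used
# directly as the per-feature embedding dimensions.
feat_vocab_sizes = [100, 50]    # hypothetical feature vocabulary sizes
if len(opt.feat_vec_size) != 0:
    feat_dims = [] if len(feat_vocab_sizes) == 0 else opt.feat_vec_size
print(feat_dims)                # [10, 20]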