You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
parser.add_argument("--mlflow_server_url", type=str, default="http://localhost:8080", help="mlflow server url")
39
+
parser.add_argument("--model_type", type=str, default=model_names[0],help="The model type to train", choices=model_names)
40
+
parser.add_argument("--generator_type", type=str, default=list(generators.keys())[1],help="The generator type to train", choices=list(generators.keys()))
41
+
parser.add_argument("--input_size", type=int, default=16, help="The input size of the model")
42
+
parser.add_argument("--hidden_size", type=int, default=128, help="The hidden size of the model")
43
+
parser.add_argument("--num_layers", type=int, default=2, help="The number of layers of the model")
44
+
parser.add_argument("--number_heads", type=int, default=16, help="The number of heads of the model (transformer model only)")
45
+
parser.add_argument("--input_sequence_length", type=int, default=10,help="The length of the input sequence")
0 commit comments