Skip to content
This repository was archived by the owner on Apr 28, 2021. It is now read-only.

Commit 9700b82

Browse files
refactor: add support for the augmentation factor in the train route
1 parent 5e5e327 commit 9700b82

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

rasa/server.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,11 @@ async def train(request: Request) -> HTTPResponse:
780780
for key in rjs["nlu"].keys():
781781
nlu_path = os.path.join(nlu_dir, "{}.md".format(key))
782782
rasa.utils.io.write_text_file(rjs["nlu"][key]["data"], nlu_path)
783+
784+
if "augmentation_factor" in rjs:
785+
augmentation_factor = rjs["augmentation_factor"]
786+
else:
787+
augmentation_factor = os.environ.get("AUGMENTATION_FACTOR", 50)
783788

784789
# << bf
785790

@@ -800,6 +805,7 @@ async def train(request: Request) -> HTTPResponse:
800805
model_output_directory = DEFAULT_MODELS_PATH
801806
else:
802807
model_output_directory = tempfile.gettempdir()
808+
803809

804810
try:
805811
with app.active_training_processes.get_lock():
@@ -815,7 +821,7 @@ async def train(request: Request) -> HTTPResponse:
815821
persist_nlu_training_data=True, # bf
816822
additional_arguments={
817823
"augmentation_factor": int(
818-
os.environ.get("AUGMENTATION_FACTOR", 50)
824+
augmentation_factor
819825
),
820826
}, # bf
821827
)

0 commit comments

Comments
 (0)