diff --git a/dpgen2/op/run_dp_train.py b/dpgen2/op/run_dp_train.py index 5a9782f4..9816d0d7 100644 --- a/dpgen2/op/run_dp_train.py +++ b/dpgen2/op/run_dp_train.py @@ -44,6 +44,7 @@ ) from dpgen2.utils.run_command import ( run_command, + run_command_streaming, ) @@ -292,7 +293,8 @@ def clean_before_quit(): train_args, ) - ret, out, err = run_command(command) + # Use streaming output for real-time monitoring + ret, out, err = run_command_streaming(command, log_file="train.log") if ret != 0: clean_before_quit() logging.error( @@ -309,10 +311,12 @@ def clean_before_quit(): ) ) raise FatalError("dp train failed") - fplog.write("#=================== train std out ===================\n") - fplog.write(out) - fplog.write("#=================== train std err ===================\n") - fplog.write(err) + # Note: output is already written to log file by run_command_streaming + # No need to write again here to avoid duplication + # fplog.write("#=================== train std out ===================\n") + # fplog.write(out) + # fplog.write("#=================== train std err ===================\n") + # fplog.write(err) if finetune_mode == "finetune" and os.path.exists("input_v2_compat.json"): shutil.copy2("input_v2_compat.json", train_script_name) @@ -339,10 +343,10 @@ def clean_before_quit(): ) raise FatalError("dp freeze failed") model_file = "frozen_model.pb" - fplog.write("#=================== freeze std out ===================\n") - fplog.write(out) - fplog.write("#=================== freeze std err ===================\n") - fplog.write(err) + fplog.write("#=================== freeze std out ===================\n") + fplog.write(out) + fplog.write("#=================== freeze std err ===================\n") + fplog.write(err) clean_before_quit() diff --git a/dpgen2/utils/run_command.py b/dpgen2/utils/run_command.py index 2d5c5764..9596920b 100644 --- a/dpgen2/utils/run_command.py +++ b/dpgen2/utils/run_command.py @@ -1,4 +1,7 @@ import os +import subprocess +import sys +import threading from typing import ( List, Tuple, @@ -11,6 +14,74 @@ from dflow.utils import run_command as dflow_run_command +def run_command_streaming( + cmd: Union[str, List[str]], + shell: bool = False, + log_file=None, +) -> Tuple[int, str, str]: + """Run command with streaming output to both terminal and log file.""" + if isinstance(cmd, str): + cmd = cmd if shell else cmd.split() + + # Open log file if specified + log_fp = open(log_file, "w") if log_file else None + + try: + # Start subprocess + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.PIPE, + shell=shell, + text=True, + bufsize=1, # Line buffered + universal_newlines=True, + ) + + # Store output + stdout_buffer = [] + stderr_buffer = [] + + def stream_output(pipe, buffer, is_stderr=False): + for line in iter(pipe.readline, ""): + buffer.append(line) + # Print to terminal + if is_stderr: + print(line, end="", file=sys.stderr) + else: + print(line, end="") + # Write to log file + if log_fp: + log_fp.write(line) + log_fp.flush() + pipe.close() + + # Start threads for streaming + stdout_thread = threading.Thread( + target=stream_output, args=(process.stdout, stdout_buffer, False) + ) + stderr_thread = threading.Thread( + target=stream_output, args=(process.stderr, stderr_buffer, True) + ) + + stdout_thread.start() + stderr_thread.start() + + # Wait for process to complete + return_code = process.wait() + + # Wait for threads to finish + stdout_thread.join() + stderr_thread.join() + + return return_code, "".join(stdout_buffer), "".join(stderr_buffer) + + finally: + if log_fp: + log_fp.close() + + def run_command( cmd: Union[str, List[str]], shell: bool = False,