Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions dpgen2/op/run_dp_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from dpgen2.utils.run_command import (
run_command,
run_command_streaming,
)


Expand Down Expand Up @@ -292,7 +293,8 @@ def clean_before_quit():
train_args,
)

ret, out, err = run_command(command)
# Use streaming output for real-time monitoring
ret, out, err = run_command_streaming(command, log_file="train.log")
if ret != 0:
clean_before_quit()
logging.error(
Expand All @@ -309,10 +311,12 @@ def clean_before_quit():
)
)
raise FatalError("dp train failed")
fplog.write("#=================== train std out ===================\n")
fplog.write(out)
fplog.write("#=================== train std err ===================\n")
fplog.write(err)
# Note: output is already written to log file by run_command_streaming
# No need to write again here to avoid duplication
# fplog.write("#=================== train std out ===================\n")
# fplog.write(out)
# fplog.write("#=================== train std err ===================\n")
# fplog.write(err)

if finetune_mode == "finetune" and os.path.exists("input_v2_compat.json"):
shutil.copy2("input_v2_compat.json", train_script_name)
Expand All @@ -339,10 +343,10 @@ def clean_before_quit():
)
raise FatalError("dp freeze failed")
model_file = "frozen_model.pb"
fplog.write("#=================== freeze std out ===================\n")
fplog.write(out)
fplog.write("#=================== freeze std err ===================\n")
fplog.write(err)
fplog.write("#=================== freeze std out ===================\n")
fplog.write(out)
fplog.write("#=================== freeze std err ===================\n")
fplog.write(err)

clean_before_quit()

Expand Down
71 changes: 71 additions & 0 deletions dpgen2/utils/run_command.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
import subprocess
import sys
import threading
from typing import (
List,
Tuple,
Expand All @@ -11,6 +14,74 @@
from dflow.utils import run_command as dflow_run_command


def run_command_streaming(
    cmd: Union[str, List[str]],
    shell: bool = False,
    log_file=None,
) -> Tuple[int, str, str]:
    """Run a command, streaming its output in real time.

    Each line of the child's stdout/stderr is echoed to the corresponding
    terminal stream as it is produced, and optionally mirrored to
    ``log_file``.  The complete captured output is also returned, matching
    the ``run_command`` contract.

    Parameters
    ----------
    cmd : str or list of str
        The command to run.  A string is tokenized with ``shlex.split``
        when ``shell`` is False; a list is joined into a single string
        with ``shlex.join`` when ``shell`` is True, since a shell expects
        one command string.
    shell : bool
        Whether to execute the command through the shell.
    log_file : str or os.PathLike, optional
        Path of a file that receives a live copy of the combined output.

    Returns
    -------
    tuple of (int, str, str)
        Return code, captured stdout, and captured stderr.
    """
    import shlex  # local import keeps this patch free of file-level changes

    # Normalize cmd to match the execution mode.  Plain str.split() breaks
    # quoted arguments, and passing a list while shell=True is undefined
    # behavior for subprocess.
    if isinstance(cmd, str):
        cmd = cmd if shell else shlex.split(cmd)
    elif shell and isinstance(cmd, list):
        # When shell=True, pass a single string; preserve quoting.
        cmd = shlex.join(cmd)

    log_fp = open(log_file, "w") if log_file else None
    # File writes are not atomic: the two reader threads below must not
    # interleave partial lines in the log, so serialize their writes.
    log_lock = threading.Lock()

    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            # DEVNULL instead of an unclosed PIPE: a child that reads
            # stdin would otherwise block forever waiting for input.
            stdin=subprocess.DEVNULL,
            shell=shell,
            text=True,
            bufsize=1,  # line buffered
        )

        stdout_lines: List[str] = []
        stderr_lines: List[str] = []

        def _pump(pipe, sink, stream):
            # Read until EOF, echoing each line and mirroring it to the log.
            for line in iter(pipe.readline, ""):
                sink.append(line)
                print(line, end="", file=stream)
                if log_fp:
                    with log_lock:
                        log_fp.write(line)
                        log_fp.flush()
            pipe.close()

        readers = [
            threading.Thread(
                target=_pump, args=(process.stdout, stdout_lines, sys.stdout)
            ),
            threading.Thread(
                target=_pump, args=(process.stderr, stderr_lines, sys.stderr)
            ),
        ]
        for t in readers:
            t.start()

        return_code = process.wait()
        # Join after wait(): the pipes reach EOF once the child exits,
        # so both reader threads terminate on their own.
        for t in readers:
            t.join()

        return return_code, "".join(stdout_lines), "".join(stderr_lines)

    finally:
        if log_fp:
            log_fp.close()


def run_command(
cmd: Union[str, List[str]],
shell: bool = False,
Expand Down
Loading