Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions engibench/problems/airfoil/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def __design_to_simulator_input(self, design: DesignType, config: dict[str, Any]
image=self.container_id,
name="machaero",
mounts=[(self.__local_base_directory, self.__docker_base_dir)],
sync_uid=True,
)

except Exception as e:
Expand Down Expand Up @@ -436,6 +437,7 @@ def simulate(self, design: DesignType, config: dict[str, Any] | None = None, mpi
image=self.container_id,
name="machaero",
mounts=[(self.__local_base_directory, self.__docker_base_dir)],
sync_uid=True,
)
except Exception as e:
raise RuntimeError(
Expand Down Expand Up @@ -513,6 +515,7 @@ def optimize(
image=self.container_id,
name="machaero",
mounts=[(self.__local_base_directory, self.__docker_base_dir)],
sync_uid=True,
)
except Exception as e:
raise RuntimeError(f"Optimization failed: {e!s}. Check logs in {self.__local_study_dir}") from e
Expand Down
56 changes: 40 additions & 16 deletions engibench/utils/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ def pull(image: str) -> None:
RUNTIME.pull(image)


def run(
def run( # noqa: PLR0913
command: list[str],
image: str,
mounts: Sequence[tuple[str, str]] = (),
env: dict[str, str] | None = None,
name: str | None = None,
*,
sync_uid: bool = False,
) -> None:
"""Run a command in a container using the selected runtime.

Expand All @@ -34,13 +36,14 @@ def run(
mounts: Pairs of host folder and destination folder inside the container.
env: Mapping of environment variable names and values to set inside the container.
name: Optional name for the container (not supported by all runtimes).
sync_uid: Use the uid of the current process as uid inside the container.
"""
if RUNTIME is None:
msg = "No container runtime found. Please ensure Docker, Podman, or Singularity is installed and running."
raise FileNotFoundError(msg)

try:
result = RUNTIME.run(command, image, mounts, env, name)
result = RUNTIME.run(command, image, mounts, env, name, sync_uid=sync_uid)
result.check_returncode()
except subprocess.CalledProcessError as e:
msg = f"Container command failed with exit code {e.returncode}:\nCommand: {' '.join(command)}\nOutput: {e.output.decode() if e.output else 'No output'}"
Expand Down Expand Up @@ -81,13 +84,15 @@ def pull(cls, image: str) -> None:
raise NotImplementedError("Must be implemented by a subclass")

@classmethod
def run(
def run( # noqa: PLR0913
cls,
command: list[str],
image: str,
mounts: Sequence[tuple[str, str]] = (),
env: dict[str, str] | None = None,
name: str | None = None,
*,
sync_uid: bool = False,
) -> subprocess.CompletedProcess:
"""Run a command in a container.

Expand All @@ -97,6 +102,7 @@ def run(
mounts: Pairs of host folder and destination folder inside the container.
env: Mapping of environment variable names and values to set inside the container.
name: Optional name for the container (not supported by all runtimes).
sync_uid: Use the uid of the current process as uid inside the container.
"""
raise NotImplementedError("Must be implemented by a subclass")

Expand Down Expand Up @@ -155,13 +161,15 @@ def pull(cls, image: str) -> None:
subprocess.run([cls.executable, "pull", image], check=True)

@classmethod
def run(
def run( # noqa: PLR0913
cls,
command: list[str],
image: str,
mounts: Sequence[tuple[str, str]] = (),
env: dict[str, str] | None = None,
name: str | None = None,
*,
sync_uid: bool = False,
) -> subprocess.CompletedProcess:
"""Run a command in a container.

Expand All @@ -171,25 +179,30 @@ def run(
mounts: Pairs of host folder and destination folder inside the container.
env: Mapping of environment variable names and values to set inside the container.
name: Optional name for the container (not supported by all runtimes).
sync_uid: Use the uid of the current process as uid inside the container.
"""
name_args = [] if name is None else ["--name", name]
mount_args = (["--mount", f"type=bind,src={src},target={target}"] for src, target in mounts)
env_args = (["--env", f"{var}={value}"] for var, value in (env or {}).items())
user_args = cls._user_args() if sync_uid else ()

return subprocess.run(
[
cls.executable,
"run",
"--rm",
*name_args,
*(arg for args in mount_args for arg in args),
*(arg for args in env_args for arg in args),
*_mount_args(mounts),
*_env_args(env or {}),
*user_args,
image,
*command,
],
check=False,
)

@classmethod
def _user_args(cls) -> tuple[str, ...]:
return ("--user", str(os.getuid()))


class Podman(Docker):
"""Podman 🦭 runtime."""
Expand Down Expand Up @@ -217,6 +230,10 @@ def is_available(cls) -> bool:
except FileNotFoundError:
return False

@classmethod
def _user_args(cls) -> tuple[str, ...]:
return ("--userns=keep-id", "--user", str(os.getuid()))


DOCKER_PREFIX = "docker://"

Expand Down Expand Up @@ -275,13 +292,15 @@ def pull(cls, image: str) -> None:
subprocess.run([cls.executable, "pull", docker_uri], check=True)

@classmethod
def run(
def run( # noqa: PLR0913
cls,
command: list[str],
image: str,
mounts: Sequence[tuple[str, str]] = (),
env: dict[str, str] | None = None,
_name: str | None = None,
name: str | None = None, # noqa: ARG003
*,
sync_uid: bool = False, # noqa: ARG003
) -> subprocess.CompletedProcess:
"""Run a command in a container.

Expand All @@ -291,31 +310,36 @@ def run(
mounts: Pairs of host folder and destination folder inside the container.
env: Mapping of environment variable names and values to set inside the container.
name: Optional name for the container (not supported by all runtimes).
sync_uid: Use the uid of the current process as uid inside the container.
"""
# Set Apptainer environment variables
cls._set_apptainer_env()

# Get sif filename
sif_image = cls.sif_filename(image)

# Reconstruct mount and env args
mount_args = (["--mount", f"type=bind,src={src},target={target}"] for src, target in mounts)
env_args = (["--env", f"{var}={value}"] for var, value in (env or {}).items())

return subprocess.run(
[
cls.executable,
"run",
"--compat",
*(arg for args in mount_args for arg in args),
*(arg for args in env_args for arg in args),
*_mount_args(mounts),
*_env_args(env or {}),
sif_image,
*command,
],
check=False,
)


def _mount_args(mounts: Sequence[tuple[str, str]]) -> list[str]:
return [arg for args in (["--mount", f"type=bind,src={src},target={target}"] for src, target in mounts) for arg in args]


def _env_args(env: dict[str, str]) -> list[str]:
return [arg for args in (["--env", f"{var}={value}"] for var, value in (env or {}).items()) for arg in args]


RUNTIMES = [
rt
for rt in globals().values()
Expand Down
Loading