diff --git a/.gitignore b/.gitignore index 9f1eeba..6443ff2 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,6 @@ dmypy.json .DS_store examples/firedrake/getting_started/output/ + +# VSCode +.vscode/ \ No newline at end of file diff --git a/examples/jaxfluids/environment_config.yaml b/examples/jaxfluids/environment_config.yaml new file mode 100644 index 0000000..8d3843e --- /dev/null +++ b/examples/jaxfluids/environment_config.yaml @@ -0,0 +1,9 @@ +jaxfluids: + resolution: "coarse" + secondary_pressure_ratio: 0.9 + is_pressure_probes: true + target_fn: "step" + steps_per_action: 100 + is_scale_observations: true + ngpus: 1 + render_mode: "SAVE" \ No newline at end of file diff --git a/examples/jaxfluids/test_jaxfluids_env.py b/examples/jaxfluids/test_jaxfluids_env.py new file mode 100644 index 0000000..46dd32c --- /dev/null +++ b/examples/jaxfluids/test_jaxfluids_env.py @@ -0,0 +1,37 @@ +import os + +from hydrogym.jaxfluids import Nozzle2D + + +def main(): + env_config = { + "environment_name": "Nozzle2D_coarse", + "configuration_file": os.path.abspath("environment_config.yaml") + } + + env = Nozzle2D(env_config=env_config) + + observation, info = env.reset(seed=0) + env.render() + + for i in range(1000): + + # Random action + # action = env.action_space.sample() + + # Fixed action + action = [0.0, 0.5] + + observation, reward, terminated, truncated, info = env.step(action) + + if env.env_step % 10 == 0: + env.render() + + if terminated or truncated: + observation, info = env.reset() + + env.close() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/hydrogym/jaxfluids/__init__.py b/hydrogym/jaxfluids/__init__.py new file mode 100644 index 0000000..e34f630 --- /dev/null +++ b/hydrogym/jaxfluids/__init__.py @@ -0,0 +1 @@ +from .envs.nozzle import Nozzle2D, Nozzle3D \ No newline at end of file diff --git a/hydrogym/jaxfluids/env_core.py b/hydrogym/jaxfluids/env_core.py new file mode 100644 index 0000000..5012efc --- /dev/null +++ b/hydrogym/jaxfluids/env_core.py @@ -0,0 +1,185 @@ +import glob +import os +from pathlib import Path +from typing import Dict, Optional, Tuple, Union + +from omegaconf import OmegaConf + +from hydrogym.data_manager import HFDataManager + +from jaxfluids_rl.jxf_env import JAXFluidsEnv, RenderMode + + +class ConfigError(Exception): + """Exception raised for configuration-related errors.""" + + pass + + +class JAXFluidsFlowEnv(JAXFluidsEnv): + """ + Base JAXFluidsFlowEnv with Hugging Face Hub integration for configuration management. + + :param JAXFluidsEnv: _description_ + :type JAXFluidsEnv: _type_ + """ + + def _init_from_hf(self, env_config: dict) -> None: + + # Initialize HF data manager + self.hf_repo_id = env_config.get("hf_repo_id", "dynamicslab/HydroGym-environments") + self.local_fallback_dir = env_config.get("local_fallback_dir", None) + self.use_clean_cache = env_config.get("use_clean_cache", True) + + self.data_manager = HFDataManager( + repo_id=self.hf_repo_id, + local_fallback_dir=self.local_fallback_dir, + use_clean_cache=self.use_clean_cache, + fallback_profile="JAXFLUIDS" + ) + + # Environment identification + self.environment_name = env_config.get("environment_name") + + if not self.environment_name: + raise ConfigError("'environment_name' must be specified in env_config") + + # Download/get environment configuration + self.env_data_path = self._setup_environment_data() + + # Resolve and load configuration file + self.configuration_file = self._resolve_configuration_file(env_config.get("configuration_file")) + + if not self.configuration_file: + raise ConfigError( + f"No configuration file found for environment '{self.environment_name}'. " + f"Expected config.yaml in: {self.env_data_path}" + ) + + # Load configuration from HF + self.conf = OmegaConf.load(self.configuration_file) + + + def _setup_environment_data(self) -> str: + """ + Download and setup environment data from HF Hub. + + First checks ~/.cache/jaxfluidsgym/ for local data, otherwise falls back to data_manager. + + Returns: + Path to the local environment data directory. + + Raises: + ConfigError: If environment data cannot be retrieved. + """ + # Check cache directory first + cache_dir = Path.home() / ".cache" / "jaxfluidsgym" / self.environment_name + if cache_dir.exists() and cache_dir.is_dir(): + print(f"Using cached environment data from: {cache_dir}") + return str(cache_dir) + + # Fall back to data_manager if cache doesn't exist + try: + env_path = self.data_manager.get_environment_path(self.environment_name) + print(f"Using environment data from: {env_path}") + return env_path + except Exception as e: + raise ConfigError(f"Failed to setup environment data for {self.environment_name}: {e}") + + + def _resolve_configuration_file(self, config_file_input: Optional[str]) -> Optional[str]: + """ + Resolve configuration file path from various input formats. + + Args: + config_file_input: Can be: + - None: Auto-detect in HF environment + - Absolute path: Use directly + - Relative path starting with . or /: Use as-is + - Just filename: Look in HF environment directory + + Returns: + Absolute path to configuration file, or None if not found. + + Raises: + ConfigError: If specified configuration file is not found. + """ + # Case 1: No config file specified - try to find one + if config_file_input is None: + print("No config file specified, searching in environment directory...") + return self._find_configuration_file() + + # Case 2: Absolute path provided + if os.path.isabs(config_file_input): + if os.path.exists(config_file_input): + print(f"Using absolute path config file: {config_file_input}") + return config_file_input + else: + raise ConfigError(f"Configuration file not found: {config_file_input}") + + # Case 3: Relative path from current directory (starts with ./ or ../) + if config_file_input.startswith("./") or config_file_input.startswith("../"): + abs_path = os.path.abspath(config_file_input) + if os.path.exists(abs_path): + print(f"Using config file from current directory: {abs_path}") + return abs_path + else: + raise ConfigError(f"Configuration file not found: {abs_path}") + + # Case 4: Just a filename - look in multiple places + # First check current directory + if os.path.exists(config_file_input): + abs_path = os.path.abspath(config_file_input) + print(f"Using config file from current directory: {abs_path}") + return abs_path + + # Then check HF environment directory + env_config_path = os.path.join(self.env_data_path, config_file_input) + if os.path.exists(env_config_path): + print(f"Using config file from environment: {env_config_path}") + return env_config_path + + raise ConfigError( + f"Configuration file '{config_file_input}' not found in:\n" + f" - Current directory: {os.getcwd()}\n" + f" - Environment directory: {self.env_data_path}" + ) + + def _find_configuration_file(self) -> Optional[str]: + """ + Auto-detect configuration file in the environment data directory. + + Returns: + Path to configuration file, or None if not found. + """ + # Look for specific configuration file names (most specific first) + config_names = [ + "config.yaml", + "environment_config.yaml", + "env_config.yaml", + "environment.yaml", + f"{self.environment_name}.yaml", + ] + + # Check exact names first + for name in config_names: + file_path = os.path.join(self.env_data_path, name) + if os.path.exists(file_path): + print(f"Auto-detected configuration file: {name}") + return file_path + + # Then try patterns (but be specific - avoid catching property files) + config_patterns = ["config_*.yaml", "config_*.yml"] + + for pattern in config_patterns: + matches = glob.glob(os.path.join(self.env_data_path, pattern)) + if matches: + print(f"Auto-detected configuration file: {os.path.basename(matches[0])}") + return matches[0] + + # Not found + print(f"WARNING: No configuration file auto-detected in {self.env_data_path}") + if os.path.exists(self.env_data_path): + print(f"Available files: {os.listdir(self.env_data_path)}") + + return None \ No newline at end of file diff --git a/hydrogym/jaxfluids/envs/__init__.py b/hydrogym/jaxfluids/envs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hydrogym/jaxfluids/envs/nozzle.py b/hydrogym/jaxfluids/envs/nozzle.py new file mode 100644 index 0000000..9633a62 --- /dev/null +++ b/hydrogym/jaxfluids/envs/nozzle.py @@ -0,0 +1,656 @@ +from abc import abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Any, ClassVar, Callable, NamedTuple + +import jax +from jax import Array +import jax.numpy as jnp +import matplotlib.pyplot as plt +from mpl_toolkits.axes_grid1 import make_axes_locatable +import numpy as np +import gymnasium as gym + +from jaxfluids.data_types import JaxFluidsBuffers +from jaxfluids.data_types.ml_buffers import ( + CallablesSetup, + ParametersSetup, + LevelSetSetup, + InterfaceFluxCallablesSetup, + InterfaceFluxParametersSetup, +) +from jaxfluids.domain.helper_functions import ( + reassemble_buffer_np, + reassemble_cell_centers, + reassemble_cell_sizes +) +from jaxfluids_rl.jxf_env import RenderMode + + +from hydrogym.jaxfluids.env_core import JAXFluidsFlowEnv +from hydrogym.jaxfluids.utils.nozzle import ( + InjectorGeometry, + ObsData, + PressureRatios, + TVCSpec, + TargetThrustAngleFn, + build_tvc_env_options, + build_tvc_runtime_setup, + compute_thrust, + initialize_injector_flux_fn, + plot_flowfield_3d, +) + + +Array = jax.Array + + +class NozzleBase(JAXFluidsFlowEnv): + + SPEC: ClassVar[TVCSpec] + TARGET_FNS: ClassVar[dict[str, TargetThrustAngleFn]] = {} + + def __init__(self, env_config: dict) -> None: + + self._init_from_hf(env_config) + + env_options = build_tvc_env_options( + env_config=self.conf.jaxfluids, + spec=self.SPEC, + target_fns=self.TARGET_FNS, + cls_name=self.__class__.__name__, + ) + + self.num_actuators = env_options.num_actuators + self.secondary_pressure_ratio = env_options.secondary_pressure_ratio + self.resolution = env_options.resolution + self.target_fn = env_options.target_fn + self.is_pressure_probes = env_options.is_pressure_probes + self.is_scale_observations = env_options.is_scale_observations + + runtime_setup = build_tvc_runtime_setup( + base_path=Path(self.env_data_path), + dim=self.SPEC.dim, + resolution=self.resolution, + ngpus=env_options.ngpus, + ) + + self.env_name = runtime_setup.env_name + self.env_dir = runtime_setup.env_dir + self.restart_file_path = runtime_setup.restart_file_path + self.injector_geometry = InjectorGeometry( + X=self.SPEC.injector_x, + IW=self.SPEC.injector_width, + N=self.num_actuators, + ) + self.pressure_ratios = PressureRatios( + NPR=self.SPEC.nozzle_pressure_ratio, + SPR=self.secondary_pressure_ratio, + ) + + super().__init__( + self.conf.jaxfluids, + runtime_setup.case_setup_dict, + runtime_setup.numerical_setup_dict, + ) + + self.default_action_reset = np.zeros(self.num_actuators) + + if self.is_pressure_probes: + self.probe_locations = self._compute_probe_locations() + self.num_probes = self.probe_locations.shape[0] + else: + self.probe_locations = None + self.num_probes = 0 + + self.action_callable_setup = self._build_action_callable_setup() + + self._set_spaces( + action_space=self._build_action_space(), + observation_space=self._build_observation_space(), + ) + + def _is_terminated(self, action: np.ndarray, jxf_buffers: JaxFluidsBuffers, info: dict) -> bool: + physical_simulation_time = jxf_buffers.time_control_variables.physical_simulation_time + return physical_simulation_time >= self.SPEC.t_end + + def _is_truncated(self, jxf_buffers: JaxFluidsBuffers, info: dict) -> bool: + return False + + @abstractmethod + def _get_reward(self, action: np.ndarray) -> float: + pass + + def _build_action_callable_setup(self) -> CallablesSetup: + interface_flux_fn = initialize_injector_flux_fn( + injector_geometry=self.injector_geometry, + pressure_ratios=self.pressure_ratios, + p_infty=self.SPEC.ambient_pressure, + T_infty=self.SPEC.ambient_temperature, + specific_heat_ratio=self.SPEC.specific_heat_ratio, + specific_gas_constant=self.SPEC.specific_gas_constant, + sim_manager=self.sim_manager, + ) + levelset_setup = LevelSetSetup( + fluid_solid=InterfaceFluxCallablesSetup(interface_flux_fn) + ) + return CallablesSetup(levelset=levelset_setup) + + def _build_action_space(self) -> gym.spaces.Box: + return gym.spaces.Box( + low=0.0, + high=1.0, + shape=(self.num_actuators,), + ) + + def _get_observation_shape(self) -> tuple[int, ...]: + num_angles = self.SPEC.dim - 1 + num_obs = 2 * num_angles + self.num_probes + return (num_obs,) + + def _build_observation_space(self) -> gym.spaces.Box: + num_angles = self.SPEC.dim - 1 + + if self.is_scale_observations: + low = np.array([-1.0] * (2 * num_angles) + [0.0] * self.num_probes, dtype=np.float32) + high = np.array([1.0] * (2 * num_angles) + [1.0] * self.num_probes, dtype=np.float32) + else: + low = np.array([-np.pi] * (2 * num_angles) + [0.0] * self.num_probes, dtype=np.float32) + high = np.array([np.pi] * (2 * num_angles) + [np.inf] * self.num_probes, dtype=np.float32) + + return gym.spaces.Box( + low=low, + high=high, + shape=self._get_observation_shape(), + ) + + def _convert_action_for_jxf(self, action: np.ndarray) -> ParametersSetup: + levelset_setup = LevelSetSetup(fluid_solid=InterfaceFluxParametersSetup(jnp.array(action))) + return ParametersSetup(levelset=levelset_setup) + + def _get_obs(self) -> np.ndarray: + + jxf_buffers, _ = self._require_state() + obs_data = self.compute_obs(jxf_buffers) + + self.thrust_angle = obs_data.thrust_angle + self.target_angle = obs_data.target_angle + self.pressure_probes = obs_data.pressure_probes + + thrust_angle = jnp.atleast_1d(obs_data.thrust_angle) + target_angle = jnp.atleast_1d(obs_data.target_angle) + if self.is_scale_observations: + thrust_angle /= jnp.pi + target_angle /= jnp.pi + + obs = [thrust_angle, target_angle] + + if self.is_pressure_probes: + pressure_probes = jnp.atleast_1d(obs_data.pressure_probes) + if self.is_scale_observations: + pressure_probes /= self.SPEC.p0 + obs.append(pressure_probes) + + obs = jnp.concatenate(obs) + + if obs.shape != self.observation_space.shape: + raise ValueError(f"Observation shape mismatch: got {obs.shape}, expected {self.observation_space.shape}") + + return np.asarray(obs) + + def _get_info(self) -> dict[str, Any]: + return { + "thrust_angle": np.array(self.thrust_angle), + "target_angle": np.array(self.target_angle), + } + + def _after_step( + self, + action: np.ndarray, + observation: np.ndarray, + reward: float, + terminated: bool, + truncated: bool, + info: dict, + jxf_buffers: JaxFluidsBuffers + ) -> None: + + t = float(jxf_buffers.time_control_variables.physical_simulation_time) + self._append_history( + time=t, + thrust_angle=self.thrust_angle, + pressure_probes=self.pressure_probes, + action=action, + ) + + def compute_obs(self, jxf_buffers: JaxFluidsBuffers) -> ObsData: + if self.sim_manager.domain_information.is_parallel: + return jax.pmap( + self._compute_obs, + axis_name="i", + in_axes=(JaxFluidsBuffers(0, None, None, None),), + out_axes=None + )(jxf_buffers) + + else: + return jax.jit(self._compute_obs)(jxf_buffers) + + def _compute_obs(self, jxf_buffers: JaxFluidsBuffers) -> ObsData: + + current_angle = self.compute_thrust_angle(jxf_buffers) + + sim_time = jxf_buffers.time_control_variables.physical_simulation_time + target_angle = jnp.asarray(self.target_fn(sim_time)) + if self.SPEC.dim == 2 and target_angle.ndim != 0: + raise ValueError(f"2D target_angle must be scalar, got {target_angle.shape}") + if self.SPEC.dim == 3 and target_angle.shape != (2,): + raise ValueError(f"3D target_angle must have shape (2,), got {target_angle.shape}") + + if self.is_pressure_probes: + pressure_probes = self.compute_pressure_probes(jxf_buffers) + else: + pressure_probes = None + + return ObsData(current_angle, target_angle, pressure_probes) + + def compute_thrust_angle(self, jxf_buffers: JaxFluidsBuffers) -> Array: + domain_information = self.sim_manager.domain_information + nhx, nhy, nhz = domain_information.domain_slices_conservatives + nhx_, nhy_, nhz_ = domain_information.domain_slices_geometry + cell_centers = domain_information.get_device_cell_centers() + cell_sizes = domain_information.get_device_cell_sizes() + is_parallel = domain_information.is_parallel + + simulation_buffers = jxf_buffers.simulation_buffers + + primitives = simulation_buffers.material_fields.primitives[..., nhx, nhy, nhz] + apertures_x = simulation_buffers.levelset_fields.apertures[0][..., nhx_, nhy_, nhz_] + thrust, _, _ = compute_thrust( + primitives, + self.SPEC.ambient_pressure, + apertures_x, + cell_centers, + cell_sizes, + ) + + if is_parallel: + thrust = jax.lax.psum(thrust, axis_name="i") + + if self.SPEC.dim == 2: + current_angle = jnp.atan2(thrust[1], thrust[0]) + else: + current_angle = jnp.stack([ + jnp.atan2(thrust[1], thrust[0]), # Pitch + jnp.atan2(thrust[2], thrust[0]), # Yaw + ]) + + return current_angle + + def compute_pressure_probes(self, jxf_buffers: JaxFluidsBuffers) -> Array: + x_p = self.probe_locations[:,0] + y_p = self.probe_locations[:,1] + z_p = self.probe_locations[:,2] + + domain_information = self.sim_manager.domain_information + + cell_centers = domain_information.get_device_cell_centers() + x, y, z = [xi.flatten() for xi in cell_centers] + + x_id = jnp.searchsorted(x, x_p, side="left", method="scan_unrolled") + y_id = jnp.searchsorted(y, y_p, side="left", method="scan_unrolled") + + if self.SPEC.dim == 3: + z_id = jnp.searchsorted(z, z_p, side="left", method="scan_unrolled") + else: + z_id = 0 + + nhx, nhy, nhz = domain_information.domain_slices_conservatives + pressure = jxf_buffers.simulation_buffers.material_fields.primitives[4, nhx, nhy, nhz] + pressure_probes = pressure[x_id, y_id, z_id] + + if domain_information.is_parallel: + device_domain_size = domain_information.get_device_domain_size() + + mask = 1 + for i in range(domain_information.dim): + xi = self.probe_locations[:,i] + mask *= (device_domain_size[i][0] <= xi) & (xi < device_domain_size[i][1]) + + pressure_probes = jax.lax.psum(mask * pressure_probes, axis_name="i") + + return pressure_probes + + def _get_fields_for_plotting(self, jxf_buffers: JaxFluidsBuffers) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + domain_information = self.sim_manager.domain_information + nhx, nhy, nhz = domain_information.domain_slices_conservatives + nhx_, nhy_, nhz_ = domain_information.domain_slices_geometry + + levelset_fields = jxf_buffers.simulation_buffers.levelset_fields + primitives = jxf_buffers.simulation_buffers.material_fields.primitives + + fields = [ + primitives[..., nhx, nhy, nhz], + levelset_fields.levelset[..., nhx, nhy, nhz], + levelset_fields.volume_fraction[..., nhx_, nhy_, nhz_], + ] + + if domain_information.is_parallel: + fields = [ + reassemble_buffer_np(field, domain_information.split_factors) + for field in fields + ] + + fields = [field.squeeze() for field in fields] + return tuple(fields) + + def _get_meshgrid_for_plotting(self) -> tuple[np.ndarray, np.ndarray]: + domain_information = self.sim_manager.domain_information + cell_centers = domain_information.get_global_cell_centers() + if domain_information.is_parallel: + cell_centers = reassemble_cell_centers(cell_centers, domain_information.split_factors) + + x, y, _ = [xi.flatten() for xi in cell_centers] + return np.meshgrid(x, y, indexing="ij") + + @abstractmethod + def _compute_probe_locations(self) -> np.ndarray: + pass + + def render(self) -> None: + if self.render_mode is None: + return + + self._plot_flow_field() + self._plot_observations() + + @abstractmethod + def _plot_flow_field(self) -> None: + pass + + @abstractmethod + def _plot_observations(self) -> None: + pass + + +class Nozzle2D(NozzleBase): + + SPEC = TVCSpec( + dim=2, + fixed_num_actuators=2, + ) + + TARGET_FNS = { + "sine": lambda t: (t > 5e-4) * (10.0 / 180.0 * jnp.pi) * jnp.sin(2 * jnp.pi * (t - 5e-4) / 4e-3), + "step": lambda t: (t > 5e-4) * (5.0 / 180.0 * jnp.pi) + } + + def _compute_probe_locations(self) -> np.ndarray: + G = self.SPEC.nozzle_geometry.G + H = self.SPEC.nozzle_geometry.H + + probe_locations = [] + for i in (1, 2): + x = G[0] + i / 3 * (H[0] - G[0]) + y = G[1] + i / 3 * (H[1] - G[1]) + + x_probes = np.array([x, x]) + y_probes = np.array([y, -y]) + z_probes = np.zeros_like(x_probes) + probe_locations.append( + np.stack([x_probes, y_probes, z_probes], axis=1) + ) + + return np.concatenate(probe_locations, axis=0) + + def _get_reward(self, action: np.ndarray) -> float: + error = np.abs(self.target_angle - self.thrust_angle) + return float(-error) + + def _plot_flow_field(self) -> None: + jxf_buffers, _ = self._require_state() + primitives, _, volume_fraction = self._get_fields_for_plotting(jxf_buffers) + X, Y = self._get_meshgrid_for_plotting() + + D_throat = self.SPEC.nozzle_geometry.D_throat + X = X / D_throat + Y = Y / D_throat + + physical_simulation_time = jxf_buffers.time_control_variables.physical_simulation_time + + rho = primitives[0] + p = primitives[-1] + v = primitives[1:3] + c = np.sqrt(self.SPEC.specific_heat_ratio * p / rho) + M = np.linalg.norm(v, axis=0, ord=2) / c + + mask = volume_fraction == 0.0 + p = np.ma.masked_where(mask, p) + M = np.ma.masked_where(mask, M) + + fig, ax = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(10, 4)) + fig.suptitle(f"Env Step: {self.env_step}, Time: {physical_simulation_time * 1e3:.3f} ms") + + quants = (M, p / self.SPEC.p0) + vmins = (0.0, 0.0); vmaxs = (3.0, 1.0) + for axi, quant, vmin, vmax in zip(ax, quants, vmins, vmaxs): + pci = axi.pcolormesh(X, Y, quant, cmap="Spectral_r", vmin=vmin, vmax=vmax, shading="auto") + axi.contour(X, Y, M, levels=[1.0], linewidths=0.5, colors="k", linestyles="-") + + divider = make_axes_locatable(axi) + cax = divider.append_axes('right', size='5%', pad=0.05) + fig.colorbar(pci, cax=cax, orientation='vertical') + + titles = ("Mach number", "Pressure p / p_0") + for axi, title in zip(ax, titles): + axi.set_aspect("equal") + axi.set_xlabel(r"x / D") + axi.set_ylabel(r"y / D") + axi.set_title(title) + axi.set_xlim(0, 8) + axi.set_ylim(-2, 2) + + if self.is_pressure_probes: + for axi in ax: + axi.scatter(self.probe_locations[:,0] / D_throat, self.probe_locations[:,1] / D_throat, s=2, c="black") + + if self.render_mode is RenderMode.SHOW: + plt.show() + elif self.render_mode is RenderMode.SAVE: + self._save_render_figure(fig, "flowfield") + else: + raise ValueError(f"RenderMode {self.render_mode} is not valid.") + plt.close(fig) + + + def _plot_observations(self) -> None: + jxf_buffers, _ = self._require_state() + physical_simulation_time = jxf_buffers.time_control_variables.physical_simulation_time + + fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(10, 10)) + ax = ax.flatten() + + fig.suptitle(f"Env Step: {self.env_step}, Time: {physical_simulation_time * 1e3:.3f} ms") + + times_full = np.linspace(0.0, self.SPEC.t_end, 100) + target_full = self.target_fn(times_full) + ax[0].plot(times_full * 1e3, np.rad2deg(target_full), "b--", label="target") + + t = np.array(self.history["time"]) + thrust_angle = np.array(self.history["thrust_angle"]) + action = np.array(self.history["action"]) + + if len(t) > 0: + ax[0].plot(t * 1e3, np.rad2deg(thrust_angle), "k", label="current") + for actuator_i in range(self.num_actuators): + ax[1].plot(t * 1e3, action[:,actuator_i], label=f"Inj. {actuator_i:02d}") + + ax[1].set_ylim(-0.05, 1.05) + + if self.is_pressure_probes and len(t) > 0: + pressure_probes = np.array(self.history["pressure_probes"]) + pressure_probes /= self.SPEC.p0 + for probe_i in range(self.num_probes // 2): + ax[2].plot(t * 1e3, pressure_probes[:, probe_i], label=f"Probe 0{probe_i:d}") + ax[3].plot(t * 1e3, pressure_probes[:, probe_i + self.num_probes // 2], label=f"Probe 1{probe_i:d}") + ax[2].set_ylim(0.0, 0.5) + ax[3].set_ylim(0.0, 0.5) + + titles = ( + "Thrust angle [deg]", + "Actuators", + "Pressure probes downstream loc 0", + "Pressure probes downstream loc 1", + ) + for axi, title in zip(ax, titles): + axi.set_box_aspect(1.0) + axi.set_xlabel("t [ms]") + axi.set_title(title) + axi.legend() + + if self.render_mode is RenderMode.SHOW: + plt.show() + elif self.render_mode is RenderMode.SAVE: + self._save_render_figure(fig, "observations") + else: + raise ValueError(f"RenderMode {self.render_mode} is not valid.") + plt.close(fig) + + +class Nozzle3D(NozzleBase): + + SPEC: ClassVar[TVCSpec] = TVCSpec( + dim=3, + min_num_actuators=4, + max_num_actuators=12, + ) + + TARGET_FNS = { + "sine": lambda t: jnp.array([ + (t > 1e-3) * (10.0 / 180.0 * jnp.pi) * jnp.sin(2 * jnp.pi * (t - 1e-3) / 4e-3), + jnp.zeros_like(t), + ]), + "step": lambda t: jnp.array([ + (t > 1e-3) * (5.0 / 180.0 * jnp.pi), + jnp.zeros_like(t), + ]), + } + + def _compute_probe_locations(self) -> np.ndarray: + G = self.SPEC.nozzle_geometry.G + H = self.SPEC.nozzle_geometry.H + + num_probes_per_diameter = 6 + theta = np.linspace(0, 2 * np.pi, num_probes_per_diameter, endpoint=False) + + probe_locations = [] + for i in (1, 2): + x = G[0] + i / 3 * (H[0] - G[0]) + R = G[1] + i / 3 * (H[1] - G[1]) + x_probes = np.full_like(theta, x) + y_probes = R * np.cos(theta) + z_probes = -R * np.sin(theta) + probe_locations.append( + np.stack([x_probes, y_probes, z_probes], axis=1) + ) + return np.concatenate(probe_locations, axis=0) + + def _get_reward(self, action: np.ndarray) -> float: + error = jnp.sqrt(jnp.sum((self.target_angle - self.thrust_angle)**2)) + return float(-error) + + def _plot_flow_field(self) -> None: + jxf_buffers, _ = self._require_state() + primitives, levelset, _ = self._get_fields_for_plotting(jxf_buffers) + + physical_simulation_time = jxf_buffers.time_control_variables.physical_simulation_time + + domain_information = self.sim_manager.domain_information + cell_centers = domain_information.get_global_cell_centers() + cell_sizes = domain_information.get_global_cell_sizes() + if domain_information.is_parallel: + cell_centers = reassemble_cell_centers(cell_centers, domain_information.split_factors) + cell_sizes = reassemble_cell_sizes(cell_sizes, domain_information.split_factors) + + cell_centers = tuple(x.squeeze() for x in cell_centers) + cell_sizes = tuple(x.squeeze() for x in cell_sizes) + + self.render_dir.mkdir(parents=True, exist_ok=True) + + plotter = plot_flowfield_3d( + primitives, + levelset, + cell_centers, + cell_sizes, + self.SPEC.p0, + self.injector_geometry, + self.SPEC.specific_heat_ratio, + ) + + filename = self.render_dir / f"flowfield_{self.env_step:04d}.png" + if self.render_mode is RenderMode.SHOW: + plotter.show(auto_close=False) + elif self.render_mode is RenderMode.SAVE: + plotter.show(screenshot=str(filename), auto_close=False) + else: + raise ValueError(f"RenderMode {self.render_mode} is not valid.") + plotter.clear() + + def _plot_observations(self) -> None: + jxf_buffers, _ = self._require_state() + physical_simulation_time = jxf_buffers.time_control_variables.physical_simulation_time + + fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(10, 10)) + ax = ax.flatten() + + fig.suptitle(f"Env Step: {self.env_step}, Time: {physical_simulation_time * 1e3:.3f} ms") + + times_full = np.linspace(0.0, self.SPEC.t_end, 100) + target_full = self.target_fn(times_full) + + ax[0].plot(times_full * 1e3, np.rad2deg(target_full[0]), "b--", label="target") + ax[1].plot(times_full * 1e3, np.rad2deg(target_full[1]), "b--", label="target") + + t = np.array(self.history["time"]) + thrust_angle = np.array(self.history["thrust_angle"]) + action = np.array(self.history["action"]) + + if len(t) > 0: + ax[0].plot(t * 1e3, np.rad2deg(thrust_angle[:,0]), "k", label="current") + ax[1].plot(t * 1e3, np.rad2deg(thrust_angle[:,1]), "k", label="current") + + ax[0].set_ylim(-10.0, 10.0) + ax[1].set_ylim(-10.0, 10.0) + + # for actuator_i in range(self.num_actuators): + # ax[1].plot(t * 1e3, action[:,actuator_i], label=f"Inj. {actuator_i:02d}") + + # ax[1].set_ylim(-0.05, 1.05) + + if self.is_pressure_probes and len(t) > 0: + pressure_probes = np.array(self.history["pressure_probes"]) + pressure_probes /= self.SPEC.p0 + for probe_i in range(self.num_probes // 2): + ax[2].plot(t * 1e3, pressure_probes[:, probe_i], label=f"Probe 0{probe_i:d}") + ax[3].plot(t * 1e3, pressure_probes[:, probe_i + self.num_probes // 2], label=f"Probe 1{probe_i:d}") + ax[2].set_ylim(0.0, 0.5) + ax[3].set_ylim(0.0, 0.5) + + titles = ( + "Thrust angle" + r"$\delta_0$" + "[deg]", + "Thrust angle" + r"$\delta_1$" + "[deg]", + "Pressure probes downstream loc 0", + "Pressure probes downstream loc 1" + ) + for axi, title in zip(ax, titles): + axi.set_box_aspect(1.0) + axi.set_xlabel("t [ms]") + axi.set_title(title) + axi.legend() + + if self.render_mode is RenderMode.SHOW: + plt.show() + elif self.render_mode is RenderMode.SAVE: + self._save_render_figure(fig, "observations") + else: + raise ValueError(f"RenderMode {self.render_mode} is not valid.") + plt.close(fig) diff --git a/hydrogym/jaxfluids/utils/__init__.py b/hydrogym/jaxfluids/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hydrogym/jaxfluids/utils/nozzle.py b/hydrogym/jaxfluids/utils/nozzle.py new file mode 100644 index 0000000..e855d06 --- /dev/null +++ b/hydrogym/jaxfluids/utils/nozzle.py @@ -0,0 +1,857 @@ +from dataclasses import dataclass +import os +from pathlib import Path +import platform +if platform.machine().lower() in {"aarch64", "arm64"}: + os.environ["VTK_DEFAULT_OPENGL_WINDOW"] = "vtkOSOpenGLRenderWindow" +from typing import Callable, NamedTuple + +from jax import Array +import jax.numpy as jnp +import json +from numpy import ndarray +import numpy as np +import pyvista as pv +pv.global_theme.allow_empty_mesh = True + + +from jaxfluids import SimulationManager +from jaxfluids_thirdparty.gas_dynamics.isentropic import ( + pressure_ratio_isentropic, + density_ratio_isentropic, + mach_number_from_pressure_ratio_isentropic, +) +from jaxfluids_thirdparty.gas_dynamics.core import ( + speed_of_sound, + total_energy, + density_from_pressure_temperature, +) + + +TargetThrustAngle = Array | float +TargetThrustAngleFn = Callable[[Array | float], TargetThrustAngle] + + +@dataclass(frozen=True, slots=True) +class NozzleGeometry: + """Fixed nozzle geometry based on the publication + Das et al. 2025 AIAA + """ + A: tuple[float, float] = (0.0,-0.01559) + B: tuple[float, float] = (0.0, 0.0352) + C: tuple[float, float] = (0.02329, 0.02954) + D: tuple[float, float] = (0.05049, 0.01552) + E: tuple[float, float] = (0.0608076, 0.0140462) + F: tuple[float, float] = (0.05779, 0.02962) + G: tuple[float, float] = (0.10401, 0.02247) + H: tuple[float, float] = (0.11557, 0.0246888) + R: float = 0.0137543 + UPPER_EDGE: float = 0.045 + + def area_ratio_inlet(self, dim: int) -> float: + """Computes the ratio of the inlet area + to the throat area. + """ + if dim not in (2, 3): + raise ValueError(f"Invalid dim. Got {dim}.") + + ratio = self.B[1] / self.R + + if dim == 2: + return ratio + else: + return ratio**2 + + @property + def D_exit(self) -> float: + return 2 * self.H[1] + + @property + def D_throat(self) -> float: + return 2 * self.R + + +@dataclass(frozen=True, slots=True) +class InjectorGeometry: + X: float # position + IW: float # width + N: int # count + + +@dataclass(frozen=True, slots=True) +class InjectorPlaneParameters: + positions: Array + tangents: Array + normals: Array + + +@dataclass(frozen=True, slots=True) +class PressureRatios: + NPR: float # nozzle pressure ratio + SPR: float # secondary pressure ratio + + +class ObsData(NamedTuple): + thrust_angle: Array + target_angle: Array + pressure_probes: Array | None + + +@dataclass(frozen=True, slots=True) +class TVCSpec: + dim: int + grid_resolutions: tuple[str, ...] = ("coarse", "fine") + fixed_num_actuators: int | None = None + min_num_actuators: int | None = None + max_num_actuators: int | None = None + ambient_pressure: float = 1e+5 + ambient_temperature: float = 300.0 + specific_gas_constant: float = 287.14 + specific_heat_ratio: float = 1.4 + nozzle_pressure_ratio: float = 4.6 + nozzle_geometry: NozzleGeometry = NozzleGeometry() + injector_x: float = 0.789 + injector_width: float = 0.002032 + t_end: float = 1e-2 + + @property + def p0(self) -> float: + return self.nozzle_pressure_ratio * self.ambient_pressure + + +@dataclass(frozen=True, slots=True) +class TVCEnvOptions: + num_actuators: int + secondary_pressure_ratio: float + resolution: str + ngpus: int + is_pressure_probes: bool + is_scale_observations: bool + target_fn: TargetThrustAngleFn + + +@dataclass(frozen=True, slots=True) +class TVCRuntimeSetup: + env_name: str + env_dir: Path + case_setup_dict: dict + numerical_setup_dict: dict + restart_file_path: Path + + +def build_tvc_env_options( + *, + env_config: dict, + spec: TVCSpec, + target_fns: dict[str, TargetThrustAngleFn], + cls_name: str, + ) -> TVCEnvOptions: + + num_actuators = env_config.get("num_actuators") + if spec.fixed_num_actuators is not None: + if num_actuators is not None: + raise ValueError( + f"{cls_name} requires {spec.fixed_num_actuators} actuators." + ) + num_actuators = spec.fixed_num_actuators + else: + if num_actuators is None: + raise ValueError("num_actuators must be provided.") + + min_num_actuators = spec.min_num_actuators + max_num_actuators = spec.max_num_actuators + if min_num_actuators is None or max_num_actuators is None: + raise ValueError( + f"{cls_name} must define either a fixed actuator count " + "or both min_num_actuators and max_num_actuators." + ) + + if not (min_num_actuators <= num_actuators <= max_num_actuators): + raise ValueError( + f"num_actuators must be in " + f"[{min_num_actuators}, {max_num_actuators}]. " + f"Got {num_actuators}." + ) + + secondary_pressure_ratio = env_config.get("secondary_pressure_ratio", 0.7) + if secondary_pressure_ratio < 0.7 or secondary_pressure_ratio > 0.9: + raise ValueError( + f"secondary_pressure_ratio must be >= 0.7 and <= 0.9." + f"Got {secondary_pressure_ratio}." + ) + + resolution = env_config.get("resolution", "fine") + if resolution not in spec.grid_resolutions: + raise ValueError( + f"Resolution {resolution} is not supported. " + f"Please choose from {spec.grid_resolutions}." + ) + + ngpus = env_config.get("ngpus", 1) + if not isinstance(ngpus, int): + raise ValueError(f"ngpus must be of type int. Got {type(ngpus)}.") + if ngpus < 1: + raise ValueError(f"ngpus must be >= 1. Got {ngpus}.") + + is_pressure_probes = env_config.get("is_pressure_probes", False) + if not isinstance(is_pressure_probes, bool): + raise ValueError( + "is_pressure_probes must be of type bool. " + f"Got {type(is_pressure_probes)}" + ) + + is_scale_observations = env_config.get("is_scale_observations", True) + if not isinstance(is_scale_observations, bool): + raise ValueError( + f"is_scale_observations needs to be of type bool. " + f"Got {type(is_scale_observations)}." + ) + + target_key = env_config.get("target_fn", "sine") + if target_key not in target_fns: + raise ValueError( + f"Unknown target_fn {target_key!r}. " + f"Please choose from {tuple(target_fns)}." + ) + + return TVCEnvOptions( + num_actuators=num_actuators, + secondary_pressure_ratio=secondary_pressure_ratio, + resolution=resolution, + ngpus=ngpus, + is_pressure_probes=is_pressure_probes, + is_scale_observations=is_scale_observations, + target_fn=target_fns[target_key], + ) + + +def build_tvc_runtime_setup( + *, + base_path: Path, + dim: int, + resolution: str, + ngpus: int, + ) -> TVCRuntimeSetup: + + env_name = f"Nozzle{dim}D_{resolution}" + # env_dir = base_path / env_name + env_dir = base_path + + case_setup_path = env_dir / "jxf_case_setup.json" + numerical_setup_path = env_dir / "jxf_numerical_setup.json" + restart_file_path = env_dir / "restart.h5" + + if not case_setup_path.exists(): + raise FileNotFoundError(f"Could not find case setup file {case_setup_path}.") + + if not numerical_setup_path.exists(): + raise FileNotFoundError(f"Could not find numerical setup file {numerical_setup_path}.") + + if not restart_file_path.exists(): + raise FileNotFoundError(f"Could not find restart file {restart_file_path}.") + + case_setup_dict = json.loads(case_setup_path.read_text()) + case_setup_dict["domain"]["decomposition"]["split_x"] = ngpus + + numerical_setup_dict = json.loads(numerical_setup_path.read_text()) + + return TVCRuntimeSetup( + env_name=env_name, + env_dir=env_dir, + case_setup_dict=case_setup_dict, + numerical_setup_dict=numerical_setup_dict, + restart_file_path=restart_file_path, + ) + + +def compute_thrust( + primitives: Array, # shape (Np,Nx,Ny,Nz) + p_infty: float, + apertures_x: Array, + cell_centers: tuple[Array, ...], + cell_sizes: tuple[Array, ...], + ) -> tuple[Array, Array, Array]: + """Computes the thrust of the nozzle. + + F_x = mdot_e * u_e + (p_e - p_infty) * A_e + F_y = mdot_e * v_e + F_z = mdot_e * w_e + """ + + nozzle_geometry = NozzleGeometry() + + x, y, z = [xi.flatten() for xi in cell_centers] + dx, _, _ = [dxi.flatten() for dxi in cell_sizes] + + DIM = 3 if len(z) > 1 else 2 + dx_min = jnp.min(dx) + cell_face_area = dx_min if DIM == 2 else dx_min**2 + + # interpolate primitives to cell face + primitives_cf = jnp.concatenate([ + primitives[:,0:1], primitives, primitives[:,-1:] + ], axis=1) + primitives_cf = (primitives_cf[:,1:] + primitives_cf[:,:-1]) / 2 + + x_cf = jnp.concatenate([x - dx/2, x[-1:] + dx[-1:]/2], axis=0) + + # x ids throat and exit + xid_e = jnp.searchsorted(x_cf, nozzle_geometry.H[0]) - 1 + xid_t = jnp.searchsorted(x_cf, nozzle_geometry.F[0]) - 1 + + _, Y, Z = jnp.meshgrid(x_cf, y, z, indexing="ij") + if DIM == 2: + mask = jnp.abs(Y[xid_e]) <= nozzle_geometry.H[1] + else: + mask = jnp.sqrt(Y[xid_e]**2 + Z[xid_e]**2) <= nozzle_geometry.H[1] + + # states at nozzle exit + A_e = apertures_x[xid_e] * cell_face_area + rho_e = primitives_cf[0, xid_e] + vel_e = primitives_cf[1:4, xid_e] + p_e = primitives_cf[4, xid_e] + + # states at throat + A_t = apertures_x[xid_t] * cell_face_area + rho_t = primitives_cf[0, xid_t] + u_t = primitives_cf[1, xid_t] + + # mass flow + mdot_e = rho_e * vel_e[0] * A_e * mask + mdot_t = rho_t * u_t * A_t * mask + + # thrust + thrust = mdot_e * vel_e + thrust = thrust.at[0].add( (p_e - p_infty) * A_e ) + thrust = jnp.sum(thrust * mask, axis=(-1,-2)) + + # metrics + mdot_e = jnp.sum(mdot_e, axis=(-1,-2)) + mdot_t = jnp.sum(mdot_t, axis=(-1,-2)) + + return thrust, mdot_t, mdot_e + + +def _compute_injector_plane_params( + injector_geometry: InjectorGeometry, + ) -> InjectorPlaneParameters: + """Computes position, normal and tangent of the injector. + + :param injector_geometry: _description_ + :type injector_geometry: InjectorGeometry + :return: _description_ + :rtype: InjectorPlaneParameters + """ + + nozzle_geometry = NozzleGeometry() + + # compute vectors for base injector + H = jnp.array(nozzle_geometry.H) # end point of convergent linear nozzle section + E = jnp.array(nozzle_geometry.E) # start point of convergent linear nozzle section + EH = H - E + tangent = EH / jnp.linalg.norm(EH) + x_position = E[0] + injector_geometry.X * (H[0] - E[0]) + y_position = (H[1] - E[1])/(H[0] - E[0]) * (x_position - E[0]) + E[1] + + t0 = jnp.array([tangent[0], tangent[1], 0.0]) + n0 = jnp.array([tangent[1], -tangent[0], 0.0]) + pos0 = jnp.array([x_position, y_position, 0.0]) + + positions = [] + tangents = [] + normals = [] + + theta = jnp.linspace(0, 2*jnp.pi, injector_geometry.N, endpoint=False) + + # rotate the vectors around the circumference of the nozzle + # to get the corresponding vectors of the remaining injectors + for th in theta: + R = jnp.array([ + [1, 0, 0], + [0, jnp.cos(th), -jnp.sin(th)], + [0, jnp.sin(th), jnp.cos(th)], + ]) + + positions.append(R @ pos0) + tangents.append(R @ t0) + normals.append(R @ n0) + + positions = jnp.stack(positions) + tangents = jnp.stack(tangents) + normals = jnp.stack(normals) + + return InjectorPlaneParameters( + positions, tangents, normals + ) + + +def _compute_injector_mask( + mesh_grid: tuple[Array], + IW: float, + injector_params: InjectorPlaneParameters, + dim: int, + dx: float + ) -> Array: + + position = injector_params.positions + tangent = injector_params.tangents + normal = injector_params.normals + + # project the distance vector R (mesh_grid - injector position) + # on injector tangent plane to compute mask + + position = position[:, None] + tangent = tangent[:, None] + normal = normal[:, None] + + if dim == 2: + X, Y = mesh_grid + + R = jnp.stack([ + X - position[0], + Y - position[1] + ], axis=0) + + s = jnp.sum(R*tangent[:2], axis=0) + d = jnp.sum(R*normal[:2], axis=0) + + mask_injector = ( + (jnp.abs(s) <= IW / 2) & + (jnp.abs(d) <= dx) + ) + + elif dim == 3: + X, Y, Z = mesh_grid + + R = jnp.stack([ + X - position[0], + Y - position[1], + Z - position[2] + ], axis=0) + + # compute circumferential tangent + t_theta = jnp.cross(normal, tangent, axis=0) + t_theta = t_theta / jnp.linalg.norm(t_theta, keepdims=True) + + # projections + s_x = jnp.sum(R*tangent, axis=0) + s_theta = jnp.sum(R*t_theta, axis=0) + d = jnp.sum(R*normal, axis=0) + + R_inj = IW / 2 + + mask_injector = ( + (s_x**2 + s_theta**2 <= R_inj**2) & + (jnp.abs(d) <= 5*dx) # NOTE large safety offset in normal direction given that nozzle exit surface is curved around circumference + ) + + else: + raise ValueError + + return mask_injector + + +def _compute_choked_state( + p_total_injector: float, + rho_total_injector: float, + specific_heat_ratio: float, + ) -> tuple[float, float, float, float]: + p_choked = p_total_injector * pressure_ratio_isentropic(1.0, specific_heat_ratio) + rho_choked = rho_total_injector * density_ratio_isentropic(1.0, specific_heat_ratio) + u_choked = speed_of_sound(p_choked, rho_choked, specific_heat_ratio) + E_choked = total_energy(p_choked, rho_choked, u_choked, specific_heat_ratio) + return p_choked, rho_choked, u_choked, E_choked + + +def _compute_unchoked_state( + p_local: float, + rho_total_injector: float, + p_ratio: float, + specific_heat_ratio: float + ) -> tuple[float, float, float]: + # NOTE if p_ratio_unchoked <= 1.0 -> no injection. We clip to prevent negative sqrt. + p_ratio_unchoked = jnp.clip(p_ratio, 0.0, 1.0) + M_unchoked = mach_number_from_pressure_ratio_isentropic(p_ratio_unchoked, specific_heat_ratio) + rho_unchoked = rho_total_injector * density_ratio_isentropic(M_unchoked, specific_heat_ratio) + u_unchoked = M_unchoked * speed_of_sound(p_local, rho_unchoked, specific_heat_ratio) + E_unchoked = total_energy(p_local, rho_unchoked, u_unchoked, specific_heat_ratio) + return rho_unchoked, u_unchoked, E_unchoked + + +def initialize_injector_flux_fn( + injector_geometry: InjectorGeometry, + pressure_ratios: PressureRatios, + p_infty: float, + T_infty: float, + specific_heat_ratio: float, + specific_gas_constant: float, + sim_manager: SimulationManager, + ) -> Callable[ + [Array, Array, Array, Array, Array], + tuple[Array, Array, Array] + ]: + + domain_information = sim_manager.domain_information + dim = domain_information.dim + dx = domain_information.smallest_cell_size + + IW = injector_geometry.IW + NPR = pressure_ratios.NPR + SPR = pressure_ratios.SPR + + # total pressure in reservoir for injector + P0 = p_infty * NPR * SPR + + num_injectors = injector_geometry.N + injector_params = _compute_injector_plane_params(injector_geometry) + + def compute_interface_flux( + primitives: Array, + interface_length: Array, + normal: Array, + mesh_grid: Array, + actuator: Array + ) -> tuple[Array, Array, Array]: + """Computes the interface flux for each injector. + """ + + if len(actuator) != num_injectors: + raise ValueError("Number of actions unequal to injector count") + + # local nozzle pressure + p_local = primitives[-1] + + # clipping actions + actuator = jnp.clip(actuator, 0.0, 1.0) + + mass_flux = 0.0 + momentum_flux = p_local * normal * interface_length + energy_flux = 0.0 + + for i in range(num_injectors): + + actuator_i = actuator[i] + + # actuator [0.0, 1.0] adjusts total pressure + p_total_injector = p_local + actuator_i * (P0 - p_local) + rho_total_injector = density_from_pressure_temperature( + p=p_total_injector, + T=T_infty, + R=specific_gas_constant, + ) + + # we assume isentropic expansion to either M=1 (choked) + # or to local pressure to compute injector state + + # choked injector state + p_choked, rho_choked, u_choked, E_choked = _compute_choked_state( + p_total_injector=p_total_injector, + rho_total_injector=rho_total_injector, + specific_heat_ratio=specific_heat_ratio, + ) + + # unchoked injector state + p_ratio = p_local / p_total_injector + rho_unchoked, u_unchoked, E_unchoked = _compute_unchoked_state( + p_local=p_local, + rho_total_injector=rho_total_injector, + p_ratio=p_ratio, + specific_heat_ratio=specific_heat_ratio + ) + + # checking if injector is choked + p_ratio_choked = pressure_ratio_isentropic(1.0, specific_heat_ratio) + mask_choked = p_ratio < p_ratio_choked + + # masking choked vs. unchoked + rho_injector = jnp.where(mask_choked, rho_choked, rho_unchoked) + p_injector = jnp.where(mask_choked, p_choked, p_local) + u_injector = jnp.where(mask_choked, u_choked, u_unchoked) + E_injector = jnp.where(mask_choked, E_choked, E_unchoked) + + mask_injector = _compute_injector_mask( + mesh_grid, IW, + InjectorPlaneParameters( + injector_params.positions[i], + injector_params.tangents[i], + injector_params.normals[i] + ), + dim, dx + ) + + u_injector_vec = u_injector * injector_params.normals[i][:,None] + u_injector_n = jnp.sum(u_injector_vec * normal, axis=0) + + mass_flux += rho_injector * u_injector_n * interface_length * mask_injector + momentum_flux_i = ( + rho_injector * u_injector_n * u_injector_vec + p_injector * normal + ) * interface_length + momentum_flux = momentum_flux * (1 - mask_injector) + momentum_flux_i * mask_injector + energy_flux += (E_injector + p_injector) * u_injector_n * interface_length * mask_injector + + return mass_flux, momentum_flux, energy_flux + + return compute_interface_flux + + +def plot_flowfield_3d( + primitives: ndarray, + levelset: ndarray, + cell_centers: tuple[ndarray, ...], + cell_sizes: tuple[ndarray, ...], + nozzle_pressure: float, + injector_geometry: InjectorGeometry, + specific_heat_ratio: float, + ) -> pv.Plotter: + + + injector_params = _compute_injector_plane_params(injector_geometry) + + mesh_grid = np.meshgrid(*cell_centers, indexing="ij") + mesh_grid = np.stack(mesh_grid, axis=0) + IW = injector_geometry.IW + min_dx = np.min(cell_sizes[0]) + + mask_injector_list = [] + + for i in range(injector_geometry.N): + mask_injector = _compute_injector_mask( + mesh_grid.reshape(3,-1), IW, + InjectorPlaneParameters( + injector_params.positions[i], + injector_params.tangents[i], + injector_params.normals[i] + ), + 3, min_dx + ) + mask_injector_list.append(mask_injector.reshape(mesh_grid[0].shape)) + + # compute fields + density = primitives[0] + velocity = np.linalg.norm(primitives[1:4],axis=0,ord=2) + pressure = primitives[-1] + speed_of_sound = np.sqrt(specific_heat_ratio * pressure / density) + mach_number = velocity/speed_of_sound + schlieren = np.linalg.norm(np.gradient(density),axis=0,ord=2) + + min_dx = np.min(cell_sizes[0]) + + # Build stretched grid + x_centers, y_centers, z_centers = cell_centers + dx, dy, dz = cell_sizes + + # Compute cell edges from centers + sizes + x_edges = np.concatenate(([x_centers[0] - dx[0] / 2], x_centers + dx / 2)) + y_edges = np.concatenate(([y_centers[0] - dy[0] / 2], y_centers + dy / 2)) + z_edges = np.concatenate(([z_centers[0] - dz[0] / 2], z_centers + dz / 2)) + + # Create rectilinear grid (supports stretched mesh) + grid_levelset = pv.RectilinearGrid(x_edges, y_edges, z_edges) + grid = pv.RectilinearGrid(x_edges, y_edges, z_edges) + + # Attach data + total_mask = np.maximum.reduce(mask_injector_list).astype(float) + grid_levelset.cell_data["levelset"] = levelset.ravel(order="F") / min_dx + grid_levelset.cell_data["pressure"] = pressure.ravel(order="F") / nozzle_pressure + grid_levelset.cell_data["total_mask"] = total_mask.ravel(order="F") + + grid.cell_data["levelset"] = levelset.ravel(order="F")/min_dx + grid.cell_data["schlieren"] = schlieren.ravel(order="F") + grid.cell_data["mach_number"] = mach_number.ravel(order="F") + + # Ensure we are working with point data + grid_levelset = grid_levelset.cell_data_to_point_data() + grid = grid.cell_data_to_point_data() + grid, grid_levelset = clip_grids(grid, grid_levelset) + + # Plot + plotter = plot_slice(grid_levelset, grid) + + return plotter + + +def clip_grids(grid, grid_levelset): + + # --- compute shared geometry once --- + points = grid.points # use one grid (same topology assumed) + x = points[:, 0] + y = points[:, 1] + z = points[:, 2] + r = np.sqrt(y**2 + z**2) + + # --- build masks --- + nozzle_x_limit = NozzleGeometry().H[0] - 5e-4 + + mask_levelset = ( + (r < 0.044) & + (x < nozzle_x_limit) + ) + + mask_grid = ( + (r < 0.044) & + (x < 0.25) & + (grid.point_data["levelset"] > 2) + ) + + # --- extract in one pass --- + grid_levelset_clipped = grid_levelset.extract_points( + mask_levelset, + adjacent_cells=True + ) + + grid_clipped = grid.extract_points( + mask_grid, + adjacent_cells=True + ) + + return grid_clipped, grid_levelset_clipped + +def plot_slice(grid_levelset, grid) -> pv.Plotter: + + plotter = pv.Plotter(off_screen=True, window_size=(3000, 2200)) + + # CMAP = "Spectral_r" + CMAP = "coolwarm" + + # nozzle + contour = grid_levelset.contour(isosurfaces=[0.0], scalars="levelset") + contour = contour.triangulate().subdivide(2, subfilter="linear") + contour = contour.sample(grid_levelset) + + contour_visible = contour.threshold( + value=0.5, + scalars="total_mask", + invert=True, + ).extract_surface(algorithm="dataset_surface") + + contour_masked = contour.threshold( + value=0.5, + scalars="total_mask", + invert=False, + ).extract_surface(algorithm="dataset_surface") + + contour_masked = contour.threshold( + value=0.5, + scalars="total_mask", + invert=False, + ).extract_surface(algorithm="dataset_surface") + + contour_masked = contour_masked.compute_normals( + cell_normals=False, + point_normals=True, + ) + contour_masked = contour_masked.warp_by_vector("Normals", factor=2e-4) + + masked_outline = contour_masked.extract_feature_edges( + boundary_edges=True, + feature_edges=False, + manifold_edges=False, + non_manifold_edges=False, + ) + + plotter.add_mesh( + contour_visible, + scalars="pressure", + cmap=CMAP, + opacity=1.0, + smooth_shading=True, + show_scalar_bar=False, + clim=[0.0, 1.0], + ) + + plotter.add_mesh( + contour_masked, + color="white", + opacity=1.0, + smooth_shading=False, + show_scalar_bar=False, + lighting=False, + ) + + plotter.add_mesh( + masked_outline, + color="black", + line_width=5.0, + lighting=False, + ) + + + # mach number + offset = 0.05 + slice_xy = grid.slice( + normal=(0, 0, 1), + origin=(0.0, 0.0, 0.0) + ) + slice_xy = slice_xy.threshold( + value=0.0, + scalars="levelset", + ) + slice_xy = slice_xy.translate((0.0, 0.0, -offset), inplace=False) + plotter.add_mesh( + slice_xy, + scalars="mach_number", + cmap=CMAP, + clim=[0.0, 3.0], + show_scalar_bar=False, + lighting=True, + ambient=0.1, + # diffuse=1.0, + # specular=0.1, + ) + + slice_xz = grid.slice( + normal=(0, 1, 0), + origin=(0.0, 0.0, 0.0) + ) + slice_xz = slice_xz.threshold( + value=0.0, + scalars="levelset", + ) + slice_xz = slice_xz.translate((0.0, -offset, 0.0), inplace=False) + plotter.add_mesh( + slice_xz, + scalars="mach_number", + cmap=CMAP, + clim=[0.0, 3.0], + show_scalar_bar=False, + # lighting=True, + # ambient=0.05, + # diffuse=0.5, + # specular=0.1, + ) + + # schlieren + low = 0.25 + high = 1.5 + n = 20 + + values = np.linspace(low, high, n) + contours = grid.contour( + isosurfaces=values, + scalars="schlieren" + ) + plotter.add_mesh( + contours, + color="darkgray", + opacity=0.3, + smooth_shading=True, + lighting=True, + # ambient=0.2, + diffuse=1.0, + specular=1.0, + specular_power=100, + ) + + plotter.add_light(pv.Light(light_type='headlight', intensity=0.3)) + # plotter.add_axes() + # plotter.show_bounds() + # plotter.show_grid() + plotter.camera_position = ( + (0.25,0.1,0.15), + (0.12,0.0,0.0), + (0,1,0), + ) + # plotter.enable_eye_dome_lighting() + # plotter.set_background("black") + + return plotter \ No newline at end of file