Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 87 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ cartopy = "^0.25.0"
matplotlib = "^3.10.7"
numpy = "^2.3.5"
pandas = "^2.3.3"
boto3 = "^1.42.65"

[tool.poetry.group.dev.dependencies]
mypy = "^1.18.1"
Expand Down
108 changes: 107 additions & 1 deletion pytrajplot/main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,103 @@
"""Command line interface of pytrajplot."""
from typing import Tuple, Dict
from typing import Tuple, Dict, Optional
import logging
import os
from pathlib import Path

# Third-party
import click
import boto3

# First-party
from pytrajplot import __version__
from pytrajplot.generate_pdf import generate_pdf
from pytrajplot.parse_data import check_input_dir
from pytrajplot.utils import count_to_log_level

# Setup logging
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
logging.basicConfig(level=log_level)
logger = logging.getLogger(__name__)

def print_version(ctx: click.Context, _param: click.Parameter, value: bool) -> None:
"""Print the version number and exit."""
if value:
click.echo(__version__)
ctx.exit(0)

def replace_variables(template_content: str) -> str:
"""
Replace $VAR with actual environment variable values.
Args:
template_content: Template string with $VARIABLE placeholders
Returns:
String with variables replaced by environment values
"""
result = template_content
# Get all environment variables as dict
env_vars = dict(os.environ)

# Replace variables found in the template
for env_key, env_value in env_vars.items():
placeholder = f'${env_key}'
if placeholder in result:
result = result.replace(placeholder, env_value)
logger.info(f"Replaced {placeholder} with {env_value}")
return result


def check_plot_info_file(input_dir: str, info_name: str, ssm_parameter_path: str | None = None) -> bool:
"""
Check if plot_info file exists in input directory.
If not found, fetch from SSM parameter and create it replacing variables.
Args:
input_dir: Input directory path
info_name: Name of the plot info file
ssm_parameter_path: SSM parameter path (optional, uses env var if not provided)
Returns:
bool: True if file exists or was created successfully, False otherwise
"""
input_path = Path(input_dir)
plot_info_file = input_path / info_name

# Check if plot_info file already exists
if plot_info_file.exists():
logger.info(f"Plot info file already exists: {plot_info_file}")
return True

# File doesn't exist, try to create it from SSM parameter
logger.info(f"Plot info file not found: {plot_info_file}")

try:
# Get SSM parameter path from argument or environment
ssm_param_path = ssm_parameter_path or os.environ.get('SSM_PARAMETER_PATH', '/pytrajplot/icon/plot_info')
logger.info(f"Fetching SSM parameter: {ssm_param_path}")

# Fetch template from SSM Parameter
ssm_client = boto3.client('ssm')
response = ssm_client.get_parameter(
Name=ssm_param_path,
WithDecryption=True
)

# Get the template content
template_content = response['Parameter']['Value']
logger.info(f"Template content length: {len(template_content)} chars")

# Replace variables with environment variable values
substituted_content = replace_variables(template_content)

# Create the plot_info file
with open(plot_info_file, 'w') as f:
f.write(substituted_content)

logger.info(f"Successfully created plot info file: {plot_info_file}")
return True

except Exception as e:
logger.error(f"Failed to create plot info file from SSM parameter: {str(e)}")
logger.error(f"SSM parameter path: {ssm_parameter_path or os.environ.get('SSM_PARAMETER_PATH', 'not_set')}")
return False

def interpret_options(start_prefix: str, traj_prefix: str, info_name: str, language: str) -> Tuple[Dict[str, str], str]:
"""Reformat command line inputs.
Expand Down Expand Up @@ -124,6 +205,17 @@ def interpret_options(start_prefix: str, traj_prefix: str, info_name: str, langu
default=["pdf"],
help="Choose data type(s) of final result. Default: pdf",
)
@click.option(
"--ssm-parameter-path",
type=str,
help="SSM parameter path for plot_info template. Uses SSM_PARAMETER_PATH env var if not specified.",
)
@click.option(
"--skip-ssm-fallback",
is_flag=True,
default=False,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would default to True here so that the implementation at CSCS would not need to change its args.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As --skip-ssm-fallback is a flag, defaulting it to true might not be practical because then setting it without a value would be the same as not setting it? Would it be better to invert the logic to a --use-ssm-fallback flag so that setting it without a value has an effect? Alternatively, could the --ssm-parameter-path be used for this functionality and the fallback is skipped if --ssm-parameter-path is not present? Unless there is an advantage to keep the --ssm-parameter-path even when the fallback should be skipped.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, sorry i misinterpreted the flag. Yes I would be fine to leave as is, True = skip. But I think I prefer the suggestion that if --ssm-parameter-path is set to True or if SSM_PARAMETER_PATH env var is defined, i would run the fallback, and otherwise skip.

help="Skip SSM parameter fallback if plot_info file is missing.",
)
@click.option(
"--version",
"-V",
Expand All @@ -143,7 +235,21 @@ def cli(
language: str,
domain: str,
datatype: str,
ssm_parameter_path: str | None = None,
skip_ssm_fallback: bool = False,
) -> None:
# Check if plot_info file exists (create from SSM if needed)
if not skip_ssm_fallback:
plot_info_created = check_plot_info_file(
input_dir=input_dir,
info_name=info_name,
ssm_parameter_path=ssm_parameter_path
)

if not plot_info_created:
logger.error("Failed to check if plot_info file exists. Use --skip-ssm-fallback to continue anyway.")
raise click.ClickException("Missing plot_info file and failed to create from SSM parameter.")

prefix_dict, language = interpret_options(
start_prefix=start_prefix,
traj_prefix=traj_prefix,
Expand Down
2 changes: 1 addition & 1 deletion test/integration/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def create_args(input_dir: str, output_dir: str, opts: dict) -> list:
# Positional arguments
args.append(input_dir)
args.append(output_dir)
args.append("--skip-ssm-fallback")

# Keyword arguments
for key, value in opts.items():
Expand Down Expand Up @@ -208,4 +209,3 @@ def test_pytrajplot(input_args, input_dir, output_dir):
for rel in expected:
expected_file = Path(output_path) / Path(rel).name
assert expected_file.exists(), f"Expected output not found: {expected_file}"

Loading