Commit ebf7334

Add network validation script executed in the sagemaker_ui_post_startup script
**Description**
This change introduces a network validation script that tests whether required AWS services are reachable by making read-only API calls with a set timeout. If a call exceeds the timeout, the script infers a faulty network setup, such as missing internet access or a missing VPC endpoint for that service. Calls that resolve within the timeout, whether they succeed or fail, are taken as evidence of a working network path. The script covers the AWS services behind compute connections and Git: it lists the project's DataZone connections to determine which services need to be checked. Unreachable services are aggregated and written to post-startup-status.json, which surfaces the error notification in the IDE.

**Testing**
Tested in SMUS portals with internet access, without internet access, and without internet access but with VPC endpoints for DataZone and S3.
1 parent 0fc2d44 commit ebf7334
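
The reachability test itself is a read-only call wrapped in a timeout. A minimal sketch of the per-service pattern the script applies (the 5-second limit matches the script's default, and STS is one of the always-checked services):

# Probe one service with a read-only call under a timeout.
# GNU timeout exits with 124 when the command is killed for exceeding the limit,
# which the script treats as "unreachable"; any other failure (for example
# AccessDenied) still proves the endpoint answered, so the network path is fine.
if timeout 5s aws sts get-caller-identity > /dev/null 2>&1; then
    echo "STS is reachable"
else
    exit_code=$?
    if [ "$exit_code" -eq 124 ]; then
        echo "STS did not resolve within 5s; likely a network problem"
    else
        echo "STS returned an error but is reachable"
    fi
fi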

4 files changed, +376 -0 lines changed

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
#!/bin/bash
set -eux

# Input parameters with defaults:
# Default to 1 (Git storage) if no parameter is passed.
is_s3_storage=${1:-"1"}
# Output file path for unreachable services JSON
network_validation_file=${2:-"/tmp/.network_validation.json"}

# Function to write unreachable services to a JSON file
write_unreachable_services_to_file() {
    local value="$1"
    local file="$network_validation_file"

    # Create the file if it doesn't exist
    if [ ! -f "$file" ]; then
        touch "$file" || {
            echo "Failed to create $file" >&2
            return 0
        }
    fi

    # Check file is writable
    if [ ! -w "$file" ]; then
        echo "Error: $file is not writable" >&2
        return 0
    fi

    # Write JSON object with UnreachableServices key and the comma-separated list value
    jq -n --arg value "$value" '{"UnreachableServices": $value}' > "$file"
}

# Configure AWS CLI region using environment variable REGION_NAME
aws configure set region "${REGION_NAME}"
echo "Successfully configured region to ${REGION_NAME}"

# Metadata file location containing DataZone info
sourceMetaData=/opt/ml/metadata/resource-metadata.json

# Extract necessary DataZone metadata fields via jq
dataZoneDomainId=$(jq -r '.AdditionalMetadata.DataZoneDomainId' < "$sourceMetaData")
dataZoneProjectId=$(jq -r '.AdditionalMetadata.DataZoneProjectId' < "$sourceMetaData")
dataZoneEndPoint=$(jq -r '.AdditionalMetadata.DataZoneEndpoint' < "$sourceMetaData")
dataZoneDomainRegion=$(jq -r '.AdditionalMetadata.DataZoneDomainRegion' < "$sourceMetaData")

# Call AWS CLI list-connections, including endpoint if specified
if [ -n "$dataZoneEndPoint" ]; then
    response=$(aws datazone list-connections \
        --endpoint-url "$dataZoneEndPoint" \
        --domain-identifier "$dataZoneDomainId" \
        --project-identifier "$dataZoneProjectId" \
        --region "$dataZoneDomainRegion")
else
    response=$(aws datazone list-connections \
        --domain-identifier "$dataZoneDomainId" \
        --project-identifier "$dataZoneProjectId" \
        --region "$dataZoneDomainRegion")
fi

# Extract each connection item as a compact JSON string
connection_items=$(echo "$response" | jq -c '.items[]')

# Required AWS Services for Compute connections and Git
# Initialize SERVICE_COMMANDS with always-needed STS and S3 checks
declare -A SERVICE_COMMANDS=(
    ["STS"]="aws sts get-caller-identity"
    ["S3"]="aws s3api list-buckets --max-items 1"
)

# Track connection types found for conditional checks
declare -A seen_types=()

# Iterate over each connection to populate service commands conditionally
while IFS= read -r item; do
    # Extract connection type
    type=$(echo "$item" | jq -r '.type')
    seen_types["$type"]=1

    # For SPARK connections, check for Glue and EMR properties
    if [[ "$type" == "SPARK" ]]; then
        # If sparkGlueProperties present, add Glue check
        if echo "$item" | jq -e '.props.sparkGlueProperties' > /dev/null; then
            SERVICE_COMMANDS["Glue"]="aws glue get-crawlers --max-items 1"
        fi

        # Check for emr-serverless in sparkEmrProperties.computeArn for EMR Serverless check
        emr_arn=$(echo "$item" | jq -r '.props.sparkEmrProperties.computeArn // empty')
        if [[ "$emr_arn" == *"emr-serverless"* ]]; then
            SERVICE_COMMANDS["EMR Serverless"]="aws emr-serverless list-applications --max-results 1"
        fi
    fi
done <<< "$connection_items"

# Add Athena if ATHENA connection found (default empty so the lookup is safe under set -u)
[[ -n "${seen_types["ATHENA"]:-}" ]] && SERVICE_COMMANDS["Athena"]="aws athena list-data-catalogs --max-results 1"

# Add Redshift checks if REDSHIFT connection found
if [[ -n "${seen_types["REDSHIFT"]:-}" ]]; then
    SERVICE_COMMANDS["Redshift Cluster"]="aws redshift describe-clusters --max-records 1"
    SERVICE_COMMANDS["Redshift Serverless"]="aws redshift-serverless list-namespaces --max-results 1"
fi

# Optionally add CodeConnections if S3 storage flag is true (Git storage)
if [[ "$is_s3_storage" == "1" ]]; then
    SERVICE_COMMANDS["CodeConnections"]="aws codeconnections list-hosts --max-results 1"
fi

# Timeout (seconds) for each API call
api_time_out_limit=5
# Array to accumulate unreachable services
unreachable_services=()
# Create a temporary directory to store individual service results
temp_dir=$(mktemp -d)

# Launch all service API checks in parallel background jobs
for service in "${!SERVICE_COMMANDS[@]}"; do
    {
        # Run command with timeout, discard stdout/stderr
        if timeout "${api_time_out_limit}s" bash -c "${SERVICE_COMMANDS[$service]}" > /dev/null 2>&1; then
            # Success: write OK to temp file
            echo "OK" > "$temp_dir/$service"
        else
            # Get exit code to differentiate timeout or other errors
            exit_code=$?
            if [ "$exit_code" -eq 124 ]; then
                # Timeout exit code
                echo "TIMEOUT" > "$temp_dir/$service"
            else
                # Other errors (e.g., permission denied)
                echo "ERROR" > "$temp_dir/$service"
            fi
        fi
    } &
done

# Wait for all background jobs to complete before continuing
wait

# Process each service's result file to identify unreachable ones
for service in "${!SERVICE_COMMANDS[@]}"; do
    result_file="$temp_dir/$service"
    if [ -f "$result_file" ]; then
        result=$(<"$result_file")
        if [[ "$result" == "TIMEOUT" ]]; then
            echo "$service API did NOT resolve within ${api_time_out_limit}s. Marking as unreachable."
            unreachable_services+=("$service")
        elif [[ "$result" == "OK" ]]; then
            echo "$service API is reachable."
        else
            echo "$service API returned an error (but not a timeout). Ignored for network check."
        fi
    else
        echo "$service check did not produce a result file. Skipping."
    fi
done

# Cleanup temporary directory
rm -rf "$temp_dir"

# Write unreachable services to file if any, else write empty string
if (( ${#unreachable_services[@]} > 0 )); then
    joined_services=$(IFS=','; echo "${unreachable_services[*]}")
    # Add spaces after commas for readability
    joined_services_with_spaces=${joined_services//,/,\ }
    write_unreachable_services_to_file "$joined_services_with_spaces"
    echo "Unreachable AWS Services: ${joined_services_with_spaces}"
else
    write_unreachable_services_to_file ""
    echo "All required AWS services reachable within ${api_time_out_limit}s"
fi
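
For reference, a hypothetical manual run of the script (the arguments mirror the defaults above; the service names in the sample output are purely illustrative):

# First argument: storage flag ("1" also enables the CodeConnections check),
# second argument: output path for the result JSON.
bash network_validation.sh 1 /tmp/.network_validation.json
cat /tmp/.network_validation.json
# Example output if two calls timed out (illustrative service names):
# {
#   "UnreachableServices": "S3, EMR Serverless"
# }
# If every call resolves within the timeout, the list is an empty string instead.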

template/v2/dirs/etc/sagemaker-ui/sagemaker_ui_post_startup.sh

Lines changed: 18 additions & 0 deletions
@@ -204,4 +204,22 @@ if [ "${SAGEMAKER_APP_TYPE_LOWERCASE}" = "jupyterlab" ]; then
     bash /etc/sagemaker-ui/workflows/sm-spark-cli-install.sh
 fi
 
+# Execute the network validation script to check whether any required AWS services are unreachable
+echo "Starting network validation script..."
+
+network_validation_file="/tmp/.network_validation.json"
+
+# Run the validation script; only if it succeeds, check for unreachable services
+if bash ./network_validation.sh "$is_s3_storage_flag" "$network_validation_file"; then
+    # Read unreachable services from the JSON file
+    failed_services=$(jq -r '.UnreachableServices // empty' "$network_validation_file" || echo "")
+    if [[ -n "$failed_services" ]]; then
+        error_message="Network issue detected. The following AWS services are not reachable: $failed_services. Please contact your admin."
+        write_status_to_file "error" "$error_message"
+        echo "$error_message"
+    fi
+else
+    echo "Warning: network_validation.sh failed, skipping unreachable services check."
+fi
+
 write_status_to_file_on_script_complete
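
If needed, the recorded result can be inspected by hand after startup with the same jq filter used above:

# Prints the comma-separated list of unreachable services;
# empty output means no unreachable services were recorded.
jq -r '.UnreachableServices // empty' /tmp/.network_validation.json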
Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
#!/bin/bash
set -eux

# Input parameters with defaults:
# Default to 1 (Git storage) if no parameter is passed.
is_s3_storage=${1:-"1"}
# Output file path for unreachable services JSON
network_validation_file=${2:-"/tmp/.network_validation.json"}

# Function to write unreachable services to a JSON file
write_unreachable_services_to_file() {
    local value="$1"
    local file="$network_validation_file"

    # Create the file if it doesn't exist
    if [ ! -f "$file" ]; then
        touch "$file" || {
            echo "Failed to create $file" >&2
            return 0
        }
    fi

    # Check file is writable
    if [ ! -w "$file" ]; then
        echo "Error: $file is not writable" >&2
        return 0
    fi

    # Write JSON object with UnreachableServices key and the comma-separated list value
    jq -n --arg value "$value" '{"UnreachableServices": $value}' > "$file"
}

# Configure AWS CLI region using environment variable REGION_NAME
aws configure set region "${REGION_NAME}"
echo "Successfully configured region to ${REGION_NAME}"

# Metadata file location containing DataZone info
sourceMetaData=/opt/ml/metadata/resource-metadata.json

# Extract necessary DataZone metadata fields via jq
dataZoneDomainId=$(jq -r '.AdditionalMetadata.DataZoneDomainId' < "$sourceMetaData")
dataZoneProjectId=$(jq -r '.AdditionalMetadata.DataZoneProjectId' < "$sourceMetaData")
dataZoneEndPoint=$(jq -r '.AdditionalMetadata.DataZoneEndpoint' < "$sourceMetaData")
dataZoneDomainRegion=$(jq -r '.AdditionalMetadata.DataZoneDomainRegion' < "$sourceMetaData")

# Call AWS CLI list-connections, including endpoint if specified
if [ -n "$dataZoneEndPoint" ]; then
    response=$(aws datazone list-connections \
        --endpoint-url "$dataZoneEndPoint" \
        --domain-identifier "$dataZoneDomainId" \
        --project-identifier "$dataZoneProjectId" \
        --region "$dataZoneDomainRegion")
else
    response=$(aws datazone list-connections \
        --domain-identifier "$dataZoneDomainId" \
        --project-identifier "$dataZoneProjectId" \
        --region "$dataZoneDomainRegion")
fi

# Extract each connection item as a compact JSON string
connection_items=$(echo "$response" | jq -c '.items[]')

# Required AWS Services for Compute connections and Git
# Initialize SERVICE_COMMANDS with always-needed STS and S3 checks
declare -A SERVICE_COMMANDS=(
    ["STS"]="aws sts get-caller-identity"
    ["S3"]="aws s3api list-buckets --max-items 1"
)

# Track connection types found for conditional checks
declare -A seen_types=()

# Iterate over each connection to populate service commands conditionally
while IFS= read -r item; do
    # Extract connection type
    type=$(echo "$item" | jq -r '.type')
    seen_types["$type"]=1

    # For SPARK connections, check for Glue and EMR properties
    if [[ "$type" == "SPARK" ]]; then
        # If sparkGlueProperties present, add Glue check
        if echo "$item" | jq -e '.props.sparkGlueProperties' > /dev/null; then
            SERVICE_COMMANDS["Glue"]="aws glue get-crawlers --max-items 1"
        fi

        # Check for emr-serverless in sparkEmrProperties.computeArn for EMR Serverless check
        emr_arn=$(echo "$item" | jq -r '.props.sparkEmrProperties.computeArn // empty')
        if [[ "$emr_arn" == *"emr-serverless"* ]]; then
            SERVICE_COMMANDS["EMR Serverless"]="aws emr-serverless list-applications --max-results 1"
        fi
    fi
done <<< "$connection_items"

# Add Athena if ATHENA connection found (default empty so the lookup is safe under set -u)
[[ -n "${seen_types["ATHENA"]:-}" ]] && SERVICE_COMMANDS["Athena"]="aws athena list-data-catalogs --max-results 1"

# Add Redshift checks if REDSHIFT connection found
if [[ -n "${seen_types["REDSHIFT"]:-}" ]]; then
    SERVICE_COMMANDS["Redshift Cluster"]="aws redshift describe-clusters --max-records 1"
    SERVICE_COMMANDS["Redshift Serverless"]="aws redshift-serverless list-namespaces --max-results 1"
fi

# Optionally add CodeConnections if S3 storage flag is true (Git storage)
if [[ "$is_s3_storage" == "1" ]]; then
    SERVICE_COMMANDS["CodeConnections"]="aws codeconnections list-hosts --max-results 1"
fi

# Timeout (seconds) for each API call
api_time_out_limit=5
# Array to accumulate unreachable services
unreachable_services=()
# Create a temporary directory to store individual service results
temp_dir=$(mktemp -d)

# Launch all service API checks in parallel background jobs
for service in "${!SERVICE_COMMANDS[@]}"; do
    {
        # Run command with timeout, discard stdout/stderr
        if timeout "${api_time_out_limit}s" bash -c "${SERVICE_COMMANDS[$service]}" > /dev/null 2>&1; then
            # Success: write OK to temp file
            echo "OK" > "$temp_dir/$service"
        else
            # Get exit code to differentiate timeout or other errors
            exit_code=$?
            if [ "$exit_code" -eq 124 ]; then
                # Timeout exit code
                echo "TIMEOUT" > "$temp_dir/$service"
            else
                # Other errors (e.g., permission denied)
                echo "ERROR" > "$temp_dir/$service"
            fi
        fi
    } &
done

# Wait for all background jobs to complete before continuing
wait

# Process each service's result file to identify unreachable ones
for service in "${!SERVICE_COMMANDS[@]}"; do
    result_file="$temp_dir/$service"
    if [ -f "$result_file" ]; then
        result=$(<"$result_file")
        if [[ "$result" == "TIMEOUT" ]]; then
            echo "$service API did NOT resolve within ${api_time_out_limit}s. Marking as unreachable."
            unreachable_services+=("$service")
        elif [[ "$result" == "OK" ]]; then
            echo "$service API is reachable."
        else
            echo "$service API returned an error (but not a timeout). Ignored for network check."
        fi
    else
        echo "$service check did not produce a result file. Skipping."
    fi
done

# Cleanup temporary directory
rm -rf "$temp_dir"

# Write unreachable services to file if any, else write empty string
if (( ${#unreachable_services[@]} > 0 )); then
    joined_services=$(IFS=','; echo "${unreachable_services[*]}")
    # Add spaces after commas for readability
    joined_services_with_spaces=${joined_services//,/,\ }
    write_unreachable_services_to_file "$joined_services_with_spaces"
    echo "Unreachable AWS Services: ${joined_services_with_spaces}"
else
    write_unreachable_services_to_file ""
    echo "All required AWS services reachable within ${api_time_out_limit}s"
fi

template/v3/dirs/etc/sagemaker-ui/sagemaker_ui_post_startup.sh

Lines changed: 18 additions & 0 deletions
@@ -204,4 +204,22 @@ if [ "${SAGEMAKER_APP_TYPE_LOWERCASE}" = "jupyterlab" ]; then
     bash /etc/sagemaker-ui/workflows/sm-spark-cli-install.sh
 fi
 
+# Execute the network validation script to check whether any required AWS services are unreachable
+echo "Starting network validation script..."
+
+network_validation_file="/tmp/.network_validation.json"
+
+# Run the validation script; only if it succeeds, check for unreachable services
+if bash ./network_validation.sh "$is_s3_storage_flag" "$network_validation_file"; then
+    # Read unreachable services from the JSON file
+    failed_services=$(jq -r '.UnreachableServices // empty' "$network_validation_file" || echo "")
+    if [[ -n "$failed_services" ]]; then
+        error_message="Network issue detected. The following AWS services are not reachable: $failed_services. Please contact your admin."
+        write_status_to_file "error" "$error_message"
+        echo "$error_message"
+    fi
+else
+    echo "Warning: network_validation.sh failed, skipping unreachable services check."
+fi
+
 write_status_to_file_on_script_complete
