tools/torchci/test_insights/file_report_generator.py (67 changes: 43 additions & 24 deletions)
@@ -26,6 +26,7 @@
import gzip
import io
import json
import logging
import re
import time
import urllib.request
@@ -39,6 +40,15 @@
from torchci.clickhouse import query_clickhouse


logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
logger.setLevel(logging.DEBUG)
handler.setLevel(logging.DEBUG)
handler.setFormatter(formatter)
logger.addHandler(handler)


def get_temp_dir() -> Path:
"""Create a temporary directory for processing files"""
temp_dir = Path("/tmp/file_report_generator")
@@ -61,7 +71,7 @@ def __init__(self, dry_run: bool = True):
@lru_cache
def load_runner_costs(self) -> Dict[str, float]:
"""Load runner costs from the S3 endpoint"""
print("Fetching EC2 pricing data from S3...")
logger.debug("Fetching EC2 pricing data from S3...")
with urllib.request.urlopen(self.EC2_PRICING_URL) as response:
compressed_data = response.read()

@@ -77,7 +87,7 @@ def load_runner_costs(self) -> Dict[str, float]:
def load_test_owners(self) -> List[Dict[str, Any]]:
"""Load the test owner labels JSON file from S3"""
S3_URL = "https://ossci-metrics.s3.us-east-1.amazonaws.com/test_owner_labels/test_owner_labels.json.gz"
print(f"Fetching test owner labels from S3: {S3_URL}")
logger.debug(f"Fetching test owner labels from S3: {S3_URL}")
with urllib.request.urlopen(S3_URL) as response:
compressed_data = response.read()
decompressed_data = gzip.decompress(compressed_data)
@@ -107,7 +117,7 @@ def _get_first_suitable_sha(self, shas: list[dict[str, Any]]) -> Optional[str]:
has_no_job_name = True
break
if has_no_job_name:
print(f"Has entries with no job name for {head_sha}")
logger.debug(f"Has entries with no job name for {head_sha}")
continue

lens.append((head_sha, len(test_data)))
@@ -119,7 +129,7 @@ def _get_first_suitable_sha(self, shas: list[dict[str, Any]]) -> Optional[str]:
_, len2 = lens[1]

if abs(len1 - len2) * 2 / (len1 + len2) < 0.1:
print(f"Using SHA {sha1} with {len1} entries")
logger.debug(f"Using SHA {sha1} with {len1} entries")
return sha1
return None

@@ -135,7 +145,7 @@ def find_suitable_sha(self, date: str) -> Optional[str]:
- All test entries have job names
"""

print("Searching for suitable SHAs from PyTorch main branch...")
logger.debug("Searching for suitable SHAs from PyTorch main branch...")

params = {
"start_date": date + " 00:00:00",
@@ -160,10 +170,10 @@ def find_suitable_sha(self, date: str) -> Optional[str]:
ORDER BY
min(w.head_commit.'timestamp') DESC
"""
print(f"Querying ClickHouse for successful shas on {date}")
logger.debug(f"Querying ClickHouse for successful shas on {date}")
candidates = query_clickhouse(query, params)

print(f"Found {len(candidates)} candidate SHAs")
logger.debug(f"Found {len(candidates)} candidate SHAs")

return self._get_first_suitable_sha(candidates)

@@ -185,7 +195,7 @@ def _get_workflow_jobs_for_sha(self, sha: str) -> List[Dict[str, Any]]:

params = {"sha": sha}

print(f"Querying ClickHouse for workflow runs with SHA: {sha}")
logger.debug(f"Querying ClickHouse for workflow runs with SHA: {sha}")
result = query_clickhouse(query, params)

for row in result:
@@ -265,7 +275,7 @@ def _fetch_from_s3(self, bucket: str, key: str) -> str:
try:
file_loc = get_temp_dir() / f"cache_{bucket}_{key.replace('/', '_')}"
if file_loc.exists():
print(f"Using cached download for {file_loc}")
logger.debug(f"Using cached download for {file_loc}")
compressed_data = file_loc.read_bytes()
else:
url = f"https://{bucket}.s3.amazonaws.com/{key}"
@@ -279,7 +289,7 @@ def load_runner_costs(self) -> Dict[str, float]:
text_data = decompressed_data.decode("utf-8")
return text_data
except Exception as e:
print(f"Failed to fetch from s3://{bucket}/{key}: {e}")
logger.debug(f"Failed to fetch from s3://{bucket}/{key}: {e}")
raise e

def _fetch_invoking_file_summary_from_s3(
@@ -305,7 +315,7 @@ def _fetch_invoking_file_summary_from_s3(
entry["short_job_name"] = f"{build} / test ({config})"
data_as_list.append(entry)

print(
logger.debug(
f"Fetched {len(data_as_list)} test entries from {key}, took {time.time() - start_time:.2f} seconds"
)
return data_as_list
@@ -403,7 +413,7 @@ def _fetch_status_changes_from_s3(
data["run_id"] = workflow_run_id
test_data.append(data)

print(
logger.debug(
f"Fetched {len(test_data)} test entries from {key}, took {time.time() - start_time:.2f} seconds"
)
return test_data
@@ -456,7 +466,7 @@ def _check_status_change_already_exists(self, sha1: str, sha2: str) -> bool:
try:
with urllib.request.urlopen(url) as response:
if response.status == 200:
print(f"Status changes for {sha1} to {sha2} already exist in S3.")
logger.debug(
f"Status changes for {sha1} to {sha2} already exist in S3."
)
return True
except Exception:
pass
@@ -515,6 +527,9 @@ def get_status_changes(
to_write = []
for key, entries in counts.items():
to_write.extend(entries[:10])
logger.debug(
f"Found {len(status_changes)} status changes between {sha1} and {sha2}, truncated to {len(to_write)} for upload"
)

self.upload_to_s3(
to_write,
@@ -588,9 +603,14 @@ def upload_to_s3(
html_url = f"https://{bucket_name}.s3.amazonaws.com/{key}"

if self.dry_run:
print(f"Dry run: would upload data to s3: {html_url}")
local_file = get_temp_dir() / f"dry_run_{key.replace('/', '_')}.json"
logger.info(
f"Dry run: would upload data to s3: {html_url}, writing to local file {local_file} instead"
)
with open(local_file, "w") as f:
f.write(body.getvalue())
return
print(f"Uploading data to s3: {html_url}")
logger.info(f"Uploading data to s3: {html_url}")
self.get_s3_resource().Object(bucket_name, key).put(
Body=gzip.compress(body.getvalue().encode()),
ContentEncoding="gzip",
@@ -599,12 +619,11 @@

def remove_key_from_s3(self, bucket: str, key: str) -> None:
"""Remove a specific key from S3"""
s3_path = f"s3://{bucket}/{key}"
html_url = f"https://{bucket}.s3.amazonaws.com/{key}"
if self.dry_run:
print(f"Dry run: would remove from s3: {html_url}")
logger.info(f"Dry run: would remove from s3: {html_url}")
return
print(f"Removing from s3: {html_url}")
logger.info(f"Removing from s3: {html_url}")
self.get_s3_resource().Object(bucket, key).delete()


@@ -645,7 +664,7 @@ def main() -> None:
if args.remove_sha:
for i, entry in enumerate(existing_metadata):
if entry["sha"] == args.remove_sha:
print(f"Removing SHA {args.remove_sha} from existing metadata")
logger.info(f"Removing SHA {args.remove_sha} from existing metadata")
generator.remove_key_from_s3(
"ossci-raw-job-status",
f"additional_info/weekly_file_report/data_{args.remove_sha}.json.gz",
@@ -668,19 +687,19 @@ def main() -> None:
shas: list[str] = []
for date in args.add_dates or []:
if date in _existing_dates:
print(f"Date {date} already exists in metadata, skipping")
logger.info(f"Date {date} already exists in metadata, skipping")
continue
sha = generator.find_suitable_sha(date)
if sha is None:
print(f"No suitable SHA found for date {date}, skipping")
logger.info(f"No suitable SHA found for date {date}, skipping")
continue
print(f"Found suitable SHA {sha} for date {date}")
logger.info(f"Found suitable SHA {sha} for date {date}")
shas.append(sha)

for sha in args.add_shas or []:
shas.append(cast(str, sha))

print(f"Adding SHAs: {shas}")
logger.info(f"Adding SHAs: {shas}")

# Load data to get dates/ordering
for sha in shas:
@@ -690,7 +709,7 @@ def main() -> None:

existing_metadata = sorted(existing_metadata, key=lambda x: x["push_date"])

print("Calculating diffs for all files and grouping by labels...")
logger.debug("Calculating diffs for all files and grouping by labels...")
for i in range(1, len(existing_metadata)):
if not generator._check_status_change_already_exists(
existing_metadata[i - 1]["sha"],
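For reference, a minimal standalone sketch of the logging pattern this diff adopts (illustrative only, not part of the PR): a module-level logger wired to a StreamHandler with a timestamped formatter, with verbose progress output at DEBUG and user-facing actions such as uploads and dry-run notices at INFO. The logger name and sample messages below are hypothetical.

import logging

# Module-level logger configured the same way as in the diff above:
# one StreamHandler, a timestamped format, and DEBUG as the threshold.
logger = logging.getLogger("file_report_demo")  # hypothetical name for this sketch
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))
handler.setLevel(logging.DEBUG)
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)

# Progress/diagnostic messages go to DEBUG; actions a user cares about
# (uploads, removals, dry-run notices) go to INFO, mirroring the print() replacements.
logger.debug("Fetching EC2 pricing data from S3...")
logger.info("Dry run: would upload data to s3: https://example-bucket.s3.amazonaws.com/example-key")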