|
| 1 | +import polars as pl |
| 2 | +from typing import List, Optional |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | +import canopy |
| 6 | +import logging |
| 7 | +import asyncio |
| 8 | + |
| 9 | +logger = logging.getLogger(__name__) |
| 10 | + |
| 11 | + |
| 12 | +async def load_channels( |
| 13 | + session: canopy.Session, |
| 14 | + job_access_information: canopy.openapi.BlobAccessInformation, |
| 15 | + sim_type: str, |
| 16 | + channel_names: List[str], |
| 17 | + vector_metadata: Optional[pd.DataFrame] = None, |
| 18 | + semaphore: Optional[asyncio.Semaphore] = None) -> List[Optional[canopy.LoadedChannel]]: |
| 19 | + sim_type = canopy.ensure_sim_type_string(sim_type) |
| 20 | + |
| 21 | + if semaphore is None: |
| 22 | + semaphore = asyncio.Semaphore(session.default_blob_storage_concurrency) |
| 23 | + |
| 24 | + if vector_metadata is None: |
| 25 | + vector_metadata = await canopy.load_vector_metadata(session, job_access_information, sim_type) |
| 26 | + |
| 27 | + if vector_metadata is None: |
| 28 | + return [None] * len(channel_names) |
| 29 | + |
| 30 | + # First attempt to load from parquet if available |
| 31 | + # We group channels by their x-domain to load from the correct parquet files. |
| 32 | + parquet_results = {} |
| 33 | + |
| 34 | + channels_by_x_domain = {} |
| 35 | + for name in channel_names: |
| 36 | + if name in vector_metadata.index: |
| 37 | + x_domain = vector_metadata.at[name, 'xDomainName'] |
| 38 | + if pd.isna(x_domain) or not x_domain: |
| 39 | + continue |
| 40 | + |
| 41 | + if x_domain not in channels_by_x_domain: |
| 42 | + channels_by_x_domain[x_domain] = [] |
| 43 | + channels_by_x_domain[x_domain].append(name) |
| 44 | + |
| 45 | + # Try loading from parquet for each x-domain |
| 46 | + for x_domain, domain_channels in channels_by_x_domain.items(): |
| 47 | + loaded_from_parquet = await _try_load_channels_from_parquet( |
| 48 | + job_access_information, |
| 49 | + sim_type, |
| 50 | + x_domain, |
| 51 | + domain_channels, |
| 52 | + vector_metadata) |
| 53 | + |
| 54 | + if loaded_from_parquet: |
| 55 | + for channel in loaded_from_parquet: |
| 56 | + if channel is not None: |
| 57 | + parquet_results[channel.name] = channel |
| 58 | + |
| 59 | + async def _load_channel(channel_name: str) -> Optional[canopy.LoadedChannel]: |
| 60 | + if channel_name in parquet_results: |
| 61 | + return parquet_results[channel_name] |
| 62 | + |
| 63 | + async with semaphore: |
| 64 | + if channel_name not in vector_metadata.index: |
| 65 | + logger.debug('Channel not found: %s', channel_name) |
| 66 | + return None |
| 67 | + |
| 68 | + channel_metadata = vector_metadata.xs(channel_name) |
| 69 | + |
| 70 | + points_count: int = channel_metadata['NPtsInChannel'] |
| 71 | + units: str = channel_metadata['units'] |
| 72 | + |
| 73 | + file_name = ''.join([sim_type, '_', channel_name, '.bin']) |
| 74 | + channel_url = ''.join([job_access_information.url, file_name, job_access_information.access_signature]) |
| 75 | + |
| 76 | + channel_bytes: Optional[bytes] = await session.try_load_bytes( |
| 77 | + channel_url, |
| 78 | + f'"{file_name}" from "{job_access_information.url}"') |
| 79 | + |
| 80 | + if channel_bytes is None: |
| 81 | + return None |
| 82 | + |
| 83 | + if points_count * 4 == len(channel_bytes): |
| 84 | + data_type = np.float32 |
| 85 | + else: |
| 86 | + data_type = np.float64 |
| 87 | + channel_data: np.array = np.frombuffer(channel_bytes, data_type) |
| 88 | + |
| 89 | + loaded_channel = canopy.LoadedChannel(channel_name, units, channel_data) |
| 90 | + return loaded_channel |
| 91 | + |
| 92 | + return await asyncio.gather(*[_load_channel(name) for name in channel_names]) |
| 93 | + |
| 94 | +async def _try_load_channels_from_parquet( |
| 95 | + job_access_information: canopy.openapi.BlobAccessInformation, |
| 96 | + sim_type: str, |
| 97 | + x_domain: str, |
| 98 | + channel_names: List[str], |
| 99 | + vector_metadata: pd.DataFrame) -> Optional[List[Optional[canopy.LoadedChannel]]]: |
| 100 | + file_name = f'{sim_type}_{x_domain}_VectorResults.parquet' |
| 101 | + url = f'{job_access_information.url}{file_name}{job_access_information.access_signature}' |
| 102 | + |
| 103 | + try: |
| 104 | + # We'll use a single scan for all requested channels that exist in metadata. |
| 105 | + valid_channels = [name for name in channel_names if name in vector_metadata.index] |
| 106 | + if not valid_channels: |
| 107 | + return None |
| 108 | + |
| 109 | + # Fetch all required columns in one go |
| 110 | + df: pl.DataFrame = await (pl.scan_parquet(url, parallel="columns", storage_options={ "max_retries": 1, "retry_timeout_ms": 100 }) |
| 111 | + .select(valid_channels) |
| 112 | + .collect_async()) |
| 113 | + |
| 114 | + return [ |
| 115 | + canopy.LoadedChannel(name, str(vector_metadata.at[name, "units"]), df.get_column(name).to_numpy()) |
| 116 | + if name in valid_channels |
| 117 | + else None |
| 118 | + for name in channel_names |
| 119 | + ] |
| 120 | + except Exception as e: |
| 121 | + logger.debug(f"Failed to load channels from parquet {file_name}: {e}") |
| 122 | + return None |
0 commit comments