diff --git a/install.sh b/install.sh old mode 100755 new mode 100644 index df8394b..8be157b --- a/install.sh +++ b/install.sh @@ -210,7 +210,7 @@ configure() { log_info "Configuring SeeSpot..." # Default values - DEFAULT_CACHE_DIR="$HOME/.seespot/cache" + DEFAULT_CACHE_DIR="$HOME/capsule/scratch/.seespot/cache" DEFAULT_PORT=5555 DEFAULT_HOST="0.0.0.0" diff --git a/src/see_spot/app.py b/src/see_spot/app.py index 0b1a377..f107362 100644 --- a/src/see_spot/app.py +++ b/src/see_spot/app.py @@ -10,6 +10,7 @@ import os from pathlib import Path import polars as pl +import pandas as pd import itertools from typing import List, Tuple, Dict, Any import yaml @@ -300,10 +301,13 @@ def calculate_sankey_data(df: Any) -> Dict[str, Any]: async def get_real_spots_data( sample_size: int = SAMPLE_SIZE, force_refresh: bool = False, - valid_spots_only: bool = False + valid_spots_only: bool = False, + sampling_type: str = "class_balanced", + display_chan: str = "mixed" ): logger.info(f"Real spots data requested with sample size: {sample_size}, " - f"force_refresh: {force_refresh}, valid_spots_only: {valid_spots_only}") + f"force_refresh: {force_refresh}, valid_spots_only: {valid_spots_only}, " + f"sampling_type: {sampling_type}, display_chan: {display_chan}") # Check if a dataset has been selected if DATA_PREFIX is None: @@ -452,7 +456,53 @@ async def get_real_spots_data( # 4. 
Subsample the data if len(df) > sample_size: logger.info(f"Subsampling DataFrame from {len(df)} to {sample_size} rows.") - plot_df = df.sample(n=sample_size, random_state=None).copy() + + if sampling_type == "class_balanced": + # Class-balanced sampling: sample equally from each channel + # Use the appropriate channel column based on display mode + channel_col = 'chan' if display_chan == 'mixed' else 'unmixed_chan' + logger.info(f"Using class-balanced sampling on column: {channel_col}") + + # Get unique channels and their counts + unique_channels = df[channel_col].unique() + num_channels = len(unique_channels) + samples_per_channel = sample_size // num_channels + + logger.info(f"Found {num_channels} unique channels, sampling {samples_per_channel} per channel") + + # Sample from each channel + sampled_dfs = [] + import secrets + for channel in unique_channels: + channel_df = df[df[channel_col] == channel] + n_samples = min(len(channel_df), samples_per_channel) + random_seed = secrets.randbelow(2**32) + sampled = channel_df.sample(n=n_samples, random_state=random_seed) + sampled_dfs.append(sampled) + logger.info(f"Channel {channel}: sampled {n_samples} from {len(channel_df)} spots") + + # Concatenate all samples, keeping df's original index labels so the fill step below can exclude them + plot_df = pd.concat(sampled_dfs) + + # If we're short of target sample size, add random samples to fill + if len(plot_df) < sample_size: + remaining = sample_size - len(plot_df) + # Get spots not already sampled (plot_df still carries df's index labels at this point) + remaining_df = df[~df.index.isin(plot_df.index)] + if len(remaining_df) > 0: + random_seed = secrets.randbelow(2**32) + additional = remaining_df.sample(n=min(remaining, len(remaining_df)), random_state=random_seed) + plot_df = pd.concat([plot_df, additional], ignore_index=True) + logger.info(f"Added {len(additional)} additional random samples to reach target") + + plot_df = plot_df.reset_index(drop=True) + logger.info(f"Class-balanced sampling complete: {len(plot_df)} total samples") + else: + # Random sampling + import secrets + random_seed = 
secrets.randbelow(2**32) + plot_df = df.sample(n=sample_size, random_state=random_seed).copy() + logger.info(f"Random sampling with seed: {random_seed}") else: plot_df = df.copy() logger.info(f"Plotting DataFrame shape: {plot_df.shape}") @@ -516,7 +566,8 @@ async def get_real_spots_data( else: spot_details = { str(int(row['spot_id'])): { - col: row[col] for col in available_detail_cols if col != 'spot_id' + col: (row[col].item() if hasattr(row[col], 'item') else row[col]) + for col in available_detail_cols if col != 'spot_id' } for _, row in spot_details_df.iterrows() } @@ -549,6 +600,11 @@ async def get_real_spots_data( # 8. Convert DataFrame to list of records (dictionaries) try: data_for_frontend = plot_df_subset.to_dict(orient='records') + # Convert numpy types to native Python types for JSON serialization + for record in data_for_frontend: + for key, value in record.items(): + if hasattr(value, 'item'): # numpy scalar + record[key] = value.item() logger.info(f"Prepared {len(data_for_frontend)} records for frontend.") except Exception as e: logger.error(f"Error converting DataFrame to dict: {e}", exc_info=True) @@ -662,13 +718,41 @@ async def create_neuroglancer_link(request: Request): cross_section_scale, ) - # Determine JSON file name (env override allowed) and full S3 path - ng_json_filename = os.getenv( - "SEE_SPOT_NG_JSON_NAME", "phase_correlation_stitching_neuroglancer.json" - ) - ng_json_path = f"s3://{S3_BUCKET}/{DATA_PREFIX}/{ng_json_filename}" - s3_key_for_json = f"{DATA_PREFIX}/{ng_json_filename}" # key relative to bucket - logger.info("Constructed Neuroglancer JSON path: %s", ng_json_path) + # Determine dataset context (tile-specific JSONs live under image_spot_detection) + import re + + tile_pattern = re.compile(r"_X_\d+_Y_\d+_Z_\d+$") + tile_match = tile_pattern.search(DATA_PREFIX) if DATA_PREFIX else None + base_dataset_name = DATA_PREFIX + tile_folder = None + + if tile_match and DATA_PREFIX: + parts = DATA_PREFIX.rsplit('_', 6) + if len(parts) 
> 1: + base_dataset_name = parts[0] + tile_suffix = '_'.join(parts[1:]) + tile_folder = f"Tile_{tile_suffix}" + logger.info( + "Detected tile dataset for Neuroglancer request | base=%s tile=%s", + base_dataset_name, + tile_folder, + ) + + # Determine JSON file name (env override allowed for non-tile datasets) and full S3 path + if tile_folder: + ng_json_filename = f"{tile_folder}_spot_annotation_ng_link.json" + s3_key_for_json = ( + f"{base_dataset_name}/image_spot_detection/{ng_json_filename}" + ) + ng_json_path = f"s3://{S3_BUCKET}/{s3_key_for_json}" + logger.info("Using tile-specific Neuroglancer JSON: %s", ng_json_path) + else: + ng_json_filename = os.getenv( + "SEE_SPOT_NG_JSON_NAME", "phase_correlation_stitching_neuroglancer.json" + ) + ng_json_path = f"s3://{S3_BUCKET}/{DATA_PREFIX}/{ng_json_filename}" + s3_key_for_json = f"{DATA_PREFIX}/{ng_json_filename}" # key relative to bucket + logger.info("Using default Neuroglancer JSON: %s", ng_json_path) # Check existence of JSON on S3 (metadata only) for better diagnostics json_metadata = None @@ -692,7 +776,20 @@ async def create_neuroglancer_link(request: Request): ) # Decide strategy: prefer JSON-based method when file exists; fall back otherwise - use_json_method = json_metadata is not None or "merged" in unmixed_spots_filename.lower() + if tile_folder: + use_json_method = json_metadata is not None + if not use_json_method: + logger.warning( + "Tile-specific Neuroglancer JSON missing: %s | falling back to direct method", + ng_json_path, + ) + else: + merged_flag = ( + unmixed_spots_filename.lower().find("merged") != -1 + if isinstance(unmixed_spots_filename, str) + else False + ) + use_json_method = json_metadata is not None or merged_flag logger.info("Use JSON method decision: %s", use_json_method) try: @@ -711,6 +808,7 @@ async def create_neuroglancer_link(request: Request): annotation_color=annotation_color, spacing=3.0, cross_section_scale=cross_section_scale, + hide_existing_annotations=True, ) 
logger.info("Successfully built Neuroglancer link from JSON") except Exception as json_err: diff --git a/src/see_spot/ng_utils.py b/src/see_spot/ng_utils.py index f760612..f587e2a 100644 --- a/src/see_spot/ng_utils.py +++ b/src/see_spot/ng_utils.py @@ -259,6 +259,7 @@ def create_link_from_json( spacing=3.0, cross_section_scale=None, base_url="https://neuroglancer-demo.appspot.com", + hide_existing_annotations=True, ): """ Create a Neuroglancer link from an existing JSON file with updated position and annotation. @@ -274,6 +275,9 @@ def create_link_from_json( cross_section_scale (float, optional): Scale for cross-section view. If None, keeps existing value base_url (str, optional): Base Neuroglancer URL. Default: "https://neuroglancer-demo.appspot.com" + hide_existing_annotations (bool, optional): When True, sets existing annotation + layers to invisible before adding the new spot annotation. Default: True + Returns: -------- str: Direct Neuroglancer URL with updated state @@ -318,55 +322,52 @@ def create_link_from_json( state_dict["crossSectionScale"] = cross_section_scale print(f"Updated crossSectionScale to: {cross_section_scale}") - # Find or create annotation layer - annotation_layer_found = False - - if "layers" in state_dict: - # Look for existing annotation layer - for i, layer in enumerate(state_dict["layers"]): + # Hide existing annotation layers if requested + if hide_existing_annotations and "layers" in state_dict: + hidden_layers = 0 + for layer in state_dict["layers"]: if layer.get("type") == "annotation": - # Update existing annotation layer - annotation = { - "type": "point", - "id": str(spot_id), - "point": point_annotation, - } - - # Update the layer properties - state_dict["layers"][i]["name"] = f"Spot {spot_id}" - state_dict["layers"][i]["annotationColor"] = annotation_color - state_dict["layers"][i][ - "crossSectionAnnotationSpacing" - ] = spacing - state_dict["layers"][i]["annotations"] = [annotation] - - annotation_layer_found = True - 
print(f"Updated existing annotation layer with spot {spot_id}") - break - - # If no annotation layer exists, create one - if not annotation_layer_found: - annotation_layer = { - "type": "annotation", - "name": f"Spot {spot_id}", - "tab": "annotations", - "visible": True, - "annotationColor": annotation_color, - "crossSectionAnnotationSpacing": spacing, - "projectionAnnotationSpacing": 10, - "tool": "annotatePoint", - "annotations": [ - { - "type": "point", - "id": str(spot_id), - "point": point_annotation, - } - ], + layer["visible"] = False + hidden_layers += 1 + if hidden_layers: + print(f"Hid {hidden_layers} existing annotation layer(s) before adding spot {spot_id}") + + # Ensure layers list exists and append fresh annotation layer for the selected spot + if "layers" not in state_dict or not isinstance(state_dict["layers"], list): + state_dict["layers"] = [] + + spot_layer_name = f"Spot {spot_id}" + + # Remove any prior custom layer for this spot to avoid duplication + state_dict["layers"] = [ + layer + for layer in state_dict["layers"] + if not ( + layer.get("type") == "annotation" + and layer.get("name") == spot_layer_name + and layer.get("tab") == "annotations" + ) + ] + + annotation_layer = { + "type": "annotation", + "name": spot_layer_name, + "tab": "annotations", + "visible": True, + "annotationColor": annotation_color, + "crossSectionAnnotationSpacing": spacing, + "projectionAnnotationSpacing": 10, + "tool": "annotatePoint", + "annotations": [ + { + "type": "point", + "id": str(spot_id), + "point": point_annotation, } - state_dict["layers"].append(annotation_layer) - print(f"Created new annotation layer with spot {spot_id}") - else: - print("Warning: No 'layers' found in Neuroglancer state") + ], + } + state_dict["layers"].append(annotation_layer) + print(f"Appended new annotation layer with spot {spot_id}") # Generate direct URL direct_url = create_direct_neuroglancer_url(state_dict, base_url=base_url) diff --git 
a/src/see_spot/static/js/unmixed_spots.js b/src/see_spot/static/js/unmixed_spots.js index 3dcdba1..1cbe14a 100644 --- a/src/see_spot/static/js/unmixed_spots.js +++ b/src/see_spot/static/js/unmixed_spots.js @@ -12,6 +12,7 @@ document.addEventListener('DOMContentLoaded', function () { const nextChannelButton = document.getElementById('next_channel_pair'); const currentChannelDisplay = document.getElementById('current_channel_display'); const sampleSizeInput = document.getElementById('sample-size-input'); + const samplingTypeSelect = document.getElementById('sampling-type-select'); const resampleButton = document.getElementById('resample_button'); const sampleSizeNote = document.getElementById('sample_size_note'); const sampleSizeIcon = document.getElementById('sample_size_icon'); @@ -54,11 +55,13 @@ document.addEventListener('DOMContentLoaded', function () { let channelPairs = []; let currentPairIndex = 0; let currentSampleSize = parseInt(sampleSizeInput.value) || 10000; + let samplingType = 'class_balanced'; // 'class_balanced' or 'random' let highlightReassigned = false; let highlightRemoved = false; let displayChanMode = 'mixed'; // 'unmixed' or 'mixed' let isNeuroglancerMode = false; let showDyeLines = false; // Toggle state for dye lines + let channelVisibilityState = {}; // Track user's manual legend selections let spotDetails = {}; // Will store the spot details for neuroglancer lookup let fusedS3Paths = {}; // Will store the fused S3 paths from the API let summaryStats = null; // Will store the summary stats from the API @@ -671,25 +674,31 @@ document.addEventListener('DOMContentLoaded', function () { } currentSampleSize = newSampleSize; + const selectedValue = samplingTypeSelect.value; + samplingType = selectedValue; // Get current sampling type + console.log(`Resample clicked: dropdown value = ${selectedValue}, samplingType = ${samplingType}, displayChanMode = ${displayChanMode}`); + console.log(`Dropdown element:`, samplingTypeSelect); 
updateSampleSizeNote(currentSampleSize); // Show loading state myChart.showLoading({ - text: 'Loading new sample...', + text: `Loading new sample (${samplingType})...`, maskColor: 'rgba(255, 255, 255, 0.8)', fontSize: 14 }); - // Fetch data with new sample size + // Fetch data with new sample size and sampling type fetchData(currentSampleSize, false); }); // Handle refresh button click (force reload data from server) - refreshButton.addEventListener('click', function() { - if (confirm("This will reload data from the server. Continue?")) { - refreshData(true); - } - }); + if (refreshButton) { + refreshButton.addEventListener('click', function() { + if (confirm("This will reload data from the server. Continue?")) { + refreshData(true); + } + }); + } // Add function for refresh button function refreshData(forceRefresh = true) { @@ -735,8 +744,8 @@ document.addEventListener('DOMContentLoaded', function () { // Fetch data function function fetchData(sampleSize, forceRefresh = false) { const validSpotsOnly = false; // validSpotToggle.checked; // Toggle disabled - const url = `/api/real_spots_data?sample_size=${sampleSize}${forceRefresh ? '&force_refresh=true' : ''}${validSpotsOnly ? '&valid_spots_only=true' : '&valid_spots_only=false'}`; - console.log(`Fetching data with URL: ${url}`); + const url = `/api/real_spots_data?sample_size=${sampleSize}&sampling_type=${samplingType}&display_chan=${displayChanMode}${forceRefresh ? '&force_refresh=true' : ''}${validSpotsOnly ? 
'&valid_spots_only=true' : '&valid_spots_only=false'}`; + console.log(`Fetching data with URL: ${url} (sampling: ${samplingType}, display: ${displayChanMode})`); fetch(url) .then(response => { @@ -1226,10 +1235,17 @@ document.addEventListener('DOMContentLoaded', function () { return a.localeCompare(b); }); - const series = sortedChannels.map(channel => ({ - name: channel, // Remove the Mixed/Unmixed prefix from individual labels - type: 'scatter', - data: seriesData[channel].map(point => { + const series = sortedChannels.map(channel => { + // Start with series hidden if it's "Removed" in unmixed mode + const isRemovedSeries = channel === 'Removed'; + const shouldHideByDefault = isRemovedSeries && displayChanMode === 'unmixed'; + + return { + name: channel, // Remove the Mixed/Unmixed prefix from individual labels + type: 'scatter', + // Hide "Removed" series by default in unmixed mode + selected: !shouldHideByDefault, + data: seriesData[channel].map(point => { const spotId = point.value[4]; const isClicked = neuroglancerClickedSpots.has(spotId); const baseSize = (channel === 'Removed' ? 
8 : 5) * markerSizeMultiplier; @@ -1346,7 +1362,8 @@ document.addEventListener('DOMContentLoaded', function () { }, // Use fixed color for the legend color: COLORS[channel] || COLORS.default - })); + }; + }); // Configuration for slider positioning and styling const sliderConfig = { @@ -1521,7 +1538,16 @@ document.addEventListener('DOMContentLoaded', function () { fontSize: 14 }, selected: sortedChannels.reduce((acc, chan) => { - acc[chan] = true; // Use channel name without prefix + // Check if user has manually set preference for this channel + if (channelVisibilityState.hasOwnProperty(chan)) { + // Use user's preference + acc[chan] = channelVisibilityState[chan]; + } else { + // Apply default behavior: hide "Removed" by default in unmixed mode + const isRemovedSeries = chan === 'Removed'; + const shouldHideByDefault = isRemovedSeries && displayChanMode === 'unmixed'; + acc[chan] = !shouldHideByDefault; + } return acc; }, {}) }, @@ -1667,6 +1693,13 @@ document.addEventListener('DOMContentLoaded', function () { } }); + // Capture legend selection changes to persist user preferences + myChart.on('legendselectchanged', function (params) { + console.log('Legend selection changed:', params.selected); + // Update our state tracking with user's selections + channelVisibilityState = Object.assign({}, params.selected); + }); + // Brush (lasso) selection event myChart.on('brushSelected', function (params) { lassoSelectedData = []; @@ -2080,6 +2113,13 @@ document.addEventListener('DOMContentLoaded', function () { // Update chart with new channel display mode updateChart(); }); + + // Event listener for sampling type select + samplingTypeSelect.addEventListener('change', function() { + samplingType = this.value; + console.log(`Sampling type changed to: ${samplingType}`); + // Note: Does not automatically resample - user must click Resample button + }); // Event listener for highlight removed toggle highlightRemovedToggle.addEventListener('change', function() { diff --git 
a/src/see_spot/templates/unmixed_spots.html b/src/see_spot/templates/unmixed_spots.html index 3c1e3b3..a3c64d1 100644 --- a/src/see_spot/templates/unmixed_spots.html +++ b/src/see_spot/templates/unmixed_spots.html @@ -658,13 +658,18 @@