diff --git a/src/segger/cli/main.py b/src/segger/cli/main.py index 27e04d5..5a3bbe4 100644 --- a/src/segger/cli/main.py +++ b/src/segger/cli/main.py @@ -75,7 +75,16 @@ def segment( validator=validators.Path(exists=True, dir_okay=True), )] = registry.get_default("output_directory"), - + save_anndata: Annotated[bool, registry.get_parameter( + "save_anndata", + group=group_io, + )] = registry.get_default("save_anndata"), + + save_cell_masks: Annotated[bool, registry.get_parameter( + "save_cell_masks", + group=group_io, + )] = registry.get_default("save_cell_masks"), + # Cell Representation node_representation_dim: Annotated[int, Parameter( help="Number of dimensions used to represent each node type.", @@ -121,7 +130,6 @@ def segment( group=group_nodes, )] = registry.get_default("genes_clusters_resolution"), - # Transcript-Transcript Graph transcripts_max_k: Annotated[int, registry.get_parameter( "transcripts_graph_max_k", @@ -145,6 +153,12 @@ def segment( ) ] = registry.get_default("prediction_graph_mode"), + prediction_expansion_ratio: Annotated[float, registry.get_parameter( + "prediction_graph_buffer_ratio", + validator=validators.Number(gt=0), + group=group_prediction, + )] = registry.get_default("prediction_graph_buffer_ratio"), + prediction_max_k: Annotated[int | None, registry.get_parameter( "prediction_graph_max_k", validator=validators.Number(gt=0), @@ -342,7 +356,11 @@ def segment( from ..data import ISTSegmentationWriter from lightning.pytorch import Trainer logger = CSVLogger(output_directory) - writer = ISTSegmentationWriter(output_directory) + writer = ISTSegmentationWriter( + output_directory=output_directory, + save_anndata=save_anndata, + save_cell_masks=save_cell_masks, + ) trainer = Trainer( logger=logger, max_epochs=n_epochs, diff --git a/src/segger/data/writer.py b/src/segger/data/writer.py index 7889c9a..1ab99da 100644 --- a/src/segger/data/writer.py +++ b/src/segger/data/writer.py @@ -8,6 +8,8 @@ from ..io import TrainingTranscriptFields, TrainingBoundaryFields from . import ISTDataModule +from .utils.anndata import anndata_from_transcripts +from ..geometry import generate_cell_boundaries def threshold(x): @@ -24,9 +26,17 @@ class ISTSegmentationWriter(BasePredictionWriter): Path to write outputs. """ - def __init__(self, output_directory: Path): + def __init__( + self, + output_directory: Path, + save_anndata: bool = True, + save_cell_masks: bool = False, + ): super().__init__(write_interval="epoch") self.output_directory = Path(output_directory) + self.output_directory.mkdir(parents=True, exist_ok=True) + self.save_anndata = save_anndata + self.save_cell_masks = save_cell_masks def write_on_epoch_end( self, @@ -125,10 +135,93 @@ def write_on_epoch_end( .alias("similarity_threshold") ) ) - # Join and write output to file + # Join thresholds + segmentation = segmentation.join(thresholds, on=tx_fields.feature, how='left') + + # Map gene encoding to gene names + gene_index = ( + pl + .from_pandas(trainer.datamodule.ad.var.reset_index()) + .rename({"index": tx_fields.feature}) + .select([tx_fields.feature, tx_fields.gene_encoding]) + ) + segmentation = ( + segmentation + .rename({tx_fields.feature: tx_fields.gene_encoding}) + .join(gene_index, on=tx_fields.gene_encoding, how='left') + ) + + # Write segmentation output (keep prior columns) ( segmentation - .join(thresholds, on=tx_fields.feature, how='left') - .drop(tx_fields.feature) + .drop([tx_fields.feature, tx_fields.gene_encoding]) .write_parquet(self.output_directory / 'segger_segmentation.parquet') ) + + transcripts = None + if self.save_anndata or self.save_cell_masks: + transcripts = ( + segmentation + .join( + trainer.datamodule.tx.select([ + tx_fields.row_index, + tx_fields.x, + tx_fields.y, + tx_fields.feature, + ]), + on=tx_fields.row_index, + how='left', + ) + .rename({tx_fields.feature: "segger_gene"}) + ) + + # Optional: save AnnData + if self.save_anndata and transcripts is not None: + adata = anndata_from_transcripts( + transcripts.select([ + tx_fields.row_index, + "segger_gene", + "segger_cell_id", + "segger_similarity", + "similarity_threshold", + tx_fields.x, + tx_fields.y, + ]), + feature_column="segger_gene", + cell_id_column="segger_cell_id", + score_column="segger_similarity", + coordinate_columns=[tx_fields.x, tx_fields.y], + ) + adata.write_h5ad(self.output_directory / 'segger_anndata.h5ad') + + if self.save_cell_masks and transcripts is not None: + cell_ids = ( + transcripts + .get_column("segger_cell_id") + .drop_nulls() + .unique() + .to_list() + ) + + if len(cell_ids) > 0: + bd_fields = TrainingBoundaryFields() + boundaries = trainer.datamodule.bd + cell_boundaries = boundaries[ + (boundaries[bd_fields.boundary_type] == bd_fields.cell_value) + & (boundaries[bd_fields.id].isin(cell_ids)) + ] + + if cell_boundaries.empty: + cell_boundaries = generate_cell_boundaries( + transcripts=transcripts, + x_column=tx_fields.x, + y_column=tx_fields.y, + cell_id_column="segger_cell_id", + ) + + if not cell_boundaries.empty: + cell_boundaries.to_parquet( + self.output_directory / "segger_cell_boundaries.parquet", + write_covering_bbox=True, + geometry_encoding="geoarrow", + ) diff --git a/src/segger/geometry/__init__.py b/src/segger/geometry/__init__.py index 3abe3e5..dccf95f 100644 --- a/src/segger/geometry/__init__.py +++ b/src/segger/geometry/__init__.py @@ -1,4 +1,5 @@ from .conversion import points_to_geoseries, polygons_to_geoseries from .query import points_in_polygons, polygons_in_polygons from .quadtree import get_quadtree_index, quadtree_to_geoseries -from .morphology import get_polygon_props \ No newline at end of file +from .morphology import get_polygon_props +from .boundaries import generate_cell_boundaries \ No newline at end of file diff --git a/src/segger/geometry/boundaries.py b/src/segger/geometry/boundaries.py new file mode 100644 index 0000000..6509ddb --- /dev/null +++ b/src/segger/geometry/boundaries.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import geopandas as gpd +import polars as pl +import shapely +from shapely.geometry import LineString, MultiPoint, Point + + +def _points_to_polygon( + coords: list[list[float]] | "np.ndarray", + concave_ratio: float, + min_buffer: float, +): + """ + Convert a set of 2D points to a polygon hull. + """ + # Small-cardinality fallbacks + if len(coords) == 1: + return Point(coords[0]).buffer(min_buffer) + if len(coords) == 2: + return LineString(coords).buffer(min_buffer) + + multipoint = MultiPoint(coords) + concave = getattr(shapely, "concave_hull", None) + + if concave is not None: + try: + geom = concave(multipoint, ratio=concave_ratio) + except Exception: + geom = multipoint.convex_hull + else: + geom = multipoint.convex_hull + + # Ensure validity + return geom.buffer(0) + + +def generate_cell_boundaries( + transcripts: pl.DataFrame, + x_column: str, + y_column: str, + cell_id_column: str, + concave_ratio: float = 0.6, + min_buffer: float = 0.5, +) -> gpd.GeoDataFrame: + """ + Build polygon boundaries per cell from assigned transcripts. + + Parameters + ---------- + transcripts : pl.DataFrame + Transcript-level table containing coordinates and cell assignments. + x_column, y_column : str + Column names for x/y coordinates. + cell_id_column : str + Column containing the assigned cell identifier. + concave_ratio : float, default=0.6 + Ratio passed to ``shapely.concave_hull`` when available. Values closer + to 0 produce tighter hulls; 1 approximates the convex hull. + min_buffer : float, default=0.5 + Buffer used when only one or two points are available to avoid + degenerate polygons. + + Returns + ------- + geopandas.GeoDataFrame + GeoDataFrame with one row per cell and polygon geometry. + """ + # Early exit on empty inputs + if transcripts.is_empty(): + return gpd.GeoDataFrame( + columns=[cell_id_column, "n_transcripts"], geometry=[] + ) + + subset = transcripts.select([cell_id_column, x_column, y_column]).drop_nulls( + cell_id_column + ) + if subset.is_empty(): + return gpd.GeoDataFrame( + columns=[cell_id_column, "n_transcripts"], geometry=[] + ) + + pdf = subset.to_pandas() + records: list[dict] = [] + geoms = [] + + for cell_id, group in pdf.groupby(cell_id_column): + coords = group[[x_column, y_column]].to_numpy() + records.append( + {cell_id_column: cell_id, "n_transcripts": coords.shape[0]} + ) + geoms.append( + _points_to_polygon( + coords=coords, + concave_ratio=concave_ratio, + min_buffer=min_buffer, + ) + ) + + return gpd.GeoDataFrame(records, geometry=geoms)