Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions src/segger/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,16 @@ def segment(
validator=validators.Path(exists=True, dir_okay=True),
)] = registry.get_default("output_directory"),


save_anndata: Annotated[bool, registry.get_parameter(
"save_anndata",
group=group_io,
)] = registry.get_default("save_anndata"),

save_cell_masks: Annotated[bool, registry.get_parameter(
"save_cell_masks",
group=group_io,
)] = registry.get_default("save_cell_masks"),

# Cell Representation
node_representation_dim: Annotated[int, Parameter(
help="Number of dimensions used to represent each node type.",
Expand Down Expand Up @@ -121,7 +130,6 @@ def segment(
group=group_nodes,
)] = registry.get_default("genes_clusters_resolution"),


# Transcript-Transcript Graph
transcripts_max_k: Annotated[int, registry.get_parameter(
"transcripts_graph_max_k",
Expand All @@ -145,6 +153,12 @@ def segment(
)
] = registry.get_default("prediction_graph_mode"),

prediction_expansion_ratio: Annotated[float, registry.get_parameter(
"prediction_graph_buffer_ratio",
validator=validators.Number(gt=0),
group=group_prediction,
)] = registry.get_default("prediction_graph_buffer_ratio"),

prediction_max_k: Annotated[int | None, registry.get_parameter(
"prediction_graph_max_k",
validator=validators.Number(gt=0),
Expand Down Expand Up @@ -342,7 +356,11 @@ def segment(
from ..data import ISTSegmentationWriter
from lightning.pytorch import Trainer
logger = CSVLogger(output_directory)
writer = ISTSegmentationWriter(output_directory)
writer = ISTSegmentationWriter(
output_directory=output_directory,
save_anndata=save_anndata,
save_cell_masks=save_cell_masks,
)
trainer = Trainer(
logger=logger,
max_epochs=n_epochs,
Expand Down
101 changes: 97 additions & 4 deletions src/segger/data/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from ..io import TrainingTranscriptFields, TrainingBoundaryFields
from . import ISTDataModule
from .utils.anndata import anndata_from_transcripts
from ..geometry import generate_cell_boundaries


def threshold(x):
Expand All @@ -24,9 +26,17 @@ class ISTSegmentationWriter(BasePredictionWriter):
Path to write outputs.
"""

def __init__(self, output_directory: Path):
def __init__(
    self,
    output_directory: Path,
    save_anndata: bool = True,
    save_cell_masks: bool = False,
) -> None:
    """
    Initialize the segmentation writer.

    Parameters
    ----------
    output_directory : Path
        Directory where all segmentation outputs are written. Created
        (including parents) if it does not already exist.
    save_anndata : bool, default=True
        If True, also write the segmentation as an AnnData ``.h5ad`` file.
    save_cell_masks : bool, default=False
        If True, also write per-cell boundary polygons to parquet.
    """
    # Write all predictions once at epoch end rather than per-batch.
    super().__init__(write_interval="epoch")
    self.output_directory = Path(output_directory)
    # Ensure the target directory exists up front so later writes cannot fail
    # on a missing path.
    self.output_directory.mkdir(parents=True, exist_ok=True)
    self.save_anndata = save_anndata
    self.save_cell_masks = save_cell_masks

def write_on_epoch_end(
self,
Expand Down Expand Up @@ -125,10 +135,93 @@ def write_on_epoch_end(
.alias("similarity_threshold")
)
)
# Join and write output to file
# Join thresholds
segmentation = segmentation.join(thresholds, on=tx_fields.feature, how='left')

# Map gene encoding to gene names
gene_index = (
pl
.from_pandas(trainer.datamodule.ad.var.reset_index())
.rename({"index": tx_fields.feature})
.select([tx_fields.feature, tx_fields.gene_encoding])
)
segmentation = (
segmentation
.rename({tx_fields.feature: tx_fields.gene_encoding})
.join(gene_index, on=tx_fields.gene_encoding, how='left')
)

# Write segmentation output (keep prior columns)
(
segmentation
.join(thresholds, on=tx_fields.feature, how='left')
.drop(tx_fields.feature)
.drop([tx_fields.feature, tx_fields.gene_encoding])
.write_parquet(self.output_directory / 'segger_segmentation.parquet')
)

transcripts = None
if self.save_anndata or self.save_cell_masks:
transcripts = (
segmentation
.join(
trainer.datamodule.tx.select([
tx_fields.row_index,
tx_fields.x,
tx_fields.y,
tx_fields.feature,
]),
on=tx_fields.row_index,
how='left',
)
.rename({tx_fields.feature: "segger_gene"})
)

# Optional: save AnnData
if self.save_anndata and transcripts is not None:
adata = anndata_from_transcripts(
transcripts.select([
tx_fields.row_index,
"segger_gene",
"segger_cell_id",
"segger_similarity",
"similarity_threshold",
tx_fields.x,
tx_fields.y,
]),
feature_column="segger_gene",
cell_id_column="segger_cell_id",
score_column="segger_similarity",
coordinate_columns=[tx_fields.x, tx_fields.y],
)
adata.write_h5ad(self.output_directory / 'segger_anndata.h5ad')

if self.save_cell_masks and transcripts is not None:
cell_ids = (
transcripts
.get_column("segger_cell_id")
.drop_nulls()
.unique()
.to_list()
)

if len(cell_ids) > 0:
bd_fields = TrainingBoundaryFields()
boundaries = trainer.datamodule.bd
cell_boundaries = boundaries[
(boundaries[bd_fields.boundary_type] == bd_fields.cell_value)
& (boundaries[bd_fields.id].isin(cell_ids))
]

if cell_boundaries.empty:
cell_boundaries = generate_cell_boundaries(
transcripts=transcripts,
x_column=tx_fields.x,
y_column=tx_fields.y,
cell_id_column="segger_cell_id",
)

if not cell_boundaries.empty:
cell_boundaries.to_parquet(
self.output_directory / "segger_cell_boundaries.parquet",
write_covering_bbox=True,
geometry_encoding="geoarrow",
)
3 changes: 2 additions & 1 deletion src/segger/geometry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .conversion import points_to_geoseries, polygons_to_geoseries
from .query import points_in_polygons, polygons_in_polygons
from .quadtree import get_quadtree_index, quadtree_to_geoseries
from .morphology import get_polygon_props
from .morphology import get_polygon_props
from .boundaries import generate_cell_boundaries
100 changes: 100 additions & 0 deletions src/segger/geometry/boundaries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

import geopandas as gpd
import polars as pl
import shapely
from shapely.geometry import LineString, MultiPoint, Point


def _points_to_polygon(
    coords: list[list[float]] | "np.ndarray",
    concave_ratio: float,
    min_buffer: float,
):
    """
    Convert a set of 2D points to a polygon hull.

    Parameters
    ----------
    coords : array-like of shape (n, 2)
        Point coordinates for a single cell.
    concave_ratio : float
        Ratio passed to ``shapely.concave_hull`` when available.
    min_buffer : float
        Buffer radius used for degenerate inputs (1-2 points, or collinear
        point sets whose hull has no area) to avoid empty polygons.

    Returns
    -------
    shapely geometry
        A valid, non-empty polygonal geometry enclosing the points.
    """
    # Small-cardinality fallbacks: a point or a segment has no area, so
    # expand by min_buffer to get a proper polygon.
    if len(coords) == 1:
        return Point(coords[0]).buffer(min_buffer)
    if len(coords) == 2:
        return LineString(coords).buffer(min_buffer)

    multipoint = MultiPoint(coords)
    # shapely.concave_hull exists only in shapely >= 2.0; fall back to the
    # convex hull on older versions or if the concave hull computation fails.
    concave = getattr(shapely, "concave_hull", None)

    if concave is not None:
        try:
            geom = concave(multipoint, ratio=concave_ratio)
        except Exception:
            geom = multipoint.convex_hull
    else:
        geom = multipoint.convex_hull

    # Repair any self-intersections produced by the hull.
    geom = geom.buffer(0)

    # Bug fix: for >= 3 collinear points both hulls return a LineString, and
    # buffer(0) collapses it to an EMPTY polygon — the cell would silently
    # get no boundary. Mirror the small-cardinality fallback: expand the
    # (degenerate) convex hull by min_buffer so the result is always a
    # non-empty polygon.
    if geom.is_empty or geom.area == 0:
        geom = multipoint.convex_hull.buffer(min_buffer)

    return geom


def generate_cell_boundaries(
    transcripts: pl.DataFrame,
    x_column: str,
    y_column: str,
    cell_id_column: str,
    concave_ratio: float = 0.6,
    min_buffer: float = 0.5,
) -> gpd.GeoDataFrame:
    """
    Build polygon boundaries per cell from assigned transcripts.

    Parameters
    ----------
    transcripts : pl.DataFrame
        Transcript-level table containing coordinates and cell assignments.
    x_column, y_column : str
        Column names for x/y coordinates.
    cell_id_column : str
        Column containing the assigned cell identifier.
    concave_ratio : float, default=0.6
        Ratio passed to ``shapely.concave_hull`` when available. Values closer
        to 0 produce tighter hulls; 1 approximates the convex hull.
    min_buffer : float, default=0.5
        Buffer used when only one or two points are available to avoid
        degenerate polygons.

    Returns
    -------
    geopandas.GeoDataFrame
        GeoDataFrame with one row per cell and polygon geometry.
    """
    empty_result = gpd.GeoDataFrame(
        columns=[cell_id_column, "n_transcripts"], geometry=[]
    )

    # Nothing to do for an empty table.
    if transcripts.is_empty():
        return empty_result

    # Keep only assigned transcripts (non-null cell id) and the three columns
    # needed to build hulls.
    assigned = transcripts.select([cell_id_column, x_column, y_column]).drop_nulls(
        cell_id_column
    )
    if assigned.is_empty():
        return empty_result

    # Group per cell via pandas, accumulating parallel columns.
    ids: list = []
    counts: list[int] = []
    hulls: list = []

    for cell_id, cell_points in assigned.to_pandas().groupby(cell_id_column):
        xy = cell_points[[x_column, y_column]].to_numpy()
        ids.append(cell_id)
        counts.append(xy.shape[0])
        hulls.append(
            _points_to_polygon(
                coords=xy,
                concave_ratio=concave_ratio,
                min_buffer=min_buffer,
            )
        )

    return gpd.GeoDataFrame(
        {cell_id_column: ids, "n_transcripts": counts}, geometry=hulls
    )