Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
rotation_matrix_to_euler
from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range
from cmlibs.utils.zinc.general import ChangeManager, HierarchicalChangeManager
from cmlibs.utils.zinc.field import get_group_list
from cmlibs.utils.zinc.group import group_add_group_local_contents
from cmlibs.utils.zinc.scene import scene_clear_selection_group, scene_get_or_create_selection_group
from cmlibs.zinc.field import Field
from cmlibs.zinc.field import Field, FieldGroup
from cmlibs.zinc.glyph import Glyph
from cmlibs.zinc.graphics import Graphicslineattributes
from cmlibs.zinc.material import Material
Expand All @@ -36,11 +37,11 @@ def __init__(self, segmentation_file_locations, location, step_identifier,
initially assigned to network group 2, allowing them to be stitched together.
"""
self._stitcher = Stitcher(segmentation_file_locations, network_group_1_keywords, network_group2_keywords)
self._location = os.path.join(location, step_identifier)
self._location_stem = os.path.join(location, step_identifier)
self._step_identifier = step_identifier
self._category_graphics_info = [
(AnnotationCategory.GENERAL, "display_line_general", "green"),
(AnnotationCategory.INDEPENDENT_NETWORK, "display_independent_networks", "yellow"),
(AnnotationCategory.GENERAL, "display_line_general", "mid green"),
(AnnotationCategory.INDEPENDENT_NETWORK, "display_independent_networks", "purple"),
(AnnotationCategory.NETWORK_GROUP_1, "display_network_group_1", "mid blue"),
(AnnotationCategory.NETWORK_GROUP_2, "display_network_group_2", "orange")
]
Expand All @@ -50,6 +51,9 @@ def __init__(self, segmentation_file_locations, location, step_identifier,
"display_axes": True,
"display_marker_points": True,
"display_marker_names": False,
"display_node_numbers": False,
"display_node_points": False,
"display_node_group_name": None,
"display_line_general": True,
"display_line_general_radius": False,
"display_line_general_trans": False,
Expand All @@ -66,7 +70,8 @@ def __init__(self, segmentation_file_locations, location, step_identifier,
"display_end_point_best_fit_lines": True,
"display_end_point_radius": False,
"display_end_point_trans": False,
"display_radius_scale": 1.0
"display_radius_scale": 1.0,
"display_theme": "Dark"
}
self._load_settings()
self.create_graphics()
Expand All @@ -79,6 +84,8 @@ def __init__(self, segmentation_file_locations, location, step_identifier,
self._current_annotation = None
self._segment_data_change_callback = None

_EMPTY_GROUP_NAME = "<empty>"

def _init_graphics_modules(self):
context = self._stitcher.get_context()
self._materialmodule = context.getMaterialmodule()
Expand All @@ -87,11 +94,27 @@ def _init_graphics_modules(self):
mid_blue = self._materialmodule.createMaterial()
mid_blue.setName("mid blue")
mid_blue.setManaged(True)
mid_blue.setAttributeReal3(Material.ATTRIBUTE_AMBIENT, [0.0, 0.2, 0.6])
mid_blue.setAttributeReal3(Material.ATTRIBUTE_DIFFUSE, [0.0, 0.7, 1.0])
mid_blue.setAttributeReal3(Material.ATTRIBUTE_AMBIENT, [0.0, 0.4, 0.8])
mid_blue.setAttributeReal3(Material.ATTRIBUTE_DIFFUSE, [0.0, 0.6, 1.0])
mid_blue.setAttributeReal3(Material.ATTRIBUTE_EMISSION, [0.0, 0.0, 0.0])
mid_blue.setAttributeReal3(Material.ATTRIBUTE_SPECULAR, [0.1, 0.1, 0.1])
mid_blue.setAttributeReal(Material.ATTRIBUTE_SHININESS, 0.2)
mid_green = self._materialmodule.createMaterial()
mid_green.setName("mid green")
mid_green.setManaged(True)
mid_green.setAttributeReal3(Material.ATTRIBUTE_AMBIENT, [0.2, 0.7, 0.2])
mid_green.setAttributeReal3(Material.ATTRIBUTE_DIFFUSE, [0.2, 0.7, 0.2])
mid_green.setAttributeReal3(Material.ATTRIBUTE_EMISSION, [0.0, 0.0, 0.0])
mid_green.setAttributeReal3(Material.ATTRIBUTE_SPECULAR, [0.1, 0.1, 0.1])
mid_green.setAttributeReal(Material.ATTRIBUTE_SHININESS, 0.2)
purple = self._materialmodule.createMaterial()
purple.setName("purple")
purple.setManaged(True)
purple.setAttributeReal3(Material.ATTRIBUTE_AMBIENT, [0.6, 0.0, 1.0])
purple.setAttributeReal3(Material.ATTRIBUTE_DIFFUSE, [0.6, 0.0, 1.0])
purple.setAttributeReal3(Material.ATTRIBUTE_EMISSION, [0.0, 0.0, 0.0])
purple.setAttributeReal3(Material.ATTRIBUTE_SPECULAR, [0.1, 0.1, 0.1])
purple.setAttributeReal(Material.ATTRIBUTE_SHININESS, 0.2)
line_material_names = [graphics_info[-1] for graphics_info in self._category_graphics_info]
for material_name in line_material_names + [self._end_point_material_name]:
material = self._materialmodule.findMaterialByName(material_name)
Expand All @@ -111,40 +134,80 @@ def _init_graphics_modules(self):
default_tessellation = tessellationmodule.getDefaultTessellation()
default_tessellation.setRefinementFactors([12])

def _get_settings_file_name(self):
return self._location + "-settings.json"
@classmethod
def get_json_settings_filename(cls, location_stem):
"""
:param location_stem: Path and stem name of workflow step to make unique filenames from.
:return: Standard settings file location.
"""
return location_stem + "-settings.json"

@classmethod
def get_json_display_settings_filename(cls, location_stem):
"""
:param location_stem: Path and stem name of workflow step to make unique filenames from.
:return: Standard display settings file location.
"""
return location_stem + "-display-settings.json"

def _get_display_settings_file_name(self):
return self._location + "-display-settings.json"
@classmethod
def get_config_files(cls, location_stem):
"""
:param location_stem: Path and stem name of workflow step to make unique filenames from.
:return: list of config file locations needed to reproduce step behaviour.
"""
return [cls.get_json_settings_filename(location_stem), cls.get_json_display_settings_filename(location_stem)]

@classmethod
def get_output_segmentation_filename(cls, location_stem):
"""
:param location_stem: Path and stem name of workflow step to make unique filenames from.
:return: Standard name of output segmentation file name.
"""
return location_stem + ".exf"

SEGMENTATION_STITCHER_DISPLAY_SETTINGS_ID = 'segmentation stitcher display settings'

def _get_output_display_settings(self):
"""
:return: Display settings augmented with id and version information.
"""
display_settings = {
'id': self.SEGMENTATION_STITCHER_DISPLAY_SETTINGS_ID,
'version': '1.0.0'
}
display_settings.update(self._display_settings)
return display_settings

def _load_settings(self):
settings_file_name = self._get_settings_file_name()
settings_file_name = self.get_json_settings_filename(self._location_stem)
if os.path.isfile(settings_file_name):
with open(settings_file_name, "r") as f:
settings = json.loads(f.read())
self._stitcher.decode_settings(settings)
display_settings_file_name = self._get_display_settings_file_name()
display_settings_file_name = self.get_json_display_settings_filename(self._location_stem)
if os.path.isfile(display_settings_file_name):
with open(display_settings_file_name, "r") as f:
display_settings = json.loads(f.read())
settings_id = display_settings.get('id')
if settings_id is not None:
assert settings_id == self.SEGMENTATION_STITCHER_DISPLAY_SETTINGS_ID
assert display_settings['version'] == '1.0.0' # future: migrate if version changes
# these are not stored:
del display_settings['id']
del display_settings['version']
self._display_settings.update(display_settings)

def _save_settings(self):
with open(self._get_settings_file_name(), "w") as f:
with open(self.get_json_settings_filename(self._location_stem), "w") as f:
settings = self._stitcher.encode_settings()
f.write(json.dumps(settings, sort_keys=False, indent=4))
with open(self._get_display_settings_file_name(), "w") as f:
f.write(json.dumps(self._display_settings, sort_keys=False, indent=4))

def get_output_file_name_stem(self):
return self._location

def get_output_segmentation_file_name(self):
return self._location + ".exf"
with open(self.get_json_display_settings_filename(self._location_stem), "w") as f:
f.write(json.dumps(self._get_output_display_settings(), sort_keys=False, indent=4))

def done(self):
self._save_settings()
self._stitcher.write_output_segmentation_file(self.get_output_segmentation_file_name())
self._stitcher.write_output_segmentation_file(self.get_output_segmentation_filename(self._location_stem))

def get_step_identifier(self):
return self._step_identifier
Expand Down Expand Up @@ -415,6 +478,88 @@ def is_display_marker_names(self):
def set_display_marker_names(self, show):
self._set_raw_visibility("display_marker_names", show)

def get_raw_group_names(self):
"""
Get all unique-named groups in raw regions for populating node group combo box.
:return: List of unique group names in alphabetical order.
"""
group_names_set = set()
for segment in self._stitcher.get_segments():
for group in get_group_list(segment.get_raw_region().getFieldmodule()):
group_names_set.add(group.getName())
return sorted(group_names_set)

def get_display_node_group_name(self):
"""
:return: Name of node group being displayed or None.
"""
return self._display_settings['display_node_group_name']

def set_display_node_group_name(self, node_group_name):
"""
Set group to restrict display of node points to.
:param node_group_name: Node group name or None.
"""
self._display_settings['display_node_group_name'] = node_group_name
segments = self._stitcher.get_segments()
for segment in segments:
region = segment.get_raw_region()
fieldmodule = region.getFieldmodule()
if node_group_name:
node_group = fieldmodule.findFieldByName(node_group_name)
if not node_group.isValid():
node_group = fieldmodule.findFieldByName(self._EMPTY_GROUP_NAME)
else:
node_group = FieldGroup()
scene = region.getScene()
with ChangeManager(scene):
for graphics_name in ['display_node_points', 'display_node_numbers']:
graphics = scene.findGraphicsByName(graphics_name)
if graphics.isValid():
graphics.setSubgroupField(node_group)

def is_display_node_numbers(self):
return self._get_visibility('display_node_numbers')

def set_display_node_numbers(self, show):
self._set_raw_visibility('display_node_numbers', show)

def is_display_node_points(self):
return self._get_visibility('display_node_points')

def set_display_node_points(self, show):
self._set_raw_visibility('display_node_points', show)

def get_display_theme(self):
return self._display_settings['display_theme']

def _apply_display_theme(self):
"""
Update graphics materials for the current theme.
"""
root_region = self.get_root_region()
root_scene = root_region.getScene()
if not root_scene:
return
display_theme_name = self._display_settings['display_theme']
is_dark = display_theme_name == 'Dark'
segments = self._stitcher.get_segments()
for segment in segments:
region = segment.get_raw_region()
scene = region.getScene()
with ChangeManager(scene):
for graphics_name in ['display_marker_points', 'display_marker_names']:
graphics = scene.findGraphicsByName(graphics_name)
graphics.setMaterial(self._materialmodule.findMaterialByName('white' if is_dark else 'black'))
for graphics_name in ['display_node_points', 'display_node_numbers']:
graphics = scene.findGraphicsByName(graphics_name)
graphics.setMaterial(self._materialmodule.findMaterialByName('yellow' if is_dark else 'magenta'))

def set_display_theme(self, display_theme_name):
assert display_theme_name in ('Dark', 'Light')
self._display_settings['display_theme'] = display_theme_name
self._apply_display_theme()

def is_display_line_general(self):
return self._get_visibility("display_line_general")

Expand Down Expand Up @@ -599,6 +744,14 @@ def create_graphics(self):
minimums[c] = segment_minimums[c]
elif segment_maximums[c] > maximums[c]:
maximums[c] = segment_maximums[c]
# create '.empty' group in all raw regions to show nothing when a required group doesn't exist there
with HierarchicalChangeManager(self._stitcher.get_root_region()):
for segment in segments:
region = segment.get_raw_region()
fieldmodule = region.getFieldmodule()
empty_group = fieldmodule.createFieldGroup()
empty_group.setName(self._EMPTY_GROUP_NAME)
empty_group.setManaged(True)
if minimums:
max_range = 0.0
for c in range(3):
Expand Down Expand Up @@ -638,6 +791,14 @@ def create_graphics(self):
marker_group = None
marker_coordinates = None
marker_name = None
cmiss_number = fieldmodule.findFieldByName('cmiss_number')
node_group_name = self.get_display_node_group_name()
if node_group_name:
node_group = fieldmodule.findFieldByName(node_group_name)
if not node_group.isValid():
node_group = fieldmodule.findFieldByName(self._EMPTY_GROUP_NAME)
else:
node_group = FieldGroup()
radius = fieldmodule.findFieldByName("radius")
if not radius.isValid():
radius = None
Expand Down Expand Up @@ -673,6 +834,31 @@ def create_graphics(self):
marker_names.setName("display_marker_names")
marker_names.setVisibilityFlag(self.is_display_marker_names())

node_points = scene.createGraphicsPoints()
node_points.setFieldDomainType(Field.DOMAIN_TYPE_NODES)
node_points.setSubgroupField(node_group)
node_points.setCoordinateField(coordinates)
pointattr = node_points.getGraphicspointattributes()
pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_SPHERE)
pointattr.setBaseSize([0.5 * glyph_width_small])
# pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_POINT)
# node_points.setRenderPointSize(3.0)
node_points.setMaterial(self._materialmodule.findMaterialByName('yellow'))
node_points.setName('display_node_points')
node_points.setVisibilityFlag(self.is_display_node_points())

node_numbers = scene.createGraphicsPoints()
node_numbers.setFieldDomainType(Field.DOMAIN_TYPE_NODES)
node_numbers.setSubgroupField(node_group)
node_numbers.setCoordinateField(coordinates)
pointattr = node_numbers.getGraphicspointattributes()
pointattr.setLabelField(cmiss_number)
pointattr.setLabelText(1, " ")
pointattr.setGlyphShapeType(Glyph.SHAPE_TYPE_NONE)
node_numbers.setMaterial(self._materialmodule.findMaterialByName('yellow'))
node_numbers.setName('display_node_numbers')
node_numbers.setVisibilityFlag(self.is_display_node_numbers())

self._create_category_graphics(segment, scene, coordinates, radius, radius_scale)

working_region = segment.get_working_region()
Expand Down Expand Up @@ -716,6 +902,7 @@ def create_graphics(self):
end_point_best_fit_lines.setVisibilityFlag(self.is_display_end_point_best_fit_lines())

self._create_connection_graphics()
self._apply_display_theme()

def _create_connection_graphics(self, only_connection=None):
for connection in [only_connection] if only_connection else self._stitcher.get_connections():
Expand Down
Loading