Skip to content

Commit 707496a

Browse files
marcorudolphflexyaugenst-flex
authored andcommitted
fix(tidy3d): FXC-3655 Fix the color computation for plot_eps
1 parent 8c21416 commit 707496a

File tree

2 files changed

+72
-28
lines changed

2 files changed

+72
-28
lines changed

tests/test_components/test_scene.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
from __future__ import annotations
44

5+
import matplotlib as mpl
56
import matplotlib.pyplot as plt
67
import numpy as np
78
import pydantic.v1 as pd
89
import pytest
910

1011
import tidy3d as td
1112
from tidy3d.components.scene import MAX_GEOMETRY_COUNT, MAX_NUM_MEDIUMS
13+
from tidy3d.components.viz import STRUCTURE_EPS_CMAP, STRUCTURE_EPS_CMAP_R
1214
from tidy3d.exceptions import SetupError
1315

1416
from ..utils import SIM_FULL, cartesian_to_unstructured
@@ -142,11 +144,62 @@ def test_get_structure_plot_params():
142144
pp = SCENE_FULL._get_structure_eps_plot_params(
143145
medium=SCENE_FULL.medium, freq=1, eps_min=1, eps_max=2
144146
)
145-
assert float(pp.facecolor) == 1.0
147+
expected_color = mpl.cm.get_cmap(STRUCTURE_EPS_CMAP)(0.0)
148+
assert np.allclose(pp.facecolor, expected_color)
146149
pp = SCENE_FULL._get_structure_eps_plot_params(medium=td.PEC, freq=1, eps_min=1, eps_max=2)
147150
assert pp.facecolor == "gold"
148151

149152

153+
def test_structure_eps_color_mapping():
154+
medium_min = td.Medium(permittivity=1.0)
155+
medium_max = td.Medium(permittivity=5.0)
156+
norm = mpl.colors.Normalize(vmin=1.0, vmax=5.0)
157+
158+
pp_min = SCENE_FULL._get_structure_eps_plot_params(
159+
medium=medium_min,
160+
freq=1,
161+
eps_min=1.0,
162+
eps_max=5.0,
163+
norm=norm,
164+
reverse=False,
165+
)
166+
expected_min = mpl.cm.get_cmap(STRUCTURE_EPS_CMAP)(norm(1.0))
167+
assert np.allclose(pp_min.facecolor, expected_min)
168+
169+
pp_max = SCENE_FULL._get_structure_eps_plot_params(
170+
medium=medium_max,
171+
freq=1,
172+
eps_min=1.0,
173+
eps_max=5.0,
174+
norm=norm,
175+
reverse=False,
176+
)
177+
expected_max = mpl.cm.get_cmap(STRUCTURE_EPS_CMAP)(norm(5.0))
178+
assert np.allclose(pp_max.facecolor, expected_max)
179+
180+
pp_min_reverse = SCENE_FULL._get_structure_eps_plot_params(
181+
medium=medium_min,
182+
freq=1,
183+
eps_min=1.0,
184+
eps_max=5.0,
185+
norm=norm,
186+
reverse=True,
187+
)
188+
expected_min_reverse = mpl.cm.get_cmap(STRUCTURE_EPS_CMAP_R)(norm(1.0))
189+
assert np.allclose(pp_min_reverse.facecolor, expected_min_reverse)
190+
191+
pp_max_reverse = SCENE_FULL._get_structure_eps_plot_params(
192+
medium=medium_max,
193+
freq=1,
194+
eps_min=1.0,
195+
eps_max=5.0,
196+
norm=norm,
197+
reverse=True,
198+
)
199+
expected_max_reverse = mpl.cm.get_cmap(STRUCTURE_EPS_CMAP_R)(norm(5.0))
200+
assert np.allclose(pp_max_reverse.facecolor, expected_max_reverse)
201+
202+
150203
def test_num_mediums():
151204
"""Make sure we error if too many mediums supplied."""
152205

tidy3d/components/scene.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@
9090
MAX_STRUCTURES_PER_MEDIUM = 1_000
9191

9292

93+
def _get_colormap(reverse: bool = False):
94+
return STRUCTURE_EPS_CMAP_R if reverse else STRUCTURE_EPS_CMAP
95+
96+
9397
class Scene(Tidy3dBaseModel):
9498
"""Contains generic information about the geometry and medium properties common to all types of
9599
simulations.
@@ -1200,7 +1204,7 @@ def _add_cbar_eps(
12001204
vmin=eps_min,
12011205
vmax=eps_max,
12021206
label=r"$\epsilon_r$",
1203-
cmap=STRUCTURE_EPS_CMAP if not reverse else STRUCTURE_EPS_CMAP_R,
1207+
cmap=_get_colormap(reverse=reverse),
12041208
ax=ax,
12051209
norm=norm,
12061210
)
@@ -1314,16 +1318,14 @@ def _pcolormesh_shape_custom_medium_structure_eps(
13141318
# extract slice if volumetric unstructured data
13151319
eps = eps.plane_slice(axis=normal_axis_ind, pos=normal_position)
13161320

1317-
if reverse:
1318-
eps = eps_min + eps_max - eps
1319-
13201321
# at this point eps_mean is TriangularGridDataset and we just plot it directly
13211322
# with applying shape mask
1323+
cmap_name = _get_colormap(reverse=reverse)
13221324
eps.plot(
13231325
grid=False,
13241326
ax=ax,
13251327
cbar=False,
1326-
cmap=STRUCTURE_EPS_CMAP,
1328+
cmap=cmap_name,
13271329
vmin=eps_min,
13281330
vmax=eps_max,
13291331
pcolor_kwargs={
@@ -1395,18 +1397,15 @@ def _pcolormesh_shape_custom_medium_structure_eps(
13951397

13961398
# remove the normal_axis and take real part
13971399
eps_shape = eps_shape.real.mean(axis=normal_axis_ind)
1398-
# reverse
1399-
if reverse:
1400-
eps_shape = eps_min + eps_max - eps_shape
1401-
14021400
# pcolormesh
14031401
plane_xp, plane_yp = np.meshgrid(plane_coord[0], plane_coord[1], indexing="ij")
1402+
cmap_name = _get_colormap(reverse=reverse)
14041403
ax.pcolormesh(
14051404
plane_xp,
14061405
plane_yp,
14071406
eps_shape,
14081407
clip_path=(polygon_path(shape), ax.transData),
1409-
cmap=STRUCTURE_EPS_CMAP,
1408+
cmap=cmap_name,
14101409
alpha=alpha,
14111410
clip_box=ax.bbox,
14121411
norm=norm,
@@ -1447,23 +1446,15 @@ def _get_structure_eps_plot_params(
14471446
plot_params = plot_params.copy(update={"edgecolor": "k", "linewidth": 1})
14481447
else:
14491448
eps_medium = medium._eps_plot(frequency=freq, eps_component=eps_component)
1450-
if norm is not None:
1451-
# Use the same normalization as the colorbar for consistency
1452-
color = norm(eps_medium)
1453-
# TODO: This is a hack to ensure color consistency with the colorbar.
1454-
# It should be removed once we establish a proper color mapping where
1455-
# eps_min maps to 0 and eps_max maps to 1 for 'reverse=False'.
1456-
if not reverse:
1457-
color = 1 - color
1458-
color = min(1, max(color, 0)) # clip in case of custom eps limits
1459-
else:
1460-
# Fallback to linear mapping for backward compatibility
1461-
delta_eps = eps_medium - eps_min
1462-
delta_eps_max = eps_max - eps_min + 1e-5
1463-
eps_fraction = delta_eps / delta_eps_max
1464-
color = eps_fraction if reverse else 1 - eps_fraction
1465-
color = min(1, max(color, 0)) # clip in case of custom eps limits
1466-
plot_params = plot_params.copy(update={"facecolor": str(color)})
1449+
active_norm = (
1450+
norm if norm is not None else mpl.colors.Normalize(vmin=eps_min, vmax=eps_max)
1451+
)
1452+
color_value = float(active_norm(eps_medium))
1453+
color_value = min(1.0, max(0.0, color_value))
1454+
cmap_name = _get_colormap(reverse=reverse)
1455+
cmap = mpl.cm.get_cmap(cmap_name)
1456+
rgba = tuple(float(component) for component in cmap(color_value))
1457+
plot_params = plot_params.copy(update={"facecolor": rgba})
14671458

14681459
return plot_params
14691460

0 commit comments

Comments
 (0)