Skip to content

Commit 4f42ad5

Browse files
alembckepre-commit-ci[bot]behacklMrDiver
authored
Enabled filling color by value for :class:.OpenGLSurface, replaced colors keyword argument of :meth:.Surface.set_fill_by_value with colorscale (#2186)
* Enable filling color by value for OpenGLSurface. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added doc string and removed unused deprecated_params * Added docstring and made method private. * Fixed spelling error. * Update manim/mobject/types/opengl_surface.py Agreed. Co-authored-by: Benjamin Hackl <[email protected]> * Fixed typings * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added None type as part of color checking * Another attempt at fixing the color typings * Made some of the requested modifications. * Fixed failing tests. * Added tests for Cairo method. * One minor change to test plot_surface. * Fixed typo in docstring. * Removed passing in of axes. * Removed axes in tests. * Fixed typo, mixed up u and v. * Fixed docstring typo, updated format to recommended. * Fixed one more typo in docstring. * One last time editing the docstring. * Add deprecation import to opengl_surface. * Added import for Surface and OpenGLSurface to coordinate_systems. * Hoping this fixes the example for the docs. * Addes colorscale_axis to plot_surface * Forgot to update the docstring. * added proper deprecation, made change temporarily backwards compatible * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Benjamin Hackl <[email protected]> Co-authored-by: Tristan Schulz <[email protected]>
1 parent f1c5b6b commit 4f42ad5

File tree

7 files changed

+279
-10
lines changed

7 files changed

+279
-10
lines changed

manim/mobject/graphing/coordinate_systems.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
from manim.mobject.graphing.number_line import NumberLine
2929
from manim.mobject.graphing.scale import LinearBase
3030
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
31+
from manim.mobject.opengl.opengl_surface import OpenGLSurface
3132
from manim.mobject.text.tex_mobject import MathTex
33+
from manim.mobject.three_d.three_dimensions import Surface
3234
from manim.mobject.types.vectorized_mobject import (
3335
VDict,
3436
VectorizedPoint,
@@ -858,6 +860,90 @@ def construct(self):
858860
graph.underlying_function = r_func
859861
return graph
860862

863+
def plot_surface(
864+
self,
865+
function: Callable[[float], float],
866+
u_range: Sequence[float] | None = None,
867+
v_range: Sequence[float] | None = None,
868+
colorscale: Sequence[[color], float] | None = None,
869+
colorscale_axis: int = 2,
870+
**kwargs,
871+
):
872+
"""Generates a surface based on a function.
873+
874+
Parameters
875+
----------
876+
function
877+
The function used to construct the :class:`~.Surface`.
878+
u_range
879+
The range of the ``u`` variable: ``(u_min, u_max)``.
880+
v_range
881+
The range of the ``v`` variable: ``(v_min, v_max)``.
882+
colorscale
883+
Colors of the surface. Passing a list of colors will color the surface by z-value.
884+
Passing a list of tuples in the form ``(color, pivot)`` allows user-defined pivots
885+
where the color transitions.
886+
colorscale_axis
887+
Defines the axis on which the colorscale is applied (0 = x, 1 = y, 2 = z), default
888+
is z-axis (2).
889+
kwargs
890+
Additional parameters to be passed to :class:`~.Surface`.
891+
892+
Returns
893+
-------
894+
:class:`~.Surface`
895+
The plotted surface.
896+
897+
Examples
898+
--------
899+
.. manim:: PlotSurfaceExample
900+
:save_last_frame:
901+
902+
class PlotSurfaceExample(ThreeDScene):
903+
def construct(self):
904+
resolution_fa = 42
905+
self.set_camera_orientation(phi=75 * DEGREES, theta=-60 * DEGREES)
906+
axes = ThreeDAxes(x_range=(-3, 3, 1), y_range=(-3, 3, 1), z_range=(-5, 5, 1))
907+
def param_trig(u, v):
908+
x = u
909+
y = v
910+
z = 2 * np.sin(x) + 2 * np.cos(y)
911+
return z
912+
trig_plane = axes.plot_surface(
913+
param_trig,
914+
resolution=(resolution_fa, resolution_fa),
915+
u_range = (-3, 3),
916+
v_range = (-3, 3),
917+
colorscale = [BLUE, GREEN, YELLOW, ORANGE, RED],
918+
)
919+
self.add(axes, trig_plane)
920+
"""
921+
if config.renderer != "opengl":
922+
surface = Surface(
923+
lambda u, v: self.c2p(u, v, function(u, v)),
924+
u_range=u_range,
925+
v_range=v_range,
926+
**kwargs,
927+
)
928+
if colorscale:
929+
surface.set_fill_by_value(
930+
axes=self.copy(),
931+
colorscale=colorscale,
932+
axis=colorscale_axis,
933+
)
934+
else:
935+
surface = OpenGLSurface(
936+
lambda u, v: self.c2p(u, v, function(u, v)),
937+
u_range=u_range,
938+
v_range=v_range,
939+
axes=self.copy(),
940+
colorscale=colorscale,
941+
colorscale_axis=colorscale_axis,
942+
**kwargs,
943+
)
944+
945+
return surface
946+
861947
def input_to_graph_point(
862948
self,
863949
x: float,

manim/mobject/opengl/opengl_surface.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,44 @@
1212
from manim.utils.bezier import integer_interpolate, interpolate
1313
from manim.utils.color import *
1414
from manim.utils.config_ops import _Data, _Uniforms
15+
from manim.utils.deprecation import deprecated
1516
from manim.utils.images import change_to_rgba_array, get_full_raster_image_path
1617
from manim.utils.iterables import listify
1718
from manim.utils.space_ops import normalize_along_axis
1819

1920

2021
class OpenGLSurface(OpenGLMobject):
22+
r"""Creates a Surface.
23+
24+
Parameters
25+
----------
26+
uv_func
27+
The function that defines the surface.
28+
u_range
29+
The range of the ``u`` variable: ``(u_min, u_max)``.
30+
v_range
31+
The range of the ``v`` variable: ``(v_min, v_max)``.
32+
resolution
33+
The number of samples taken of the surface.
34+
axes
35+
Axes on which the surface is to be drawn. Optional
36+
parameter used when coloring a surface by z-value.
37+
color
38+
Color of the surface. Defaults to grey.
39+
colorscale
40+
Colors of the surface. Optional parameter used when
41+
coloring a surface by values. Passing a list of
42+
colors and an axes will color the surface by z-value.
43+
Passing a list of tuples in the form ``(color, pivot)``
44+
allows user-defined pivots where the color transitions.
45+
colorscale_axis
46+
Defines the axis on which the colorscale is applied
47+
(0 = x, 1 = y, 2 = z), default is z-axis (2).
48+
opacity
49+
Opacity of the surface from 0 being fully transparent
50+
to 1 being fully opaque. Defaults to 1.
51+
"""
52+
2153
shader_dtype = [
2254
("point", np.float32, (3,)),
2355
("du_point", np.float32, (3,)),
@@ -35,7 +67,10 @@ def __init__(
3567
# each coordinate is one more than the the number of
3668
# rows/columns of approximating squares
3769
resolution=None,
70+
axes=None,
3871
color=GREY,
72+
colorscale=None,
73+
colorscale_axis=2,
3974
opacity=1.0,
4075
gloss=0.3,
4176
shadow=0.4,
@@ -55,6 +90,9 @@ def __init__(
5590
# each coordinate is one more than the the number of
5691
# rows/columns of approximating squares
5792
self.resolution = resolution if resolution is not None else (101, 101)
93+
self.axes = axes
94+
self.colorscale = colorscale
95+
self.colorscale_axis = colorscale_axis
5896
self.prefered_creation_axis = prefered_creation_axis
5997
# For du and dv steps. Much smaller and numerical error
6098
# can crop up in the shaders.
@@ -206,22 +244,106 @@ def index_dot(index):
206244

207245
# For shaders
208246
def get_shader_data(self):
247+
"""Called by parent Mobject to calculate and return
248+
the shader data.
249+
250+
Returns
251+
-------
252+
shader_dtype
253+
An array containing the shader data (vertices and
254+
color of each vertex)
255+
"""
209256
s_points, du_points, dv_points = self.get_surface_points_and_nudged_points()
210257
shader_data = np.zeros(len(s_points), dtype=self.shader_dtype)
211258
if "points" not in self.locked_data_keys:
212259
shader_data["point"] = s_points
213260
shader_data["du_point"] = du_points
214261
shader_data["dv_point"] = dv_points
215-
self.fill_in_shader_color_info(shader_data)
262+
if self.colorscale:
263+
shader_data["color"] = self._get_color_by_value(s_points)
264+
else:
265+
self.fill_in_shader_color_info(shader_data)
216266
return shader_data
217267

218268
def fill_in_shader_color_info(self, shader_data):
269+
"""Fills in the shader color data when the surface
270+
is all one color.
271+
272+
Parameters
273+
----------
274+
shader_data
275+
The vertices of the surface.
276+
277+
Returns
278+
-------
279+
shader_dtype
280+
An array containing the shader data (vertices and
281+
color of each vertex)
282+
"""
219283
self.read_data_to_shader(shader_data, "color", "rgbas")
220284
return shader_data
221285

286+
def _get_color_by_value(self, s_points):
287+
"""Matches each vertex to a color associated to it's z-value.
288+
289+
Parameters
290+
----------
291+
s_points
292+
The vertices of the surface.
293+
294+
Returns
295+
-------
296+
List
297+
A list of colors matching the vertex inputs.
298+
"""
299+
if type(self.colorscale[0]) in (list, tuple):
300+
new_colors, pivots = [
301+
[i for i, j in self.colorscale],
302+
[j for i, j in self.colorscale],
303+
]
304+
else:
305+
new_colors = self.colorscale
306+
307+
pivot_min = self.axes.z_range[0]
308+
pivot_max = self.axes.z_range[1]
309+
pivot_frequency = (pivot_max - pivot_min) / (len(new_colors) - 1)
310+
pivots = np.arange(
311+
start=pivot_min,
312+
stop=pivot_max + pivot_frequency,
313+
step=pivot_frequency,
314+
)
315+
316+
return_colors = []
317+
for point in s_points:
318+
axis_value = self.axes.point_to_coords(point)[self.colorscale_axis]
319+
if axis_value <= pivots[0]:
320+
return_colors.append(color_to_rgba(new_colors[0], self.opacity))
321+
elif axis_value >= pivots[-1]:
322+
return_colors.append(color_to_rgba(new_colors[-1], self.opacity))
323+
else:
324+
for i, pivot in enumerate(pivots):
325+
if pivot > axis_value:
326+
color_index = (axis_value - pivots[i - 1]) / (
327+
pivots[i] - pivots[i - 1]
328+
)
329+
color_index = max(min(color_index, 1), 0)
330+
temp_color = interpolate_color(
331+
new_colors[i - 1],
332+
new_colors[i],
333+
color_index,
334+
)
335+
break
336+
return_colors.append(color_to_rgba(temp_color, self.opacity))
337+
338+
return return_colors
339+
222340
def get_shader_vert_indices(self):
223341
return self.get_triangle_indices()
224342

343+
@deprecated(
344+
since="v0.16.0",
345+
message="Use colorscale attribute instead.",
346+
)
225347
def set_fill_by_value(self, axes, colors):
226348
# directly copied from three_dimensions.py with some compatibility changes.
227349
"""Sets the color of each mobject of a parametric surface to a color relative to its z-value

manim/mobject/three_d/three_dimensions.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import numpy as np
2323
from colour import Color
2424

25-
from manim import config
25+
from manim import config, logger
2626
from manim.constants import *
2727
from manim.mobject.geometry.arc import Circle
2828
from manim.mobject.geometry.polygram import Square
@@ -31,6 +31,7 @@
3131
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
3232
from manim.mobject.types.vectorized_mobject import VGroup, VMobject
3333
from manim.utils.color import *
34+
from manim.utils.deprecation import deprecated_params
3435
from manim.utils.iterables import tuplify
3536
from manim.utils.space_ops import normalize, perpendicular_bisector, z_to_vector
3637

@@ -165,19 +166,21 @@ def set_fill_by_checkerboard(self, *colors, opacity=None):
165166
face.set_fill(colors[c_index], opacity=opacity)
166167
return self
167168

169+
@deprecated_params("colors", since="v0.16.0")
168170
def set_fill_by_value(
169171
self,
170172
axes: Mobject,
171-
colors: Union[Iterable[Color], Color],
173+
colorscale: Union[Iterable[Color], Color] | None = None,
172174
axis: int = 2,
175+
**kwargs,
173176
):
174177
"""Sets the color of each mobject of a parametric surface to a color relative to its axis-value
175178
176179
Parameters
177180
----------
178181
axes :
179182
The axes for the parametric surface, which will be used to map axis-values to colors.
180-
colors :
183+
colorscale :
181184
A list of colors, ordered from lower axis-values to higher axis-values. If a list of tuples is passed
182185
containing colors paired with numbers, then those numbers will be used as the pivots.
183186
axis :
@@ -210,16 +213,32 @@ def param_surface(u, v):
210213
u_range=[0, 5],
211214
)
212215
surface_plane.set_style(fill_opacity=1)
213-
surface_plane.set_fill_by_value(axes=axes, colors=[(RED, -0.5), (YELLOW, 0), (GREEN, 0.5)], axis=2)
216+
surface_plane.set_fill_by_value(axes=axes, colorscale=[(RED, -0.5), (YELLOW, 0), (GREEN, 0.5)], axis=2)
214217
self.add(axes, surface_plane)
215218
"""
219+
if "colors" in kwargs and colorscale is None:
220+
colorscale = kwargs.pop("colors")
221+
if kwargs:
222+
raise ValueError(
223+
"Unsupported keyword argument(s): "
224+
f"{', '.join(str(key) for key in kwargs)}"
225+
)
226+
if colorscale is None:
227+
logger.warning(
228+
"The value passed to the colorscale keyword argument was None, "
229+
"the surface fill color has not been changed"
230+
)
231+
return self
216232

217233
ranges = [axes.x_range, axes.y_range, axes.z_range]
218234

219-
if type(colors[0]) is tuple:
220-
new_colors, pivots = [[i for i, j in colors], [j for i, j in colors]]
235+
if type(colorscale[0]) is tuple:
236+
new_colors, pivots = [
237+
[i for i, j in colorscale],
238+
[j for i, j in colorscale],
239+
]
221240
else:
222-
new_colors = colors
241+
new_colors = colorscale
223242

224243
pivot_min = ranges[axis][0]
225244
pivot_max = ranges[axis][1]
Binary file not shown.
Binary file not shown.

tests/test_graphical_units/test_coordinate_systems.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,46 @@ def test_line_graph(scene):
3737
scene.add(plane, first_line, second_line)
3838

3939

40+
@frames_comparison(base_scene=ThreeDScene)
41+
def test_plot_surface(scene):
42+
axes = ThreeDAxes(x_range=(-5, 5, 1), y_range=(-5, 5, 1), z_range=(-5, 5, 1))
43+
44+
def param_trig(u, v):
45+
x = u
46+
y = v
47+
z = 2 * np.sin(x) + 2 * np.cos(y)
48+
return z
49+
50+
trig_plane = axes.plot_surface(
51+
param_trig,
52+
u_range=(-5, 5),
53+
v_range=(-5, 5),
54+
color=BLUE,
55+
)
56+
57+
scene.add(axes, trig_plane)
58+
59+
60+
@frames_comparison(base_scene=ThreeDScene)
61+
def test_plot_surface_colorscale(scene):
62+
axes = ThreeDAxes(x_range=(-3, 3, 1), y_range=(-3, 3, 1), z_range=(-5, 5, 1))
63+
64+
def param_trig(u, v):
65+
x = u
66+
y = v
67+
z = 2 * np.sin(x) + 2 * np.cos(y)
68+
return z
69+
70+
trig_plane = axes.plot_surface(
71+
param_trig,
72+
u_range=(-3, 3),
73+
v_range=(-3, 3),
74+
colorscale=[BLUE, GREEN, YELLOW, ORANGE, RED],
75+
)
76+
77+
scene.add(axes, trig_plane)
78+
79+
4080
@frames_comparison
4181
def test_implicit_graph(scene):
4282
ax = Axes()

0 commit comments

Comments
 (0)