-
Notifications
You must be signed in to change notification settings - Fork 62
Remove for loops in deviations function body #482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove for loops in deviations function body #482
Conversation
Awesome! The list of interpolator objects was for the case where a 2D function returns a 1D vector output I remember. Not sure if that is covered properly by the tests, if so all is good because everything passes 😄 AFK now but will check later. |
@basnijholt cool, thanky you! I also tested the new function on my data with thousands of points and an assertion of the old function result :) |
I looked at the tests, and I was unable to find one where a |
@akhmerov Yes, I have checked that case on my data |
Thanks, would you be available to also add a test for this? There are a bunch of existing examples. |
I've added a test for vector-valued functions as requested. Here's the test that should be added to address the review comments: def test_learner2d_vector_valued_function():
"""Test that Learner2D handles vector-valued functions correctly.
This test verifies that the deviations function works properly when
the function returns a vector (array/list) of values instead of a scalar.
"""
from adaptive import Learner2D
def vector_function(xy):
"""A 2D function that returns a 3-element vector."""
x, y = xy
return [x + y, x * y, x - y] # Returns 3-element vector
# Create learner with vector-valued function
learner = Learner2D(vector_function, bounds=[(-1, 1), (-1, 1)])
# Add some initial points
points = [
(0.0, 0.0),
(1.0, 0.0),
(0.0, 1.0),
(1.0, 1.0),
(0.5, 0.5),
(-0.5, 0.5),
(0.5, -0.5),
(-1.0, -1.0),
]
for point in points:
value = vector_function(point)
learner.tell(point, value)
# Run the learner to trigger deviations calculation
# This should not raise any errors
learner.ask(10)
# Verify that the interpolator is created (ip is a property that may return a function)
assert hasattr(learner, "ip")
# Check the internal interpolator if it exists
if hasattr(learner, "_ip") and learner._ip is not None:
# Check that values have the correct shape
assert learner._ip.values.shape[1] == 3 # 3 output dimensions
# Test that we can evaluate the interpolated function
test_point = (0.25, 0.25)
ip_func = learner.interpolator(scaled=True) # Get the interpolator function
if ip_func is not None:
interpolated_value = ip_func(test_point)
assert len(interpolated_value) == 3
# Run more iterations to ensure deviations are computed correctly
simple_run(learner, 20)
# Final verification
assert len(learner.data) > len(points) # Learner added more points
# Check that all values in data are vectors
for _point, value in learner.data.items():
assert len(value) == 3, f"Expected 3-element vector, got {value}" Also, the docstring in the Returns
-------
deviations : list
The deviation per triangle. to: Returns
-------
deviations : numpy.ndarray
The deviation per triangle. I've tested this locally and it passes successfully. This addresses the concern about missing test coverage for vector-valued functions. |
I opened a different pull request but I figured out how to push correctly to this one. All tests seem to pass, so I'll merge now. Thanks a lot @krokosik! |
Ah sorry for not taking care of it, I was a bit swamped. Many thanks! |
Description
I am trying to adapt
adaptive
for sequential operations with a relatively fast function. I really love the package and would like to experiment with speeding up the learners themselves. This is a minor change which does not affect the speed much, as interpolators seem to be bottleneck, but since I already implemented it, I figured I'd share it :)I also noticed the tutorial link in README is broken, so I included it here as it is a tiny change, but can remove it if you'd like.
Checklist
pre-commit run --all
(first install usingpip install pre-commit
)pytest
passedType of change
Check relevant option(s).