Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions warp/_src/builtins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

Expand Down Expand Up @@ -223,6 +235,7 @@ def static_len_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
doc="Compute the cosh of ``x``.",
group="Scalar Math",
)

add_builtin(
"tanh",
input_types={"x": Float},
Expand Down Expand Up @@ -999,6 +1012,7 @@ def get_diag_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str,

# scalar type constructors between all storage / compute types
scalar_types_all = [*scalar_types, bool, int, float]

for t in scalar_types_all:
for u in scalar_types_all:
add_builtin(
Expand Down Expand Up @@ -9784,7 +9798,7 @@ def matrix_ij_dispatch_func(input_types: Mapping[str, type], return_type: Any, a


def vector_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
vec = args["a"].type
vec = strip_reference(args["a"].type)
idx = args["i"].type
value_type = strip_reference(args["value"].type)

Expand Down Expand Up @@ -9862,7 +9876,7 @@ def vector_assign_dispatch_func(input_types: Mapping[str, type], return_type: An


def vector_assign_copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
vec_type = arg_types["a"]
vec_type = strip_reference(arg_types["a"])
return vec_type


Expand Down Expand Up @@ -10081,7 +10095,7 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):


def matrix_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
mat = args["a"].type
mat = strip_reference(args["a"].type)
value_type = strip_reference(args["value"].type)

idxs = tuple(args[x].type for x in "ij" if args.get(x, None) is not None)
Expand Down Expand Up @@ -10195,7 +10209,7 @@ def matrix_assign_dispatch_func(input_types: Mapping[str, type], return_type: An


def matrix_assign_copy_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
mat_type = arg_types["a"]
mat_type = strip_reference(arg_types["a"])
return mat_type


Expand Down
99 changes: 80 additions & 19 deletions warp/_src/codegen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

Expand Down Expand Up @@ -3139,7 +3151,7 @@ def strip_indices(indices, count):

return indices

def recurse_subscript(adj, node, indices):
def recurse_subscript(adj, node, indices, allow_partial=True):
if isinstance(node, ast.Name):
target = adj.eval(node)
return target, indices
Expand All @@ -3161,10 +3173,10 @@ def recurse_subscript(adj, node, indices):

indices = [ij, *indices] # prepend

target, indices = adj.recurse_subscript(node.value, indices)
target, indices = adj.recurse_subscript(node.value, indices, allow_partial)

target_type = strip_reference(target.type)
if is_array(target_type):
if is_array(target_type) and allow_partial:
flat_indices = [i for ij in indices for i in ij]
if len(flat_indices) > target_type.ndim:
target = adj.emit_indexing(target, flat_indices[: target_type.ndim])
Expand All @@ -3176,8 +3188,8 @@ def recurse_subscript(adj, node, indices):
return target, indices

# returns the object being indexed, and the list of indices
def eval_subscript(adj, node):
target, indices = adj.recurse_subscript(node, [])
def eval_subscript(adj, node, allow_partial=True):
target, indices = adj.recurse_subscript(node, [], allow_partial)
flat_indices = [i for ij in indices for i in ij]
return target, flat_indices

Expand Down Expand Up @@ -3211,7 +3223,7 @@ def emit_Assign(adj, node):
# more generally in `adj.eval()`.
if isinstance(node.value, ast.List):
raise WarpCodegenError(
"List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small fixed-size collections, or `wp.zeros(shape=N, dtype=...)` for stack-allocated arrays."
"List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
)

lhs = node.targets[0]
Expand Down Expand Up @@ -3273,22 +3285,45 @@ def emit_Assign(adj, node):
adj.add_forward(f"{var.emit()} = {rhs.emit()};")
return

target, indices = adj.eval_subscript(lhs)
target, indices = adj.eval_subscript(lhs, allow_partial=False)

target_type = strip_reference(target.type)
indices = adj.eval_indices(target_type, indices)

if is_array(target_type):
adj.add_builtin_call("array_store", [target, *indices, rhs])
if len(indices) > target_type.ndim:
# Array vector/matrix component assignment
array_indices = indices[: target_type.ndim]
vec_indices = indices[target_type.ndim :]

if warp.config.verify_autograd_array_access:
kernel_name = adj.fun_name
filename = adj.filename
lineno = adj.lineno + adj.fun_lineno
array_indices = adj.eval_indices(target_type, array_indices)

target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
vec_target = adj.emit_indexing(target, array_indices)
vec_target_type = strip_reference(vec_target.type)

vec_indices = adj.eval_indices(vec_target_type, vec_indices)

new_vec = adj.add_builtin_call("assign_copy", [vec_target, *vec_indices, rhs])
adj.add_builtin_call("array_store", [target, *array_indices, new_vec])

if warp.config.verify_autograd_array_access:
kernel_name = adj.fun_name
filename = adj.filename
lineno = adj.lineno + adj.fun_lineno
target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
return
else:
indices = adj.eval_indices(target_type, indices)
adj.add_builtin_call("array_store", [target, *indices, rhs])

elif is_tile(target_type):
if warp.config.verify_autograd_array_access:
kernel_name = adj.fun_name
filename = adj.filename
lineno = adj.lineno + adj.fun_lineno
target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
return

indices = adj.eval_indices(target_type, indices)
if is_tile(target_type):
adj.add_builtin_call("assign", [target, *indices, rhs])

elif (
Expand All @@ -3307,7 +3342,6 @@ def emit_Assign(adj, node):
adj.add_builtin_call("store", [attr, rhs])
return

# TODO: array vec component case
if is_reference(target.type):
attr = adj.add_builtin_call("indexref", [target, *indices])
adj.add_builtin_call("store", [attr, rhs])
Expand Down Expand Up @@ -3502,8 +3536,6 @@ def make_new_assign_statement():
new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
adj.eval(new_node)

rhs = adj.eval(node.value)

if isinstance(lhs, ast.Subscript):
# wp.adjoint[var] appears in custom grad functions, and does not require
# special consideration in the AugAssign case
Expand All @@ -3512,8 +3544,34 @@ def make_new_assign_statement():
return

target, indices = adj.eval_subscript(lhs)

target_type = strip_reference(target.type)

if is_reference(target.type):
if is_array(target_type) and len(indices) > target_type.ndim:
rhs = adj.eval(node.value)

array_indices = indices[: target_type.ndim]
vec_indices = indices[target_type.ndim :]

array_indices = adj.eval_indices(target_type, array_indices)
vec_target = adj.emit_indexing(target, array_indices)
vec_target_type = strip_reference(vec_target.type)

vec_indices = adj.eval_indices(vec_target_type, vec_indices)
old_val = adj.emit_indexing(vec_target, vec_indices)

op_name = builtin_operators[type(node.op)]
new_val = adj.add_builtin_call(op_name, [old_val, rhs])

new_vec = adj.add_builtin_call("assign_copy", [vec_target, *vec_indices, new_val])
adj.add_builtin_call("array_store", [target, *array_indices, new_vec])
return
Comment on lines +3549 to +3568
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: emit_indexing is called with more indices than the array has dimensions in the AugAssign path

At line 3553, adj.emit_indexing(target, indices) is called where target is an array (reference) and indices contains both array-level and vector/component-level indices — more than target_type.ndim. Inside emit_indexing, this hits the else branch (since len(indices) != target_type.ndim) and emits a view builtin call instead of a scalar address load. That is incorrect semantics for the intended "read the current scalar component, compute new value, write back" operation.

The correct approach is to split the indices first (as is done later for the write-back path at lines 3557–3567) and index step by step:

if is_array(target_type) and len(indices) > target_type.ndim:
    rhs = adj.eval(node.value)

    array_indices = indices[:target_type.ndim]
    vec_indices = indices[target_type.ndim:]

    array_indices = adj.eval_indices(target_type, array_indices)
    vec_target = adj.emit_indexing(target, array_indices)
    vec_target_type = strip_reference(vec_target.type)

    vec_indices = adj.eval_indices(vec_target_type, vec_indices)
    old_val = adj.emit_indexing(vec_target, vec_indices)

    op_name = builtin_operators[type(node.op)]
    new_val = adj.add_builtin_call(op_name, [old_val, rhs])

    new_vec = adj.add_builtin_call("assign_copy", [vec_target, *vec_indices, new_val])
    adj.add_builtin_call("array_store", [target, *array_indices, new_vec])
    return

Note: in practice, eval_subscript with allow_partial=True will already have collapsed the array indices into a vector reference, making this branch unreachable for typical kernels. Even so, the explicit call to emit_indexing with the combined indices would be wrong if the branch were ever reached; the intent is clearly to read a single component.

else:
make_new_assign_statement()
return

rhs = adj.eval(node.value)

indices = adj.eval_indices(target_type, indices)

if is_array(target_type):
Expand Down Expand Up @@ -3659,6 +3717,9 @@ def emit_Pass(adj, node):
}

def eval(adj, node):
if isinstance(node, Var):
return node

if hasattr(node, "lineno"):
adj.set_lineno(node.lineno - 1)

Expand Down
15 changes: 13 additions & 2 deletions warp/tests/matrix/test_mat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from typing import Any
Expand Down Expand Up @@ -1653,7 +1665,7 @@ def test_mat_array_extract(test, device):
assert_np_equal(x.grad.numpy(), np.array([[[[1.0, 1.0], [2.0, 2.0]]]], dtype=float))


""" TODO: gradient propagation for in-place array assignment
# TODO: gradient propagation for in-place array assignment
@wp.kernel
def mat_array_assign_element(x: wp.array2d(dtype=float), y: wp.array2d(dtype=wp.mat22)):
i, j = wp.tid()
Expand Down Expand Up @@ -1700,7 +1712,6 @@ def test_mat_array_assign(test, device):

assert_np_equal(y.numpy(), np.array([[[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]]], dtype=float))
assert_np_equal(x.grad.numpy(), np.array([[[3.0, 3.0, 3.0]]], dtype=float))
"""


@wp.kernel
Expand Down
15 changes: 13 additions & 2 deletions warp/tests/test_vec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from typing import Any
Expand Down Expand Up @@ -980,8 +992,7 @@ def run(kernel):
tape.backward()

assert_np_equal(y.numpy(), np.array([[[1.0, 2.0, 3.0]]], dtype=float))
# TODO: gradient propagation for in-place array assignment
# assert_np_equal(x.grad.numpy(), np.array([[6.0]], dtype=float))
assert_np_equal(x.grad.numpy(), np.array([[6.0]], dtype=float))

run(vec_array_assign_subscript)
run(vec_array_assign_attribute)
Expand Down