diff --git a/firedrake/external_operators/__init__.py b/firedrake/external_operators/__init__.py new file mode 100644 index 0000000000..d37df34173 --- /dev/null +++ b/firedrake/external_operators/__init__.py @@ -0,0 +1 @@ +from firedrake.external_operators.neural_networks import * \ No newline at end of file diff --git a/firedrake/external_operators/neural_networks/__init__.py b/firedrake/external_operators/neural_networks/__init__.py new file mode 100644 index 0000000000..6be941f2dd --- /dev/null +++ b/firedrake/external_operators/neural_networks/__init__.py @@ -0,0 +1,2 @@ +from .backends import get_backend +from .ml_backend_coupling import HybridOperator, torch_operator \ No newline at end of file diff --git a/firedrake/external_operators/neural_networks/ml_backend_coupling.py b/firedrake/external_operators/neural_networks/ml_backend_coupling.py index c1c5463d13..bc5fff7b28 100644 --- a/firedrake/external_operators/neural_networks/ml_backend_coupling.py +++ b/firedrake/external_operators/neural_networks/ml_backend_coupling.py @@ -31,5 +31,5 @@ def __call__(self, *ω): return φ(*ω) -def torch_op(*args, **kwargs): +def torch_operator(*args, **kwargs): return HybridOperator(*args, **kwargs)