converting sympy.NumberSymbol to torch.tensor in export_torch.py#726
converting sympy.NumberSymbol to torch.tensor in export_torch.py#726tbuckworth wants to merge 7 commits intoMilesCranmer:masterfrom
Conversation
attempting to address MilesCranmer#656
|
Nice! Do you want to add a unit test for the MWE you described in the issue? |
Pull Request Test Coverage Report for Build 11200885649Details
💛 - Coveralls |
|
Hi I've added a unit test to here, let me know if that is sufficient or if i need to do anything else |
for more information, see https://pre-commit.ci
|
Seems like the test is failing |
|
Apologies, I won't be able to address it until next week |
|
No worries! |
|
I figured out what's going on.
But So i have added a line to make perhaps you would rather fix it at a different level of abstraction? |
MilesCranmer
left a comment
There was a problem hiding this comment.
I wonder if the underlying issue is the use of issubclass instead of isinstance... Maybe try replacing those conditions with the following code?
def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
super().__init__(**kwargs)
self._sympy_func = expr.func
if isinstance(expr, sympy.Float):
self._value = torch.nn.Parameter(torch.tensor(float(expr)))
self._torch_func = lambda: self._value
self._args = ()
elif isinstance(expr, sympy.Rational) and not isinstance(expr, sympy.Integer):
# This is some fraction fixed in the operator.
self._value = float(expr)
self._torch_func = lambda: self._value
self._args = ()
elif isinstance(expr, sympy.UnevaluatedExpr):
if len(expr.args) != 1 or not isinstance(expr.args[0], sympy.Float):
raise ValueError(
"UnevaluatedExpr should only be used to wrap floats."
)
self.register_buffer("_value", torch.tensor(float(expr.args[0])))
self._torch_func = lambda: self._value
self._args = ()
elif isinstance(expr, sympy.Integer):
# Handles Integer special cases like NegativeOne, One, Zero
self._value = int(expr)
self._torch_func = lambda: self._value
self._args = ()
elif isinstance(expr, sympy.NumberSymbol):
# Handles mathematical constants like pi, E
self._value = float(expr)
self._torch_func = lambda: self._value
self._args = ()
elif isinstance(expr, sympy.Symbol):
self._name = expr.name
self._torch_func = lambda value: value
self._args = ((lambda memodict: memodict[expr.name]),)
else:
try:
self._torch_func = _func_lookup[expr.func]
except KeyError:
raise KeyError(
f"Function {expr.func} was not found in Torch function mappings. "
"Please add it to extra_torch_mappings in the format, e.g., "
"{sympy.sqrt: torch.sqrt}."
)
args = []
for arg in expr.args:
try:
arg_ = _memodict[arg]
except KeyError:
arg_ = type(self)(
expr=arg,
_memodict=_memodict,
_func_lookup=_func_lookup,
**kwargs,
)
_memodict[arg] = arg_
args.append(arg_)
self._args = torch.nn.ModuleList(args)|
Thanks! That code gives the following error
note that it is now saying that's due to this code: elif isinstance(expr, sympy.Integer):
# Handles Integer special cases like NegativeOne, One, Zero
self._value = int(expr)
self._torch_func = lambda: self._value
self._args = ()is there a reason that for the would it be ok to just add the following code before the integer case? elif isinstance(expr, sympy.core.numbers.One):
# Handles Integer special cases like NegativeOne, One, Zero
self._value = torch.tensor(int(expr))
self._torch_func = lambda: self._value
self._args = ()or should it be either way, it then passes the test |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #726 +/- ##
=======================================
Coverage 93.96% 93.96%
=======================================
Files 21 21
Lines 1641 1641
=======================================
Hits 1542 1542
Misses 99 99 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Ok I moved this to #1058, sorry for the long delay! |
attempting to address #656