Skip to content

Commit 482ec82

Browse files
rojicincuranet
andcommitted
Fix null check optimization for IQueryable/DbSet types
Fixes #35598 Co-Authored-By: cincuranet <[email protected]>
1 parent 6dd2adb commit 482ec82

File tree

5 files changed

+327
-11
lines changed

5 files changed

+327
-11
lines changed

src/EFCore/Query/Internal/ExpressionTreeFuncletizer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -876,8 +876,8 @@ protected override Expression VisitMember(MemberExpression member)
876876
if (_state.IsEvaluatable)
877877
{
878878
// If the query contains a captured variable that's a nested IQueryable, inline it into the main query.
879-
// Otherwise, evaluation of a terminating operator up the call chain will cause us to execute the query and do another
880-
// roundtrip.
879+
// Note that we do this only for IQueryable; evaluation of a terminating operator up the call chain would cause us to execute
880+
// the query and do another roundtrip.
881881
// Note that we only do this when the MemberExpression is typed as IQueryable/IOrderedQueryable; this notably excludes
882882
// DbSet captured variables integrated directly into the query, as that also evaluates e.g. context.Order in
883883
// context.Order.FromSqlInterpolated(), which fails.

src/EFCore/Query/Internal/NullCheckRemovingExpressionVisitor.cs

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
2424
{
2525
var visitedExpression = base.VisitBinary(binaryExpression);
2626

27-
return TryOptimizeConditionalEquality(visitedExpression) ?? visitedExpression;
27+
return TryOptimizeConditionalEquality(visitedExpression)
28+
?? ProcessNullCheck(visitedExpression)
29+
?? visitedExpression;
2830
}
2931

3032
/// <summary>
@@ -39,14 +41,14 @@ protected override Expression VisitConditional(ConditionalExpression conditional
3941

4042
if (test is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binaryTest)
4143
{
42-
var isLeftNullConstant = IsNullConstant(binaryTest.Left);
43-
var isRightNullConstant = IsNullConstant(binaryTest.Right);
44+
var isLeftNullConstant = binaryTest.Left is ConstantExpression { Value: null };
45+
var isRightNullConstant = binaryTest.Right is ConstantExpression { Value: null };
4446

4547
if ((isLeftNullConstant == isRightNullConstant)
4648
|| (binaryTest.NodeType == ExpressionType.Equal
47-
&& !IsNullConstant(conditionalExpression.IfTrue))
49+
&& conditionalExpression.IfTrue is not ConstantExpression { Value: null })
4850
|| (binaryTest.NodeType == ExpressionType.NotEqual
49-
&& !IsNullConstant(conditionalExpression.IfFalse)))
51+
&& conditionalExpression.IfFalse is not ConstantExpression { Value: null }))
5052
{
5153
return conditionalExpression;
5254
}
@@ -75,7 +77,7 @@ protected override Expression VisitConditional(ConditionalExpression conditional
7577
return base.VisitConditional(conditionalExpression);
7678
}
7779

78-
private static Expression? TryOptimizeConditionalEquality(Expression expression)
80+
private static BinaryExpression? TryOptimizeConditionalEquality(Expression expression)
7981
{
8082
// Simplify (a ? b : null) == null => !a || b == null
8183
// Simplify (a ? null : b) == null => a || b == null
@@ -115,6 +117,45 @@ protected override Expression VisitConditional(ConditionalExpression conditional
115117
return null;
116118
}
117119

120+
private static ConstantExpression? ProcessNullCheck(Expression expression)
121+
{
122+
// Optimize IQueryable/DbSet null checks for expressions that are guaranteed to be non-null:
123+
// * queryableMethodCall != null => true
124+
// * queryableMethodCall == null => false
125+
// This applies to method calls and member accesses that produce IQueryable results, which can never be null.
126+
// We do NOT optimize null checks for parameters/variables, as they could legitimately be null.
127+
if (expression is BinaryExpression
128+
{
129+
NodeType: ExpressionType.Equal or ExpressionType.NotEqual
130+
} binaryExpression)
131+
{
132+
var isLeftNull = binaryExpression.Left is ConstantExpression { Value: null };
133+
var isRightNull = binaryExpression.Right is ConstantExpression { Value: null };
134+
135+
if (isLeftNull || isRightNull)
136+
{
137+
// null == null => true
138+
// null != null => false
139+
if (isLeftNull && isRightNull)
140+
{
141+
return Expression.Constant(binaryExpression.NodeType is ExpressionType.Equal);
142+
}
143+
144+
var nonNullExpression = isLeftNull ? binaryExpression.Right : binaryExpression.Left;
145+
146+
// Only optimize if the expression is a query operation that cannot be null
147+
// (method call returning IQueryable, DbSet property access, or QueryRootExpression)
148+
if (nonNullExpression.Type.IsAssignableTo(typeof(IQueryable))
149+
&& (nonNullExpression is MethodCallExpression or QueryRootExpression))
150+
{
151+
return Expression.Constant(binaryExpression.NodeType is ExpressionType.NotEqual);
152+
}
153+
}
154+
}
155+
156+
return null;
157+
}
158+
118159
private sealed class NullSafeAccessVerifyingExpressionVisitor : ExpressionVisitor
119160
{
120161
private readonly ISet<Expression> _nullSafeAccesses = new HashSet<Expression>(ExpressionEqualityComparer.Instance);
@@ -158,7 +199,4 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
158199
return unaryExpression;
159200
}
160201
}
161-
162-
private static bool IsNullConstant(Expression expression)
163-
=> expression is ConstantExpression { Value: null };
164202
}

test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1699,6 +1699,112 @@ public override async Task Where_Queryable_AsEnumerable_Contains_negated(bool as
16991699
AssertSql();
17001700
}
17011701

1702+
public override async Task Where_Queryable_conditional_not_null_check_with_Contains(bool async, bool withNull)
1703+
{
1704+
if (withNull)
1705+
{
1706+
await Fixture.NoSyncTest(
1707+
async, async a =>
1708+
{
1709+
await base.Where_Queryable_conditional_not_null_check_with_Contains(a, withNull);
1710+
1711+
AssertSql(
1712+
"""
1713+
SELECT VALUE c
1714+
FROM root c
1715+
WHERE false
1716+
""");
1717+
});
1718+
}
1719+
else
1720+
{
1721+
// Cosmos client evaluation. Issue #17246.
1722+
await AssertTranslationFailed(() => base.Where_Queryable_conditional_not_null_check_with_Contains(async, withNull));
1723+
1724+
AssertSql();
1725+
}
1726+
}
1727+
1728+
public override async Task Where_Queryable_conditional_null_check_with_Contains(bool async, bool withNull)
1729+
{
1730+
if (withNull)
1731+
{
1732+
await Fixture.NoSyncTest(
1733+
async, async a =>
1734+
{
1735+
await base.Where_Queryable_conditional_null_check_with_Contains(a, withNull);
1736+
1737+
AssertSql(
1738+
"""
1739+
SELECT VALUE c
1740+
FROM root c
1741+
""");
1742+
});
1743+
}
1744+
else
1745+
{
1746+
// Cosmos client evaluation. Issue #17246.
1747+
await AssertTranslationFailed(() => base.Where_Queryable_conditional_null_check_with_Contains(async, withNull));
1748+
1749+
AssertSql();
1750+
}
1751+
}
1752+
1753+
public override Task Where_Enumerable_conditional_not_null_check_with_Contains(bool async, bool withNull)
1754+
=> Fixture.NoSyncTest(
1755+
async, async a =>
1756+
{
1757+
await base.Where_Enumerable_conditional_not_null_check_with_Contains(a, withNull);
1758+
1759+
if (withNull)
1760+
{
1761+
AssertSql(
1762+
"""
1763+
SELECT VALUE c
1764+
FROM root c
1765+
WHERE false
1766+
""");
1767+
}
1768+
else
1769+
{
1770+
AssertSql(
1771+
"""
1772+
@ids='["ALFKI","ANATR"]'
1773+
1774+
SELECT VALUE c
1775+
FROM root c
1776+
WHERE ARRAY_CONTAINS(@ids, c["id"])
1777+
""");
1778+
}
1779+
});
1780+
1781+
public override Task Where_Enumerable_conditional_null_check_with_Contains(bool async, bool withNull)
1782+
=> Fixture.NoSyncTest(
1783+
async, async a =>
1784+
{
1785+
await base.Where_Enumerable_conditional_null_check_with_Contains(a, withNull);
1786+
1787+
if (withNull)
1788+
{
1789+
AssertSql(
1790+
"""
1791+
SELECT VALUE c
1792+
FROM root c
1793+
""");
1794+
}
1795+
else
1796+
{
1797+
AssertSql(
1798+
"""
1799+
@ids='["ALFKI","ANATR"]'
1800+
1801+
SELECT VALUE c
1802+
FROM root c
1803+
WHERE NOT(ARRAY_CONTAINS(@ids, c["id"]))
1804+
""");
1805+
}
1806+
});
1807+
17021808
public override Task Where_list_object_contains_over_value_type(bool async)
17031809
=> Fixture.NoSyncTest(
17041810
async, async a =>

test/EFCore.Specification.Tests/Query/NorthwindWhereQueryTestBase.cs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,70 @@ public virtual Task Where_Queryable_ToArray_Length_member(bool async)
14381438
assertOrder: true,
14391439
elementAsserter: (e, a) => AssertCollection(e, a));
14401440

1441+
[ConditionalTheory]
1442+
[InlineData(true, true)]
1443+
[InlineData(true, false)]
1444+
[InlineData(false, true)]
1445+
[InlineData(false, false)]
1446+
public virtual Task Where_Queryable_conditional_not_null_check_with_Contains(bool async, bool withNull)
1447+
=> AssertQuery(
1448+
async,
1449+
ss =>
1450+
{
1451+
var ids = withNull ? null : ss.Set<Customer>().Where(c => c.CustomerID != "ALFKI").Select(c => c.CustomerID);
1452+
return ss.Set<Customer>().Where(c => ids != null && ids.Contains(c.CustomerID));
1453+
},
1454+
assertEmpty: withNull);
1455+
1456+
[ConditionalTheory]
1457+
[InlineData(true, true)]
1458+
[InlineData(true, false)]
1459+
[InlineData(false, true)]
1460+
[InlineData(false, false)]
1461+
public virtual Task Where_Queryable_conditional_null_check_with_Contains(bool async, bool withNull)
1462+
=> AssertQuery(
1463+
async,
1464+
ss =>
1465+
{
1466+
var ids = withNull ? null : ss.Set<Customer>().Where(c => c.CustomerID != "ALFKI").Select(c => c.CustomerID);
1467+
return ss.Set<Customer>().Where(c => ids == null || !ids.Contains(c.CustomerID));
1468+
});
1469+
1470+
[ConditionalTheory]
1471+
[InlineData(true, true)]
1472+
[InlineData(true, false)]
1473+
[InlineData(false, true)]
1474+
[InlineData(false, false)]
1475+
public virtual Task Where_Enumerable_conditional_not_null_check_with_Contains(bool async, bool withNull)
1476+
=> AssertQuery(
1477+
async,
1478+
ss =>
1479+
{
1480+
// Check also with Enumerable here so we don't handle the null check
1481+
// incorrectly because Contains is coverted in
1482+
// QueryableMethodNormalizingExpressionVisitor.TryConvertCollectionContainsToQueryableContains.
1483+
List<string> ids = withNull ? null : ["ALFKI", "ANATR"];
1484+
return ss.Set<Customer>().Where(c => ids != null && ids.Contains(c.CustomerID));
1485+
},
1486+
assertEmpty: withNull);
1487+
1488+
[ConditionalTheory]
1489+
[InlineData(true, true)]
1490+
[InlineData(true, false)]
1491+
[InlineData(false, true)]
1492+
[InlineData(false, false)]
1493+
public virtual Task Where_Enumerable_conditional_null_check_with_Contains(bool async, bool withNull)
1494+
=> AssertQuery(
1495+
async,
1496+
ss =>
1497+
{
1498+
// Check also with Enumerable here so we don't handle the null check
1499+
// incorrectly because Contains is coverted in
1500+
// QueryableMethodNormalizingExpressionVisitor.TryConvertCollectionContainsToQueryableContains.
1501+
List<string> ids = withNull ? null : ["ALFKI", "ANATR"];
1502+
return ss.Set<Customer>().Where(c => ids == null || !ids.Contains(c.CustomerID));
1503+
});
1504+
14411505
[ConditionalTheory, MemberData(nameof(IsAsyncData))]
14421506
public virtual Task Where_collection_navigation_ToList_Count(bool async)
14431507
=> AssertQuery(

0 commit comments

Comments
 (0)