Skip to content

Commit

Permalink
Fix precompiled query handling of non-captured variables flowing in
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Jan 20, 2025
1 parent 3d9cdf9 commit 546845c
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 253 deletions.
20 changes: 10 additions & 10 deletions src/EFCore.Design/Query/Internal/CSharpToLinqTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ INamedTypeSymbol GetTypeSymbolOrThrow(string fullyQualifiedMetadataName)
private readonly Stack<ImmutableDictionary<string, ParameterExpression>> _parameterStack
= new(new[] { ImmutableDictionary<string, ParameterExpression>.Empty });

private readonly Dictionary<ISymbol, MemberExpression?> _capturedVariables = new(SymbolEqualityComparer.Default);
private readonly Dictionary<ISymbol, MemberExpression?> _dataFlowsIn = new(SymbolEqualityComparer.Default);

/// <summary>
/// Translates a Roslyn syntax tree into a LINQ expression tree.
Expand Down Expand Up @@ -100,11 +100,11 @@ public virtual Expression Translate(SyntaxNode node, SemanticModel semanticModel

_semanticModel = semanticModel;

// Perform data flow analysis to detect all captured data (closure parameters)
_capturedVariables.Clear();
foreach (var captured in _semanticModel.AnalyzeDataFlow(node).Captured)
// Perform data flow analysis to detect all variables flowing into the query (e.g. captured variables)
_dataFlowsIn.Clear();
foreach (var flowsIn in _semanticModel.AnalyzeDataFlow(node).DataFlowsIn)
{
_capturedVariables[captured] = null;
_dataFlowsIn[flowsIn] = null;
}

var result = Visit(node);
Expand Down Expand Up @@ -445,13 +445,13 @@ public override Expression VisitIdentifierName(IdentifierNameSyntax identifierNa
return Constant(_userDbContext);
}

// The Translate entry point into the translator uses Roslyn's data flow analysis to locate all captured variables, and populates
// the _capturedVariable dictionary with them (with null values).
if (symbol is ILocalSymbol localSymbol && _capturedVariables.TryGetValue(localSymbol, out var memberExpression))
// The Translate entry point into the translator uses Roslyn's data flow analysis to locate all local variables flowing in
// (e.g. captured variables), and populates the _dataFlowsIn dictionary with them (with null values).
if (symbol is ILocalSymbol localSymbol && _dataFlowsIn.TryGetValue(localSymbol, out var memberExpression))
{
// The first time we see a captured variable, we create MemberExpression for it and cache it in _capturedVariables.
// The first time we see a flowing-in variable, we create MemberExpression for it and cache it in _dataFlowsIn.
return memberExpression
?? (_capturedVariables[localSymbol] =
?? (_dataFlowsIn[localSymbol] =
Field(
Constant(new FakeClosureFrameClass()),
new FakeFieldInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public virtual Task Invoke_no_evaluatability_is_not_supported()
.Where(Expression.Lambda<Func<Blog, bool>>(Expression.Invoke(lambda, parameter), parameter))
.ToListAsync();
""",
errorAsserter: errors => Assert.IsType<InvalidOperationException>(errors.Single().Exception));
errorAsserter: errors => Assert.IsType<ArgumentNullException>(errors.Single().Exception));

[ConditionalFact]
public virtual Task ListInit_no_evaluatability()
Expand Down Expand Up @@ -202,6 +202,58 @@ public virtual Task Unary()

#endregion Expression types

#region Regular operators

[ConditionalFact]
public virtual Task OrderBy()
=> Test("_ = await context.Blogs.OrderBy(b => b.Name).ToListAsync();");

[ConditionalFact]
public virtual Task Skip_with_constant()
=> Test("_ = await context.Blogs.OrderBy(b => b.Name).Skip(1).ToListAsync();");

[ConditionalFact]
public virtual Task Skip_with_parameter()
=> Test(
"""
var toSkip = 1;
_ = await context.Blogs.OrderBy(b => b.Name).Skip(toSkip).ToListAsync();
""");

[ConditionalFact]
public virtual Task Take_with_constant()
=> Test("_ = await context.Blogs.OrderBy(b => b.Name).Take(1).ToListAsync();");

[ConditionalFact]
public virtual Task Take_with_parameter()
=> Test(
"""
var toTake = 1;
_ = await context.Blogs.OrderBy(b => b.Name).Take(toTake).ToListAsync();
""");

[ConditionalFact]
public virtual Task Select_changes_type()
=> Test("_ = await context.Blogs.Select(b => b.Name).ToListAsync();");

[ConditionalFact]
public virtual Task Select_anonymous_object()
=> Test("""_ = await context.Blogs.Select(b => new { Foo = b.Name + "Foo" }).ToListAsync();""");

[ConditionalFact]
public virtual Task Include_single()
=> Test("var blogs = await context.Blogs.Include(b => b.Posts).Where(b => b.Id > 8).ToListAsync();");

[ConditionalFact]
public virtual Task Include_split()
=> Test("var blogs = await context.Blogs.AsSplitQuery().Include(b => b.Posts).ToListAsync();");

[ConditionalFact]
public virtual Task Final_GroupBy()
=> Test("""var blogs = await context.Blogs.GroupBy(b => b.Name).ToListAsync();""");

#endregion Regular operators

#region Terminating operators

[ConditionalFact]
Expand Down Expand Up @@ -772,6 +824,16 @@ public virtual Task Terminating_ExecuteUpdateAsync_without_lambda()
var rowsAffected = await context.Blogs.Where(b => b.Id > 8).ExecuteUpdateAsync(setters => setters.SetProperty(b => b.Name, newValue));
Assert.Equal(1, rowsAffected);
Assert.Equal(1, await context.Blogs.CountAsync(b => b.Id == 9 && b.Name == "NewValue"));
""");

[ConditionalFact] // #35494
public virtual Task Terminating_with_cancellation_token()
=> Test(
"""
CancellationTokenSource source = new CancellationTokenSource();
CancellationToken token = source.Token;
Assert.Equal("Blog1", (await context.Blogs.Where(b => b.Id == 8).FirstOrDefaultAsync(token)).Name);
Assert.Null(await context.Blogs.Where(b => b.Id == 7).FirstOrDefaultAsync(token));
""");

#endregion Reducing terminating operators
Expand Down Expand Up @@ -1010,6 +1072,60 @@ private static PrecompiledQueryContext GetContext()

#endregion Different DbContext expressions

#region Captured variable handling

[ConditionalFact]
public virtual Task Two_captured_variables_in_same_lambda()
=> Test(
"""
var yes = "yes";
var no = "no";
var blogs = await context.Blogs.Select(b => b.Id == 3 ? yes : no).ToListAsync();
""");

[ConditionalFact]
public virtual Task Two_captured_variables_in_different_lambdas()
=> Test(
"""
var starts = "Blog";
var ends = "2";
var blog = await context.Blogs.Where(b => b.Name.StartsWith(starts)).Where(b => b.Name.EndsWith(ends)).SingleAsync();
Assert.Equal(9, blog.Id);
""");

[ConditionalFact]
public virtual Task Same_captured_variable_twice_in_same_lambda()
=> Test(
"""
var foo = "X";
var blogs = await context.Blogs.Where(b => b.Name.StartsWith(foo) && b.Name.EndsWith(foo)).ToListAsync();
""");

[ConditionalFact]
public virtual Task Same_captured_variable_twice_in_different_lambdas()
=> Test(
"""
var foo = "X";
var blogs = await context.Blogs.Where(b => b.Name.StartsWith(foo)).Where(b => b.Name.EndsWith(foo)).ToListAsync();
""");

[ConditionalFact]
public virtual Task Multiple_queries_with_captured_variables()
=> Test(
"""
var id1 = 8;
var id2 = 9;
var blogs = await context.Blogs.Where(b => b.Id == id1 || b.Id == id2).ToListAsync();
var blog1 = await context.Blogs.Where(b => b.Id == id1).SingleAsync();
Assert.Collection(
blogs.OrderBy(b => b.Id),
b => Assert.Equal(8, b.Id),
b => Assert.Equal(9, b.Id));
Assert.Equal("Blog1", blog1.Name);
""");

#endregion Captured variable handling

#region Negative cases

[ConditionalFact]
Expand Down Expand Up @@ -1093,88 +1209,6 @@ where b.Id > 8

#endregion Negative cases

[ConditionalFact]
public virtual Task Select_changes_type()
=> Test("_ = await context.Blogs.Select(b => b.Name).ToListAsync();");

[ConditionalFact]
public virtual Task OrderBy()
=> Test("_ = await context.Blogs.OrderBy(b => b.Name).ToListAsync();");

[ConditionalFact]
public virtual Task Skip()
=> Test("_ = await context.Blogs.OrderBy(b => b.Name).Skip(1).ToListAsync();");

[ConditionalFact]
public virtual Task Take()
=> Test("_ = await context.Blogs.OrderBy(b => b.Name).Take(1).ToListAsync();");

[ConditionalFact]
public virtual Task Project_anonymous_object()
=> Test("""_ = await context.Blogs.Select(b => new { Foo = b.Name + "Foo" }).ToListAsync();""");

[ConditionalFact]
public virtual Task Two_captured_variables_in_same_lambda()
=> Test(
"""
var yes = "yes";
var no = "no";
var blogs = await context.Blogs.Select(b => b.Id == 3 ? yes : no).ToListAsync();
""");

[ConditionalFact]
public virtual Task Two_captured_variables_in_different_lambdas()
=> Test(
"""
var starts = "Blog";
var ends = "2";
var blog = await context.Blogs.Where(b => b.Name.StartsWith(starts)).Where(b => b.Name.EndsWith(ends)).SingleAsync();
Assert.Equal(9, blog.Id);
""");

[ConditionalFact]
public virtual Task Same_captured_variable_twice_in_same_lambda()
=> Test(
"""
var foo = "X";
var blogs = await context.Blogs.Where(b => b.Name.StartsWith(foo) && b.Name.EndsWith(foo)).ToListAsync();
""");

[ConditionalFact]
public virtual Task Same_captured_variable_twice_in_different_lambdas()
=> Test(
"""
var foo = "X";
var blogs = await context.Blogs.Where(b => b.Name.StartsWith(foo)).Where(b => b.Name.EndsWith(foo)).ToListAsync();
""");

[ConditionalFact]
public virtual Task Include_single()
=> Test("var blogs = await context.Blogs.Include(b => b.Posts).Where(b => b.Id > 8).ToListAsync();");

[ConditionalFact]
public virtual Task Include_split()
=> Test("var blogs = await context.Blogs.AsSplitQuery().Include(b => b.Posts).ToListAsync();");

[ConditionalFact]
public virtual Task Final_GroupBy()
=> Test("""var blogs = await context.Blogs.GroupBy(b => b.Name).ToListAsync();""");

[ConditionalFact]
public virtual Task Multiple_queries_with_captured_variables()
=> Test(
"""
var id1 = 8;
var id2 = 9;
var blogs = await context.Blogs.Where(b => b.Id == id1 || b.Id == id2).ToListAsync();
var blog1 = await context.Blogs.Where(b => b.Id == id1).SingleAsync();
Assert.Collection(
blogs.OrderBy(b => b.Id),
b => Assert.Equal(8, b.Id),
b => Assert.Equal(9, b.Id));
Assert.Equal("Blog1", blog1.Name);
""");

[ConditionalFact]
public virtual Task Unsafe_accessor_gets_generated_once_for_multiple_queries()
=> Test(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public async Task FullSourceTest(
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using System.Text.RegularExpressions;
using Microsoft.EntityFrameworkCore;
Expand Down
Loading

0 comments on commit 546845c

Please sign in to comment.