Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ref switch expressions #69575

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 181 additions & 5 deletions src/Compilers/CSharp/Portable/Binder/Binder.ValueChecks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.CodeAnalysis.CSharp.CodeGen;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.CSharp
Expand Down Expand Up @@ -493,6 +494,17 @@ private BoundExpression CheckValue(BoundExpression expr, BindValueKind valueKind
}
}

if (expr.Kind == BoundKind.UnconvertedSwitchExpression &&
expr.Type is not null &&
valueKind is BindValueKind.RValue or BindValueKind.Assignable)
{
var switchExpression = (BoundUnconvertedSwitchExpression)expr;
if (switchExpression.IsRef)
{
expr = ConvertSwitchExpression(switchExpression, expr.Type, null, diagnostics);
}
}

if (!hasResolutionErrors && CheckValueKind(expr.Syntax, expr, valueKind, checkingReceiver: false, diagnostics: diagnostics) ||
expr.HasAnyErrors && valueKind == BindValueKind.RValueOrMethodGroup)
{
Expand Down Expand Up @@ -550,6 +562,14 @@ internal bool CheckValueKind(SyntaxNode node, BoundExpression expr, BindValueKin

case BoundKind.EventAccess:
return CheckEventValueKind((BoundEventAccess)expr, valueKind, diagnostics);

case BoundKind.UnconvertedSwitchExpression:
case BoundKind.ConvertedSwitchExpression:
bool check = CheckSwitchExpressionValueKind((BoundSwitchExpression)expr, valueKind, diagnostics);
if (!check)
return false;

break;
}

// easy out for a very common RValue case.
Expand All @@ -566,6 +586,8 @@ internal bool CheckValueKind(SyntaxNode node, BoundExpression expr, BindValueKin
return false;
}

var errorSpan = node.Span;

switch (expr.Kind)
{
case BoundKind.NamespaceExpression:
Expand Down Expand Up @@ -789,13 +811,52 @@ internal bool CheckValueKind(SyntaxNode node, BoundExpression expr, BindValueKin
// Strict RValue
break;

case BoundKind.UnconvertedSwitchExpression:
case BoundKind.ConvertedSwitchExpression:
var switchExpression = (BoundSwitchExpression)expr;

var switchExpressionNode = (CSharp.Syntax.SwitchExpressionSyntax)switchExpression.Syntax;
errorSpan = TextSpan.FromBounds(switchExpressionNode.SpanStart, switchExpressionNode.SwitchKeyword.Span.End);

if (switchExpression.IsRef)
{
Debug.Assert(switchExpression.SwitchArms.Length > 0, "By-ref switch expressions must always have at least one switch arm");

// defer check to the switch arms' values
bool check = true;
foreach (var arm in switchExpression.SwitchArms)
{
// Specially handle throw expressions in arms because they are not treated the same elsewhere
if (arm.Value is BoundThrowExpression or BoundConversion { Operand: BoundThrowExpression })
continue;

check &= CheckValueKind(arm.Value.Syntax, arm.Value, valueKind, checkingReceiver: false, diagnostics: diagnostics);
}
if (check)
return true;
break;
}
else
{
if (RequiresReferenceToLocation(valueKind))
{
// We use a different error to better hint the user about the feature
var switchExpressionErrorLocation = Location.Create(node.SyntaxTree, errorSpan);
Error(diagnostics, ErrorCode.ERR_RefOnNonRefSwitchExpression, switchExpressionErrorLocation);
return false;
}
}

return true;

default:
Debug.Assert(expr is not BoundValuePlaceholderBase, $"Placeholder kind {expr.Kind} should be explicitly handled");
break;
}

// At this point we should have covered all the possible cases for anything that is not a strict RValue.
Error(diagnostics, GetStandardLvalueError(valueKind), node);
var errorLocation = Location.Create(node.SyntaxTree, errorSpan);
Error(diagnostics, GetStandardLvalueError(valueKind), errorLocation);
return false;

bool checkArrayAccessValueKind(SyntaxNode node, BindValueKind valueKind, ImmutableArray<BoundExpression> indices, BindingDiagnosticBag diagnostics)
Expand Down Expand Up @@ -1128,14 +1189,14 @@ private bool CheckParameterRefEscape(SyntaxNode node, BoundExpression parameter,
{
(checkingReceiver: true, isRefScoped: true, inUnsafeRegion: false, _) => (ErrorCode.ERR_RefReturnScopedParameter2, parameter.Syntax),
(checkingReceiver: true, isRefScoped: true, inUnsafeRegion: true, _) => (ErrorCode.WRN_RefReturnScopedParameter2, parameter.Syntax),
(checkingReceiver: true, isRefScoped: false, inUnsafeRegion: false, ReturnOnlyScope) => (ErrorCode.ERR_RefReturnOnlyParameter2, parameter.Syntax),
(checkingReceiver: true, isRefScoped: false, inUnsafeRegion: true, ReturnOnlyScope) => (ErrorCode.WRN_RefReturnOnlyParameter2, parameter.Syntax),
(checkingReceiver: true, isRefScoped: false, inUnsafeRegion: false, ReturnOnlyScope) => (ErrorCode.ERR_RefReturnOnlyParameter2, parameter.Syntax),
(checkingReceiver: true, isRefScoped: false, inUnsafeRegion: true, ReturnOnlyScope) => (ErrorCode.WRN_RefReturnOnlyParameter2, parameter.Syntax),
(checkingReceiver: true, isRefScoped: false, inUnsafeRegion: false, _) => (ErrorCode.ERR_RefReturnParameter2, parameter.Syntax),
(checkingReceiver: true, isRefScoped: false, inUnsafeRegion: true, _) => (ErrorCode.WRN_RefReturnParameter2, parameter.Syntax),
(checkingReceiver: false, isRefScoped: true, inUnsafeRegion: false, _) => (ErrorCode.ERR_RefReturnScopedParameter, node),
(checkingReceiver: false, isRefScoped: true, inUnsafeRegion: true, _) => (ErrorCode.WRN_RefReturnScopedParameter, node),
(checkingReceiver: false, isRefScoped: false, inUnsafeRegion: false, ReturnOnlyScope) => (ErrorCode.ERR_RefReturnOnlyParameter, node),
(checkingReceiver: false, isRefScoped: false, inUnsafeRegion: true, ReturnOnlyScope) => (ErrorCode.WRN_RefReturnOnlyParameter, node),
(checkingReceiver: false, isRefScoped: false, inUnsafeRegion: false, ReturnOnlyScope) => (ErrorCode.ERR_RefReturnOnlyParameter, node),
(checkingReceiver: false, isRefScoped: false, inUnsafeRegion: true, ReturnOnlyScope) => (ErrorCode.WRN_RefReturnOnlyParameter, node),
(checkingReceiver: false, isRefScoped: false, inUnsafeRegion: false, _) => (ErrorCode.ERR_RefReturnParameter, node),
(checkingReceiver: false, isRefScoped: false, inUnsafeRegion: true, _) => (ErrorCode.WRN_RefReturnParameter, node)
};
Expand Down Expand Up @@ -1353,6 +1414,61 @@ private bool CheckFieldLikeEventRefEscape(SyntaxNode node, BoundEventAccess even
}
}

partial class RefSafetyAnalysis
{
private void ValidateRefSwitchExpression(SyntaxNode node, ImmutableArray<BoundSwitchExpressionArm> arms, BindingDiagnosticBag diagnostics)
{
var currentScope = _localScopeDepth;

var expressionEscapes = PooledHashSet<(BoundExpression expression, uint escape)>.GetInstance();

bool hasSameEscapes = true;
uint minEscape = uint.MaxValue;

// val-escape must agree on all arms
foreach (var arm in arms)
{
var expression = arm.Value;

if (expression is BoundConversion conversion)
{
Debug.Assert(conversion is { Operand: BoundThrowExpression });
continue;
}

uint expressionEscape = GetValEscape(expression, currentScope);
if (expressionEscapes.Count > 0)
{
if (minEscape != expressionEscape)
{
hasSameEscapes = false;
}
}
minEscape = Math.Min(minEscape, expressionEscape);
expressionEscapes.Add((expression, expressionEscape));
}

if (!hasSameEscapes)
{
// pass through all the expressions whose value escape was calculated
// ask the ones with narrower escape, for the wider
foreach (var expressionEscape in expressionEscapes)
{
var (expression, escape) = expressionEscape;
if (escape != minEscape)
{
Debug.Assert(escape > minEscape);
CheckValEscape(expression.Syntax, expression, currentScope, minEscape, checkingReceiver: false, diagnostics: diagnostics);
}
}

diagnostics.Add(_inUnsafeRegion ? ErrorCode.WRN_MismatchedRefEscapeInSwitchExpression : ErrorCode.ERR_MismatchedRefEscapeInSwitchExpression, node.Location);
}

expressionEscapes.Free();
}
}

internal partial class Binder
{
private bool CheckEventValueKind(BoundEventAccess boundEvent, BindValueKind valueKind, BindingDiagnosticBag diagnostics)
Expand Down Expand Up @@ -1495,6 +1611,31 @@ protected bool CheckMethodReturnValueKind(

}

private bool CheckSwitchExpressionValueKind(BoundSwitchExpression expression, BindValueKind valueKind, BindingDiagnosticBag diagnostics)
{
Debug.Assert(expression is not null);

switch (valueKind)
{
case BindValueKind.CompoundAssignment:
if (!expression.IsRef)
{
Error(diagnostics, ErrorCode.ERR_RequiresRefReturningSwitchExpression, expression.Syntax);
return false;
}
return true;
case BindValueKind.RValue:
if (expression.IsRef)
{
Error(diagnostics, ErrorCode.ERR_UnusedSwitchExpressionRef, expression.Syntax);
return false;
}
return true;
}

return true;
}

private bool CheckPropertyValueKind(SyntaxNode node, BoundExpression expr, BindValueKind valueKind, bool checkingReceiver, BindingDiagnosticBag diagnostics)
{
// SPEC: If the left operand is a property or indexer access, the property or indexer must
Expand Down Expand Up @@ -3037,6 +3178,30 @@ internal uint GetRefEscape(BoundExpression expr, uint scopeOfTheContainingExpres
// otherwise it is an RValue
break;

case BoundKind.UnconvertedSwitchExpression:
throw ExceptionUtilities.UnexpectedValue(expr.Kind);

case BoundKind.ConvertedSwitchExpression:
var switchExpression = (BoundConvertedSwitchExpression)expr;

if (switchExpression.IsRef)
{
uint maxScope = uint.MinValue;
foreach (var arm in switchExpression.SwitchArms)
{
if (arm.Value is BoundConversion boundConversion)
{
Debug.Assert(boundConversion is { Operand: BoundThrowExpression });
continue;
}

maxScope = Math.Max(GetRefEscape(arm.Value, scopeOfTheContainingExpression), maxScope);
}
return maxScope;
}

break;

case BoundKind.FieldAccess:
return GetFieldRefEscape((BoundFieldAccess)expr, scopeOfTheContainingExpression);

Expand Down Expand Up @@ -3549,6 +3714,17 @@ internal bool CheckRefEscape(SyntaxNode node, BoundExpression expr, uint escapeF

case BoundKind.ThrowExpression:
return true;

case BoundKind.ConvertedSwitchExpression:
case BoundKind.UnconvertedSwitchExpression:
var switchExpression = (BoundSwitchExpression)expr;
foreach (var arm in switchExpression.SwitchArms)
{
bool canEscape = CheckRefEscape(node, arm.Value, escapeFrom, escapeTo, checkingReceiver: false, diagnostics);
if (!canEscape)
return false;
}
return true;
}

// At this point we should have covered all the possible cases for anything that is not a strict RValue.
Expand Down
26 changes: 16 additions & 10 deletions src/Compilers/CSharp/Portable/Binder/Binder_Conversions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -951,23 +951,29 @@ private BoundExpression ConvertSwitchExpression(BoundUnconvertedSwitchExpression
Debug.Assert(targetTyped || destination.IsErrorType() || destination.Equals(source.Type, TypeCompareKind.ConsiderEverything));
ImmutableArray<Conversion> underlyingConversions = conversion.UnderlyingConversions;
var builder = ArrayBuilder<BoundSwitchExpressionArm>.GetInstance(source.SwitchArms.Length);
bool allowConversion = !source.IsRef;
for (int i = 0, n = source.SwitchArms.Length; i < n; i++)
{
var oldCase = source.SwitchArms[i];
var oldValue = oldCase.Value;
var newValue =
targetTyped
? CreateConversion(oldValue.Syntax, oldValue, underlyingConversions[i], isCast: false, conversionGroupOpt: null, destination, diagnostics)
: GenerateConversionForAssignment(destination, oldValue, diagnostics);
var newCase = (oldValue == newValue) ? oldCase :
new BoundSwitchExpressionArm(oldCase.Syntax, oldCase.Locals, oldCase.Pattern, oldCase.WhenClause, newValue, oldCase.Label, oldCase.HasErrors);
builder.Add(newCase);
var oldArm = source.SwitchArms[i];
var oldValue = oldArm.Value;
var newValue = oldValue;

bool requiresConversion = oldValue.Type is null;
if (allowConversion || requiresConversion)
{
newValue = targetTyped
? CreateConversion(oldValue.Syntax, oldValue, underlyingConversions[i], isCast: false, conversionGroupOpt: null, destination, diagnostics)
: GenerateConversionForAssignment(destination, oldValue, diagnostics);
}
var newArm = (oldValue == newValue) ? oldArm :
new BoundSwitchExpressionArm(oldArm.Syntax, oldArm.Locals, oldArm.Pattern, oldArm.WhenClause, newValue, oldArm.Label, oldArm.RefKind, oldArm.HasErrors);
builder.Add(newArm);
}

var newSwitchArms = builder.ToImmutableAndFree();
return new BoundConvertedSwitchExpression(
source.Syntax, source.Type, targetTyped, source.Expression, newSwitchArms, source.ReachabilityDecisionDag,
source.DefaultLabel, source.ReportedNotExhaustive, destination, hasErrors || source.HasErrors).WithSuppression(source.IsSuppressed);
source.DefaultLabel, source.ReportedNotExhaustive, source.RefKind, destination, hasErrors || source.HasErrors).WithSuppression(source.IsSuppressed);
}

private BoundExpression CreateUserDefinedConversion(
Expand Down
14 changes: 14 additions & 0 deletions src/Compilers/CSharp/Portable/Binder/Binder_Operators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ private BoundExpression BindCompoundAssignment(AssignmentExpressionSyntax node,
leftPlaceholder: null, leftConversion: null, finalPlaceholder: null, finalConversion: null, LookupResultKind.NotAVariable, CreateErrorType(), hasErrors: true);
}

if (left.Kind == BoundKind.UnconvertedSwitchExpression)
{
var switchExpression = (BoundUnconvertedSwitchExpression)left;
var switchExpressionDiagnostics = diagnostics;
if (!switchExpression.IsRef)
{
Error(diagnostics, ErrorCode.ERR_RequiresRefReturningSwitchExpression, node.OperatorToken);
// Ignore further binding errors, potentially including unavailable conversions
switchExpressionDiagnostics = BindingDiagnosticBag.Discarded;
}

left = this.ConvertSwitchExpression(switchExpression, destination: left.Type, null, switchExpressionDiagnostics);
}

// A compound operator, say, x |= y, is bound as x = (X)( ((T)x) | ((T)y) ). We must determine
// the binary operator kind, the type conversions from each side to the types expected by
// the operator, and the type conversion from the return type of the operand to the left hand side.
Expand Down
15 changes: 14 additions & 1 deletion src/Compilers/CSharp/Portable/Binder/Binder_Statements.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1941,7 +1941,7 @@ internal BoundExpression GenerateConversionForAssignment(TypeSymbol targetType,
}
else
{
return expression;
return ConvertIdentityRefExpression(expression, targetType, diagnostics);
}
}
else if (!conversion.IsValid ||
Expand All @@ -1965,6 +1965,19 @@ internal BoundExpression GenerateConversionForAssignment(TypeSymbol targetType,
return CreateConversion(expression.Syntax, expression, conversion, isCast: false, conversionGroupOpt: null, targetType, diagnostics);
}

private BoundExpression ConvertIdentityRefExpression(BoundExpression expression, TypeSymbol destination, BindingDiagnosticBag diagnostics)
{
switch (expression.Kind)
{
case BoundKind.UnconvertedSwitchExpression:
var switchExpression = (BoundUnconvertedSwitchExpression)expression;
return ConvertSwitchExpression(switchExpression, destination, null, diagnostics);

default:
return expression;
}
}

#nullable enable
private static Location GetAnonymousFunctionLocation(SyntaxNode node)
=> node switch
Expand Down
4 changes: 4 additions & 0 deletions src/Compilers/CSharp/Portable/Binder/RefSafetyAnalysis.cs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,10 @@ private void AssertVisited(BoundExpression expr)
this.Visit(node.Expression);
using var _ = new PatternInput(this, GetValEscape(node.Expression, _localScopeDepth));
this.VisitList(node.SwitchArms);
if (node.IsRef)
{
this.ValidateRefSwitchExpression(node.Syntax, node.SwitchArms, this._diagnostics);
}
return null;
}

Expand Down
Loading