Skip to content

Commit a07a1bd

Browse files
authored
Fixes and improvements to StartsWith/EndsWith/Contains (dotnet#31482)
Closes dotnet#30493 Closes dotnet#11881 Closes dotnet#26735
1 parent 3cf064e commit a07a1bd

File tree

69 files changed

+1351
-905
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+1351
-905
lines changed

src/EFCore.Relational/Query/QuerySqlGenerator.cs

+1
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,7 @@ protected virtual void GenerateLike(LikeExpression likeExpression, bool negated)
736736
}
737737

738738
_relationalCommandBuilder.Append(" LIKE ");
739+
739740
Visit(likeExpression.Pattern);
740741

741742
if (likeExpression.EscapeChar != null)

src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1029,7 +1029,7 @@ protected override Expression VisitNewArray(NewArrayExpression newArrayExpressio
10291029
/// <inheritdoc />
10301030
protected override Expression VisitParameter(ParameterExpression parameterExpression)
10311031
=> parameterExpression.Name?.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal) == true
1032-
? new SqlParameterExpression(parameterExpression, null)
1032+
? new SqlParameterExpression(parameterExpression.Name, parameterExpression.Type, null)
10331033
: throw new InvalidOperationException(CoreStrings.TranslationFailed(parameterExpression.Print()));
10341034

10351035
/// <inheritdoc />

src/EFCore.Relational/Query/SqlExpressions/SqlParameterExpression.cs

+17-19
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,30 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions;
66
/// <summary>
77
/// An expression that represents a parameter in a SQL tree.
88
/// </summary>
9-
/// <remarks>
10-
/// This is a simple wrapper around a <see cref="ParameterExpression" /> in the SQL tree.
11-
/// Instances of this type cannot be constructed by application or database provider code. If this is a problem for your
12-
/// application or provider, then please file an issue at
13-
/// <see href="https://github.com/dotnet/efcore">github.com/dotnet/efcore</see>.
14-
/// </remarks>
159
public sealed class SqlParameterExpression : SqlExpression
1610
{
17-
private readonly ParameterExpression _parameterExpression;
18-
private readonly string _name;
19-
20-
internal SqlParameterExpression(ParameterExpression parameterExpression, RelationalTypeMapping? typeMapping)
21-
: base(parameterExpression.Type.UnwrapNullableType(), typeMapping)
11+
/// <summary>
12+
/// Creates a new instance of the <see cref="SqlParameterExpression" /> class.
13+
/// </summary>
14+
/// <param name="name">The parameter name.</param>
15+
/// <param name="type">The <see cref="Type" /> of the expression.</param>
16+
/// <param name="typeMapping">The <see cref="RelationalTypeMapping" /> associated with the expression.</param>
17+
public SqlParameterExpression(string name, Type type, RelationalTypeMapping? typeMapping)
18+
: this(name, type.UnwrapNullableType(), type.IsNullableType(), typeMapping)
2219
{
23-
Check.DebugAssert(parameterExpression.Name != null, "Parameter must have name.");
20+
}
2421

25-
_parameterExpression = parameterExpression;
26-
_name = parameterExpression.Name;
27-
IsNullable = parameterExpression.Type.IsNullableType();
22+
private SqlParameterExpression(string name, Type type, bool nullable, RelationalTypeMapping? typeMapping)
23+
: base(type, typeMapping)
24+
{
25+
Name = name;
26+
IsNullable = nullable;
2827
}
2928

3029
/// <summary>
3130
/// The name of the parameter.
3231
/// </summary>
33-
public string Name
34-
=> _name;
32+
public string Name { get; }
3533

3634
/// <summary>
3735
/// The bool value indicating if this parameter can have null values.
@@ -44,15 +42,15 @@ public string Name
4442
/// <param name="typeMapping">A relational type mapping to apply.</param>
4543
/// <returns>A new expression which has supplied type mapping.</returns>
4644
public SqlExpression ApplyTypeMapping(RelationalTypeMapping? typeMapping)
47-
=> new SqlParameterExpression(_parameterExpression, typeMapping);
45+
=> new SqlParameterExpression(Name, Type, IsNullable, typeMapping);
4846

4947
/// <inheritdoc />
5048
protected override Expression VisitChildren(ExpressionVisitor visitor)
5149
=> this;
5250

5351
/// <inheritdoc />
5452
protected override void Print(ExpressionPrinter expressionPrinter)
55-
=> expressionPrinter.Append("@" + _parameterExpression.Name);
53+
=> expressionPrinter.Append("@" + Name);
5654

5755
/// <inheritdoc />
5856
public override bool Equals(object? obj)

src/EFCore.Relational/Query/SqlNullabilityProcessor.cs

+96-28
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ protected virtual TableExpressionBase Visit(TableExpressionBase tableExpressionB
161161
var newTable = Visit(innerJoinExpression.Table);
162162
var newJoinPredicate = ProcessJoinPredicate(innerJoinExpression.JoinPredicate);
163163

164-
return TryGetBoolConstantValue(newJoinPredicate) == true
164+
return IsTrue(newJoinPredicate)
165165
? new CrossJoinExpression(newTable)
166166
: innerJoinExpression.Update(newTable, newJoinPredicate);
167167
}
@@ -301,7 +301,7 @@ protected virtual SelectExpression Visit(SelectExpression selectExpression)
301301
var predicate = Visit(selectExpression.Predicate, allowOptimizedExpansion: true, out _);
302302
changed |= predicate != selectExpression.Predicate;
303303

304-
if (TryGetBoolConstantValue(predicate) == true)
304+
if (IsTrue(predicate))
305305
{
306306
predicate = null;
307307
changed = true;
@@ -333,7 +333,7 @@ protected virtual SelectExpression Visit(SelectExpression selectExpression)
333333
var having = Visit(selectExpression.Having, allowOptimizedExpansion: true, out _);
334334
changed |= having != selectExpression.Having;
335335

336-
if (TryGetBoolConstantValue(having) == true)
336+
if (IsTrue(having))
337337
{
338338
having = null;
339339
changed = true;
@@ -519,20 +519,17 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
519519
var test = Visit(
520520
whenClause.Test, allowOptimizedExpansion: testIsCondition, preserveColumnNullabilityInformation: true, out _);
521521

522-
if (TryGetBoolConstantValue(test) is bool testConstantBool)
522+
if (IsTrue(test))
523523
{
524-
if (testConstantBool)
525-
{
526-
testEvaluatesToTrue = true;
527-
}
528-
else
529-
{
530-
// if test evaluates to 'false' we can remove the WhenClause
531-
RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
532-
RestoreNullValueColumnsList(currentNullValueColumnsCount);
524+
testEvaluatesToTrue = true;
525+
}
526+
else if (IsFalse(test))
527+
{
528+
// if test evaluates to 'false' we can remove the WhenClause
529+
RestoreNonNullableColumnsList(currentNonNullableColumnsCount);
530+
RestoreNullValueColumnsList(currentNullValueColumnsCount);
533531

534-
continue;
535-
}
532+
continue;
536533
}
537534

538535
var newResult = Visit(whenClause.Result, out var resultNullable);
@@ -570,7 +567,7 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
570567
// if there is only one When clause and it's test evaluates to 'true' AND there is no else block, simply return the result
571568
return elseResult == null
572569
&& whenClauses.Count == 1
573-
&& TryGetBoolConstantValue(whenClauses[0].Test) == true
570+
&& IsTrue(whenClauses[0].Test)
574571
? whenClauses[0].Result
575572
: caseExpression.Update(operand, whenClauses, elseResult);
576573
}
@@ -635,7 +632,7 @@ protected virtual SqlExpression VisitExists(
635632

636633
// if subquery has predicate which evaluates to false, we can simply return false
637634
// if the exists is negated we need to return true instead
638-
return TryGetBoolConstantValue(subquery.Predicate) == false
635+
return IsFalse(subquery.Predicate)
639636
? _sqlExpressionFactory.Constant(false, existsExpression.TypeMapping)
640637
: existsExpression.Update(subquery);
641638
}
@@ -658,7 +655,7 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt
658655
var subquery = Visit(inExpression.Subquery);
659656

660657
// a IN (SELECT * FROM table WHERE false) => false
661-
if (TryGetBoolConstantValue(subquery.Predicate) == false)
658+
if (IsFalse(subquery.Predicate))
662659
{
663660
nullable = false;
664661

@@ -967,9 +964,64 @@ protected virtual SqlExpression VisitLike(LikeExpression likeExpression, bool al
967964
var pattern = Visit(likeExpression.Pattern, out var patternNullable);
968965
var escapeChar = Visit(likeExpression.EscapeChar, out var escapeCharNullable);
969966

970-
nullable = matchNullable || patternNullable || escapeCharNullable;
967+
SqlExpression result = likeExpression.Update(match, pattern, escapeChar);
968+
969+
if (UseRelationalNulls)
970+
{
971+
nullable = matchNullable || patternNullable || escapeCharNullable;
972+
973+
return result;
974+
}
975+
976+
nullable = false;
977+
978+
// The null semantics behavior we implement for LIKE is that it only returns true when both sides are non-null and match; any other
979+
// input returns false:
980+
// foo LIKE f% -> true
981+
// foo LIKE null -> false
982+
// null LIKE f% -> false
983+
// null LIKE null -> false
984+
985+
if (IsNull(match) || IsNull(pattern) || IsNull(escapeChar))
986+
{
987+
return _sqlExpressionFactory.Constant(false, likeExpression.TypeMapping);
988+
}
989+
990+
// A constant match-all pattern (%) returns true for all cases, except where the match string is null:
991+
// nullable_foo LIKE % -> foo IS NOT NULL
992+
// non_nullable_foo LIKE % -> true
993+
if (pattern is SqlConstantExpression { Value: "%" })
994+
{
995+
return matchNullable
996+
? _sqlExpressionFactory.IsNotNull(match)
997+
: _sqlExpressionFactory.Constant(true, likeExpression.TypeMapping);
998+
}
971999

972-
return likeExpression.Update(match, pattern, escapeChar);
1000+
if (!allowOptimizedExpansion)
1001+
{
1002+
if (matchNullable)
1003+
{
1004+
result = _sqlExpressionFactory.AndAlso(result, GenerateNotNullCheck(match));
1005+
}
1006+
1007+
if (patternNullable)
1008+
{
1009+
result = _sqlExpressionFactory.AndAlso(result, GenerateNotNullCheck(pattern));
1010+
}
1011+
1012+
if (escapeChar is not null && escapeCharNullable)
1013+
{
1014+
result = _sqlExpressionFactory.AndAlso(result, GenerateNotNullCheck(escapeChar));
1015+
}
1016+
}
1017+
1018+
return result;
1019+
1020+
SqlExpression GenerateNotNullCheck(SqlExpression operand)
1021+
=> OptimizeNonNullableNotExpression(
1022+
_sqlExpressionFactory.Not(
1023+
ProcessNullNotNull(
1024+
_sqlExpressionFactory.IsNull(operand), operandNullable: true)));
9731025
}
9741026

9751027
/// <summary>
@@ -1395,8 +1447,28 @@ protected virtual SqlExpression VisitJsonScalar(
13951447
/// </summary>
13961448
protected virtual bool PreferExistsToComplexIn => false;
13971449

1398-
private static bool? TryGetBoolConstantValue(SqlExpression? expression)
1399-
=> expression is SqlConstantExpression { Value: bool boolValue } ? boolValue : null;
1450+
// Note that we can check parameter values for null since we cache by the parameter nullability; but we cannot do the same for bool.
1451+
private bool IsNull(SqlExpression? expression)
1452+
=> expression is SqlConstantExpression { Value: null }
1453+
|| expression is SqlParameterExpression { Name: string parameterName } && ParameterValues[parameterName] is null;
1454+
1455+
private bool IsTrue(SqlExpression? expression)
1456+
=> expression is SqlConstantExpression { Value: true };
1457+
1458+
private bool IsFalse(SqlExpression? expression)
1459+
=> expression is SqlConstantExpression { Value: false };
1460+
1461+
private bool TryGetBool(SqlExpression? expression, out bool value)
1462+
{
1463+
if (expression is SqlConstantExpression { Value: bool b })
1464+
{
1465+
value = b;
1466+
return true;
1467+
}
1468+
1469+
value = false;
1470+
return false;
1471+
}
14001472

14011473
private void RestoreNonNullableColumnsList(int counter)
14021474
{
@@ -1486,7 +1558,7 @@ private SqlExpression OptimizeComparison(
14861558
return result;
14871559
}
14881560

1489-
if (TryGetBoolConstantValue(right) is bool rightBoolValue
1561+
if (TryGetBool(right, out var rightBoolValue)
14901562
&& !leftNullable
14911563
&& left.TypeMapping!.Converter == null)
14921564
{
@@ -1502,7 +1574,7 @@ private SqlExpression OptimizeComparison(
15021574
: left;
15031575
}
15041576

1505-
if (TryGetBoolConstantValue(left) is bool leftBoolValue
1577+
if (TryGetBool(left, out var leftBoolValue)
15061578
&& !rightNullable
15071579
&& right.TypeMapping!.Converter == null)
15081580
{
@@ -2069,10 +2141,6 @@ private SqlExpression ProcessNullNotNull(SqlUnaryExpression sqlUnaryExpression,
20692141
private static bool IsLogicalNot(SqlUnaryExpression? sqlUnaryExpression)
20702142
=> sqlUnaryExpression is { OperatorType: ExpressionType.Not } && sqlUnaryExpression.Type == typeof(bool);
20712143

2072-
private bool IsNull(SqlExpression expression)
2073-
=> expression is SqlConstantExpression { Value: null }
2074-
|| expression is SqlParameterExpression { Name: string parameterName } && ParameterValues[parameterName] is null;
2075-
20762144
// ?a == ?b -> [(a == b) && (a != null && b != null)] || (a == null && b == null))
20772145
//
20782146
// a | b | F1 = a == b | F2 = (a != null && b != null) | F3 = F1 && F2 |

0 commit comments

Comments
 (0)