diff --git a/analyzers/src/SonarAnalyzer.CSharp/Rules/RedundantCast.cs b/analyzers/src/SonarAnalyzer.CSharp/Rules/RedundantCast.cs index a23a168971f..9c24d53abc8 100644 --- a/analyzers/src/SonarAnalyzer.CSharp/Rules/RedundantCast.cs +++ b/analyzers/src/SonarAnalyzer.CSharp/Rules/RedundantCast.cs @@ -18,161 +18,137 @@ * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. */ -namespace SonarAnalyzer.Rules.CSharp -{ - [DiagnosticAnalyzer(LanguageNames.CSharp)] - public sealed class RedundantCast : SonarDiagnosticAnalyzer - { - internal const string DiagnosticId = "S1905"; - private const string MessageFormat = "Remove this unnecessary cast to '{0}'."; - - private static readonly DiagnosticDescriptor rule = - DescriptorFactory.Create(DiagnosticId, MessageFormat); - - public override ImmutableArray SupportedDiagnostics { get; } = ImmutableArray.Create(rule); +namespace SonarAnalyzer.Rules.CSharp; - private static readonly ISet CastIEnumerableMethods = new HashSet { "Cast", "OfType" }; - - protected override void Initialize(SonarAnalysisContext context) - { - context.RegisterNodeAction( - c => - { - var castExpression = (CastExpressionSyntax)c.Node; - CheckCastExpression(c, castExpression.Expression, castExpression.Type, castExpression.Type.GetLocation()); - }, - SyntaxKind.CastExpression); +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public sealed class RedundantCast : SonarDiagnosticAnalyzer +{ + internal const string DiagnosticId = "S1905"; + private const string MessageFormat = "Remove this unnecessary cast to '{0}'."; - context.RegisterNodeAction( - c => - { - var castExpression = (BinaryExpressionSyntax)c.Node; - CheckCastExpression(c, castExpression.Left, castExpression.Right, - castExpression.OperatorToken.CreateLocation(castExpression.Right)); - }, - SyntaxKind.AsExpression); - - context.RegisterNodeAction( - CheckExtensionMethodInvocation, - SyntaxKind.InvocationExpression); - } + private static readonly DiagnosticDescriptor Rule = + DescriptorFactory.Create(DiagnosticId, MessageFormat); - private static void CheckCastExpression(SonarSyntaxNodeReportingContext context, ExpressionSyntax expression, ExpressionSyntax type, Location location) - { - if (expression.IsKind(SyntaxKindEx.DefaultLiteralExpression)) - { - return; - } + public override ImmutableArray SupportedDiagnostics { get; } = ImmutableArray.Create(Rule); - var expressionType = context.SemanticModel.GetTypeInfo(expression).Type; - if (expressionType == null) - { - return; - } + private static readonly ISet CastIEnumerableMethods = new HashSet { "Cast", "OfType" }; - var castType = context.SemanticModel.GetTypeInfo(type).Type; - if (castType == null) + protected override void Initialize(SonarAnalysisContext context) + { + context.RegisterNodeAction( + c => { - return; - } + var castExpression = (CastExpressionSyntax)c.Node; + CheckCastExpression(c, castExpression.Expression, castExpression.Type, castExpression.Type.GetLocation()); + }, + SyntaxKind.CastExpression); - if (expressionType.Equals(castType)) + context.RegisterNodeAction( + c => { - context.ReportIssue(Diagnostic.Create(rule, location, - castType.ToMinimalDisplayString(context.SemanticModel, expression.SpanStart))); - } - } + var castExpression = (BinaryExpressionSyntax)c.Node; + CheckCastExpression(c, castExpression.Left, castExpression.Right, + castExpression.OperatorToken.CreateLocation(castExpression.Right)); + }, + SyntaxKind.AsExpression); + + context.RegisterNodeAction( + CheckExtensionMethodInvocation, + SyntaxKind.InvocationExpression); + } - private static void CheckExtensionMethodInvocation(SonarSyntaxNodeReportingContext context) + private static void CheckCastExpression(SonarSyntaxNodeReportingContext context, ExpressionSyntax expression, ExpressionSyntax type, Location location) + { + if (!expression.IsKind(SyntaxKindEx.DefaultLiteralExpression) + && context.SemanticModel.GetTypeInfo(expression) is { Type: { } expressionType } expressionTypeInfo + && context.SemanticModel.GetTypeInfo(type) is { Type: { } castType } + && expressionType.Equals(castType) + && FlowStateEquals(expressionTypeInfo, type)) { - var invocation = (InvocationExpressionSyntax)context.Node; - if (GetEnumerableExtensionSymbol(invocation, context.SemanticModel) is { } methodSymbol) - { - var returnType = methodSymbol.ReturnType; - if (GetGenericTypeArgument(returnType) is { } castType) - { - if (methodSymbol.Name == "OfType" && CanHaveNullValue(castType)) - { - // OfType() filters 'null' values from enumerables - return; - } - - var elementType = GetElementType(invocation, methodSymbol, context.SemanticModel); - // Generic types {T} and {T?} are equal and there is no way to access NullableAnnotation field right now - // See https://github.com/SonarSource/sonar-dotnet/issues/3273 - if (elementType != null && elementType.Equals(castType) && string.Equals(elementType.ToString(), castType.ToString(), System.StringComparison.Ordinal)) - { - var methodCalledAsStatic = methodSymbol.MethodKind == MethodKind.Ordinary; - context.ReportIssue(Diagnostic.Create(rule, GetReportLocation(invocation, methodCalledAsStatic), - returnType.ToMinimalDisplayString(context.SemanticModel, invocation.SpanStart))); - } - } - } + ReportIssue(context, expression, castType, location); } + } - /// If the invocation one of the extensions, returns the method symbol. - private static IMethodSymbol GetEnumerableExtensionSymbol(InvocationExpressionSyntax invocation, SemanticModel semanticModel) => - invocation.GetMethodCallIdentifier() is { } methodName - && CastIEnumerableMethods.Contains(methodName.ValueText) - && semanticModel.GetSymbolInfo(invocation).Symbol is IMethodSymbol methodSymbol - && methodSymbol.IsExtensionOn(KnownType.System_Collections_IEnumerable) - ? methodSymbol - : null; - - private static ITypeSymbol GetGenericTypeArgument(ITypeSymbol type) => - type is INamedTypeSymbol returnType && returnType.Is(KnownType.System_Collections_Generic_IEnumerable_T) - ? returnType.TypeArguments.Single() - : null; - - private static bool CanHaveNullValue(ITypeSymbol type) => type.IsReferenceType || type.Name == "Nullable"; - - private static Location GetReportLocation(InvocationExpressionSyntax invocation, bool methodCalledAsStatic) + private static bool FlowStateEquals(TypeInfo expressionTypeInfo, ExpressionSyntax type) + { + var castingToNullable = type.IsKind(SyntaxKind.NullableType); + return expressionTypeInfo.Nullability().FlowState switch { - if (!(invocation.Expression is MemberAccessExpressionSyntax memberAccess)) - { - return invocation.Expression.GetLocation(); - } - - return methodCalledAsStatic - ? memberAccess.GetLocation() - : memberAccess.OperatorToken.CreateLocation(invocation); - } + NullableFlowState.None => true, + NullableFlowState.MaybeNull => castingToNullable, + NullableFlowState.NotNull => !castingToNullable, + _ => true, + }; + } - private static ITypeSymbol GetElementType(InvocationExpressionSyntax invocation, IMethodSymbol methodSymbol, - SemanticModel semanticModel) + private static void CheckExtensionMethodInvocation(SonarSyntaxNodeReportingContext context) + { + var invocation = (InvocationExpressionSyntax)context.Node; + if (GetEnumerableExtensionSymbol(invocation, context.SemanticModel) is { } methodSymbol) { - ExpressionSyntax collection; - if (methodSymbol.MethodKind == MethodKind.Ordinary) + var returnType = methodSymbol.ReturnType; + if (GetGenericTypeArgument(returnType) is { } castType) { - if (!invocation.ArgumentList.Arguments.Any()) + if (methodSymbol.Name == "OfType" && CanHaveNullValue(castType)) { - return null; + // OfType() filters 'null' values from enumerables + return; } - collection = invocation.ArgumentList.Arguments.First().Expression; - } - else - { - if (!(invocation.Expression is MemberAccessExpressionSyntax memberAccess)) + + var elementType = GetElementType(invocation, methodSymbol, context.SemanticModel); + // Generic types {T} and {T?} are equal and there is no way to access NullableAnnotation field right now + // See https://github.com/SonarSource/sonar-dotnet/issues/3273 + if (elementType != null && elementType.Equals(castType) && string.Equals(elementType.ToString(), castType.ToString(), StringComparison.Ordinal)) { - return null; + var methodCalledAsStatic = methodSymbol.MethodKind == MethodKind.Ordinary; + ReportIssue(context, invocation, returnType, GetReportLocation(invocation, methodCalledAsStatic)); } - collection = memberAccess.Expression; } + } + } - var typeInfo = semanticModel.GetTypeInfo(collection); - if (typeInfo.Type is INamedTypeSymbol collectionType && - collectionType.TypeArguments.Length == 1) - { - return collectionType.TypeArguments.First(); - } + private static void ReportIssue(SonarSyntaxNodeReportingContext context, ExpressionSyntax expression, ITypeSymbol castType, Location location) => + context.ReportIssue(Diagnostic.Create(Rule, location, castType.ToMinimalDisplayString(context.SemanticModel, expression.SpanStart))); - if (typeInfo.Type is IArrayTypeSymbol arrayType && - arrayType.Rank == 1) // casting is necessary for multidimensional arrays - { - return arrayType.ElementType; - } + /// If the invocation one of the extensions, returns the method symbol. + private static IMethodSymbol GetEnumerableExtensionSymbol(InvocationExpressionSyntax invocation, SemanticModel semanticModel) => + invocation.GetMethodCallIdentifier() is { } methodName + && CastIEnumerableMethods.Contains(methodName.ValueText) + && semanticModel.GetSymbolInfo(invocation).Symbol is IMethodSymbol methodSymbol + && methodSymbol.IsExtensionOn(KnownType.System_Collections_IEnumerable) + ? methodSymbol + : null; - return null; - } + private static ITypeSymbol GetGenericTypeArgument(ITypeSymbol type) => + type is INamedTypeSymbol returnType && returnType.Is(KnownType.System_Collections_Generic_IEnumerable_T) + ? returnType.TypeArguments.Single() + : null; + + private static bool CanHaveNullValue(ITypeSymbol type) => + type.IsReferenceType || type.Is(KnownType.System_Nullable_T); + + private static Location GetReportLocation(InvocationExpressionSyntax invocation, bool methodCalledAsStatic) => + methodCalledAsStatic is false && invocation.Expression is MemberAccessExpressionSyntax memberAccess + ? memberAccess.OperatorToken.CreateLocation(invocation) + : invocation.Expression.GetLocation(); + + private static ITypeSymbol GetElementType(InvocationExpressionSyntax invocation, IMethodSymbol methodSymbol, SemanticModel semanticModel) + { + return semanticModel.GetTypeInfo(CollectionExpression(invocation, methodSymbol)).Type switch + { + INamedTypeSymbol { TypeArguments: { Length: 1 } typeArguments } => typeArguments.First(), + IArrayTypeSymbol { Rank: 1 } arrayType => arrayType.ElementType, // casting is necessary for multidimensional arrays + _ => null + }; + + static ExpressionSyntax CollectionExpression(InvocationExpressionSyntax invocation, IMethodSymbol methodSymbol) => + methodSymbol.MethodKind is MethodKind.ReducedExtension + ? ReducedExtensionExpression(invocation) + : invocation.ArgumentList.Arguments.FirstOrDefault()?.Expression; + + static ExpressionSyntax ReducedExtensionExpression(InvocationExpressionSyntax invocation) => + invocation.Expression is MemberAccessExpressionSyntax { Expression: { } memberAccessExpression } + ? memberAccessExpression + : invocation.GetParentConditionalAccessExpression()?.Expression; } } diff --git a/analyzers/tests/SonarAnalyzer.UnitTest/Rules/RedundantCastTest.cs b/analyzers/tests/SonarAnalyzer.UnitTest/Rules/RedundantCastTest.cs index a183818ceb7..7e707e05d17 100644 --- a/analyzers/tests/SonarAnalyzer.UnitTest/Rules/RedundantCastTest.cs +++ b/analyzers/tests/SonarAnalyzer.UnitTest/Rules/RedundantCastTest.cs @@ -21,45 +21,103 @@ using Microsoft.CodeAnalysis.CSharp; using SonarAnalyzer.Rules.CSharp; -namespace SonarAnalyzer.UnitTest.Rules +namespace SonarAnalyzer.UnitTest.Rules; + +[TestClass] +public class RedundantCastTest { - [TestClass] - public class RedundantCastTest - { - private readonly VerifierBuilder builder = new VerifierBuilder(); + private readonly VerifierBuilder builder = new VerifierBuilder(); - [TestMethod] - public void RedundantCast() => - builder.AddPaths("RedundantCast.cs").Verify(); + [TestMethod] + public void RedundantCast() => + builder.AddPaths("RedundantCast.cs").Verify(); - [TestMethod] - public void RedundantCast_CSharp8() => - builder.AddPaths("RedundantCast.CSharp8.cs").WithOptions(ParseOptionsHelper.FromCSharp8).Verify(); + [TestMethod] + public void RedundantCast_CSharp8() => + builder.AddPaths("RedundantCast.CSharp8.cs").WithOptions(ParseOptionsHelper.FromCSharp8).Verify(); #if NET - [TestMethod] - public void RedundantCast_CSharp9() => - builder.AddPaths("RedundantCast.CSharp9.cs").WithOptions(ParseOptionsHelper.FromCSharp9).Verify(); + [TestMethod] + public void RedundantCast_CSharp9() => + builder.AddPaths("RedundantCast.CSharp9.cs").WithOptions(ParseOptionsHelper.FromCSharp9).Verify(); #endif - [TestMethod] - public void RedundantCast_CodeFix() => - builder.AddPaths("RedundantCast.cs").WithCodeFix().WithCodeFixedPaths("RedundantCast.Fixed.cs").VerifyCodeFix(); + [TestMethod] + public void RedundantCast_CodeFix() => + builder.AddPaths("RedundantCast.cs").WithCodeFix().WithCodeFixedPaths("RedundantCast.Fixed.cs").VerifyCodeFix(); - [TestMethod] - public void RedundantCast_DefaultLiteral() => - builder.AddSnippet(@" -using System; -public static class MyClass -{ - public static void RunAction(Action action) + [TestMethod] + public void RedundantCast_DefaultLiteral() => + builder.AddSnippet(""" + using System; + public static class MyClass + { + public static void RunAction(Action action) + { + bool myBool = (bool)default; // FN - the cast is unneeded + RunFunc(() => { action(); return default; }, (bool)default); // should not raise because of the generic the cast is mandatory + RunFunc(() => { action(); return default; }, (bool)default); // FN - the cast is unneeded + } + + public static T RunFunc(Func func, T returnValue = default) => returnValue; + } + """).WithLanguageVersion(LanguageVersion.CSharp7_1).Verify(); + + [TestMethod] + [DynamicData(nameof(NullableTestDataWithFlowState))] + public void RedundantCast_NullableEnabled(string snippet, bool compliant) + => VerifyNullableTests(snippet, "enable", compliant); + + [TestMethod] + [DynamicData(nameof(NullableTestDataWithFlowState))] + public void RedundantCast_NullableWarnings(string snippet, bool compliant) + => VerifyNullableTests(snippet, "enable warnings", compliant); + + [TestMethod] + [DynamicData(nameof(NullableTestDataWithoutFlowState))] + public void RedundantCast_NullableDisabled(string snippet, bool compliant) + => VerifyNullableTests(snippet, "disable", compliant); + + [TestMethod] + [DynamicData(nameof(NullableTestDataWithoutFlowState))] + public void RedundantCast_NullableAnnotations(string snippet, bool compliant) + => VerifyNullableTests(snippet, "enable annotations", compliant); + + private void VerifyNullableTests(string snippet, string nullableContext, bool compliant) { - bool myBool = (bool)default; // FN - the cast is unneeded - RunFunc(() => { action(); return default; }, (bool)default); // should not raise because of the generic the cast is mandatory - RunFunc(() => { action(); return default; }, (bool)default); // FN - the cast is unneeded + var code = $$""" + #nullable {{nullableContext}} + void Test(string nonNullable, string? nullable) + { + {{snippet}} // {{compliant switch { true => "Compliant", false => "Noncompliant" }}} + } + """; + builder.AddSnippet(code).WithTopLevelStatements().Verify(); } - public static T RunFunc(Func func, T returnValue = default) => returnValue; -}").WithLanguageVersion(LanguageVersion.CSharp7_1).Verify(); - } + private static IEnumerable<(string Snippet, bool CompliantWithFlowState, bool CompliantWithoutFlowState)> NullableTestData => new[] + { + ("""_ = (string)"Test";""", false, false), + ("""_ = (string?)"Test";""", true, false), + ("""_ = (string)null;""", true, true), + ("""_ = (string?)null;""", true, true), + ("""_ = (string)nullable;""", true, false), + ("""_ = (string?)nullable;""", false, false), + ("""_ = (string)nonNullable;""", false, false), + ("""_ = (string?)nonNullable;""", true, false), + ("""_ = nullable as string;""", true, false), + ("""_ = nonNullable as string;""", false, false), + ("""if (nullable != null) _ = (string)nullable;""", false, false), + ("""if (nullable != null) _ = (string?)nullable;""", true, false), + ("""if (nullable != null) _ = nullable as string;""", false, false), + ("""if (nonNullable == null) _ = (string)nonNullable;""", true, false), + ("""if (nonNullable == null) _ = (string?)nonNullable;""", false, false), + ("""if (nonNullable == null) _ = nonNullable as string;""", true, false), + }; + + private static IEnumerable NullableTestDataWithFlowState => + NullableTestData.Select(x => new object[] { x.Snippet, x.CompliantWithFlowState }); + + private static IEnumerable NullableTestDataWithoutFlowState => + NullableTestData.Select(x => new object[] { x.Snippet, x.CompliantWithoutFlowState }); } diff --git a/analyzers/tests/SonarAnalyzer.UnitTest/TestCases/RedundantCast.CSharp8.cs b/analyzers/tests/SonarAnalyzer.UnitTest/TestCases/RedundantCast.CSharp8.cs index adab201f62f..1e623abffc3 100644 --- a/analyzers/tests/SonarAnalyzer.UnitTest/TestCases/RedundantCast.CSharp8.cs +++ b/analyzers/tests/SonarAnalyzer.UnitTest/TestCases/RedundantCast.CSharp8.cs @@ -2,15 +2,99 @@ using System.Collections.Generic; using System.Linq; +#nullable enable namespace Tests.Diagnostics { + public class InvocationTests + { + public static void Invocations() + { + var ints = new int[1]; + var objects= new object[1]; + var moreInts = new int[1][]; + var moreObjects = new object[1][]; + ints?.Cast(); // Noncompliant + objects?.Cast(); // Compliant + moreInts[0].Cast(); // Noncompliant + moreObjects[0].Cast(); // Compliant + moreInts[0]?.Cast(); // Noncompliant + moreObjects[0]?.Cast(); // Compliant + GetInts().Cast(); // Noncompliant + GetObjects().Cast(); // Compliant + GetInts()?.Cast(); // Noncompliant + GetObjects()?.Cast(); // Compliant + Enumerable.Cast(); // Error - overload resolution failure + } + + public static int[] GetInts() => null; + public static object[] GetObjects() => null; + } + // https://github.com/SonarSource/sonar-dotnet/issues/3273 public class CastOnNullable { - public static IEnumerable UsefulCast() + public static IEnumerable Array() { var nullableStrings = new string?[] { "one", "two", null, "three" }; return nullableStrings.OfType(); // Compliant - filters out the null } - } + + public void Tuple() + { + _ = (a: (string?)"", b: ""); // Compliant + } + + public void ValueTypes(int nonNullable, int? nullable) + { + _ = (int?)nonNullable; // Compliant + _ = (int?)nullable; // Noncompliant + _ = (int)nonNullable; // Noncompliant + _ = (int)nullable; // Compliant + } + } + + // https://github.com/SonarSource/sonar-dotnet/issues/6438 + public class AnonTypes + { + public void Simple(string nonNullable, string? nullable) + { + _ = new { X = (string?)nonNullable }; // Compliant + _ = new { X = (string?)nullable }; // Noncompliant + _ = new { X = (string)nonNullable }; // Noncompliant + _ = new { X = (string)nullable }; // Compliant + } + + public void Array(string nonNullable, string? nullable) + { + _ = new[] { new { X = (string?)nonNullable }, new { X = (string?)null } }; // Compliant + _ = new[] { new { X = (string?)nullable }, new { X = (string?)null } }; // Noncompliant + _ = new[] { new { X = (string?)nonNullable } }; // Compliant + _ = new[] { new { X = (string?)nullable } }; // Noncompliant + _ = new[] { new HoldsObject(new { X = (string?)nonNullable }) }; // Compliant + _ = new[] { new HoldsObject(new { X = (string?)nullable }) }; // Noncompliant + } + + public void SwitchExpression(string nonNullable, string? nullable) + { + _ = true switch + { + true => new { X = (string?)nonNullable }, // Compliant + false => new { X = (string?)null } // Compliant + }; + _ = true switch + { + true => new { X = (string?)nullable }, // Noncompliant + false => new { X = (string?)null } // Compliant + }; + } + } + + internal class HoldsObject + { + object O { get; } + public HoldsObject(object o) + { + O = o; + } + } }