Implementing a Code Action using Roslyn

By Brian Rasmussen

The Roslyn Services API makes it easy to implement extensions that detect and fix code issues directly in Visual Studio. The Roslyn Services API is available as part of the Roslyn CTP.

In this post we implement a Visual Studio extension that identifies calls to the extension method Count() on Enumerable whose result is tested for being greater than zero, e.g. someSequence.Count() > 0. The problem with that construct is that Count() may have to enumerate the entire sequence to compute its result, even though we only need to know whether the sequence contains at least one element. A much better approach in this case is to call Enumerable.Any() instead, which stops as soon as it finds the first element.
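To make the cost concrete, here is a small standalone sketch (plain LINQ, not part of the extension) that counts how many elements are pulled from a lazily generated sequence. The names EnumerationCost, Tracked and ElementsPulled are hypothetical helpers for illustration. Note that Enumerable.Count() takes a shortcut when the source implements ICollection<T>, so the difference matters most for lazy sequences such as iterators and query results.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class EnumerationCost
{
    // A lazy sequence 0..size-1 that invokes onPull for every element produced.
    // As a compiler-generated iterator it does not implement ICollection<T>,
    // so Count() cannot read a stored count and must walk the whole sequence.
    public static IEnumerable<int> Tracked(int size, Action onPull)
    {
        for (int i = 0; i < size; i++)
        {
            onPull();
            yield return i;
        }
    }

    // Evaluates the test against a fresh tracked sequence and reports how many
    // elements had to be produced to get the answer.
    public static int ElementsPulled(int size, Func<IEnumerable<int>, bool> test)
    {
        int pulled = 0;
        test(Tracked(size, () => pulled++));
        return pulled;
    }
}
```

With this helper, ElementsPulled(1000, s => s.Count() > 0) reports 1000, while ElementsPulled(1000, s => s.Any()) reports 1.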

To address this we implement a CodeIssueProvider, which detects the problem, and a CodeAction, which replaces the condition with a call to Enumerable.Any() as appropriate. For example, our code action will change someSequence.Count() > 0 to someSequence.Any().

There are a couple of additional scenarios we want to handle as well. First, the expression could be reversed and written as 0 < someSequence.Count(). Second, the comparison could be >= 1 instead of > 0, which is equivalent because Count() returns an integer (and the two variations can be combined, as in 1 <= someSequence.Count()). We want the extension to handle all of these cases.
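As a quick sanity check of that equivalence, the following sketch (plain LINQ; EquivalentForms is a hypothetical name) evaluates the four patterns the extension should recognize and compares each against Any(). Because Count() never returns a negative number, every form is true exactly when the sequence is non-empty.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class EquivalentForms
{
    // Returns true if each comparison form agrees with Any() for this sequence.
    // The source is materialized first so it can be enumerated repeatedly.
    public static bool AllAgreeWithAny<T>(IEnumerable<T> source)
    {
        IEnumerable<T> items = source.ToList();
        bool any = items.Any();
        return (items.Count() > 0) == any    // Count() > 0
            && (0 < items.Count()) == any    // 0 < Count()
            && (items.Count() >= 1) == any   // Count() >= 1
            && (1 <= items.Count()) == any;  // 1 <= Count()
    }
}
```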

Obviously we do not want to change calls to methods called Count() unless they bind to the IEnumerable<T> extension method defined on Enumerable.
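The following sketch shows why the binding check matters. Bag is a hypothetical collection type with its own Count() method. Because C# only considers extension methods when no applicable instance method is found, bag.Count() > 0 calls Bag.Count(), and rewriting it to bag.Any() would silently change which code runs.

```csharp
using System.Collections;
using System.Collections.Generic;
using System.Linq;

// Hypothetical collection with an instance method named Count().
public class Bag : IEnumerable<int>
{
    private readonly List<int> items = new List<int>();

    // Records how often the instance method runs, so callers can tell
    // which Count() a given call expression actually bound to.
    public int InstanceCountCalls { get; private set; }

    public void Add(int item) => items.Add(item);

    // Instance methods take precedence over extension methods such as
    // Enumerable.Count, so bag.Count() binds here.
    public int Count()
    {
        InstanceCountCalls++;
        return items.Count;
    }

    public IEnumerator<int> GetEnumerator() => items.GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
```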

Getting Started

The Roslyn CTP ships with a number of templates designed to make it easy to get started using the Roslyn APIs. To get started we create a new project and select Code Issue from the Roslyn templates under the Visual C# section. Let’s call the project ReplaceCountWithAny.

The template generates a small, working code issue provider that highlights words containing the letter “a”. To see the sample in action, let’s build and run the project created by the template. This launches a new instance of Visual Studio with the extension enabled. In the newly launched instance we create a console application and notice that keywords such as namespace and class are underlined by our extension.

While the sample may not be very useful as an extension to Visual Studio, it neatly sets up everything we need to get started implementing our own extension. All we have to do is replace the contents of the generated GetIssues method. Notice that there are three overloads for GetIssues. We’ll implement the overload that takes a CommonSyntaxNode as input. The remaining two overloads can be left as they are in this case.

The generated CodeIssueProvider class implements ICodeIssueProvider and is decorated with the ExportSyntaxNodeCodeIssueProvider attribute. This allows Visual Studio to import the type as an extension that implements the contract defined by the ICodeIssueProvider interface.

Implementing GetIssues

Our GetIssues method will be invoked for every syntax node, so the first thing we need to do is filter out all the nodes we don’t want to handle. Since we are looking for expressions like someSequence.Count() > 0, we’re only interested in nodes of the type BinaryExpressionSyntax. We can instruct Visual Studio to only invoke our provider for nodes of specific type(s) by providing a list of types through the ExportSyntaxNodeCodeIssueProvider attribute. Let’s update the attribute as follows:

[ExportSyntaxNodeCodeIssueProvider("ReplaceCountWithAny",
    LanguageNames.CSharp, typeof(BinaryExpressionSyntax))]
class CodeIssueProvider : ICodeIssueProvider ...

This allows us to safely cast the provided CommonSyntaxNode to BinaryExpressionSyntax in GetIssues.

To identify the cases we want to handle, we have to determine whether one side of the expression is a call to Enumerable.Count() and the other side is a relevant comparison value. We’ll move those checks into a couple of helper methods, so our implementation of GetIssues looks as follows:

public IEnumerable<CodeIssue> GetIssues(IDocument document,
    CommonSyntaxNode node, CancellationToken cancellationToken)
{
    // Safe because the attribute restricts us to BinaryExpressionSyntax nodes.
    var binaryExpression = (BinaryExpressionSyntax)node;
    var left = binaryExpression.Left;
    var right = binaryExpression.Right;
    var kind = binaryExpression.Kind;

    // Either Count() on the left and the constant on the right,
    // or the constant on the left and Count() on the right.
    if ((IsCallToEnumerableCount(document, left, cancellationToken) &&
         IsRelevantRightSideComparison(document, right, kind, cancellationToken)) ||
        (IsCallToEnumerableCount(document, right, cancellationToken) &&
         IsRelevantLeftSideComparison(document, left, kind, cancellationToken)))
    {
        yield return new CodeIssue(CodeIssue.Severity.Info, binaryExpression.Span,
            string.Format("Change {0} to use Any() instead of Count() to avoid " +
                          "possible enumeration of entire sequence.",
                          binaryExpression));
    }
}

The CodeIssue we return specifies three things: a severity level (Error, Warning or Info), a span that determines which part of the source code is highlighted, and a message describing the issue to the user.

Helper Methods

Let’s turn our attention to the helper methods used in GetIssues. IsCallToEnumerableCount returns true if the expression we’re looking at is in fact a call to Enumerable.Count() on some sequence. Once again we’ll start by filtering out the expressions we are not interested in.

First of all, the expression must be an invocation. If it is, we get the invoked member from the Expression property of the invocation, and that must be a member access. So if the expression we’re looking at is someSequence.Count(), we now have the someSequence.Count part, but how do we figure out whether this binds to the method on the Enumerable type?

To answer questions like that we need to query the semantic model. Fortunately, part of the input to GetIssues is an IDocument, which represents a single document in a project and solution. We can get the semantic model via the document, and from the semantic model we can get the semantic information, including the bound symbol, for our call.

With the symbol in hand, we can check whether our call binds to the desired method. Because Count() is an extension method, we need to handle it a bit differently. Recall that C# allows extension methods to be called as if they were instance methods of the extended type. The semantic model represents such a call as a MethodSymbol whose ConstructedFrom property refers to the original method definition. This could possibly be handled slightly better, so look out for changes in the API here.

All that remains is to check the containing type of the method our call was constructed from. If that type is Enumerable, we have found an invocation of Enumerable.Count().

The implementation looks like this: 

private bool IsCallToEnumerableCount(IDocument document,
    ExpressionSyntax expression, CancellationToken cancellationToken)
{
    // The expression must be an invocation, e.g. someSequence.Count() ...
    var invocation = expression as InvocationExpressionSyntax;
    if (invocation == null)
    {
        return false;
    }

    // ... of a member access, e.g. someSequence.Count.
    var call = invocation.Expression as MemberAccessExpressionSyntax;
    if (call == null)
    {
        return false;
    }

    // Ask the semantic model which method the call binds to.
    var semanticModel = document.GetSemanticModel(cancellationToken);
    var methodSymbol = semanticModel.GetSemanticInfo(call, cancellationToken).Symbol
        as MethodSymbol;
    if (methodSymbol == null ||
        methodSymbol.Name != "Count" ||
        methodSymbol.ConstructedFrom == null)
    {
        return false;
    }

    // The method we were constructed from must be declared on System.Linq.Enumerable.
    var enumerable = semanticModel.Compilation.GetTypeByMetadataName(
        typeof(Enumerable).FullName);

    return enumerable != null &&
        methodSymbol.ConstructedFrom.ContainingType.Equals(enumerable);
}
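For readers more familiar with runtime reflection than with the Roslyn symbol API, the relationship this check relies on has a close analogy in System.Reflection: a call such as someSequence.Count() on an IEnumerable<int> compiles to the constructed method Count<int>, whose generic definition Count<TSource> is declared on Enumerable and marked as an extension method. This sketch (CountMethodInfo is a hypothetical name) is only an analogy; it does not use Roslyn, and the CTP’s ConstructedFrom is not the same API as MethodInfo.GetGenericMethodDefinition.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;

public static class CountMethodInfo
{
    // The generic definition Enumerable.Count<TSource>(IEnumerable<TSource>),
    // i.e. the one-parameter overload (the other overload takes a predicate).
    public static MethodInfo Definition() =>
        typeof(Enumerable)
            .GetMethods(BindingFlags.Public | BindingFlags.Static)
            .Single(m => m.Name == "Count" && m.GetParameters().Length == 1);

    // The constructed method Count<int> that a call on IEnumerable<int> uses.
    public static MethodInfo ConstructedForInt() =>
        Definition().MakeGenericMethod(typeof(int));

    // Same shape as the Roslyn check: go from the constructed method back to
    // its definition, then compare the declaring type with Enumerable.
    public static bool IsEnumerableCount(MethodInfo method)
    {
        MethodInfo definition = method.IsGenericMethod && !method.IsGenericMethodDefinition
            ? method.GetGenericMethodDefinition()
            : method;
        return definition.Name == "Count" && definition.DeclaringType == typeof(Enumerable);
    }
}
```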

With that settled, the next thing we need to look for is a relevant comparison on the other side of the binary expression. That is the job of the IsRelevantRightSideComparison and IsRelevantLeftSideComparison helper methods.

Here are the implementations for those:

// Count() is on the left: matches Count() > 0 and Count() >= 1.
private bool IsRelevantRightSideComparison(IDocument document,
    ExpressionSyntax expression, SyntaxKind kind,
    CancellationToken cancellationToken)
{
    var semanticInfo = document.GetSemanticModel(cancellationToken).
        GetSemanticInfo(expression, cancellationToken);

    // The other side must be a compile-time constant of type int.
    int? value;
    if (!semanticInfo.IsCompileTimeConstant ||
        (value = semanticInfo.ConstantValue as int?) == null)
    {
        return false;
    }

    if (kind == SyntaxKind.GreaterThanExpression && value == 0 ||
        kind == SyntaxKind.GreaterThanOrEqualExpression && value == 1)
    {
        return true;
    }

    return false;
}

// Count() is on the right: matches 0 < Count() and 1 <= Count().
private bool IsRelevantLeftSideComparison(IDocument document,
    ExpressionSyntax expression, SyntaxKind kind,
    CancellationToken cancellationToken)
{
    var semanticInfo = document.GetSemanticModel(cancellationToken).
        GetSemanticInfo(expression, cancellationToken);

    // The other side must be a compile-time constant of type int.
    int? value;
    if (!semanticInfo.IsCompileTimeConstant ||
        (value = semanticInfo.ConstantValue as int?) == null)
    {
        return false;
    }

    if (kind == SyntaxKind.LessThanExpression && value == 0 ||
        kind == SyntaxKind.LessThanOrEqualExpression && value == 1)
    {
        return true;
    }

    return false;
}

The two methods are almost identical, but it is important that we get both the comparison and the value right, so that we don’t highlight something like Count() >= 0, which is always true and therefore not equivalent to Any().
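Since the two helpers differ only in which side holds the constant, one alternative (a sketch, not the implementation used above) is to mirror the operator when Count() is on the right and then apply a single rule. The ComparisonKind enum below is a hypothetical stand-in for the relevant SyntaxKind values, so the rule can be exercised without Roslyn.

```csharp
// Hypothetical stand-in for the SyntaxKind values the helpers test.
public enum ComparisonKind { GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Other }

public static class ComparisonRule
{
    // Rewrites "c < x" as "x > c" (and so on) by swapping the operands' roles.
    public static ComparisonKind Mirror(ComparisonKind kind) => kind switch
    {
        ComparisonKind.GreaterThan => ComparisonKind.LessThan,
        ComparisonKind.GreaterThanOrEqual => ComparisonKind.LessThanOrEqual,
        ComparisonKind.LessThan => ComparisonKind.GreaterThan,
        ComparisonKind.LessThanOrEqual => ComparisonKind.GreaterThanOrEqual,
        _ => kind,
    };

    // True for Count() > 0, Count() >= 1, 0 < Count() and 1 <= Count().
    // countOnLeft says which side of the comparison the Count() call is on.
    public static bool IsRelevant(ComparisonKind kind, int constant, bool countOnLeft)
    {
        ComparisonKind normalized = countOnLeft ? kind : Mirror(kind);
        return (normalized == ComparisonKind.GreaterThan && constant == 0)
            || (normalized == ComparisonKind.GreaterThanOrEqual && constant == 1);
    }
}
```

Mirroring keeps the "> 0 or >= 1" rule in one place, at the cost of one extra small function.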

Testing the CodeIssueProvider

At this point our code issue provider can detect the issues we’re interested in handling. Compile and run the project to launch a new instance of Visual Studio with our extension enabled. Add some code and notice that calls to Enumerable.Count() in the comparisons described above are correctly underlined, while calls to other methods named Count() are not.

The next step is to provide a code action for our code issue.

CodeAction

To implement a code action we need a type which implements ICodeAction. ICodeAction is a simple interface that defines a description and an icon for the action and a single method called GetEdit, which returns an edit that will transform the current syntax tree. Let’s start by looking at the constructor for our CodeAction class.

public CodeAction(ICodeActionEditFactory editFactory,
    IDocument document, BinaryExpressionSyntax binaryExpression)
{
    this.editFactory = editFactory;
    this.document = document;
    this.binaryExpression = binaryExpression;
}

An instance of CodeAction is created for every code issue identified, so for convenience we simply pass the arguments needed to implement the change to the constructor. We need an ICodeActionEditFactory to create the edit that transforms the document’s syntax tree into our new tree. As syntax trees are immutable in Roslyn, returning a new tree is the only way we can make any changes. Fortunately, Roslyn reuses as much of the original tree as possible and thus avoids creating unnecessary nodes.

Furthermore, we need the document, which gives us access to the syntax tree as well as the project and solution containing our source code, and a reference to the syntax node we want to replace.
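The idea of replacing a node in an immutable tree while reusing the rest can be sketched with a toy tree. Node and its ReplaceNode method below are hypothetical and are not Roslyn’s syntax types; the sketch only shows the general technique (path copying): nodes on the path from the root to the replaced node are copied, every other subtree is shared, and the original tree is left untouched.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical minimal immutable tree, not Roslyn's syntax types.
public sealed class Node
{
    public string Label { get; }
    public IReadOnlyList<Node> Children { get; }

    public Node(string label, params Node[] children)
    {
        Label = label;
        Children = children.ToArray(); // defensive copy keeps the node immutable
    }

    // Returns a tree in which 'target' (matched by reference) is replaced by
    // 'replacement'. Only the nodes on the path down to 'target' are copied;
    // all other subtrees are shared with the original, which is never modified.
    public Node ReplaceNode(Node target, Node replacement)
    {
        if (ReferenceEquals(this, target))
        {
            return replacement;
        }

        var newChildren = new Node[Children.Count];
        bool changed = false;
        for (int i = 0; i < Children.Count; i++)
        {
            newChildren[i] = Children[i].ReplaceNode(target, replacement);
            changed |= !ReferenceEquals(newChildren[i], Children[i]);
        }

        // If nothing below us changed, reuse this node as-is.
        return changed ? new Node(Label, newChildren) : this;
    }
}
```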

This brings us to the GetEdit method. This is where we create a transformation, which will replace the identified binary expression node with a newly created invocation node for the call to Any(). The creation of the new node is handled by a small helper method called GetNewNode. Both are listed below.

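public ICodeActionEdit GetEdit(CancellationToken cancellationToken)
{
    var syntaxTree = (SyntaxTree)document.GetSyntaxTree(cancellationToken);
    var newExpression = GetNewNode(binaryExpression).
        WithLeadingTrivia(binaryExpression.GetLeadingTrivia()).
        WithTrailingTrivia(binaryExpression.GetTrailingTrivia());
    var newRoot = syntaxTree.Root.ReplaceNode(binaryExpression, newExpression);

    return editFactory.CreateTreeTransformEdit(
        document.Project.Solution,
        syntaxTree,
        newRoot,
        cancellationToken: cancellationToken);
}

private ExpressionSyntax GetNewNode(BinaryExpressionSyntax node)
{
    // The Count() call is whichever side of the comparison is an invocation;
    // the other side is the constant. Taking it directly, instead of searching
    // all descendants with Single(), keeps receivers such as GetItems().Count()
    // or this.items.Count() (which contain further invocations or member
    // accesses) from throwing.
    var invocation = node.Left as InvocationExpressionSyntax ??
        (InvocationExpressionSyntax)node.Right;
    // The invoked member, e.g. someSequence.Count (verified in GetIssues).
    var caller = (MemberAccessExpressionSyntax)invocation.Expression;

    // Rebuild the call with the name Any, keeping the receiver, the dot
    // token and the original argument list.
    return invocation.Update(
        caller.Update(caller.Expression,
            caller.OperatorToken,
            Syntax.IdentifierName("Any")),
        invocation.ArgumentList);
}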

The Roslyn syntax tree maintains full fidelity with the original source code, so each node in the tree may have leading and trailing trivia representing white space and comments. We therefore need to copy the trivia from the original node to preserve comments and layout when we exchange the nodes. We call the WithLeadingTrivia and WithTrailingTrivia extension methods to handle that.

Also notice that GetNewNode preserves the argument list from Count(), so if Count() was called with a predicate lambda to count only specific items in the sequence, Any() is called with the same predicate.
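The effect of carrying the predicate over can be checked with plain LINQ. In this sketch (PredicateCost is a hypothetical name) each predicate is wrapped to count how often it runs: Count(predicate) > 0 evaluates the predicate for every element, while Any(predicate) stops at the first match, and both return the same answer.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class PredicateCost
{
    // Evaluates Count(predicate) > 0 and Any(predicate) on the same source and
    // reports each result together with how many times the predicate ran.
    public static (bool ViaCount, int CountChecks, bool ViaAny, int AnyChecks) Compare(
        IEnumerable<int> source, Func<int, bool> predicate)
    {
        int countChecks = 0;
        int anyChecks = 0;

        bool viaCount = source.Count(x => { countChecks++; return predicate(x); }) > 0;
        bool viaAny = source.Any(x => { anyChecks++; return predicate(x); });

        return (viaCount, countChecks, viaAny, anyChecks);
    }
}
```

For Enumerable.Range(1, 100) and x => x % 10 == 0, both results are true, but the predicate runs 100 times for Count() and 10 times for Any().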

Wrapping Up

To enable our code action we need to update GetIssues in our CodeIssueProvider to return a CodeAction with each CodeIssue. Each code issue may provide a number of code actions to let the user pick between different resolutions. In this case we will just return a single code action as outlined above.

The updated part of GetIssues looks like this (editFactory is an ICodeActionEditFactory held by the code issue provider, for example in a field):

yield return new CodeIssue(CodeIssue.Severity.Info, binaryExpression.Span,
    string.Format("Change {0} to use Any() instead of Count() to avoid " +
                  "possible enumeration of entire sequence.", binaryExpression),
    new CodeAction(editFactory, document, binaryExpression));

Rebuild and run the project to launch a new instance of Visual Studio with our extension loaded. Notice that the code issue now provides a drop-down with the option to invoke our code action and fix the issue.

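We have implemented a Visual Studio extension that detects calls to Enumerable.Count() compared against zero (or one) and offers to replace them with Enumerable.Any(), helping us improve our code.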

About the author

Brian is a Senior SDET at Microsoft working on the C# and VB language services in Roslyn. Before joining Microsoft Brian was a Microsoft MVP for Visual C# for four years. Brian can be found on Twitter (@kodehoved).