Dealing with Linq’s Immutable Expression Trees

 

Jomo Fisher -- I recently got a question via my blog that dovetailed nicely with something I’ve been working on:

I know that expression trees are (or at least appear to be) immutable, which requires that you rewrite the entire tree if you want a tree that is similar to, but not the same as, an existing tree.

My question is really whether there is anything in the near-term pipeline from Microsoft that might address this gap.

Linq’s expression trees are indeed immutable. I’ve worked with both mutable trees (LINQ to SQL has an internal tree structure that’s mutable) and immutable trees like the expression trees in the System.Linq.Expressions namespace. I’ve really come to believe that, maybe counter-intuitively, immutable trees are far easier to work with. See this blog article for a discussion of the sorts of problems that can arise with mutable classes.

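These problems are actually magnified when the mutable class is a tree structure, both because trees are intrinsically more complex and because it’s easy to unintentionally turn the tree into a cyclic graph, for example by attaching a node beneath one of its own descendants. This latter problem is avoided in immutable trees because a node must exist before it can be assigned as the child of another node: children are always built before their parents, so a node can never end up as its own ancestor.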

Another advantage of immutable trees is that you can effectively duplicate them just by reference assignment. Here’s what I mean:

    Expression<Func<int, int>> f = (n => 1 - n);
    Expression<Func<int, int>> g = f;   // g and f now refer to the same tree
    f = MyRewriteFunction(f);           // any rewrite that returns a new tree

Here, the variable ‘g’ keeps the original expression tree while ‘f’ is reassigned to the rewritten tree. The assignment to ‘g’ is a very fast reference copy. Without the immutability guarantee, you would have to physically copy (deep-clone) the tree to be sure that ‘g’ was unaffected by the rewrite.
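To make the reference-copy point concrete, here is a small, self-contained sketch. AddOneToResult is a hypothetical stand-in for MyRewriteFunction. It builds a new tree whose left child is the original body, so the rewrite never touches the tree that ‘g’ still refers to.

```csharp
using System;
using System.Linq.Expressions;

public static class ReferenceCopyDemo
{
    // Hypothetical stand-in for MyRewriteFunction: returns a NEW tree for
    // "n => (1 - n) + 1" that uses the original body as its left child.
    public static Expression<Func<int, int>> AddOneToResult(Expression<Func<int, int>> e)
    {
        return Expression.Lambda<Func<int, int>>(
            Expression.Add(e.Body, Expression.Constant(1)),
            e.Parameters);
    }

    // Runs the copy-then-rewrite sequence and returns { f, g }.
    public static Expression<Func<int, int>>[] CopyThenRewrite()
    {
        Expression<Func<int, int>> f = n => 1 - n;
        Expression<Func<int, int>> g = f;   // reference copy: no tree walk
        f = AddOneToResult(f);              // f now points at a new tree; g is untouched
        return new[] { f, g };
    }

    public static void Main()
    {
        var trees = CopyThenRewrite();
        Console.WriteLine(trees[1].Compile()(5));   // -4: g still computes 1 - n
        Console.WriteLine(trees[0].Compile()(5));   // -3: f computes (1 - n) + 1
    }
}
```

Because the rewritten tree holds the old body as a child, the rewrite allocated only the new Add, Constant, and lambda nodes.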

Changing an Immutable Tree without Complete Rewrite

The original concern, though, is about the need to completely rewrite the tree in order to change it. It turns out that you only need to rewrite the slice of the tree from the point of the change up through its parents to the root expression. While you do need to recreate those parents, you can reuse the unchanged siblings along with everything beneath them. This is actually a very powerful concept:

You can create a forest of similar expression trees in which only the differences between the trees occupy extra memory.
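Here is a minimal sketch of that idea, using a hand-built rewrite rather than a visitor. DoubleD is a hypothetical helper that replaces the leaf d in (a + b) * (c + d) with d * 2. Only the two ancestors of that leaf (plus the lambda) are recreated, and the untouched (a + b) subtree is shared by both trees.

```csharp
using System;
using System.Linq.Expressions;

public static class PathCopyDemo
{
    public static readonly Expression<Func<int, int, int, int, int>> Original =
        (a, b, c, d) => (a + b) * (c + d);

    // Hypothetical helper: replace the leaf 'd' with 'd * 2', rebuilding only
    // the nodes on the path from that leaf up to the root.
    public static Expression<Func<int, int, int, int, int>> DoubleD(
        Expression<Func<int, int, int, int, int>> e)
    {
        var mul = (BinaryExpression)e.Body;      // (a + b) * (c + d)
        var sum = (BinaryExpression)mul.Right;   // c + d
        var newLeaf = Expression.Multiply(sum.Right, Expression.Constant(2));   // d * 2
        var newSum = Expression.Add(sum.Left, newLeaf);       // rebuilt parent, reuses 'c'
        var newMul = Expression.Multiply(mul.Left, newSum);   // rebuilt root, reuses (a + b)
        return Expression.Lambda<Func<int, int, int, int, int>>(newMul, e.Parameters);
    }

    public static void Main()
    {
        var changed = DoubleD(Original);
        var oldLeft = ((BinaryExpression)Original.Body).Left;
        var newLeft = ((BinaryExpression)changed.Body).Left;
        Console.WriteLine(ReferenceEquals(oldLeft, newLeft));   // True: (a + b) is shared
        Console.WriteLine(Original.Compile()(1, 2, 3, 4));      // 21
        Console.WriteLine(changed.Compile()(1, 2, 3, 4));       // 33
    }
}
```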

For example, if you have a system that rewrites an expression tree in many phases, you can keep all of the intermediate rewrite steps with little effort and without consuming unnecessary memory.

Another example might be genetic programming, in which the mutations are kept in a forest of expression trees.

Yet another example might be a theorem prover in which you need to search for a particular tree but you need to be able to backtrack if you reach a dead end.
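As a toy illustration of the backtracking case (and of keeping intermediate trees), the sketch below runs a depth-first search over trees grown from the constant 1 with two hypothetical moves, "+ 1" and "* 2", looking for a tree that evaluates to a target. Backtracking needs no undo logic: each child tree is built on top of its parent, which stays intact, so a dead end is abandoned simply by dropping the last tree from the path. Compiling every candidate is wasteful, but it keeps the sketch short.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

public static class BacktrackDemo
{
    // Two hypothetical rewrite moves; each builds a new tree on top of the old one.
    static readonly Func<Expression, Expression>[] Moves =
    {
        e => Expression.Add(e, Expression.Constant(1)),
        e => Expression.Multiply(e, Expression.Constant(2)),
    };

    public static int Eval(Expression e)
    {
        return Expression.Lambda<Func<int>>(e).Compile()();
    }

    // Depth-first search. 'path' holds every tree on the current branch; on a
    // dead end we just drop the last entry, because the parent tree is unchanged.
    public static bool Search(Expression current, int target, int depth, List<Expression> path)
    {
        path.Add(current);
        if (Eval(current) == target)
            return true;
        if (depth > 0)
        {
            foreach (var move in Moves)
            {
                if (Search(move(current), target, depth - 1, path))
                    return true;
            }
        }
        path.RemoveAt(path.Count - 1);   // backtrack
        return false;
    }

    public static void Main()
    {
        var path = new List<Expression>();
        if (Search(Expression.Constant(1), 10, 4, path))
        {
            foreach (var step in path)
                Console.WriteLine("{0} = {1}", step, Eval(step));
        }
    }
}
```

For a target of 10 and a limit of four moves, the search abandons the branches starting 1, 2, 3 and settles on the trees whose values are 1, 2, 4, 5, 10; each tree in that path contains the previous one as its left child.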

I should back this up with an example. My example is based on the chain-of-responsibility code that I posted here; you’ll need that code to make this work. What we need is a visitor over Linq expression trees that recognizes when a child node has changed and rebuilds the parent nodes all the way to the root, returning the original node wherever nothing beneath it changed.

 

(Note: This is post-Beta1 code and won't compile with Beta1 or earlier builds of Orcas. At the bottom, I've attached the equivalent Beta1 ExprOp class and sample code.)

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq.Expressions;

// A rewriting visitor: each case visits the node's children and returns the
// original node unless at least one child came back as a different object.
public static class ExprOp {
    static public Func<Expression, Expression> Visit = FuncOp.Create<Expression, Expression>(
        (self, expr) => {
            // Optional children (the target of a static call or static member,
            // a Coalesce without a conversion) are null; pass them through.
            if (expr == null)
                return null;
            switch (expr.NodeType) {
                case ExpressionType.Coalesce:
                    var c = (BinaryExpression)expr;
                    var left = self(c.Left);
                    var right = self(c.Right);
                    var conv = self(c.Conversion);
                    return (left == c.Left && right == c.Right && conv == c.Conversion) ? expr : Expression.Coalesce(left, right, (LambdaExpression)conv);
                case ExpressionType.TypeIs:
                    var tbe = (TypeBinaryExpression)expr;
                    var tbex = self(tbe.Expression);
                    return (tbe.Expression == tbex) ? expr : Expression.TypeIs(tbex, tbe.TypeOperand);
                case ExpressionType.Conditional:
                    var ce = (ConditionalExpression)expr;
                    var t = self(ce.Test);
                    var it = self(ce.IfTrue);
                    var @if = self(ce.IfFalse);
                    return (t == ce.Test && it == ce.IfTrue && @if == ce.IfFalse) ? expr : Expression.Condition(t, it, @if);
                case ExpressionType.MemberAccess:
                    var ma = (MemberExpression)expr;
                    var maex = self(ma.Expression);
                    return (maex == ma.Expression) ? expr : Expression.MakeMemberAccess(maex, ma.Member);
                case ExpressionType.Call:
                    var mce = (MethodCallExpression)expr;
                    var o = self(mce.Object);
                    var ca = self.VisitExpressionList(mce.Arguments);
                    return (o == mce.Object && ca == mce.Arguments) ? expr : Expression.Call(o, mce.Method, ca);
                case ExpressionType.Lambda:
                    var le = (LambdaExpression)expr;
                    var b = self(le.Body);
                    return (b == le.Body) ? expr : Expression.Lambda(le.Type, b, le.Parameters);
                case ExpressionType.New:
                    var ne = (NewExpression)expr;
                    var nar = self.VisitExpressionList(ne.Arguments);
                    if (nar != ne.Arguments)
                        return (ne.Members != null) ? Expression.New(ne.Constructor, nar, ne.Members)
                            : Expression.New(ne.Constructor, nar);
                    return expr;
                case ExpressionType.NewArrayInit:
                case ExpressionType.NewArrayBounds:
                    var na = (NewArrayExpression)expr;
                    var inits = self.VisitExpressionList(na.Expressions);
                    if (inits != na.Expressions)
                        return (na.NodeType == ExpressionType.NewArrayInit) ?
                            Expression.NewArrayInit(na.Type.GetElementType(), inits) :
                            Expression.NewArrayBounds(na.Type.GetElementType(), inits);
                    return expr;
                case ExpressionType.Invoke:
                    var inv = (InvocationExpression)expr;
                    var args = self.VisitExpressionList(inv.Arguments);
                    var ie = self(inv.Expression);
                    return (args == inv.Arguments && ie == inv.Expression) ? expr : Expression.Invoke(ie, args);
                case ExpressionType.MemberInit:
                    var mi = (MemberInitExpression)expr;
                    var n = (NewExpression)self(mi.NewExpression);
                    var bindings = self.VisitBindingList(mi.Bindings);
                    return (n == mi.NewExpression && bindings == mi.Bindings) ? expr : Expression.MemberInit(n, bindings);
                case ExpressionType.ListInit:
                    var li = (ListInitExpression)expr;
                    var lin = (NewExpression)self(li.NewExpression);
                    var lii = self.VisitElementInitializerList(li.Initializers);
                    // Reuse the node only if BOTH parts are unchanged.
                    return (lin == li.NewExpression && lii == li.Initializers) ? expr : Expression.ListInit(lin, lii);
            }
            if (expr.IsBinary()) {
                var b = (BinaryExpression)expr;
                var left = self(b.Left);
                var right = self(b.Right);
                // Carry over any user-defined operator method and lifting behavior.
                return (left == b.Left && right == b.Right) ? expr : Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method);
            }
            else if (expr.IsUnary()) {
                var u = (UnaryExpression)expr;
                var op = self(u.Operand);
                return (u.Operand == op) ? expr : Expression.MakeUnary(u.NodeType, op, u.Type, u.Method);
            }
            // Leaves such as Parameter and Constant have no children to visit.
            return expr;
        }
    );

    public static bool IsBinary(this Expression expr) {
        return expr is BinaryExpression;
    }

    public static bool IsUnary(this Expression expr) {
        return expr is UnaryExpression;
    }

    public static MemberBinding VisitBinding(this Func<Expression, Expression> self, MemberBinding b) {
        switch (b.BindingType) {
            case MemberBindingType.Assignment:
                return self.VisitMemberAssignment((MemberAssignment)b);
            case MemberBindingType.MemberBinding:
                return self.VisitMemberMemberBinding((MemberMemberBinding)b);
        }
        return self.VisitMemberListBinding((MemberListBinding)b);
    }

    public static MemberAssignment VisitMemberAssignment(this Func<Expression, Expression> self, MemberAssignment assignment) {
        var e = self(assignment.Expression);
        return (e == assignment.Expression) ? assignment : Expression.Bind(assignment.Member, e);
    }

    public static MemberMemberBinding VisitMemberMemberBinding(this Func<Expression, Expression> self, MemberMemberBinding binding) {
        var bindings = self.VisitBindingList(binding.Bindings);
        return (bindings == binding.Bindings) ? binding : Expression.MemberBind(binding.Member, bindings);
    }

    public static MemberListBinding VisitMemberListBinding(this Func<Expression, Expression> self, MemberListBinding binding) {
        var initializers = self.VisitElementInitializerList(binding.Initializers);
        return (initializers == binding.Initializers) ? binding : Expression.ListBind(binding.Member, initializers);
    }

    public static ElementInit VisitElementInitializer(this Func<Expression, Expression> self, ElementInit initializer) {
        var arguments = self.VisitExpressionList(initializer.Arguments);
        return (arguments == initializer.Arguments) ? initializer : Expression.ElementInit(initializer.AddMethod, arguments);
    }

    public static ReadOnlyCollection<Expression> VisitExpressionList(this Func<Expression, Expression> self, ReadOnlyCollection<Expression> original) {
        return VisitList(original, e => self(e));
    }

    public static ReadOnlyCollection<MemberBinding> VisitBindingList(this Func<Expression, Expression> self, ReadOnlyCollection<MemberBinding> original) {
        return VisitList(original, e => self.VisitBinding(e));
    }

    public static ReadOnlyCollection<ElementInit> VisitElementInitializerList(this Func<Expression, Expression> self, ReadOnlyCollection<ElementInit> original) {
        return VisitList(original, e => self.VisitElementInitializer(e));
    }

    // Copy-on-write: a new list is allocated only at the first element that
    // changed; if nothing changed, the original collection itself is returned.
    private static ReadOnlyCollection<T> VisitList<T>(ReadOnlyCollection<T> original, Func<T, T> op) {
        List<T> @new = null;
        for (int i = 0, n = original.Count; i < n; i++) {
            T init = op(original[i]);
            if (@new != null)
                @new.Add(init);
            else if (!ReferenceEquals(init, original[i])) {
                @new = new List<T>(n);
                for (int j = 0; j < i; j++)
                    @new.Add(original[j]);
                @new.Add(init);
            }
        }
        return (@new == null) ? original : @new.AsReadOnly();
    }
}

 

Pasted in, this looks like a lot of code (~150 lines). But keep in mind that this is a complete, reusable rewriting visitor for Linq expression trees: it never mutates a node, it only rebuilds the parents of nodes that changed.

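Ok, now to see it in action. This code chains to the visitor above and creates an expression tree rewriter that replaces all subtractions with additions:

    // New link in the chain: rewrite Subtract nodes as Add nodes and hand
    // every other node to 'last' (the general-purpose ExprOp.Visit).
    var RewriteSubtractToAdd = ExprOp.Visit.Chain(
        (self, last, expr) => {
            if (expr == null)
                return null;   // optional children, such as a static call's target
            switch (expr.NodeType) {
                case ExpressionType.Subtract:
                    var b = (BinaryExpression)expr;
                    // Visit the operands through 'self' so nested subtractions are rewritten too.
                    return Expression.Add(self(b.Left), self(b.Right));
                default:
                    return last(expr);
            }
        }
    );

Here's how to use the rewriter:

    Expression<Func<int, int>> SubtractOneExpr = (n => n - 1);
    var AddOneExpr = (LambdaExpression)RewriteSubtractToAdd(SubtractOneExpr);
    var AddOne = (Func<int, int>)AddOneExpr.Compile();
    Console.WriteLine(AddOne(5));   // prints 6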

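In this example, I take an expression that subtracts one and rewrite it as an expression that adds one. Then I compile the resulting expression into a delegate and invoke it to prove it works. The chain-of-responsibility expression visitor does the heavy lifting: the new link only has to recognize Subtract nodes, while ExprOp.Visit walks the rest of the tree and rebuilds the enclosing lambda.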
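For comparison, .NET Framework 4.0 and later include a public System.Linq.Expressions.ExpressionVisitor class whose default Visit methods also return the original node when none of its children changed. Assuming such a runtime, the same subtract-to-add rewrite can be sketched without the chain-of-responsibility helpers; SubtractToAddRewriter is a hypothetical name, and this is a swapped-in framework class, not the ExprOp code above.

```csharp
using System;
using System.Linq.Expressions;

// Requires .NET Framework 4.0 or later, where ExpressionVisitor is public.
public sealed class SubtractToAddRewriter : ExpressionVisitor
{
    protected override Expression VisitBinary(BinaryExpression node)
    {
        if (node.NodeType == ExpressionType.Subtract)
        {
            // Visit the operands first so nested subtractions are rewritten too.
            return Expression.Add(Visit(node.Left), Visit(node.Right));
        }
        // Any other binary node: the base class rebuilds it only if a child changed.
        return base.VisitBinary(node);
    }

    public static void Main()
    {
        Expression<Func<int, int>> subtractOneExpr = n => n - 1;
        var addOneExpr = (Expression<Func<int, int>>)new SubtractToAddRewriter().Visit(subtractOneExpr);
        Func<int, int> addOne = addOneExpr.Compile();
        Console.WriteLine(addOne(5));   // 6

        // A tree with no subtraction comes back as the very same object.
        Expression<Func<int, int>> doubleExpr = n => n * 2;
        Console.WriteLine(ReferenceEquals(new SubtractToAddRewriter().Visit(doubleExpr), doubleExpr));   // True
    }
}
```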

(Follow-up 5-25-2007: Aaron's got a different utility class for dealing with Expression trees. Check out ExpressionBuilder: https://blog.magenic.com/blogs/aarone/archive/2007/05/24/Announcing-MetaLinq-_2D00_-Linq-to-Expressions.aspx )  

(Follow-up 7-31-2007: Matt Warren (the guy in the office next to mine) has posted a different expression visitor. His is pretty close to the one we use internally: https://blogs.msdn.com/mattwar/archive/2007/07/31/linq-building-an-iqueryable-provider-part-ii.aspx )  

This posting is provided "AS IS" with no warranties, and confers no rights.

 


ExprOpSample-Beta1.cs