Improving ObjectQuery.Include

Having spent some time using the sample from my previous post on ObjectQuery.Include, I’ve encountered a bug! It turns out that the code generates the wrong include string for

context.Customers.Include(c => c.Order.SubInclude(o=>o.OrderDetail))

The fix for this is a small change to the BuildString method to recurse up the MemberExpression if necessary. The updated code is below - usual disclaimers apply!

    public static class ObjectQueryExtensions

    {

        public static ObjectQuery<TSource> Include<TSource,

                               TPropType>(this ObjectQuery<TSource> source,

                               Expression<Func<TSource,

                               TPropType>> propertySelector)

        {

            string includeString = BuildString(propertySelector);

            return source.Include(includeString);

        }

        private static string BuildString(Expression propertySelector)

    {

            switch (propertySelector.NodeType)

            {

                case ExpressionType.Lambda:

                    LambdaExpression lambdaExpression = (LambdaExpression)propertySelector;

                    return BuildString(lambdaExpression.Body);

                case ExpressionType.Quote:

                    UnaryExpression unaryExpression = (UnaryExpression)propertySelector;

                    return BuildString(unaryExpression.Operand);

                case ExpressionType.MemberAccess:

                    MemberExpression memberExpression = (MemberExpression)propertySelector;

                    MemberInfo propertyInfo = memberExpression.Member;

                    if (memberExpression.Expression is ParameterExpression)

            {

                        return propertyInfo.Name;

                    }

                    else

                    {

                        // we've got a nested property (e.g. MyType.SomeProperty.SomeNestedProperty)

                        return BuildString(memberExpression.Expression) + "." + propertyInfo.Name;

                    }

                case ExpressionType.Call:

                    MethodCallExpression methodCallExpression =

                            (MethodCallExpression)propertySelector;

               // check that it's a SubInclude cal

            if (IsSubInclude(methodCallExpression.Method))l

                    {

                        // argument 0 is the expression to which the SubInclude is applied (this could

                        // be member access or another SubInclude)

                        // argument 1 is the expression to apply to get the included property

                        // Pass both to BuildString to get the full expression

                        return BuildString(methodCallExpression.Arguments[0]) + "." +

                               BuildString(methodCallExpression.Arguments[1]);

                    }

                    // else drop out and throw

                    break;

            }

            throw new InvalidOperationException("Expression must be a member expression or

                    an SubInclude call: " + propertySelector.ToString());

        }

        private static readonly MethodInfo[] SubIncludeMethods;

        static ObjectQueryExtensions()

        {

            Type type = typeof(ObjectQueryExtensions);

            SubIncludeMethods =

                    type.GetMethods().Where(mi => mi.Name == "SubInclude").ToArray();

        }

        private static bool IsSubInclude(MethodInfo methodInfo)

        {

            if (methodInfo.IsGenericMethod)

            {

          if (!methodInfo.IsGenericMethodDefinition)

                {

                    methodInfo = methodInfo.GetGenericMethodDefinition();

                }

            }

            return SubIncludeMethods.Contains(methodInfo);

        }

        public static TPropType SubInclude<TSource,

                    TPropType>(this EntityCollection<TSource> source,

                    Expression<Func<TSource, TPropType>> propertySelector)

            where TSource : class, IEntityWithRelationships

            where TPropType : class

        {

      throw new InvalidOperationException("This method is only intended for use with

                    ObjectQueryExtensions.Include to generate expressions trees");

                    // no actually using this - just want the expression!

        }

        public static TPropType SubInclude<TSource,

               TPropType>(this TSource source,

               Expression<Func<TSource, TPropType>> propertySelector)

            where TSource : class, IEntityWithRelationships

            where TPropType : class

        {

            throw new InvalidOperationException("This method is only intended for use with

                    ObjectQueryExtensions.Include to generate expressions trees");

                    // no actually using this - just want the expression!

        }

    }

Originally posted by Stuart Leeks on April 24th 2009 here.