Improving ObjectQuery.Include – Updated

Having spent some time using the sample from my previous post on ObjectQuery.Include, I’ve encountered a bug! It turns out that the code generates the wrong include string for

 context.Customers.Include(c => c.Order.SubInclude(o=>o.OrderDetail))

The fix for this is a small change to the BuildString method so that it recurses up the chain of MemberExpressions when the selected property is nested. The updated code is below - usual disclaimers apply!

     public static class ObjectQueryExtensions
    {
        public static ObjectQuery<TSource> Include<TSource, TPropType>(this ObjectQuery<TSource> source, Expression<Func<TSource, TPropType>> propertySelector)
        {
            string includeString = BuildString(propertySelector);
            return source.Include(includeString);
        }
        private static string BuildString(Expression propertySelector)
        {
            switch (propertySelector.NodeType)
            {
                case ExpressionType.Lambda:
                    LambdaExpression lambdaExpression = (LambdaExpression)propertySelector;
                    return BuildString(lambdaExpression.Body);

                case ExpressionType.Quote:
                    UnaryExpression unaryExpression = (UnaryExpression)propertySelector;
                    return BuildString(unaryExpression.Operand);

                case ExpressionType.MemberAccess:

                    MemberExpression memberExpression = (MemberExpression)propertySelector;
                    MemberInfo propertyInfo = memberExpression.Member;

                    if (memberExpression.Expression is ParameterExpression)
                    {
                        return propertyInfo.Name;
                    }
                    else
                    {
                        // we've got a nested property (e.g. MyType.SomeProperty.SomeNestedProperty)
                        return BuildString(memberExpression.Expression) + "." + propertyInfo.Name;
                    }

                case ExpressionType.Call:
                    MethodCallExpression methodCallExpression = (MethodCallExpression)propertySelector;
                    if (IsSubInclude(methodCallExpression.Method)) // check that it's a SubInclude call
                    {
                        // argument 0 is the expression to which the SubInclude is applied (this could be member access or another SubInclude)
                        // argument 1 is the expression to apply to get the included property
                        // Pass both to BuildString to get the full expression
                        return BuildString(methodCallExpression.Arguments[0]) + "." +
                               BuildString(methodCallExpression.Arguments[1]);
                    }
                    // else drop out and throw
                    break;
            }
            throw new InvalidOperationException("Expression must be a member expression or an SubInclude call: " + propertySelector.ToString());

        }

        private static readonly MethodInfo[] SubIncludeMethods;
        static ObjectQueryExtensions()
        {
            Type type = typeof(ObjectQueryExtensions);
            SubIncludeMethods = type.GetMethods().Where(mi => mi.Name == "SubInclude").ToArray();
        }
        private static bool IsSubInclude(MethodInfo methodInfo)
        {
            if (methodInfo.IsGenericMethod)
            {
                if (!methodInfo.IsGenericMethodDefinition)
                {
                    methodInfo = methodInfo.GetGenericMethodDefinition();
                }
            }
            return SubIncludeMethods.Contains(methodInfo);
        }

        public static TPropType SubInclude<TSource, TPropType>(this EntityCollection<TSource> source, Expression<Func<TSource, TPropType>> propertySelector)
            where TSource : class, IEntityWithRelationships
            where TPropType : class
        {
            throw new InvalidOperationException("This method is only intended for use with ObjectQueryExtensions.Include to generate expressions trees"); // no actually using this - just want the expression!
        }
        public static TPropType SubInclude<TSource, TPropType>(this TSource source, Expression<Func<TSource, TPropType>> propertySelector)
            where TSource : class, IEntityWithRelationships
            where TPropType : class
        {
            throw new InvalidOperationException("This method is only intended for use with ObjectQueryExtensions.Include to generate expressions trees"); // no actually using this - just want the expression!
        }
    }

 

UPDATE: Alex James has a great post on Eager Loading Strategies that I'd recommend reading.