Making Equality Testing Simple


Getting equality correct on a .NET type is a fairly involved process: the implementation must adhere to a large set of rules in order to be considered correct, including:

  • Object.Equals overrides on reference types must return false for null values
  • Object.Equals overrides must return false for incompatible types
  • Excluding null cases, x.Equals(y) must return the same value as y.Equals(x)
  • Excluding null cases, if x.Equals(y) and y.Equals(z) are both true, then x.Equals(z) must also be true
  • If operator == or != is overloaded
    • Both == and != should be overloaded, or neither
    • Operator == must handle the left side being null
    • Operator == should mimic Object.Equals in all cases where the left side is not null
    • Operator != must handle the left side being null
    • Operator != should mimic !Object.Equals in all cases where the left side is not null
  • If two values are equal according to Object.Equals, they must return the same value from GetHashCode

I’m sure I missed one or two subtle ones, but these are the major players.  It gets even more fun when you add IEquatable<T> to the mix.
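The rules above are usually satisfied with a fairly standard template. The sketch below shows one possible version, assuming a hypothetical sealed MyType that wraps a single int. It matches the MyType named in the usage example later in this post, but the implementation itself is illustrative and not the author's code.

```csharp
using System;

// Hypothetical MyType: a sketch of a common equality template, not the author's code.
public sealed class MyType : IEquatable<MyType>
{
    private readonly int _value;

    public MyType(int value)
    {
        _value = value;
    }

    // IEquatable<T>.Equals: a null argument is never equal to an instance.
    public bool Equals(MyType other)
    {
        if (ReferenceEquals(other, null))
        {
            return false;
        }
        return _value == other._value;
    }

    // Object.Equals: null and incompatible types both become null via 'as',
    // so both return false.
    public override bool Equals(object obj)
    {
        return Equals(obj as MyType);
    }

    // Equal values must produce equal hash codes.
    public override int GetHashCode()
    {
        return _value.GetHashCode();
    }

    // == and != are overloaded together. The static object.Equals(object, object)
    // handles null on either side before calling the Equals(object) override.
    public static bool operator ==(MyType left, MyType right)
    {
        return Equals(left, right);
    }

    public static bool operator !=(MyType left, MyType right)
    {
        return !Equals(left, right);
    }
}
```

Sealing the class avoids the symmetry problems that arise when a derived class adds its own fields to the comparison.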

Luckily, correctly implementing equality is fairly straightforward, and most template code available on the web respects the above rules.  However, it’s easy to miss a corner case and introduce hard-to-track-down bugs.
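A typical corner case is an operator == that forwards to Equals without first checking its left operand. The sketch below uses a hypothetical BrokenPoint class, which is not from the original post, to show the failure. Comparing null == value throws instead of returning false, which breaks the rule that operator == must handle a null left side.

```csharp
using System;

// Hypothetical class with a common bug: operator == calls Equals on its left
// operand without a null check, so "null == p" throws NullReferenceException.
public sealed class BrokenPoint : IEquatable<BrokenPoint>
{
    public readonly int X;

    public BrokenPoint(int x) { X = x; }

    public bool Equals(BrokenPoint other)
    {
        return !ReferenceEquals(other, null) && X == other.X;
    }

    public override bool Equals(object obj) { return Equals(obj as BrokenPoint); }

    public override int GetHashCode() { return X; }

    // BUG: dereferences 'left' even when it is null.
    public static bool operator ==(BrokenPoint left, BrokenPoint right)
    {
        return left.Equals(right);
    }

    public static bool operator !=(BrokenPoint left, BrokenPoint right)
    {
        return !(left == right);
    }
}

public static class BrokenPointDemo
{
    // Returns true if "null == value" throws instead of returning false.
    public static bool NullLeftOperandThrows()
    {
        BrokenPoint nothing = null;
        try
        {
            _ = nothing == new BrokenPoint(1);
            return false;
        }
        catch (NullReferenceException)
        {
            return true;
        }
    }
}
```

The bug only shows up when the left operand is null, which is easy to overlook when testing by hand. A common fix is to delegate to the static object.Equals(left, right). That method returns true when both arguments are the same reference (including both null), returns false when only one is null, and otherwise calls left.Equals(right).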

I’m not satisfied with simply following a standard template and hoping I got it right; I only sleep easy once I’ve tested these cases.  Yet testing all of them is tedious and involves quite a bit of repetitive code that screams for an abstraction.  As the author of a new type, I simply want to provide a collection of units, each associating a value with the values that should and should not be equal to it, and let the abstraction verify that I properly implemented equality semantics.

The first step is defining a type that encapsulates a value along with its sets of equal and not-equal values.

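using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;

public class EqualityUnit<T> {
    private static readonly ReadOnlyCollection<T> EmptyCollection = new ReadOnlyCollection<T>(new T[] { });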

    public readonly T Value;
    public readonly ReadOnlyCollection<T> EqualValues;
    public readonly ReadOnlyCollection<T> NotEqualValues;
    public IEnumerable<T> AllValues {
        get { return Enumerable.Repeat(Value, 1).Concat(EqualValues).Concat(NotEqualValues); }
    }
    public EqualityUnit(T value) {
        Value = value;
        EqualValues = EmptyCollection;
        NotEqualValues = EmptyCollection;
    }
    public EqualityUnit(
        T value,
        ReadOnlyCollection<T> equalValues,
        ReadOnlyCollection<T> notEqualValues) {
        Value = value;
        EqualValues = equalValues;
        NotEqualValues = notEqualValues;
    }
    public EqualityUnit<T> WithEqualValues(params T[] equalValues) {
        return new EqualityUnit<T>(
            Value,
            EqualValues.Concat(equalValues).ToList().AsReadOnly(),
            NotEqualValues);
    }
    public EqualityUnit<T> WithNotEqualValues(params T[] notEqualValues) {
        return new EqualityUnit<T>(
            Value,
            EqualValues,
            NotEqualValues.Concat(notEqualValues).ToList().AsReadOnly());
    }
}

public static class EqualityUnit {
    public static EqualityUnit<T> Create<T>(T value) {
        return new EqualityUnit<T>(value);
    }
}

I chose a fluent interface design here because it makes the usage code very readable.  For example:

var unit = EqualityUnit
    .Create(new MyType(42))
    .WithEqualValues(new MyType(42))
    .WithNotEqualValues(new MyType(13));

Now that we have the data defined, we need to follow through with the actual test code.  Most of it is a straightforward enforcement of the rules above.  The only tricky part is testing operators == and !=.  The testing class is necessarily generic, but neither == nor != can be applied to values of an unconstrained generic type parameter.  Instead, they must be applied where the concrete, non-generic type is known.
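To see why, consider the sketch below. OperatorBindingDemo is a hypothetical name. With an unconstrained T, x == y is a compile-time error. Adding a class constraint makes it compile, but the operator is then bound to reference equality for every T, so string’s overloaded == is never called. A lambda written where the concrete type is known does pick up the overload.

```csharp
using System;

public static class OperatorBindingDemo
{
    // With an unconstrained T, "x == y" does not compile.
    // With "where T : class" it compiles, but always means reference equality:
    // a user-defined operator == on the concrete type is never called.
    public static bool GenericEquals<T>(T x, T y) where T : class
    {
        return x == y;
    }

    // The lambda is written where the concrete type (string) is known,
    // so it binds to string's overloaded operator ==.
    public static readonly Func<string, string, bool> StringEquals = (x, y) => x == y;
}
```

Take two distinct string instances that both hold "xxx", for example each created with new string('x', 3). GenericEquals returns false for them, while StringEquals returns true.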

This can be solved by having the calling code provide two lambda expressions of type Func<T, T, bool> that invoke the == and != operators.

EqualityUtil.RunAll(
    (x, y) => x == y,
    (x, y) => x != y,
    unit);

This is boilerplate code that has to be repeated by every caller, but it’s small enough not to be much of a burden.  Now, finally, the code.

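using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using NUnit.Framework;

public sealed class EqualityUtil<T> {
    private readonly ReadOnlyCollection<EqualityUnit<T>> _equalityUnits;
    private readonly Func<T, T, bool> _compareWithEqualityOperator;
    private readonly Func<T, T, bool> _compareWithInequalityOperator;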

    public EqualityUtil(
        IEnumerable<EqualityUnit<T>> equalityUnits,
        Func<T, T, bool> compEquality,
        Func<T, T, bool> compInequality) {
        _equalityUnits = equalityUnits.ToList().AsReadOnly();
        _compareWithEqualityOperator = compEquality;
        _compareWithInequalityOperator = compInequality;
    }

    public void RunAll(
        bool skipOperators = false,
        bool skipEquatable = false) {
        if (!skipOperators) {
            EqualityOperator();
            EqualityOperatorCheckNull();
            InEqualityOperator();
            InEqualityOperatorCheckNull();
        }

        if (!skipEquatable) {
            ImplementsIEquatable();
            EquatableEquals();
            EquatableEqualsCheckNull();
        }

        ObjectEquals();
        ObjectEqualsCheckNull();
        ObjectEqualsDifferentType();
        GetHashCodeSemantics();
    }

    private void EqualityOperator() {
        foreach (var unit in _equalityUnits) {
            foreach (var value in unit.EqualValues) {
                Assert.IsTrue(_compareWithEqualityOperator(unit.Value, value));
                Assert.IsTrue(_compareWithEqualityOperator(value, unit.Value));
            }

            foreach (var value in unit.NotEqualValues) {
                Assert.IsFalse(_compareWithEqualityOperator(unit.Value, value));
                Assert.IsFalse(_compareWithEqualityOperator(value, unit.Value));
            }
        }
    }

    private void EqualityOperatorCheckNull() {
        if (typeof(T).IsValueType) {
            return;
        }

        foreach (var value in _equalityUnits.SelectMany(x => x.AllValues)) {
            if (!Object.ReferenceEquals(value, null)) {
                Assert.IsFalse(_compareWithEqualityOperator(default(T), value));
                Assert.IsFalse(_compareWithEqualityOperator(value, default(T)));
            }
        }
    }

    private void InEqualityOperator() {
        foreach (var unit in _equalityUnits) {
            foreach (var value in unit.EqualValues) {
                Assert.IsFalse(_compareWithInequalityOperator(unit.Value, value));
                Assert.IsFalse(_compareWithInequalityOperator(value, unit.Value));
            }

            foreach (var value in unit.NotEqualValues) {
                Assert.IsTrue(_compareWithInequalityOperator(unit.Value, value));
                Assert.IsTrue(_compareWithInequalityOperator(value, unit.Value));
            }
        }
    }

    private void InEqualityOperatorCheckNull() {
        if (typeof(T).IsValueType) {
            return;
        }
        foreach (var value in _equalityUnits.SelectMany(x => x.AllValues)) {
            if (!Object.ReferenceEquals(value, null)) {
                Assert.IsTrue(_compareWithInequalityOperator(default(T), value));
                Assert.IsTrue(_compareWithInequalityOperator(value, default(T)));
            }
        }
    }

    private void ImplementsIEquatable() {
        var type = typeof(T);
        var targetType = typeof(IEquatable<T>);
        Assert.IsTrue(type.GetInterfaces().Contains(targetType));
    }

    private void ObjectEquals() {
        foreach (var unit in _equalityUnits) {
            var unitValue = unit.Value;
            foreach (var value in unit.EqualValues) {
                Assert.IsTrue(unitValue.Equals(value));
                Assert.IsTrue(value.Equals(unitValue));
            }
            foreach (var value in unit.NotEqualValues) {
                Assert.IsFalse(unitValue.Equals(value));
                Assert.IsFalse(value.Equals(unitValue));
            }
        }
    }

    /// <summary>
    /// Comparison with Null should be false for reference types
    /// </summary>
    private void ObjectEqualsCheckNull() {
        if (typeof(T).IsValueType) {
            return;
        }

        var allValues = _equalityUnits.SelectMany(x => x.AllValues);
        foreach (var value in allValues) {
            Assert.IsFalse(value.Equals(null));
        }
    }

    private sealed class NotAccessible { } 

    /// <summary>
    /// Passing a value of a different type should just return false
    /// </summary>
    private void ObjectEqualsDifferentType() {
        var allValues = _equalityUnits.SelectMany(x => x.AllValues);
        foreach (var value in allValues) {
            Assert.IsFalse(value.Equals(new NotAccessible()));
        }
    }

    private void GetHashCodeSemantics() {
        foreach (var unit in _equalityUnits) {
            foreach (var value in unit.EqualValues) {
                Assert.AreEqual(value.GetHashCode(), unit.Value.GetHashCode());
            }
        }
    }

    private void EquatableEquals() {
        foreach (var unit in _equalityUnits) {
            var equatableUnit = (IEquatable<T>)unit.Value;
            foreach (var value in unit.EqualValues) {
                Assert.IsTrue(equatableUnit.Equals(value));
                var equatableValue = (IEquatable<T>)value;
                Assert.IsTrue(equatableValue.Equals(unit.Value));
            }

            foreach (var value in unit.NotEqualValues) {
                Assert.IsFalse(equatableUnit.Equals(value));
                var equatableValue = (IEquatable<T>)value;
                Assert.IsFalse(equatableValue.Equals(unit.Value));
            }
        }
    }

    /// <summary>
    /// If T is a reference type, null should return false in all cases
    /// </summary>
    private void EquatableEqualsCheckNull() {
        if (typeof(T).IsValueType) {
            return;
        }

        foreach (var cur in _equalityUnits.SelectMany(x => x.AllValues)) {
            var value = (IEquatable<T>)cur;
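            // default(T) is null here; a bare null literal cannot convert to T,
            // so it would bind to Object.Equals(object) instead of IEquatable<T>.Equals
            Assert.IsFalse(value.Equals(default(T)));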
        }
    }
}

public static class EqualityUtil {
    public static void RunAll<T>(
        Func<T, T, bool> compEqualsOperator,
        Func<T, T, bool> compNotEqualsOperator,
        bool skipOperators,
        bool skipEquatable,
        params EqualityUnit<T>[] values) {
        var util = new EqualityUtil<T>(values, compEqualsOperator, compNotEqualsOperator);
        util.RunAll(skipOperators: skipOperators, skipEquatable: skipEquatable);
    }

    public static void RunAll<T>(
        Func<T, T, bool> compEqualsOperator,
        Func<T, T, bool> compNotEqualsOperator,
        params EqualityUnit<T>[] values) {
        RunAll(compEqualsOperator, compNotEqualsOperator, skipEquatable: false, skipOperators: false, values: values);
    }
}
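The harness above does not check reflexivity: x.Equals(x) must be true for any non-null x. The comments below point out that this rule is missing from the list. The sketch below shows one way to add an extra check. ReflexivityCheck is a hypothetical standalone helper, not part of the original EqualityUtil. It throws InvalidOperationException rather than using NUnit’s Assert, so it can run on its own.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not part of the original EqualityUtil): checks the
// reflexive rule, x.Equals(x) must be true for every non-null x.
public static class ReflexivityCheck
{
    public static void Verify<T>(IEnumerable<T> values, Func<T, T, bool> compEqualsOperator)
    {
        foreach (var value in values)
        {
            // The rule excludes null, like the other null-excluding rules.
            if (object.ReferenceEquals(value, null))
            {
                continue;
            }

            // Object.Equals(object) called with the value itself.
            if (!value.Equals(value))
            {
                throw new InvalidOperationException("Object.Equals is not reflexive for " + value);
            }

            // IEquatable<T>.Equals, when the type implements it.
            if (value is IEquatable<T> equatable && !equatable.Equals(value))
            {
                throw new InvalidOperationException("IEquatable<T>.Equals is not reflexive for " + value);
            }

            // The caller-supplied operator == lambda, as used by EqualityUtil.
            if (!compEqualsOperator(value, value))
            {
                throw new InvalidOperationException("operator == is not reflexive for " + value);
            }
        }
    }
}
```

Inside EqualityUtil<T>, it could be called as ReflexivityCheck.Verify(_equalityUnits.SelectMany(x => x.AllValues), _compareWithEqualityOperator). Be aware that for double, double.NaN.Equals(double.NaN) is true but NaN == NaN is false. This sketch would therefore reject a double unit that contains NaN.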

And what would any code, including test framework code, be without a few test cases?

using NUnit.Framework;

[TestFixture]
public class EqualityUtilTesting {

    [Test]
    public void EqualityWithIntegers() {
        EqualityUtil.RunAll(
            (x, y) => x == y,
            (x, y) => x != y,
            EqualityUnit.Create(1).WithEqualValues(1).WithNotEqualValues(2),
            EqualityUnit.Create(42).WithNotEqualValues(13));
    }

    [Test]
    public void EqualityWithStrings() {
        EqualityUtil.RunAll(
            (x, y) => x == y,
            (x, y) => x != y,
            EqualityUnit.Create("foo").WithEqualValues("foo").WithNotEqualValues("no"),
            EqualityUnit.Create("FOO").WithNotEqualValues("foo"));
    }
}

Comments (4)

  1. jtenos says:

    Is there some kind of standard that says that Equals should mimic ==?  If so, what’s the point in having both Equals and == in the framework, both overridable?  I’ve done this myself, but it just means that I’m basically doing double work, since all I’m doing is making it possible for the calling code to take their pick between == and Equals and get the same result, rather than making them use the right one.

    If anything, I would have liked the standard to be "== means reference equality" and "Equals means logical equality", which would be completely different things.  But then you throw value types in the mix, and == can’t be limited to reference, because then (1 == 1) wouldn’t work.

    So I don’t really have a better answer, but I’ve never seen anything official, and haven’t seen any de facto standards either.

  2. Mathematically, .Equals(…) and == should be "equivalence relations".

    http://en.wikipedia.org/wiki/Equivalence_relation#Definition

    You have the symmetric and transitive parts listed:

    # Excluding null cases x.Equals(y) must be the same as y.Equals(x)

    # Excluding null cases (x.Equals(y) && y.Equals(z)) is true only if x.Equals(z)

    I would add the reflexive:

    # Excluding null cases x.Equals(x) must be true

  3. @jtenos

    There is no hard and fast rule about overloading == if you override .Equals.  In general, though, if .Equals is overridden then == should be overloaded as well, because quite simply too many people assume it will be (myself included).  This is not consistently done in the BCL (see IPAddress), but newer code follows this standard.

    The one advantage == has over .Equals is that it pushes the burden of the null check on the left-hand side onto the operator, so you can write "x == y" instead of "x != null && x.Equals(y)".

    Having both == and .Equals is something I’ve definitely found frustrating as well (especially due to the lack of consistency).  I do find C# to be lacking here as it does nothing to distinguish between referential and value equality at an operator level.  

    This is one place where I feel that VB.Net wins over C#, because it distinguishes between referential and value equality.  Reference equality is achieved with the "Is" operator, which cannot be overridden.  Using the ‘=’ comparison will only compile if the type actually implements operator ‘=’ (this compiles to the same operator as C#’s ==).

  4. @Maurits

    True, that should definitely be in the list.  I’ll add it to the post later.