How to rewrite Single(OrDefault) using the LINQ Aggregate catamorphism?

178 Views Asked by At

I was reading the article written by Bart de Smet quite a while ago: http://community.bartdesmet.net/blogs/bart/archive/2009/11/08/jumping-the-trampoline-in-c-stack-friendly-recursion.aspx

It's left as an exercise to the reader to define all other catamorphic operators in LINQ in terms of the Aggregate operator:
- Simple: (Long)Count, Sum, Average, Min, Max
- A bit harder: All, Any, Contains
- More thinking: First(OrDefault), Last(OrDefault), Single(OrDefault)

I managed to use the Aggregate catamorphism for almost everything except the Single() extension method. Also, in a couple of cases, such as Contains() or First(), the operators are not properly rewritten, because I don't see how the Aggregate catamorphism can be stopped (as soon as an item matching some condition has been found) instead of iterating through the rest of the given sequence.

To sum up:
- How to rewrite the Single() extension method using Aggregate?
- How to stop the LINQ Aggregate catamorphism so that it returns as soon as an item has been found, without iterating over the rest of the given sequence (e.g. AggregateContains() should return true at the first item matching the value and stop iterating)?

/// <summary>
/// Console harness that prints, for each Aggregate-based operator, whether it agrees with the
/// corresponding built-in LINQ operator on the integers 1..n. Every printed line is expected
/// to read <c>True</c>.
/// </summary>
public static class Program
{
    private static void Main()
    {
        var n = 42;
        var numbers = Enumerable.Range(1, n).ToArray();

        // Simple: (Long)Count, Sum, Average, Min, Max
        Console.WriteLine(numbers.Count() == numbers.AggregateCount());
        Console.WriteLine(numbers.Count(IsEven) == numbers.AggregateCount(IsEven));

        Console.WriteLine(numbers.LongCount() == numbers.AggregateLongCount());
        Console.WriteLine(numbers.LongCount(IsEven) == numbers.AggregateLongCount(IsEven));

        Console.WriteLine(numbers.Sum() == numbers.AggregateSum());
        Console.WriteLine(numbers.Sum(i => i.IsEven() ? i : 0) == numbers.AggregateSum(i => i.IsEven() ? i : 0));

        Console.WriteLine(numbers.Average() == numbers.AggregateAverage());
        Console.WriteLine(numbers.Average(i => i * 2) == numbers.AggregateAverage(i => i * 2));

        Console.WriteLine(numbers.Min() == numbers.AggregateMin());
        Console.WriteLine(numbers.Max() == numbers.AggregateMax());

        // Negating flips the ordering, so the selector overloads must pick the opposite end.
        Console.WriteLine(numbers.Min(i => -i) == numbers.AggregateMin(i => -i));
        Console.WriteLine(numbers.Max(i => -i) == numbers.AggregateMax(i => -i));

        // A bit harder: All, Any, Contains
        Console.WriteLine(numbers.All(i => i < n + 1) == numbers.AggregateAll(i => i < n + 1));
        Console.WriteLine(numbers.Any(IsEven) == numbers.AggregateAny(IsEven));

        // Check a value that is present (n) and one that is absent (n + 1).
        Console.WriteLine(numbers.Contains(n) == numbers.AggregateContains(n));
        Console.WriteLine(numbers.Contains(n + 1) == numbers.AggregateContains(n + 1));
        Console.WriteLine(numbers.Contains(n, EqualityComparer<int>.Default) == numbers.AggregateContains(n, EqualityComparer<int>.Default));

        // More thinking: First(OrDefault), Last(OrDefault)
        Console.WriteLine(numbers.Last() == numbers.AggregateLast());
        Console.WriteLine(numbers.LastOrDefault() == numbers.AggregateLastOrDefault());
        Console.WriteLine(numbers.Last(IsEven) == numbers.AggregateLast(IsEven));
        Console.WriteLine(numbers.LastOrDefault(IsEven) == numbers.AggregateLastOrDefault(IsEven));

        Console.WriteLine(numbers.First() == numbers.AggregateFirst());
        Console.WriteLine(numbers.FirstOrDefault() == numbers.AggregateFirstOrDefault());
        Console.WriteLine(numbers.First(IsEven) == numbers.AggregateFirst(IsEven));
        Console.WriteLine(numbers.FirstOrDefault(IsEven) == numbers.AggregateFirstOrDefault(IsEven));

        // Keep the console window open until a key is pressed.
        Console.ReadKey();
    }

    /// <summary>Returns <c>true</c> when <paramref name="number"/> is divisible by two.</summary>
    private static bool IsEven(this int number)
    {
        return number % 2 == 0;
    }
}


/// <summary>
/// Re-implementations of the catamorphic LINQ operators (Count, Sum, Average, Min, Max, All, Any,
/// Contains, First, Last and their OrDefault variants) written as folds over
/// <c>Enumerable.Aggregate</c>.
/// </summary>
/// <remarks>
/// <para>
/// <c>Aggregate</c> is a plain left fold with no early exit: every operator here enumerates the
/// whole source, even when the answer is known after the first element. The boolean and
/// "first match" folds only avoid re-invoking the predicate/comparer once the answer is settled.
/// </para>
/// <para>
/// A <c>null</c> source is rejected by <c>Enumerable.Aggregate</c> (or <c>Enumerable.Select</c>)
/// itself with an <see cref="ArgumentNullException"/>.
/// </para>
/// </remarks>
public static class EnumerableAggregateExtensions
{
    private const string NoElementsMessage = "Sequence contains no elements";
    private const string NoMatchMessage = "Sequence contains no matching element";

    /// <summary>Counts the elements of <paramref name="source"/>.</summary>
    /// <exception cref="OverflowException">The count exceeds <see cref="int.MaxValue"/>.</exception>
    public static int AggregateCount<T>(this IEnumerable<T> source)
    {
        return source.Aggregate(0, (r, i) => checked(r + 1));
    }

    /// <summary>Counts the elements of <paramref name="source"/> that satisfy <paramref name="predicate"/>.</summary>
    /// <exception cref="OverflowException">The count exceeds <see cref="int.MaxValue"/>.</exception>
    public static int AggregateCount<T>(this IEnumerable<T> source, Func<T, bool> predicate)
    {
        return source.Aggregate(0, (r, i) => predicate(i) ? checked(r + 1) : r);
    }

    /// <summary>Counts the elements of <paramref name="source"/> as a 64-bit integer.</summary>
    public static long AggregateLongCount<T>(this IEnumerable<T> source)
    {
        return source.Aggregate(0L, (r, i) => checked(r + 1));
    }

    /// <summary>Counts, as a 64-bit integer, the elements that satisfy <paramref name="predicate"/>.</summary>
    public static long AggregateLongCount<T>(this IEnumerable<T> source, Func<T, bool> predicate)
    {
        return source.Aggregate(0L, (r, i) => predicate(i) ? checked(r + 1) : r);
    }

    /// <summary>Sums the elements of <paramref name="source"/>; an empty source yields 0.</summary>
    /// <exception cref="OverflowException">The sum does not fit in an <see cref="int"/>.</exception>
    public static int AggregateSum(this IEnumerable<int> source)
    {
        // checked: a silent wrap-around would report a wrong total instead of failing.
        return source.Aggregate(0, (r, i) => checked(r + i));
    }

    /// <summary>Sums the values projected by <paramref name="selector"/>; an empty source yields 0.</summary>
    /// <exception cref="OverflowException">The sum does not fit in an <see cref="int"/>.</exception>
    public static int AggregateSum<TSource>(this IEnumerable<TSource> source, Func<TSource, int> selector)
    {
        return source.Aggregate(0, (r, i) => checked(r + selector(i)));
    }

    /// <summary>Computes the arithmetic mean of the elements of <paramref name="source"/>.</summary>
    /// <exception cref="InvalidOperationException"><paramref name="source"/> is empty.</exception>
    /// <remarks>
    /// The fold allocates a new immutable anonymous accumulator per element. A named, mutable
    /// sum/count holder updated in place would avoid those allocations; the anonymous type is
    /// kept here for brevity.
    /// </remarks>
    public static double AggregateAverage(this IEnumerable<int> source)
    {
        return source.Aggregate(
            // Accumulate in long so that summing many large ints cannot wrap around.
            new { Sum = 0L, Count = 0L },
            (r, i) => new { Sum = checked(r.Sum + i), Count = checked(r.Count + 1) },
            r =>
            {
                if (r.Count == 0)
                {
                    // Without this guard 0 / 0.0 would quietly produce NaN.
                    throw new InvalidOperationException(NoElementsMessage);
                }

                return (double)r.Sum / r.Count;
            });
    }

    /// <summary>Computes the arithmetic mean of the values projected by <paramref name="selector"/>.</summary>
    /// <exception cref="InvalidOperationException"><paramref name="source"/> is empty.</exception>
    public static double AggregateAverage<TSource>(this IEnumerable<TSource> source, Func<TSource, int> selector)
    {
        return source.Select(selector).AggregateAverage();
    }

    /// <summary>
    /// Returns the smallest element according to the default comparer. <c>null</c> elements are
    /// skipped, so <c>null</c> is returned only when every element is <c>null</c>.
    /// </summary>
    /// <exception cref="InvalidOperationException"><paramref name="source"/> is empty.</exception>
    public static TSource AggregateMin<TSource>(this IEnumerable<TSource> source)
    {
        var comparer = Comparer<TSource>.Default;

        // The default comparer orders null before everything else, so a bare "Compare(i, r) < 0"
        // would let a single null win; nulls are therefore excluded explicitly.
        // (For non-nullable value types both null tests are constant and only the comparison matters.)
        return source.Aggregate((r, i) => i != null && (r == null || comparer.Compare(i, r) < 0) ? i : r);
    }

    /// <summary>Returns the smallest value projected by <paramref name="selector"/>.</summary>
    /// <exception cref="InvalidOperationException"><paramref name="source"/> is empty.</exception>
    public static TResult AggregateMin<TSource, TResult>(this IEnumerable<TSource> source, Func<TSource, TResult> selector)
    {
        return source.Select(selector).AggregateMin();
    }

    /// <summary>Returns the largest element according to the default comparer.</summary>
    /// <exception cref="InvalidOperationException"><paramref name="source"/> is empty.</exception>
    public static TSource AggregateMax<TSource>(this IEnumerable<TSource> source)
    {
        var comparer = Comparer<TSource>.Default;

        // null sorts lowest under the default comparer, so it can never displace a real value here.
        return source.Aggregate((r, i) => comparer.Compare(i, r) > 0 ? i : r);
    }

    /// <summary>Returns the largest value projected by <paramref name="selector"/>.</summary>
    /// <exception cref="InvalidOperationException"><paramref name="source"/> is empty.</exception>
    public static TResult AggregateMax<TSource, TResult>(this IEnumerable<TSource> source, Func<TSource, TResult> selector)
    {
        return source.Select(selector).AggregateMax();
    }

    /// <summary>
    /// Returns <c>true</c> when every element satisfies <paramref name="predicate"/> (vacuously
    /// <c>true</c> for an empty source).
    /// </summary>
    public static bool AggregateAll<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
    {
        // && stops calling the predicate after the first failure; the enumeration itself continues.
        return source.Aggregate(true, (r, i) => r && predicate(i));
    }

    /// <summary>Returns <c>true</c> when at least one element satisfies <paramref name="predicate"/>.</summary>
    public static bool AggregateAny<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
    {
        // Aggregate cannot be interrupted: || only stops calling the predicate after the first
        // match, the remaining elements are still enumerated.
        return source.Aggregate(false, (r, i) => r || predicate(i));
    }

    /// <summary>Returns <c>true</c> when <paramref name="source"/> contains <paramref name="value"/> (default equality).</summary>
    public static bool AggregateContains<TSource>(this IEnumerable<TSource> source, TSource value)
    {
        return source.AggregateContains(value, EqualityComparer<TSource>.Default);
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="source"/> contains <paramref name="value"/>
    /// according to <paramref name="comparer"/>; a <c>null</c> comparer means the default one.
    /// </summary>
    public static bool AggregateContains<TSource>(this IEnumerable<TSource> source, TSource value, IEqualityComparer<TSource> comparer)
    {
        var equality = comparer ?? EqualityComparer<TSource>.Default;

        // Same limitation as AggregateAny: the comparer is no longer called after a hit, but the
        // rest of the sequence is still enumerated.
        return source.Aggregate(false, (r, i) => r || equality.Equals(i, value));
    }

    /// <summary>Returns the first element of <paramref name="source"/>.</summary>
    /// <exception cref="InvalidOperationException"><paramref name="source"/> is empty.</exception>
    public static TSource AggregateFirst<TSource>(this IEnumerable<TSource> source)
    {
        return Unwrap(FoldFirst(source), NoElementsMessage);
    }

    /// <summary>Returns the first element of <paramref name="source"/>, or the default value when it is empty.</summary>
    public static TSource AggregateFirstOrDefault<TSource>(this IEnumerable<TSource> source)
    {
        var result = FoldFirst(source);

        return result == null ? default(TSource) : result.Item1;
    }

    /// <summary>Returns the first element that satisfies <paramref name="predicate"/>.</summary>
    /// <exception cref="InvalidOperationException">No element satisfies <paramref name="predicate"/>.</exception>
    public static TSource AggregateFirst<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
    {
        return Unwrap(FoldFirst(source, predicate), NoMatchMessage);
    }

    /// <summary>Returns the first element that satisfies <paramref name="predicate"/>, or the default value when none does.</summary>
    public static TSource AggregateFirstOrDefault<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
    {
        var result = FoldFirst(source, predicate);

        return result == null ? default(TSource) : result.Item1;
    }

    /// <summary>Returns the last element of <paramref name="source"/>.</summary>
    /// <exception cref="InvalidOperationException"><paramref name="source"/> is empty.</exception>
    public static TSource AggregateLast<TSource>(this IEnumerable<TSource> source)
    {
        // The seedless Aggregate overload itself throws on an empty source.
        return source.Aggregate((r, i) => i);
    }

    /// <summary>Returns the last element that satisfies <paramref name="predicate"/>.</summary>
    /// <exception cref="InvalidOperationException">No element satisfies <paramref name="predicate"/>.</exception>
    public static TSource AggregateLast<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
    {
        var result = source.Aggregate(default(Tuple<TSource>), (r, i) => predicate(i) ? new Tuple<TSource>(i) : r);

        return Unwrap(result, NoMatchMessage);
    }

    /// <summary>Returns the last element of <paramref name="source"/>, or the default value when it is empty.</summary>
    public static TSource AggregateLastOrDefault<TSource>(this IEnumerable<TSource> source)
    {
        return source.Aggregate(default(TSource), (r, i) => i);
    }

    /// <summary>Returns the last element that satisfies <paramref name="predicate"/>, or the default value when none does.</summary>
    public static TSource AggregateLastOrDefault<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
    {
        return source.Aggregate(default(TSource), (r, i) => predicate(i) ? i : r);
    }

    // Folds to a one-element tuple holding the first element, or null for an empty source.
    // The tuple distinguishes "nothing found" from a found element that is itself null/default.
    private static Tuple<TSource> FoldFirst<TSource>(IEnumerable<TSource> source)
    {
        return source.Aggregate(default(Tuple<TSource>), (r, i) => r ?? new Tuple<TSource>(i));
    }

    // Folds to a one-element tuple holding the first match, or null when nothing matches.
    // "r == null" is tested first so the predicate is not invoked again once a match is held.
    private static Tuple<TSource> FoldFirst<TSource>(IEnumerable<TSource> source, Func<TSource, bool> predicate)
    {
        return source.Aggregate(default(Tuple<TSource>), (r, i) => r == null && predicate(i) ? new Tuple<TSource>(i) : r);
    }

    // Extracts the boxed element, or throws InvalidOperationException when the fold found nothing.
    private static TSource Unwrap<TSource>(Tuple<TSource> found, string message)
    {
        if (found == null)
        {
            throw new InvalidOperationException(message);
        }

        return found.Item1;
    }
}
0

There are 0 best solutions below