先贴源码地址
https://github.com/dotnet/corefx/tree/master/src/System.Linq/src
.NET Core 的一大好处就是代码开源,你可以详细地查看所使用的类的源代码,并学习微软的写法和实现思路。我们这个系列的一个目的是熟悉基本类库,另一个目的是学习微软的实现思路和编程方法。
今天我们单独讨论一个问题:LINQ 中的 Distinct 方法是如何实现的。最后还会介绍我们在实际编程时对 Distinct 方法所做的扩展。
System.Linq
linq中Distinct方法在Enumerable类中
Enumerable
public static partial class Enumerable
内部去重方法实现有2个重载
1
/// <summary>
/// Returns the distinct elements of <paramref name="source"/> by forwarding to the
/// two-argument overload with a <c>null</c> comparer.
/// </summary>
/// <typeparam name="TSource">The element type of the sequence.</typeparam>
/// <param name="source">The sequence to remove duplicates from.</param>
/// <returns>Whatever the comparer-taking overload returns for a <c>null</c> comparer.</returns>
public static IEnumerable<TSource> Distinct<TSource>(this IEnumerable<TSource> source)
{
    return Distinct(source, null);
}
2
/// <summary>
/// Returns the distinct elements of <paramref name="source"/>, using
/// <paramref name="comparer"/> to decide which elements are equal.
/// </summary>
/// <typeparam name="TSource">The element type of the sequence.</typeparam>
/// <param name="source">The sequence to remove duplicates from; must not be <c>null</c>.</param>
/// <param name="comparer">The equality comparer handed to the iterator; may be <c>null</c>.</param>
/// <returns>A new <c>DistinctIterator</c> wrapping <paramref name="source"/>; nothing is enumerated here.</returns>
/// <remarks>
/// When <paramref name="source"/> is <c>null</c> the exception built by
/// <c>Error.ArgumentNull</c> is thrown (presumably an ArgumentNullException; that helper is not visible here).
/// </remarks>
public static IEnumerable<TSource> Distinct<TSource>(this IEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
{
    if (source != null)
    {
        return new DistinctIterator<TSource>(source, comparer);
    }

    throw Error.ArgumentNull(nameof(source));
}
去重迭代器DistinctIterator
private sealed class DistinctIterator
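去重迭代器:枚举时逐个把元素加入 Set<TSource> _set,只有当 Add 返回 true(即该元素第一次出现)时才产出这个元素,因此去重是延迟、逐项进行的;而 ToArray、ToList、GetCount 则通过 FillSet 调用 set 的 UnionWith,一次性把整个源序列加入 set 来去重。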
这里的set是内部实现的一个轻量级的hash set 具体代码下一部分介绍
/// <summary>
/// An iterator that yields the distinct values in an <see cref="IEnumerable{TSource}"/>.
/// </summary>
/// <typeparam name="TSource">The type of the source enumerable.</typeparam>
/// <remarks>
/// Elements are de-duplicated lazily: each call to <see cref="MoveNext"/> pulls from the
/// source until <c>Set.Add</c> reports an element that has not been seen before.
/// </remarks>
private sealed class DistinctIterator<TSource> : Iterator<TSource>, IIListProvider<TSource>
{
    private readonly IEnumerable<TSource> _source;
    private readonly IEqualityComparer<TSource> _comparer;

    // Both are created on the first MoveNext and released again by Dispose.
    private Set<TSource> _set;
    private IEnumerator<TSource> _enumerator;

    /// <summary>Creates an iterator over <paramref name="source"/>.</summary>
    /// <param name="source">The sequence to de-duplicate; asserted non-null in debug builds.</param>
    /// <param name="comparer">The comparer passed on to the set; may be <c>null</c>.</param>
    public DistinctIterator(IEnumerable<TSource> source, IEqualityComparer<TSource> comparer)
    {
        Debug.Assert(source != null);
        _source = source;
        _comparer = comparer;
    }

    /// <summary>Creates a fresh iterator over the same source and comparer.</summary>
    public override Iterator<TSource> Clone()
    {
        return new DistinctIterator<TSource>(_source, _comparer);
    }

    /// <summary>Advances to the next element that has not been yielded before.</summary>
    /// <returns><c>true</c> if a new distinct element is available; otherwise <c>false</c>.</returns>
    public override bool MoveNext()
    {
        // NOTE(review): _state is declared by the Iterator base class, which is not visible here;
        // 1 is handled as "first call" and 2 as "already enumerating".
        if (_state == 1)
        {
            _enumerator = _source.GetEnumerator();
            if (_enumerator.MoveNext())
            {
                // The first element is always distinct, so the set is only allocated
                // once the source is known to be non-empty.
                TSource first = _enumerator.Current;
                _set = new Set<TSource>(_comparer);
                _set.Add(first);
                _current = first;
                _state = 2;
                return true;
            }
        }
        else if (_state == 2)
        {
            while (_enumerator.MoveNext())
            {
                TSource candidate = _enumerator.Current;
                if (_set.Add(candidate))
                {
                    _current = candidate;
                    return true;
                }
            }
        }

        // Empty source, exhausted source, or any other state: release resources.
        Dispose();
        return false;
    }

    /// <summary>Disposes the source enumerator (if one was created) and drops the set.</summary>
    public override void Dispose()
    {
        if (_enumerator != null)
        {
            _enumerator.Dispose();
            _enumerator = null;
            _set = null;
        }

        base.Dispose();
    }

    /// <summary>Builds a new set holding every element of the source.</summary>
    private Set<TSource> FillSet()
    {
        var filled = new Set<TSource>(_comparer);
        filled.UnionWith(_source);
        return filled;
    }

    /// <summary>Returns the distinct elements as an array by filling a set in one pass.</summary>
    public TSource[] ToArray()
    {
        return FillSet().ToArray();
    }

    /// <summary>Returns the distinct elements as a list by filling a set in one pass.</summary>
    public List<TSource> ToList()
    {
        return FillSet().ToList();
    }

    /// <summary>Returns the number of distinct elements.</summary>
    /// <param name="onlyIfCheap">
    /// When <c>true</c>, -1 is returned without enumerating, because counting requires filling a set.
    /// </param>
    public int GetCount(bool onlyIfCheap)
    {
        if (onlyIfCheap)
        {
            return -1;
        }

        return FillSet().Count;
    }
}
Set的UnionWith
这部分其实是distinct实现的重点,所以内容较多。
/// <summary>
/// A lightweight hash set.
/// </summary>
/// <typeparam name="TElement">The type of the set's items.</typeparam>
/// <remarks>
/// Items live in <c>_slots</c> in insertion order; <c>_buckets</c> maps a hash code to the head
/// of a chain of slots. Removed slots are marked with a hash code of -1 and are never reused.
/// </remarks>
internal sealed class Set<TElement>
{
    // ----- Fields -----

    /// <summary>
    /// The comparer used to hash and compare items in the set.
    /// </summary>
    private readonly IEqualityComparer<TElement> _comparer;

    /// <summary>
    /// The hash buckets, which are used to index into the slots.
    /// Each entry holds (slot index + 1) of the chain head, so the array default 0 means "empty".
    /// </summary>
    private int[] _buckets;

    /// <summary>
    /// The slots, each of which store an item and its hash code.
    /// </summary>
    private Slot[] _slots;

    /// <summary>
    /// An entry in the hash set.
    /// </summary>
    private struct Slot
    {
        /// <summary>
        /// The hash code of the item, or -1 once the item has been removed.
        /// </summary>
        internal int _hashCode;

        /// <summary>
        /// In the case of a hash collision, the index of the next slot to probe (-1 ends the chain).
        /// </summary>
        internal int _next;

        /// <summary>
        /// The item held by this slot.
        /// </summary>
        internal TElement _value;
    }

    /// <summary>
    /// The number of slots handed out so far, including slots vacated by <see cref="Remove"/>.
    /// </summary>
    private int _count;

    /// <summary>
    /// The number of slots vacated by <see cref="Remove"/>.
    /// </summary>
    private int _removed;

    // ----- Constructor -----

    /// <summary>
    /// Constructs a set that compares items with the specified comparer.
    /// </summary>
    /// <param name="comparer">
    /// The comparer. If this is <c>null</c>, it defaults to <see cref="EqualityComparer{TElement}.Default"/>.
    /// </param>
    public Set(IEqualityComparer<TElement> comparer)
    {
        _comparer = comparer ?? EqualityComparer<TElement>.Default;
        _buckets = new int[7];
        _slots = new Slot[7];
    }

    // ----- Add -----

    /// <summary>
    /// Attempts to add an item to this set.
    /// </summary>
    /// <param name="value">The item to add.</param>
    /// <returns>
    /// <c>true</c> if the item was not in the set; otherwise, <c>false</c>.
    /// </returns>
    public bool Add(TElement value)
    {
        // Hash of the value; this ends up calling _comparer.GetHashCode(value).
        int hashCode = InternalGetHashCode(value);

        // Go straight to the matching bucket and walk its chain; the last slot's _next is -1,
        // which terminates the loop.
        for (int i = _buckets[hashCode % _buckets.Length] - 1; i >= 0; i = _slots[i]._next)
        {
            if (_slots[i]._hashCode == hashCode && _comparer.Equals(_slots[i]._value, value))
            {
                return false;
            }
        }

        // All slots are used: grow to (count * 2) + 1.
        if (_count == _slots.Length)
        {
            Resize();
        }

        int index = _count;
        _count++;

        // The bucket is the hash code modulo the (possibly resized) bucket count.
        int bucket = hashCode % _buckets.Length;
        _slots[index]._hashCode = hashCode;
        _slots[index]._value = value;
        _slots[index]._next = _buckets[bucket] - 1; // index of the previous head of this bucket's chain
        _buckets[bucket] = index + 1;
        return true;
    }

    // ----- Remove -----

    /// <summary>
    /// Attempts to remove an item from this set.
    /// </summary>
    /// <param name="value">The item to remove.</param>
    /// <returns>
    /// <c>true</c> if the item was in the set; otherwise, <c>false</c>.
    /// </returns>
    public bool Remove(TElement value)
    {
        int hashCode = InternalGetHashCode(value);
        int bucket = hashCode % _buckets.Length;
        int last = -1;
        for (int i = _buckets[bucket] - 1; i >= 0; last = i, i = _slots[i]._next)
        {
            if (_slots[i]._hashCode == hashCode && _comparer.Equals(_slots[i]._value, value))
            {
                // Unlink the slot from its chain.
                if (last < 0)
                {
                    _buckets[bucket] = _slots[i]._next + 1;
                }
                else
                {
                    _slots[last]._next = _slots[i]._next;
                }

                // Mark the slot as dead; real hash codes are never negative.
                _slots[i]._hashCode = -1;
                _slots[i]._value = default(TElement);
                _slots[i]._next = -1;
                _removed++;
                return true;
            }
        }

        return false;
    }

    // ----- Resize -----

    /// <summary>
    /// Expands the capacity of this set to double the current capacity, plus one.
    /// </summary>
    private void Resize()
    {
        // checked: fail loudly instead of wrapping if the new size exceeds the int range.
        int newSize = checked((_count * 2) + 1);
        int[] newBuckets = new int[newSize];
        Slot[] newSlots = new Slot[newSize];

        // Copy the existing slots, then rebuild every chain for the new bucket count.
        Array.Copy(_slots, 0, newSlots, 0, _count);
        for (int i = 0; i < _count; i++)
        {
            // Removed slots carry a hash code of -1; "-1 % newSize" is negative and would
            // index outside newBuckets, so they are left unlinked.
            if (newSlots[i]._hashCode < 0)
            {
                continue;
            }

            int bucket = newSlots[i]._hashCode % newSize;
            newSlots[i]._next = newBuckets[bucket] - 1;
            newBuckets[bucket] = i + 1;
        }

        _buckets = newBuckets;
        _slots = newSlots;
    }

    /// <summary>
    /// Creates an array from the items in this set.
    /// </summary>
    /// <returns>An array of the items in this set, in insertion order.</returns>
    public TElement[] ToArray()
    {
        TElement[] array = new TElement[_count - _removed];
        int written = 0;
        for (int i = 0; i < _count; i++)
        {
            if (_slots[i]._hashCode >= 0)
            {
                array[written++] = _slots[i]._value;
            }
        }

        return array;
    }

    /// <summary>
    /// Creates a list from the items in this set.
    /// </summary>
    /// <returns>A list of the items in this set, in insertion order.</returns>
    public List<TElement> ToList()
    {
        List<TElement> list = new List<TElement>(_count - _removed);
        for (int i = 0; i < _count; i++)
        {
            if (_slots[i]._hashCode >= 0)
            {
                list.Add(_slots[i]._value);
            }
        }

        return list;
    }

    // ----- UnionWith (simply calls Add for every item) -----

    /// <summary>
    /// The number of items in this set.
    /// </summary>
    public int Count => _count - _removed;

    /// <summary>
    /// Unions this set with an enumerable.
    /// </summary>
    /// <param name="other">The enumerable.</param>
    public void UnionWith(IEnumerable<TElement> other)
    {
        Debug.Assert(other != null);
        foreach (TElement item in other)
        {
            Add(item);
        }
    }

    // ----- Internal hashing -----

    /// <summary>
    /// Gets the hash code of the provided value with its sign bit zeroed out, so that modulo has a positive result.
    /// </summary>
    /// <param name="value">The value to hash.</param>
    /// <returns>The lower 31 bits of the value's hash code, or 0 for <c>null</c>.</returns>
    private int InternalGetHashCode(TElement value) => value == null ? 0 : _comparer.GetHashCode(value) & 0x7FFFFFFF;
}
扩展distinct的关键
实现IEqualityComparer接口
/// <summary>
/// Defines the two operations a hash-based collection needs to compare values of type
/// <typeparamref name="T"/>: an equality test and a hash function.
/// </summary>
/// <typeparam name="T">The type of objects to compare.</typeparam>
public interface IEqualityComparer<in T>
{
    /// <summary>Determines whether the specified objects are equal.</summary>
    /// <param name="x">The first object to compare.</param>
    /// <param name="y">The second object to compare.</param>
    /// <returns>true if the specified objects are equal; otherwise, false.</returns>
    bool Equals(T x, T y);

    /// <summary>Returns a hash code for the specified object.</summary>
    /// <param name="obj">The object for which a hash code is to be returned.</param>
    /// <returns>A hash code for the specified object.</returns>
    /// <exception cref="System.ArgumentNullException">
    /// The type of obj is a reference type and obj is null.
    /// </exception>
    int GetHashCode(T obj);
}
distinct扩展方法
使用params,支持多字段。
/// <summary>
/// Extension methods for removing duplicates from a sequence based on selected fields.
/// </summary>
public static class ComparerHelper
{
    /// <summary>
    /// Returns the distinct elements of <paramref name="source"/>, treating two elements as
    /// duplicates when every selected field compares equal.
    /// </summary>
    /// <typeparam name="T">The element type of the sequence.</typeparam>
    /// <param name="source">The sequence to remove duplicates from.</param>
    /// <param name="getfield">
    /// Delegates that each extract one field to compare. When none are supplied, the default
    /// equality of <typeparamref name="T"/> is used.
    /// </param>
    /// <returns>The elements of <paramref name="source"/> with duplicates removed.</returns>
    public static IEnumerable<T> DistinctEx<T>(this IEnumerable<T> source, params Func<T, object>[] getfield)
    {
        return source.Distinct(new CompareEntityFields<T>(getfield));
    }
}

/// <summary>
/// An equality comparer that compares objects by the values of selected fields.
/// </summary>
/// <typeparam name="T">The type of objects to compare.</typeparam>
public class CompareEntityFields<T> : IEqualityComparer<T>
{
    private readonly Func<T, object>[] _compareFields;

    /// <summary>
    /// Creates a comparer over the given field selectors.
    /// </summary>
    /// <param name="compareFields">
    /// Delegates that each extract one field to compare. When null or empty, the comparer
    /// falls back to <see cref="EqualityComparer{T}.Default"/>.
    /// </param>
    public CompareEntityFields(params Func<T, object>[] compareFields)
    {
        _compareFields = compareFields;
    }

    /// <summary>Determines whether the specified objects are equal.</summary>
    /// <param name="x">The first object of type T to compare.</param>
    /// <param name="y">The second object of type T to compare.</param>
    /// <returns>true if the specified objects are equal; otherwise, false.</returns>
    bool IEqualityComparer<T>.Equals(T x, T y)
    {
        if (_compareFields == null || _compareFields.Length == 0)
        {
            return EqualityComparer<T>.Default.Equals(x, y);
        }

        // A null element cannot be handed to the selectors: two nulls are equal,
        // a null and a non-null are not.
        if (x == null || y == null)
        {
            return x == null && y == null;
        }

        foreach (var func in _compareFields)
        {
            // object.Equals already treats two null field values as equal.
            if (!object.Equals(func(x), func(y)))
            {
                return false;
            }
        }

        return true;
    }

    /// <summary>Returns a hash code for the specified object.</summary>
    /// <param name="obj">The object for which a hash code is to be returned.</param>
    /// <returns>
    /// A hash code built from the same field values that Equals compares, so that objects
    /// considered equal always share a hash code; 0 when <paramref name="obj"/> is null.
    /// </returns>
    int IEqualityComparer<T>.GetHashCode(T obj)
    {
        if (obj == null)
        {
            return 0;
        }

        if (_compareFields == null || _compareFields.Length == 0)
        {
            return EqualityComparer<T>.Default.GetHashCode(obj);
        }

        // Hashing obj.ToString() would put objects with equal fields but different string
        // forms into different buckets, so hash-based callers would never compare them.
        unchecked
        {
            int hash = 17;
            foreach (var func in _compareFields)
            {
                object field = func(obj);
                hash = (hash * 31) + (field == null ? 0 : field.GetHashCode());
            }

            return hash;
        }
    }
}