package cn.test;
import java.util.Arrays;
import java.util.Collection;
import java.util.ConcurrentModificationException;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Set;
/**
 * @Description : Hand-written HashSet implementation: an array of buckets, each bucket being a
 *                singly linked chain of nodes (separate chaining). Permits a single {@code null}
 *                element. Not thread-safe. Iterators are fail-fast on a best-effort basis.
 * @Author : houshuiqiang@163.com, 2016-06-09 08:24:23
 * @Modified : houshuiqiang@163.com, 2016-06-10
 */
public class MyHashSet<E> implements Set<E> {

    /** Bucket count used by the no-arg constructor. */
    private static final int DEFAULT_CAPACITY = 16;

    /** Largest initial bucket count a constructor will honor (2 to the 16th power). */
    private static final int MAX_CAPACITY = 1 << 16;

    /** Largest bucket count the table may grow to; doubling beyond this would overflow int. */
    private static final int MAX_TABLE_LENGTH = 1 << 30;

    private static final float DEFAULT_LOAD_FACTOR = 0.75F;

    /** The table length is multiplied by this factor on every resize. */
    private static final int INCREMENT_FACTOR = 2;

    /** Number of structural modifications; used by iterators to detect concurrent changes. */
    private int modCount;

    /** Bucket array; its length is always a power of two so that masking can replace modulo. */
    private Node<E>[] table;

    private int size;

    private final float loadFactor;

    /** Maximum number of elements the set may hold before the table has to grow. */
    private int capacitySize;

    /** Creates an empty set with the default capacity (16) and load factor (0.75). */
    public MyHashSet() {
        this(DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR);
    }

    /**
     * Creates an empty set with the given initial capacity and the default load factor.
     *
     * @param capacity requested initial bucket count; rounded up to a power of two
     * @throws IllegalArgumentException if {@code capacity} is negative
     */
    public MyHashSet(int capacity) {
        this(capacity, DEFAULT_LOAD_FACTOR);
    }

    /**
     * Creates an empty set.
     *
     * @param capacity   requested initial bucket count; rounded up to a power of two and clamped
     *                   to 2^16
     * @param loadFactor fill ratio at which the table is doubled
     * @throws IllegalArgumentException if {@code capacity} is negative, or {@code loadFactor} is
     *                                  not a positive number
     */
    @SuppressWarnings("unchecked")
    public MyHashSet(int capacity, float loadFactor) {
        if (capacity < 0) {
            throw new IllegalArgumentException("初始容量不合法");
        }
        if (loadFactor <= 0 || Float.isNaN(loadFactor)) {
            throw new IllegalArgumentException("增长因子不合法");
        }
        this.loadFactor = loadFactor;
        // Generic arrays cannot be created directly; the raw array only ever holds Node<E>.
        this.table = new Node[tableSizeFor(capacity)];
        this.capacitySize = (int) (table.length * loadFactor);
    }

    @Override
    public int size() {
        return size;
    }

    @Override
    public boolean isEmpty() {
        return size == 0;
    }

    /** Returns {@code true} if the set holds an element equal to {@code o} ({@code null} allowed). */
    @Override
    public boolean contains(Object o) {
        return findNode(o, hashOf(o)) != null;
    }

    @Override
    public Iterator<E> iterator() {
        return new MyIterator();
    }

    /** Returns a freshly allocated array holding every element, in iteration order. */
    @Override
    public Object[] toArray() {
        Object[] array = new Object[size];
        int index = 0;
        for (E element : this) {
            array[index++] = element;
        }
        return array;
    }

    /**
     * Copies the elements into {@code a} if it is large enough, otherwise into a new array of the
     * same runtime type. When the target has spare room, the slot right after the last element
     * is set to {@code null}.
     *
     * @throws ArrayStoreException  if an element cannot be stored in an array of {@code a}'s type
     * @throws NullPointerException if {@code a} is {@code null}
     */
    @Override
    public <T> T[] toArray(T[] a) {
        T[] result = a.length >= size ? a : Arrays.copyOf(a, size);
        // Writing through an Object[] view lets the JVM's array store check enforce the type.
        Object[] target = result;
        int index = 0;
        for (E element : this) {
            target[index++] = element;
        }
        if (result.length > size) {
            result[size] = null;
        }
        return result;
    }

    /**
     * Adds {@code e} if it is not already present.
     *
     * @return {@code true} if the set changed
     */
    @Override
    public boolean add(E e) {
        int hash = hashOf(e);
        if (findNode(e, hash) != null) {
            return false;
        }
        // Grow only when an insertion really happens, so that a no-op add never swaps the table
        // underneath a live iterator without bumping modCount.
        if (size >= capacitySize) {
            resize();
        }
        int index = indexFor(hash, table.length);
        table[index] = new Node<E>(e, hash, table[index]);
        size++;
        modCount++;
        return true;
    }

    /**
     * Removes the element equal to {@code o}, if present. Only the bucket {@code o} hashes to is
     * searched.
     *
     * @return {@code true} if an element was removed
     */
    @Override
    public boolean remove(Object o) {
        int hash = hashOf(o);
        int index = indexFor(hash, table.length);
        Node<E> previous = null;
        for (Node<E> node = table[index]; node != null; previous = node, node = node.next) {
            if (matches(o, hash, node)) {
                if (previous == null) {
                    table[index] = node.next; // head of the chain
                } else {
                    previous.next = node.next; // somewhere inside the chain
                }
                size--;
                modCount++;
                return true;
            }
        }
        return false;
    }

    @Override
    public boolean containsAll(Collection<?> c) {
        for (Object candidate : c) {
            if (!contains(candidate)) {
                return false;
            }
        }
        return true;
    }

    @Override
    public boolean addAll(Collection<? extends E> c) {
        boolean changed = false;
        for (E element : c) {
            changed |= add(element);
        }
        return changed;
    }

    /**
     * Keeps only the elements that are also contained in {@code c}.
     *
     * @return {@code true} if at least one element was removed
     * @throws NullPointerException if {@code c} is {@code null}
     */
    @Override
    public boolean retainAll(Collection<?> c) {
        if (c == null) {
            throw new NullPointerException("c");
        }
        boolean changed = false;
        for (Iterator<E> it = iterator(); it.hasNext();) {
            if (!c.contains(it.next())) {
                it.remove();
                changed = true;
            }
        }
        return changed;
    }

    @Override
    public boolean removeAll(Collection<?> c) {
        boolean changed = false;
        // Iterate over a snapshot so that removeAll(this) does not invalidate its own iterator.
        for (Object candidate : c.toArray()) {
            changed |= remove(candidate);
        }
        return changed;
    }

    @Override
    public void clear() {
        Arrays.fill(table, null);
        size = 0;
        modCount++;
    }

    /** Two sets are equal when they have the same size and contain the same elements. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof Set)) {
            return false;
        }
        Set<?> other = (Set<?>) o;
        return other.size() == size && containsAll(other);
    }

    /** Sum of the element hash codes ({@code null} counts as 0), as the Set contract requires. */
    @Override
    public int hashCode() {
        int sum = 0;
        for (E element : this) {
            if (element != null) {
                sum += element.hashCode();
            }
        }
        return sum;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("[");
        boolean first = true;
        for (E element : this) {
            if (!first) {
                sb.append(", ");
            }
            // Guard against infinite recursion when the set contains itself.
            sb.append(element == this ? "(this Collection)" : String.valueOf(element));
            first = false;
        }
        return sb.append(']').toString();
    }

    /**
     * Rounds {@code capacity} up to the next power of two, clamped to [1, MAX_CAPACITY]. Uses the
     * same bit-smearing technique as {@code java.util.HashMap}.
     */
    private static int tableSizeFor(int capacity) {
        if (capacity > MAX_CAPACITY) {
            return MAX_CAPACITY;
        }
        int n = capacity - 1;
        n |= n >>> 1;
        n |= n >>> 2;
        n |= n >>> 4;
        n |= n >>> 8;
        n |= n >>> 16;
        return (n < 0) ? 1 : (n >= MAX_CAPACITY) ? MAX_CAPACITY : n + 1;
    }

    /** Doubles the table and redistributes every node according to the NEW table length. */
    @SuppressWarnings("unchecked")
    private void resize() {
        int oldLength = table.length;
        if (oldLength >= MAX_TABLE_LENGTH) {
            // Cannot grow any further; let the chains get longer instead.
            capacitySize = Integer.MAX_VALUE;
            return;
        }
        Node<E>[] newTable = new Node[oldLength * INCREMENT_FACTOR];
        for (Node<E> head : table) {
            Node<E> node = head;
            while (node != null) {
                // Remember the successor first: relinking below overwrites node.next.
                Node<E> following = node.next;
                int index = indexFor(node.hash, newTable.length);
                node.next = newTable[index];
                newTable[index] = node;
                node = following;
            }
        }
        table = newTable;
        capacitySize = (int) (newTable.length * loadFactor);
    }

    /** Returns the node holding an element equal to {@code o}, or {@code null} if absent. */
    private Node<E> findNode(Object o, int hash) {
        for (Node<E> node = table[indexFor(hash, table.length)]; node != null; node = node.next) {
            if (matches(o, hash, node)) {
                return node;
            }
        }
        return null;
    }

    /** Tells whether {@code node} stores an element equal to {@code o} (null matches null). */
    private static boolean matches(Object o, int hash, Node<?> node) {
        if (node.hash != hash) {
            return false;
        }
        if (o == null) {
            return node.element == null;
        }
        return node.element != null && o.equals(node.element);
    }

    /** Hash used for bucketing: 0 for {@code null}, otherwise the spread hash code. */
    private static int hashOf(Object o) {
        return o == null ? 0 : hash(o.hashCode());
    }

    /** Folds the high bits into the low bits, since only the low bits select a bucket. */
    private static int hash(int h) {
        return h ^ (h >>> 16);
    }

    /** Bucket index for {@code hash} in a table of {@code length} slots (a power of two). */
    private static int indexFor(int hash, int length) {
        return hash & (length - 1);
    }

    /** One link of a bucket chain. */
    private static final class Node<T> {
        final T element;
        /** Cached hash so that resizing and lookups need not call hashCode() again. */
        final int hash;
        Node<T> next;

        Node(T element, int hash, Node<T> next) {
            this.element = element;
            this.hash = hash;
            this.next = next;
        }
    }

    /**
     * Walks the buckets from index 0 upwards and each chain from head to tail. When a chain is
     * exhausted the iterator skips ahead to the next non-empty bucket.
     */
    public class MyIterator implements Iterator<E> {
        private int bucketIndex = -1;
        private Node<E> upcoming;
        private Node<E> lastReturned;
        private int expectedModCount = modCount;

        public MyIterator() {
            // Bucket 0 may be empty, so locate the first real node up front.
            skipToNextNode();
        }

        @Override
        public boolean hasNext() {
            return upcoming != null;
        }

        /**
         * @throws ConcurrentModificationException if the set was modified outside this iterator
         * @throws NoSuchElementException          if the iteration has no more elements
         */
        @Override
        public E next() {
            checkModCount();
            if (upcoming == null) {
                throw new NoSuchElementException();
            }
            lastReturned = upcoming;
            upcoming = upcoming.next;
            skipToNextNode();
            return lastReturned.element;
        }

        /**
         * Removes the element most recently returned by {@link #next()}.
         *
         * @throws IllegalStateException if {@code next()} has not been called yet, or
         *                               {@code remove()} was already called since the last call
         */
        @Override
        public void remove() {
            if (lastReturned == null) {
                throw new IllegalStateException();
            }
            checkModCount();
            MyHashSet.this.remove(lastReturned.element);
            lastReturned = null;
            expectedModCount = modCount;
        }

        private void checkModCount() {
            if (expectedModCount != modCount) {
                throw new ConcurrentModificationException("遍历期间不能删元素,如果需要,请使用iterator.remove()");
            }
        }

        /** If the current chain is exhausted, moves on to the head of the next non-empty bucket. */
        private void skipToNextNode() {
            while (upcoming == null && ++bucketIndex < table.length) {
                upcoming = table[bucketIndex];
            }
        }
    }
}
// 写完ArrayList熟悉数组结构,写完LinkedList熟悉链表结构之后,再写HashSet就比较容易。就是数组+链表结构而已。遍历的时候,如果当前链表没有下一个节点了,就得去数组的下一个index去找了,直到找到有值的节点。如果找不到节点了,数组里也没有元素可遍历了,则hasNext()返回false。