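Given some memory-ordering guarantees, can we build a mutually exclusive critical section without hardware-supported atomic operations such as compare-and-swap (CAS)? The answer is yes.
Computing pioneer Edsger Wybe Dijkstra proposed a solution in this classic paper more than 50 years ago. It also opened up the line of research into implementing concurrency control in ordinary programming languages.
Assume we have N threads, numbered 1 to N, and a shared variable k that helps indicate which thread currently holds the critical section. The assumed memory model is that earlier operations are visible to later ones, and that accesses to the same memory location happen one after another.
Initially every element of the arrays b[N] and c[N] is true. The initial value of k is arbitrary (any value from 1 to N). The variable i denotes the current unit of execution (the thread).
For each thread i, b[i] and c[i] both express thread i's intent to compete for the critical section: b[i] == false means thread i is about to compete, and c[i] == false means thread i is actively competing. When a thread leaves the critical section, it sets both b[i] and c[i] back to true. Other threads can therefore inspect b[k] and c[k] to judge whether thread k still occupies the critical section. This judgment is only approximate, because the order in which the threads execute is nondeterministic.
Several threads may inspect b[k] and each set k to its own id, so that all of them reach the point just before the critical section. Even so, a thread must check every other thread's c[j] before entering, so at most one thread enters the critical section and the others fall back to the logic at Li1. It is also possible that no thread obtains the critical section at all; in that case all of them return to Li1 and compete again.
Note: the statement c[i] := true at Li2 in the paper causes many redundant writes (c[i] is usually already true). It only matters when control jumps from Li4 back to Li1, so we can move c[i] := true to just before that goto Li1. This preserves the program's semantics and avoids the wasted work.
Let's implement this scheme in Java and verify it with 10 threads, each entering the critical section 10 million times and adding 1 each time. The runnable code is: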
package com.psly.testatomic;
import sun.misc.Unsafe;
public class TestVolatile {
//Memory-ordering guarantees come from putXXVolatile/getXXVolatile
private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
private static final int _Obase = _unsafe.arrayBaseOffset(int[].class);
private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
//N: number of threads; TIMES: how many times each thread must enter the critical section.
private final static int N = 10;
private final static int TIMES = 10000000;
private final static int[] B = new int[N+1];
private final static int[] C = new int[N+1];
//Each thread does ++count inside the critical section; in the end count == N * TIMES
private static long count;
//countObj: the object that holds the count field (in effect its base address)
private final static Object countObj;
//countOffset: the offset of the count field within that object
private final static long countOffset;
//k is handled the same way as count above
private static int k = 1;
private final static Object kObj;
private final static long kOffset;
static{
for(int i = 1; i <= N; ++i){
B[i] = 1;
C[i] = 1;
}
try {
countObj = _unsafe.staticFieldBase(TestVolatile.class.getDeclaredField("count"));
countOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("count"));
kObj = _unsafe.staticFieldBase(TestVolatile.class.getDeclaredField("k"));
kOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("k"));
} catch (Exception e) {
throw new Error(e);
}
}
final static void dijkstrasConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
B[i] = 0;
L1: for(;;){
if( k != i ) {
//C[i] = 1;
if(B[_unsafe.getIntVolatile(kObj, kOffset)] == 1)
_unsafe.putIntVolatile(kObj, kOffset, i);//k = i;
continue L1;
} else{
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 0);//C[i] = 0;
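//The scan below is guaranteed to see the updated C[i]; this is what fundamentally guarantees mutual exclusion: at most one thread in the critical section.
for(int j = 1; j <= N; ++j )
if(j != i && _unsafe.getIntVolatile(C, _Obase + j * _Oscale) == 0){
//Restore C[i] here; doing it at this point is more efficient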
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
continue L1;
}
}
break L1;
}
//critical section begins
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
//critical section ends
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
B[i]=1;
if( --times != 0){
continue L0; //goto L0;
}
return;
}
}
public static void main(String[] args) throws InterruptedException
{
//start time
long start = System.currentTimeMillis();
//print the initial value of the accumulator
System.out.println( count + " initial\n");
Thread handle[] = new Thread[N+1];
//create the threads
for (int i = 1; i <= N; ++i){
int j = i;
handle[i] = new Thread(new Runnable(){
@Override
public void run(){
dijkstrasConcurMethod(j);
}
});
}
//start the threads
for (int i = 1; i <= N; ++i)
handle[i].start();
//the main thread waits for the worker threads to finish
for (int i = 1; i <= N; ++i)
handle[i].join();
//print the accumulated value, == N * TIMES
System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
//print the elapsed time
System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
Running it once produces:
0 initial
100000000
12.936 seconds
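10 threads, each entering the critical section 10 million times, for a total of 100 million increments, in 12.936 seconds. So this example at least appears to be correct.
Next, let's focus on the dijkstrasConcurMethod method: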
final static void dijkstrasConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
B[i] = 0;
L1: for(;;){
if( k != i ) {
//C[i] = 1;
if(B[_unsafe.getIntVolatile(kObj, kOffset)] == 1)
_unsafe.putIntVolatile(kObj, kOffset, i);//k = i;
continue L1;
} else{
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 0);//C[i] = 0;
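//The scan below is guaranteed to see the updated C[i]; this is what fundamentally guarantees mutual exclusion: at most one thread in the critical section.
for(int j = 1; j <= N; ++j )
if(j != i && _unsafe.getIntVolatile(C, _Obase + j * _Oscale) == 0){
//Restore C[i] here; doing it at this point is more efficient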
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
continue L1;
}
}
break L1;
}
//critical section begins
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
//critical section ends
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
B[i]=1;
if( --times != 0){
continue L0; //goto L0;
}
return;
}
}
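We replace the paper's true/false with 1/0. Java has no goto statement, so we use labeled for(;;) loops to achieve the same effect. Here pM is the thread's own index, and TIMES is the number of times the critical section must be executed.
Strictly speaking, this program is not exactly equivalent to the example in Dijkstra's paper. The paper requires strongly consistent shared memory: any write to B[i], C[i], or k is immediately visible to the other threads.
The paper was published in 1965, and that is probably how memory models and hardware capabilities were imagined at the time. As computer architecture evolved, however, multi-level caches and out-of-order execution were introduced to make programs run faster, and as a result the memory models of most programming languages no longer satisfy that assumption.
Even so, once we use operations with volatile semantics, our Java program is still correct, because two things are guaranteed: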
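First, the value written to C[i] is guaranteed to be visible before the scan of the whole C array begins.
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 0);//C[i] = 0;
//The scan below is guaranteed to see the updated C[i]; this is what fundamentally guarantees mutual exclusion: at most one thread in the critical section.
for(int j = 1; j <= N; ++j )
if(j != i && _unsafe.getIntVolatile(C, _Obase + j * _Oscale) == 0){
//Restore C[i] here; doing it at this point is more efficient
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);
continue L1;
}
Second, C[i] is set back to 1 only after the thread has left the critical section. This prevents the 1 from leaking out too early and misleading the scan loop above.
//critical section begins
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
//critical section ends
_unsafe.putIntVolatile(C, _Obase + i * _Oscale, 1);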
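The same two guarantees can also be had without sun.misc.Unsafe. The following is a minimal sketch, not the article's code: it assumes that the get/set methods of AtomicInteger, AtomicIntegerArray and AtomicLong (which have volatile semantics) are an acceptable substitute for getXXVolatile/putXXVolatile. The class name DijkstraMutexSketch and the method run are hypothetical names chosen for illustration, and the Thread.yield() call is an addition that only makes spinning friendlier on machines with few cores.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicLong;

public class DijkstraMutexSketch {

    // Runs n threads that each enter the critical section `times` times
    // and returns the final counter value (hypothetical helper).
    static long run(int n, int times) {
        AtomicIntegerArray b = new AtomicIntegerArray(n + 1); // 1 = true, 0 = false
        AtomicIntegerArray c = new AtomicIntegerArray(n + 1);
        AtomicInteger k = new AtomicInteger(1);
        AtomicLong count = new AtomicLong();
        for (int i = 1; i <= n; i++) {
            b.set(i, 1);
            c.set(i, 1);
        }
        Thread[] threads = new Thread[n + 1];
        for (int t = 1; t <= n; t++) {
            final int i = t;
            threads[t] = new Thread(() -> {
                for (int round = 0; round < times; round++) {
                    b.set(i, 0);                       // announce intent
                    retry:
                    for (;;) {
                        int owner = k.get();
                        if (owner != i) {
                            if (b.get(owner) == 1) {
                                k.set(i);              // owner is idle: claim k
                            }
                            Thread.yield();
                            continue retry;
                        }
                        c.set(i, 0);                   // volatile write precedes the scan
                        for (int j = 1; j <= n; j++) {
                            if (j != i && c.get(j) == 0) {
                                c.set(i, 1);           // back off and try again
                                continue retry;
                            }
                        }
                        break;                         // no competitor seen: enter
                    }
                    // Critical section: a deliberately non-atomic read-then-write,
                    // so lost updates would show up if mutual exclusion failed.
                    count.set(count.get() + 1);
                    c.set(i, 1);                       // released only after leaving
                    b.set(i, 1);
                }
            });
        }
        for (int t = 1; t <= n; t++) {
            threads[t].start();
        }
        try {
            for (int t = 1; t <= n; t++) {
                threads[t].join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return count.get();
    }

    public static void main(String[] args) {
        System.out.println(run(4, 100_000));
    }
}
```

The counter is incremented with a separate get and set on purpose: it mirrors the getLongVolatile/putLongVolatile pair in the article and relies on the protocol, not on an atomic increment, for correctness.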
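- First, scan from k to our own id (i). If some control[j] != 0 is found along the way, a thread ahead of us is already competing, so we goto back and retry. Otherwise every control value from k up to the id just before ours is 0, and we move on to step two.
- Step two sets our control value to 2, indicating that we have advanced further in the competition. Several threads may still reach this point, so we then use a probe-and-eliminate method similar to Dijkstra's, which lets at most one thread through to the next step.
- Step three sets k to the current id and enters the critical section.
- Step four: after leaving the critical section, set k to the id immediately to the right of the current id, which will most likely produce a ring-shaped execution order. Finally, set control[i] to 0.
- Finally, return. Note that this assignment to k is uncontended. k := if i = 1 then N else i - 1; tries to let the thread to the right run next, but in extreme cases another thread may still obtain the lock, so the line L3: k := i; is still needed.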
package com.psly.testatomic;
import java.util.Random;
import com.psly.locksupprot.LockSupport;
import sun.misc.Unsafe;
public class TestVolatileKnuthMethod {
private final static Random random = new Random();
//Memory-ordering guarantees come from putXXVolatile/getXXVolatile
private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
private static final int _Obase = _unsafe.arrayBaseOffset(int[].class);
private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
//N: number of threads; TIMES: how many times each thread must enter the critical section.
private final static int N = 5;
private final static int TIMES = 1000;
private final static int[] B = new int[N+1];
private final static int[] C = new int[N+1];
//knuth's method
private final static int[] control = new int[N+1];
//Each thread does ++count inside the critical section; in the end count == N * TIMES
private static long count;
//countObj: the object that holds the count field (in effect its base address)
private final static Object countObj;
//countOffset: the offset of the count field within that object
private final static long countOffset;
//k is handled the same way as count above
private static int k = 1;
private final static Object kObj;
private final static long kOffset;
static{
for(int i = 1; i <= N; ++i){
B[i] = 1;
C[i] = 1;
}
try {
countObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("count"));
countOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("count"));
kObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("k"));
kOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("k"));
} catch (Exception e) {
throw new Error(e);
}
}
private static Object obj = new Object();
final static void knuthConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
L1: for(;;){
for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
continue L1;
}
}
for(int j = N; j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
continue L1;
}
}
}
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
for(int j = N; j >= 1; --j){
if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
continue L0;
}
}
_unsafe.putIntVolatile(kObj, kOffset, i);
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
_unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i -1);
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
if( --times != 0)
continue L0;
return ;
}
}
private static Thread[] handle = new Thread[N+1];
public static void main(String[] args) throws InterruptedException
{
//start time
long start = System.currentTimeMillis();
//print the initial value of the accumulator
System.out.println( count + " initial\n");
//create the threads
for (int i = 1; i <= N; ++i){
int j = i;
handle[i] = new Thread(new Runnable(){
@Override
public void run(){
knuthConcurMethod(j);
}
});
}
//start the threads
for (int i = 1; i <= N; ++i)
handle[i].start();
//the main thread waits for the worker threads to finish
for (int i = 1; i <= N; ++i)
handle[i].join();
//print the accumulated value, == N * TIMES
System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
//print the elapsed time
System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
The output is:
0 initial
5000
7.464 seconds
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
int j = (i == 1)? N : i -1;
for(int m = 0; m < N - 1; ++m){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
break;
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
LockSupport.unpark(handle[j]);
break;
}
j = (j == 1)? N : j -1;
}
if( --times != 0)
Parking (sleeping):
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
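Because park can return spuriously, and because an unpark that arrives before the park simply makes the permit available, the parked thread has to re-check its condition in a loop. The sketch below illustrates that pattern in isolation. It uses the standard java.util.concurrent.locks.LockSupport rather than the article's copied class, and the names ParkUnparkSketch, waitForFlag, ready and observed are hypothetical.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.LockSupport;

public class ParkUnparkSketch {

    // Starts a waiter thread that parks until `ready` becomes true,
    // then reports whether the waiter got past the loop.
    static boolean waitForFlag() {
        AtomicBoolean ready = new AtomicBoolean(false);
        AtomicBoolean observed = new AtomicBoolean(false);

        Thread waiter = new Thread(() -> {
            // Re-check the condition after every return from park:
            // the return may be spurious or caused by an earlier permit.
            while (!ready.get()) {
                LockSupport.park(ready);
            }
            observed.set(true);
        });
        waiter.start();

        // Publish the condition first, then wake the waiter. If the waiter
        // has not parked yet, the permit makes its next park return at once.
        ready.set(true);
        LockSupport.unpark(waiter);

        try {
            waiter.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return observed.get();
    }

    public static void main(String[] args) {
        System.out.println(waitForFlag());
    }
}
```

The order on the waking side matters: the condition is written before unpark is called, so a waiter that wakes up always finds the condition it was waiting for.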
package com.psly.testatomic;
import com.psly.locksupprot.LockSupport;
import sun.misc.Unsafe;
public class TestVolatileKnuthMethod {
//Memory-ordering guarantees come from putXXVolatile/getXXVolatile
private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
private static final int _Obase = _unsafe.arrayBaseOffset(int[].class);
private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
//N: number of threads; TIMES: how many times each thread must enter the critical section.
private final static int N = 5;
private final static int TIMES = 1000;
private final static int[] B = new int[N+1];
private final static int[] C = new int[N+1];
//knuth's method
private final static int[] control = new int[N+1];
//Each thread does ++count inside the critical section; in the end count == N * TIMES
private static long count;
//countObj: the object that holds the count field (in effect its base address)
private final static Object countObj;
//countOffset: the offset of the count field within that object
private final static long countOffset;
//k is handled the same way as count above
private static int k = 1;
private final static Object kObj;
private final static long kOffset;
static{
for(int i = 1; i <= N; ++i){
B[i] = 1;
C[i] = 1;
}
try {
countObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("count"));
countOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("count"));
kObj = _unsafe.staticFieldBase(TestVolatileKnuthMethod.class.getDeclaredField("k"));
kOffset = _unsafe.staticFieldOffset(TestVolatileKnuthMethod.class.getDeclaredField("k"));
} catch (Exception e) {
throw new Error(e);
}
}
private static Object obj = new Object();
final static void knuthConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
L1: for(;;){
for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
for(int j = N; j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
}
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
for(int j = N; j >= 1; --j){
if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
continue L0;
}
}
//critical section begins
_unsafe.putIntVolatile(kObj, kOffset, i);
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
_unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i -1); //critical section ends
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
int j = (i == 1)? N : i -1;
for(int m = 0; m < N - 1; ++m){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
break;
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
LockSupport.unpark(handle[j]);
break;
}
j = (j == 1)? N : j -1;
}
if( --times != 0)
continue L0;
return ;
}
}
private static Thread[] handle = new Thread[N+1];
public static void main(String[] args) throws InterruptedException
{
//start time
long start = System.currentTimeMillis();
//print the initial value of the accumulator
System.out.println( count + " initial\n");
//create the threads
for (int i = 1; i <= N; ++i){
int j = i;
handle[i] = new Thread(new Runnable(){
@Override
public void run(){
knuthConcurMethod(j);
}
});
}
//start the threads
for (int i = 1; i <= N; ++i)
handle[i].start();
//the main thread waits for the worker threads to finish
for (int i = 1; i <= N; ++i)
handle[i].join();
//print the accumulated value, == N * TIMES
System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
//print the elapsed time
System.out.println((System.currentTimeMillis() - start) / 1000.0 + " milliseconds");
}
}
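Output (the label printed by this version says milliseconds, but the value is divided by 1000.0 and is therefore in seconds):
0 initial
5000
0.043 milliseconds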
0 initial
500000
2.938 seconds
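With 100 threads, each entering the critical section 5000 times, the whole run takes 2.938 seconds, which is far better than the polling version.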
final static void knuthConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
L1: for(;;){
for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
for(int j = N; j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
}
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
for(int j = N; j >= 1; --j){
if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
continue L0;
}
}
_unsafe.putIntVolatile(kObj, kOffset, i);
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
_unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i -1);
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
int j = (i == 1)? N : i -1;
for(int m = 0; m < N - 1; ++m){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
break;
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
LockSupport.unpark(handle[j]);
break;
}
j = (j == 1)? N : j -1;
}
if( --times != 0)
continue L0;
return ;
}
}
int j = (i == 1)? N : i -1;
for(int m = 0; m < N - 1; ++m){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
break;
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
LockSupport.unpark(handle[j]);
break;
}
j = (j == 1)? N : j -1;
}
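The wake-up scan above can be read as a small decision rule: walk the ring leftwards from the thread next to i, stop without waking anyone if some thread is already at stage 2, and otherwise wake the first thread found at stage 1. A sketch of that rule as a pure function over a snapshot of the control array follows. The names WakeupScanSketch and pickThreadToWake are hypothetical, and reading a plain int[] snapshot instead of using volatile reads is a simplification for illustration only.

```java
public class WakeupScanSketch {

    // Returns the id of the thread that should be unparked after thread i
    // leaves the critical section, or 0 if nobody needs to be woken.
    // control[1..n] holds 0 (idle), 1 (waiting) or 2 (about to enter).
    static int pickThreadToWake(int[] control, int i) {
        int n = control.length - 1;
        int j = (i == 1) ? n : i - 1;      // start with the neighbour of i
        for (int m = 0; m < n - 1; ++m) {  // visit every thread except i
            if (control[j] == 2) {
                return 0;                  // someone is already running ahead
            }
            if (control[j] == 1) {
                return j;                  // first waiting thread on the ring
            }
            j = (j == 1) ? n : j - 1;      // step to the next id, wrapping at 1
        }
        return 0;
    }

    public static void main(String[] args) {
        int[] control = {0, 0, 1, 0, 1, 0}; // index 0 unused; threads 2 and 4 wait
        System.out.println(pickThreadToWake(control, 1));
    }
}
```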
package com.psly.testatomic;
import java.text.SimpleDateFormat;
import java.util.Date;
import com.psly.locksupprot.LockSupport;
import sun.misc.Unsafe;
public class TestVolatileBruijnMethod {
//Memory-ordering guarantees come from putXXVolatile/getXXVolatile
private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
private static final int _Obase = _unsafe.arrayBaseOffset(int[].class);
private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
//N: number of threads; TIMES: how many times each thread must enter the critical section.
private final static int N = 100;
private final static int TIMES = 5000;
private final static int[] B = new int[N+1];
private final static int[] C = new int[N+1];
//knuth's method
private final static int[] control = new int[N+1];
//Each thread does ++count inside the critical section; in the end count == N * TIMES
private static long count;
//countObj: the object that holds the count field (in effect its base address)
private final static Object countObj;
//countOffset: the offset of the count field within that object
private final static long countOffset;
//k is handled the same way as count above
private static int k = 1;
private final static Object kObj;
private final static long kOffset;
static{
for(int i = 1; i <= N; ++i){
B[i] = 1;
C[i] = 1;
}
try {
countObj = _unsafe.staticFieldBase(TestVolatileBruijnMethod.class.getDeclaredField("count"));
countOffset = _unsafe.staticFieldOffset(TestVolatileBruijnMethod.class.getDeclaredField("count"));
kObj = _unsafe.staticFieldBase(TestVolatileBruijnMethod.class.getDeclaredField("k"));
kOffset = _unsafe.staticFieldOffset(TestVolatileBruijnMethod.class.getDeclaredField("k"));
} catch (Exception e) {
throw new Error(e);
}
}
private static Object obj = new Object();
final static void knuthConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
L1: for(;;){
for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
for(int j = N; j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
}
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
for(int j = N; j >= 1; --j){
if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
continue L0;
}
}
// _unsafe.putIntVolatile(kObj, kOffset, i);
int kLocal = _unsafe.getIntVolatile(kObj, kOffset);
int kNew = kLocal;
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
if(_unsafe.getIntVolatile(control, _Obase + kLocal * _Oscale) == 0 || kLocal == i)
_unsafe.putIntVolatile(kObj, kOffset, kNew = ((kLocal == 1)? N: kLocal - 1));
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
int j = kNew;
for(int m = 0; m < N; ++m){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
break;
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
LockSupport.unpark(handle[j]);
break;
}
j = (j == 1)? N : j -1;
}
if( --times != 0)
continue L0;
return ;
}
}
private static Thread[] handle = new Thread[N+1];
public static void main(String[] args) throws InterruptedException
{
System.out.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
.format(new Date()));
//start time
long start = System.currentTimeMillis();
//print the initial value of the accumulator
System.out.println( count + " initial\n");
//create the threads
for (int i = 1; i <= N; ++i){
int j = i;
handle[i] = new Thread(new Runnable(){
@Override
public void run(){
knuthConcurMethod(j);
}
});
}
//start the threads
for (int i = 1; i <= N; ++i)
handle[i].start();
//the main thread waits for the worker threads to finish
for (int i = 1; i <= N; ++i)
handle[i].join();
//print the accumulated value, == N * TIMES
System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
//print the elapsed time
System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
package com.psly.testatomic;
import java.text.SimpleDateFormat;
import java.util.Date;
import com.psly.locksupprot.LockSupport;
import sun.misc.Unsafe;
public class TestVolatileEisenbergMethod {
//Memory-ordering guarantees come from putXXVolatile/getXXVolatile
private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
private static final int _Obase = _unsafe.arrayBaseOffset(int[].class);
private static final int _Oscale = _unsafe.arrayIndexScale(int[].class);
//N: number of threads; TIMES: how many times each thread must enter the critical section.
private final static int N = 100;
private final static int TIMES = 5000;
private final static int[] B = new int[N+1];
private final static int[] C = new int[N+1];
//knuth's method
private final static int[] control = new int[N+1];
//Each thread does ++count inside the critical section; in the end count == N * TIMES
private static long count;
//countObj: the object that holds the count field (in effect its base address)
private final static Object countObj;
//countOffset: the offset of the count field within that object
private final static long countOffset;
//k is handled the same way as count above
private static int k = 1;
private final static Object kObj;
private final static long kOffset;
static{
for(int i = 1; i <= N; ++i){
B[i] = 1;
C[i] = 1;
}
try {
countObj = _unsafe.staticFieldBase(TestVolatileEisenbergMethod.class.getDeclaredField("count"));
countOffset = _unsafe.staticFieldOffset(TestVolatileEisenbergMethod.class.getDeclaredField("count"));
kObj = _unsafe.staticFieldBase(TestVolatileEisenbergMethod.class.getDeclaredField("k"));
kOffset = _unsafe.staticFieldOffset(TestVolatileEisenbergMethod.class.getDeclaredField("k"));
} catch (Exception e) {
throw new Error(e);
}
}
private static Object obj = new Object();
final static void EisenbergConcurMethod(int pM){
int times = TIMES;
int i = pM;
L0: for(;;){
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 1);
L1: for(;;){
int kLocal;
for(int j = (kLocal = _unsafe.getIntVolatile(kObj, kOffset)); j <= N; ++j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
for(int j = 1; j <= kLocal - 1; ++j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
LockSupport.park(obj);
continue L1;
}
}
}
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
for(int j = 1; j <= N; ++j){
if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
continue L0;
}
}
int kLocal;
if(_unsafe.getIntVolatile(control, _Obase + (kLocal = _unsafe.getIntVolatile(kObj, kOffset)) *_Oscale ) != 0
&& kLocal != i)
continue L0;
_unsafe.putIntVolatile(kObj, kOffset, i);
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
// System.out.println(Thread.currentThread().getName());
int kNew = i;
L2: for(;;){
for(int j = i + 1; j <= N; ++j){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
_unsafe.putIntVolatile(kObj, kOffset, j);
// LockSupport.unpark(handle[j]);
kNew = j;
break L2;
}
}
for(int j = 1; j <= i - 1; ++j){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
_unsafe.putIntVolatile(kObj, kOffset, j);
// LockSupport.unpark(handle[j]);
kNew = j;
break L2;
}
}
break;
}
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
int j = kNew;
for(int m = 0; m < N; ++m){
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2)
break;
if(_unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 1){
LockSupport.unpark(handle[j]);
break;
}
j = (j == N)? 1 : j + 1;
}
if( --times != 0)
continue L0;
return ;
}
}
private static Thread[] handle = new Thread[N+1];
public static void main(String[] args) throws InterruptedException
{
System.out.println(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
.format(new Date()));
//start time
long start = System.currentTimeMillis();
//print the initial value of the accumulator
System.out.println( count + " initial\n");
//create the threads
for (int i = 1; i <= N; ++i){
int j = i;
handle[i] = new Thread(new Runnable(){
@Override
public void run(){
EisenbergConcurMethod(j);
}
});
}
//start the threads
for (int i = 1; i <= N; ++i)
handle[i].start();
//the main thread waits for the worker threads to finish
for (int i = 1; i <= N; ++i)
handle[i].join();
//print the accumulated value, == N * TIMES
System.out.println(_unsafe.getLongVolatile(countObj, countOffset));
//print the elapsed time
System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
L1: for(;;){ //the two loops below decide whether the current thread is eligible to compete for the critical section
for(int j = _unsafe.getIntVolatile(kObj, kOffset); j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
// LockSupport.park(obj);
continue L1;
}
}
for(int j = N; j >= 1; --j){
if(j == i)
break L1;
if( _unsafe.getIntVolatile(control, _Obase + j * _Oscale) != 0){
// LockSupport.park(obj);
continue L1;
}
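} //the two loops above decide whether the current thread is eligible to compete for the critical section
}
//the code below guarantees that at most one thread enters the critical section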
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 2);//control[i] = 2;
for(int j = N; j >= 1; --j){
if(j != i && _unsafe.getIntVolatile(control, _Obase + j * _Oscale) == 2/*control[j] ==2*/){
continue L0;
}
}
//the code above guarantees that at most one thread enters the critical section
_unsafe.putIntVolatile(kObj, kOffset, i);
//critical section start
long val = _unsafe.getLongVolatile(countObj, countOffset);
_unsafe.putLongVolatile(countObj, countOffset, val + 1);
_unsafe.putIntVolatile(kObj, kOffset, (i == 1)? N : i - 1);
//critical section end
_unsafe.putIntVolatile(control, _Obase + i * _Oscale, 0);//control[i] = 0;
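- First, two loops decide whether the current thread is eligible to compete for the lock. If it is, it breaks out of L1; otherwise it keeps looping.
- Next, after setting its own control value to 2, the thread probes the control values of the other threads in another loop. If none of them is 2, the loop ends and the thread obtains the lock; otherwise it jumps back to L0 and repeats the earlier checks. Note that this guarantees that at most one thread enters the critical section. In the extreme case no thread obtains the lock and all of them jump back to L0.
- At the end of the critical section, control[i] is set to 0, replacing its value of 2, so that other threads then get a chance to obtain the lock (by the competition rule, a thread that sees any other thread's value at 2 cannot obtain the lock).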
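The three steps just listed can be sketched compactly with the atomic classes from java.util.concurrent.atomic. This is an illustrative sketch under the assumption that their volatile get/set are a fair stand-in for the article's getXXVolatile/putXXVolatile calls; it is a polling version with no park/unpark. The names KnuthMutexSketch, run and enter are hypothetical, and Thread.yield() is added only to make spinning cheaper.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicLong;

public class KnuthMutexSketch {

    // Blocks (by spinning) until thread i may enter the critical section.
    static void enter(AtomicIntegerArray control, AtomicInteger k, int n, int i) {
        attempt:
        for (;;) {
            control.set(i, 1);
            // Step 1: am I eligible? Scan k..1 and then n..1 until reaching i.
            scan:
            for (;;) {
                for (int j = k.get(); j >= 1; j--) {
                    if (j == i) break scan;
                    if (control.get(j) != 0) { Thread.yield(); continue scan; }
                }
                for (int j = n; j >= 1; j--) {
                    if (j == i) break scan;
                    if (control.get(j) != 0) { Thread.yield(); continue scan; }
                }
            }
            // Step 2: claim stage 2, then make sure nobody else is at stage 2.
            control.set(i, 2);
            for (int j = n; j >= 1; j--) {
                if (j != i && control.get(j) == 2) continue attempt;
            }
            k.set(i);
            return;
        }
    }

    // Runs n threads, each entering the critical section `times` times.
    static long run(int n, int times) {
        AtomicIntegerArray control = new AtomicIntegerArray(n + 1); // all 0
        AtomicInteger k = new AtomicInteger(1);
        AtomicLong count = new AtomicLong();
        Thread[] threads = new Thread[n + 1];
        for (int t = 1; t <= n; t++) {
            final int i = t;
            threads[t] = new Thread(() -> {
                for (int round = 0; round < times; round++) {
                    enter(control, k, n, i);
                    count.set(count.get() + 1);   // critical section
                    k.set(i == 1 ? n : i - 1);    // hand priority onwards
                    control.set(i, 0);            // Step 3: release
                }
            });
        }
        for (int t = 1; t <= n; t++) threads[t].start();
        try {
            for (int t = 1; t <= n; t++) threads[t].join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return count.get();
    }

    public static void main(String[] args) {
        System.out.println(run(4, 10_000));
    }
}
```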
Appendix:
package com.psly.testatomic;
import java.lang.reflect.Field;
import sun.misc.Unsafe;
public class UtilUnsafe {
private UtilUnsafe() { } // dummy private constructor
/** Fetch the Unsafe. Use With Caution. */
public static Unsafe getUnsafe() {
// Not on bootclasspath
if( UtilUnsafe.class.getClassLoader() == null )
return Unsafe.getUnsafe();
try {
final Field fld = Unsafe.class.getDeclaredField("theUnsafe");
fld.setAccessible(true);
return (Unsafe) fld.get(UtilUnsafe.class);
} catch (Exception e) {
throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
}
}
}
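On JDK 9 and later, volatile access to array elements can also be expressed without reflection into sun.misc.Unsafe. The following sketch shows one way to do it with a VarHandle obtained from MethodHandles.arrayElementVarHandle; it is offered as an alternative under that assumption, not as what the article's programs use. The class name VarHandleArraySketch and the method roundTrip are hypothetical.

```java
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;

public class VarHandleArraySketch {

    // A VarHandle that addresses int[] elements by (array, index),
    // replacing the manual base + index * scale arithmetic.
    private static final VarHandle INT_ARRAY =
            MethodHandles.arrayElementVarHandle(int[].class);

    // Writes `value` to arr[index] with volatile semantics and reads it back
    // with volatile semantics.
    static int roundTrip(int[] arr, int index, int value) {
        INT_ARRAY.setVolatile(arr, index, value);
        return (int) INT_ARRAY.getVolatile(arr, index);
    }

    public static void main(String[] args) {
        int[] control = new int[6];
        System.out.println(roundTrip(control, 3, 2));
    }
}
```

Compared with the Unsafe version, the index is bounds-checked and no offset computation is needed, at the cost of requiring a newer JDK.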
package com.psly.locksupprot;
import com.psly.testatomic.UtilUnsafe;
public class LockSupport {
private LockSupport() {} // Cannot be instantiated.
private static void setBlocker(Thread t, Object arg) {
// Even though volatile, hotspot doesn't need a write barrier here.
UNSAFE.putObject(t, parkBlockerOffset, arg);
}
/**
* Makes available the permit for the given thread, if it
* was not already available. If the thread was blocked on
* {@code park} then it will unblock. Otherwise, its next call
* to {@code park} is guaranteed not to block. This operation
* is not guaranteed to have any effect at all if the given
* thread has not been started.
*
* @param thread the thread to unpark, or {@code null}, in which case
* this operation has no effect
*/
public static void unpark(Thread thread) {
if (thread != null)
UNSAFE.unpark(thread);
}
/**
* Disables the current thread for thread scheduling purposes unless the
* permit is available.
*
* <p>If the permit is available then it is consumed and the call returns
* immediately; otherwise
* the current thread becomes disabled for thread scheduling
* purposes and lies dormant until one of three things happens:
*
* <ul>
* <li>Some other thread invokes {@link #unpark unpark} with the
* current thread as the target; or
*
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread; or
*
* <li>The call spuriously (that is, for no reason) returns.
* </ul>
*
* <p>This method does <em>not</em> report which of these caused the
* method to return. Callers should re-check the conditions which caused
* the thread to park in the first place. Callers may also determine,
* for example, the interrupt status of the thread upon return.
*
* @param blocker the synchronization object responsible for this
* thread parking
* @since 1.6
*/
public static void park(Object blocker) {
Thread t = Thread.currentThread();
setBlocker(t, blocker);
UNSAFE.park(false, 0L);
setBlocker(t, null);
}
/**
* Disables the current thread for thread scheduling purposes, for up to
* the specified waiting time, unless the permit is available.
*
* <p>If the permit is available then it is consumed and the call
* returns immediately; otherwise the current thread becomes disabled
* for thread scheduling purposes and lies dormant until one of four
* things happens:
*
* <ul>
* <li>Some other thread invokes {@link #unpark unpark} with the
* current thread as the target; or
*
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread; or
*
* <li>The specified waiting time elapses; or
*
* <li>The call spuriously (that is, for no reason) returns.
* </ul>
*
* <p>This method does <em>not</em> report which of these caused the
* method to return. Callers should re-check the conditions which caused
* the thread to park in the first place. Callers may also determine,
* for example, the interrupt status of the thread, or the elapsed time
* upon return.
*
* @param blocker the synchronization object responsible for this
* thread parking
* @param nanos the maximum number of nanoseconds to wait
* @since 1.6
*/
public static void parkNanos(Object blocker, long nanos) {
if (nanos > 0) {
Thread t = Thread.currentThread();
setBlocker(t, blocker);
UNSAFE.park(false, nanos);
setBlocker(t, null);
}
}
/**
* Disables the current thread for thread scheduling purposes, until
* the specified deadline, unless the permit is available.
*
* <p>If the permit is available then it is consumed and the call
* returns immediately; otherwise the current thread becomes disabled
* for thread scheduling purposes and lies dormant until one of four
* things happens:
*
* <ul>
* <li>Some other thread invokes {@link #unpark unpark} with the
* current thread as the target; or
*
* <li>Some other thread {@linkplain Thread#interrupt interrupts} the
* current thread; or
*
* <li>The specified deadline passes; or
*
* <li>The call spuriously (that is, for no reason) returns.
* </ul>
*
* <p>This method does <em>not</em> report which of these caused the
* method to return. Callers should re-check the conditions which caused
* the thread to park in the first place. Callers may also determine,
* for example, the interrupt status of the thread, or the current time
* upon return.
*
* @param blocker the synchronization object responsible for this
* thread parking
* @param deadline the absolute time, in milliseconds from the Epoch,
* to wait until
* @since 1.6
*/
public static void parkUntil(Object blocker, long deadline) {
Thread t = Thread.currentThread();
setBlocker(t, blocker);
UNSAFE.park(true, deadline);
setBlocker(t, null);
}
/**
* Returns the blocker object supplied to the most recent
* invocation of a park method that has not yet unblocked, or null
* if not blocked. The value returned is just a momentary
* snapshot -- the thread may have since unblocked or blocked on a
* different blocker object.
*
* @param t the thread
* @return the blocker
* @throws NullPointerException if argument is null
* @since 1.6
*/
public static Object getBlocker(Thread t) {
if (t == null)
throw new NullPointerException();
return UNSAFE.getObjectVolatile(t, parkBlockerOffset);
}
/**
* Disables the current thread for thread scheduling purposes unless the
* permit is available.
*
* <p>If the permit is available then it is consumed and the call
* returns immediately; otherwise the current thread becomes disabled
* for thread scheduling purposes and lies dormant until one of three
* things happens:
*
* <ul>
*
* <li>Some other thread invokes {@link #unpark unpark} with the
* current thread as the target; or
*
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread; or
*
* <li>The call spuriously (that is, for no reason) returns.
* </ul>
*
* <p>This method does <em>not</em> report which of these caused the
* method to return. Callers should re-check the conditions which caused
* the thread to park in the first place. Callers may also determine,
* for example, the interrupt status of the thread upon return.
*/
public static void park() {
UNSAFE.park(false, 0L);
}
/**
* Disables the current thread for thread scheduling purposes, for up to
* the specified waiting time, unless the permit is available.
*
* <p>If the permit is available then it is consumed and the call
* returns immediately; otherwise the current thread becomes disabled
* for thread scheduling purposes and lies dormant until one of four
* things happens:
*
* <ul>
* <li>Some other thread invokes {@link #unpark unpark} with the
* current thread as the target; or
*
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread; or
*
* <li>The specified waiting time elapses; or
*
* <li>The call spuriously (that is, for no reason) returns.
* </ul>
*
* <p>This method does <em>not</em> report which of these caused the
* method to return. Callers should re-check the conditions which caused
* the thread to park in the first place. Callers may also determine,
* for example, the interrupt status of the thread, or the elapsed time
* upon return.
*
* @param nanos the maximum number of nanoseconds to wait
*/
public static void parkNanos(long nanos) {
if (nanos > 0)
UNSAFE.park(false, nanos);
}
/**
* Disables the current thread for thread scheduling purposes, until
* the specified deadline, unless the permit is available.
*
* <p>If the permit is available then it is consumed and the call
* returns immediately; otherwise the current thread becomes disabled
* for thread scheduling purposes and lies dormant until one of four
* things happens:
*
* <ul>
* <li>Some other thread invokes {@link #unpark unpark} with the
* current thread as the target; or
*
* <li>Some other thread {@linkplain Thread#interrupt interrupts}
* the current thread; or
*
* <li>The specified deadline passes; or
*
* <li>The call spuriously (that is, for no reason) returns.
* </ul>
*
* <p>This method does <em>not</em> report which of these caused the
* method to return. Callers should re-check the conditions which caused
* the thread to park in the first place. Callers may also determine,
* for example, the interrupt status of the thread, or the current time
* upon return.
*
* @param deadline the absolute time, in milliseconds from the Epoch,
* to wait until
*/
public static void parkUntil(long deadline) {
UNSAFE.park(true, deadline);
}
/**
* Returns the pseudo-randomly initialized or updated secondary seed.
* Copied from ThreadLocalRandom due to package access restrictions.
*/
static final int nextSecondarySeed() {
int r;
Thread t = Thread.currentThread();
if ((r = UNSAFE.getInt(t, SECONDARY)) != 0) {
r ^= r << 13; // xorshift
r ^= r >>> 17;
r ^= r << 5;
}
else if ((r = java.util.concurrent.ThreadLocalRandom.current().nextInt()) == 0)
r = 1; // avoid zero
UNSAFE.putInt(t, SECONDARY, r);
return r;
}
// Hotspot implementation via intrinsics API
private static final sun.misc.Unsafe UNSAFE;
private static final long parkBlockerOffset;
private static final long SEED;
private static final long PROBE;
private static final long SECONDARY;
static {
try {
UNSAFE = UtilUnsafe.getUnsafe();
Class<?> tk = Thread.class;
parkBlockerOffset = UNSAFE.objectFieldOffset
(tk.getDeclaredField("parkBlocker"));
SEED = UNSAFE.objectFieldOffset
(tk.getDeclaredField("threadLocalRandomSeed"));
PROBE = UNSAFE.objectFieldOffset
(tk.getDeclaredField("threadLocalRandomProbe"));
SECONDARY = UNSAFE.objectFieldOffset
(tk.getDeclaredField("threadLocalRandomSecondarySeed"));
} catch (Exception ex) { throw new Error(ex); }
}
}
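The Javadoc above notes that `park` may return for several reasons, including spuriously, and that callers should re-check the condition they were waiting on. A minimal sketch of that re-check loop follows. It assumes the JDK's own `java.util.concurrent.locks.LockSupport` rather than the local copy listed above, and the class name `ParkLoopDemo` and the flag `ready` are hypothetical names used only for illustration.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.LockSupport;

class ParkLoopDemo {
    // Hypothetical condition flag that the waiting thread blocks on.
    static final AtomicBoolean ready = new AtomicBoolean(false);

    // Parks until `ready` becomes true. The loop re-checks the flag every
    // time park() returns, because park() does not say why it returned.
    static void awaitReady() {
        while (!ready.get()) {
            LockSupport.park();
        }
    }

    // Starts a waiter, publishes the condition, then unparks the waiter.
    // Returns true if the waiter terminated within five seconds.
    static boolean runDemo() {
        ready.set(false);
        Thread waiter = new Thread(ParkLoopDemo::awaitReady);
        waiter.start();
        ready.set(true);            // publish the condition first
        LockSupport.unpark(waiter); // then hand the waiter a permit
        try {
            waiter.join(5000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
        return !waiter.isAlive();
    }

    public static void main(String[] args) {
        System.out.println("waiter finished: " + runDemo());
    }
}
```

Setting the flag before calling `unpark` matters: if the permit arrives before the waiter parks, `park` consumes it and returns immediately, and the loop then sees the flag already set.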
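I recently found that Leslie Lamport, the 2013 Turing Award winner, proposed another algorithm in 1974 for the same problem of controlling concurrent access to an exclusive critical section. A screenshot of the paper follows:
Proof:
The distinguishing feature of this algorithm is that it has no central control.
Let's implement it in Java: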
package com.psly.testatomic;
import sun.misc.Unsafe;
public class TestVolatile {
// Used for memory-ordering guarantees: putXXVolatile/getXXVolatile
private static final Unsafe _unsafe = UtilUnsafe.getUnsafe();
private static final int _Obase = _unsafe.arrayBaseOffset(long[].class);
private static final int _Oscale = _unsafe.arrayIndexScale(long[].class);
// N: number of threads; TIMES: how many times each thread must enter the critical section.
private final static int N = 2000;
private final static int TIMES = 1000;
private final static long[] choosing = new long[N+1];
private final static long[] number = new long[N+1];
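// Each thread does ++count inside the critical section; at the end count == N * TIMES.
private static long count;
// mainObj: the object that holds the static field count (in effect, its base address).
private final static Object mainObj;
// countOffset: the offset of the count field within that object.
private final static long countOffset;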
private static Object obj = new Object();
// private static Queue<Thread> queues = new ConcurrentLinkedQueue();
static{
for(int i = 1; i <= N; ++i){
choosing[i] = 0;
number[i] = 0;
}
try {
mainObj = _unsafe.staticFieldBase(TestVolatile.class.getDeclaredField("count"));
countOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("count"));
// waitersOffset = _unsafe.staticFieldOffset(TestVolatile.class.getDeclaredField("waiters"));
} catch (Exception e) {
throw new Error(e);
}
}
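// Note: despite its name (carried over from the Dijkstra version), this method implements Lamport's 1974 algorithm.
final static void dijkstrasConcurMethod(int pM){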
int times = TIMES;
int i = pM;
L0: for(;;){
_unsafe.putLongVolatile(choosing, _Obase + i * _Oscale, 1);
// Read the largest number currently held and add 1 to get this thread's number.
long maxNum = _unsafe.getLongVolatile(number, _Obase + _Oscale), midNum;
for(int j = 2; j <= N; ++j)
if(maxNum < (midNum = _unsafe.getLongVolatile(number, _Obase + j * _Oscale)))
maxNum = midNum;
_unsafe.putLongVolatile(number, _Obase + i * _Oscale, 1 + maxNum);
_unsafe.putLongVolatile(choosing, _Obase + i * _Oscale, 0);
/* for(int j = 1; j <i; ++j)
LockSupport.unpark(handle[j]);
for(int j = i+1; j <= N; ++j)
LockSupport.unpark(handle[j]);*/
long jNumber, iNumber;
for(int j = 1; j <= N; ++j){
L1: for(;;){
for(int k = 0 ; k < 100; ++k)
if(!(_unsafe.getLongVolatile(choosing, _Obase + j * _Oscale) != 0))
break L1;
// LockSupport.park(obj);
}
L2: for(;;){
for(int k = 0; k < 1000; ++k)
if(!(_unsafe.getLongVolatile(number, _Obase + j * _Oscale) != 0
&& ((jNumber=_unsafe.getLongVolatile(number, _Obase + j * _Oscale))
< (iNumber=_unsafe.getLongVolatile(number, _Obase + i * _Oscale))
|| (jNumber == iNumber && j < i))))
break L2;
LockSupport.park(obj);
}
}
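// Critical section: begin
long val = _unsafe.getLongVolatile(mainObj, countOffset);
_unsafe.putLongVolatile(mainObj, countOffset, val + 1);
// Critical section: end
// Reset this thread's number to mark that it has left.
_unsafe.putLongVolatile(number, _Obase + i * _Oscale, 0);
// Wake the thread that needs it: the other thread holding the smallest nonzero number.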
Thread target = handle[i];
long numMax = Long.MAX_VALUE, arg;
for(int j = 1; j <i; ++j)
if((arg = _unsafe.getLongVolatile(number, _Obase + j * _Oscale)) != 0 && arg < numMax)
{ target = handle[j]; numMax = arg;}
for(int j = i+1; j <= N; ++j)
if((arg = _unsafe.getLongVolatile(number, _Obase + j * _Oscale)) != 0 && arg < numMax)
{ target = handle[j]; numMax = arg;}
LockSupport.unpark(target);
/*for(int j = 1; j <= N; ++j)
LockSupport.unpark(handle[j]);*/
// Count down the remaining iterations.
if( --times != 0){
continue L0; //goto L0;
}
return;
}
}
private static Thread[] handle = new Thread[N+1];
public static void main(String[] args) throws InterruptedException
{
// Start time.
long start = System.currentTimeMillis();
// Print the counter's initial value.
System.out.println( count + " initial\n");
// Thread handle[] = new Thread[N+1];
// Create the threads.
for (int i = 1; i <= N; ++i){
int j = i;
handle[i] = new Thread(new Runnable(){
@Override
public void run(){
dijkstrasConcurMethod(j);
}
});
}
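// Start the threads.
for (int i = 1; i <= N; ++i)
handle[i].start();
// The main thread waits for the worker threads to finish.
for (int i = 1; i <= N; ++i)
handle[i].join();
// Print the accumulated value; it should equal N * TIMES.
System.out.println(_unsafe.getLongVolatile(mainObj, countOffset));
// Print the program's running time.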
System.out.println((System.currentTimeMillis() - start) / 1000.0 + " seconds");
}
}
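For comparison, here is a compact sketch of the same algorithm that avoids `sun.misc.Unsafe`. It is not the code above: it swaps in `AtomicIntegerArray` and `AtomicLongArray` (whose `get`/`set` have volatile semantics) for the `getLongVolatile`/`putLongVolatile` calls, spins with `Thread.yield()` instead of parking, and uses 0-based thread ids. The class name `BakeryLock` and the helper `run` are hypothetical names chosen for this illustration.

```java
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicLongArray;

class BakeryLock {
    private final int n;
    private final AtomicIntegerArray choosing; // 1 while thread i is picking a number
    private final AtomicLongArray number;      // thread i's number, 0 = not competing
    private volatile long count;               // demo counter; ++ is still not atomic

    BakeryLock(int n) {
        this.n = n;
        this.choosing = new AtomicIntegerArray(n);
        this.number = new AtomicLongArray(n);
    }

    void lock(int i) {
        // Pick a number one larger than every number currently visible.
        choosing.set(i, 1);
        long max = 0;
        for (int j = 0; j < n; j++) max = Math.max(max, number.get(j));
        number.set(i, max + 1);
        choosing.set(i, 0);
        for (int j = 0; j < n; j++) {
            if (j == i) continue;
            // Wait until thread j has finished picking its number.
            while (choosing.get(j) != 0) Thread.yield();
            // Wait while (number[j], j) is nonzero and smaller than (number[i], i).
            for (;;) {
                long nj = number.get(j);
                if (nj == 0) break;
                long ni = number.get(i);
                if (nj > ni || (nj == ni && j > i)) break;
                Thread.yield();
            }
        }
    }

    void unlock(int i) {
        number.set(i, 0); // leave the competition
    }

    // Runs `threads` workers that each increment the counter `iterations` times.
    static long run(int threads, int iterations) {
        BakeryLock lock = new BakeryLock(threads);
        Thread[] workers = new Thread[threads];
        for (int t = 0; t < threads; t++) {
            final int id = t;
            workers[t] = new Thread(() -> {
                for (int k = 0; k < iterations; k++) {
                    lock.lock(id);
                    lock.count++; // only safe because the lock excludes other threads
                    lock.unlock(id);
                }
            });
            workers[t].start();
        }
        for (Thread w : workers) {
            try {
                w.join();
            } catch (InterruptedException e) {
                throw new IllegalStateException(e);
            }
        }
        return lock.count;
    }

    public static void main(String[] args) {
        System.out.println(run(4, 10_000));
    }
}
```

If mutual exclusion holds, `run(4, 10_000)` returns 40000. Each thread writes only its own `choosing[i]` and `number[i]` slots and merely reads the others, which is what "no central control" means here: there is no shared variable like `k` that every thread competes to set.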