// BitCastLoad函数是为了对bitcast + Load语句做等价替换
// %1 = bitcast (%struct.point* getelementptr inbounds (%struct.str, %struct.str* @global, i64 0, i32 3) to i32*)
/ /等价替换为
// %0 = getelementptr inbounds (%struct.str, %struct.str* @global, i64 0, i32 3)
// %1 = get (%0, i64 0,i32 0); %2 = load %1; %3 = zext %2 to i32; %4 = shl 0,%3;
// %5 = get (%0, i64 0,i32 1); %6 = load%5; %7 = zext %6 to i32; %8 = shl 16,%7
// %9 = xor (%4, %8); %10 = load %9;
void RemoveStruct::BitCastLoad(Instruction* &instr) {
Value* val = instr->getOperand(0);
Type* val_ty = val->getType()->getPointerElementType();
// Get the list of value's Type in memory
vector<int> val_list;
map<vector<int>, Type*> val_map;
// **align是对齐size,通过遍历struct中变量最大size得到alignment**
int64_t align = 0;
// **BitCastStruct函数,通过该函数可以得到每个变量的内存排布顺序**。
// **val_list 是变量在struct的相对位置地址,举个栗子:**
// **struct_1 { struct_2 {struct_3{x, y}, p}, q}**
//**(0 0 1 就是x),(0 1是p),(1是q)**
// **val_ty是struct的type**
// **val_map是map,key是位置地址,value是位置地址对应变量的type**
BitCastStruct(val_list, val_ty, val_map, align);
map<vector<int>, Type*>::iterator iter;
//两个for循环可以打印map的所有信息
for (iter = val_map.begin(); iter != val_map.end(); iter++) {
vector<int> temp = iter->first ;
for (std::vector<int>::iterator it = temp.begin(); it != temp.end(); ++it)
{
int a = *it;
outs() << a << " ";
}
outs() << "\n";
outs() << "Type list: " << *iter->second << "\n";
//outs() <<"Type IntegerBitWidth " <<iter->second->getIntegerBitWidth()<<"\n";
}
outs() << "align: " << align <<"\n";
// change BitCast Instruction by the Type list and bitcast_bitwidth
// **bitcast_bitwidth是需要拼接value的位宽**
int bitcast_bitwidth = instr->getType()->getPointerElementType()->getIntegerBitWidth();
Type* type_zextI = Type::getIntNTy(module_->getContext(), bitcast_bitwidth);
// 当前value的位宽
int val_width = 0;
// step_width是步长大小
int step_width = 1;
// 已经用掉的位宽长度
int use_width = 0;
Instruction* Inst_temp;
// 首先拿到val_map.begin()->first 第一个list 的size大小,保证在处理过程中size大小不会改变,否者align不对齐现象,因为每个struct直接需要对齐,此时没处理对齐。下面有assert
vector<int> begin = val_map.begin()->first;
int64_t vec_temp_size = begin.size();
outs()<< "vec_temp_size: " << vec_temp_size << "\n";
for (iter = val_map.begin(); iter != val_map.end(); iter++) {
val_width = iter->second->getIntegerBitWidth();
while (use_width % val_width != 0) {
use_width += step_width;
}
use_width += val_width;
assert(use_width <= bitcast_bitwidth);
vector<Value *> val_list;
val_list.push_back(ConstantInt::get(Type::getInt64Ty(module_->getContext()), 0));
vector<int> vec_temp = iter->first;
//出现跨层,报错,进行添加align处理
assert(vec_temp_size == vec_temp.size());
for (vector<int>::iterator it = vec_temp.begin(); it != vec_temp.end(); ++it) {
val_list.push_back( ConstantInt::get(Type::getInt32Ty(module_->getContext()), *it));
}
ArrayRef<Value*> list(val_list);
GetElementPtrInst* ptrInst_new =GetElementPtrInst::CreateInBounds(
dyn_cast<PointerType> (instr->getOperand(0)->getType())->getElementType(),instr->getOperand(0), list, "", instr);
llvm::LoadInst* load_GEP = new LoadInst(ptrInst_new, "", instr);
Instruction* zextI = CastInst::Create(Instruction::ZExt, load_GEP, type_zextI, "", instr);
Instruction* shlI = BinaryOperator::Create(llvm::Instruction::Shl, zextI,
ConstantInt::get(type_zextI, use_width-val_width), "", instr);
if (use_width == val_width) { Inst_temp = shlI; }
else if (use_width < bitcast_bitwidth) {
Instruction* xorI = BinaryOperator::Create(llvm::Instruction::Xor, shlI, Inst_temp,"",instr);
Inst_temp = xorI;
}
if (use_width == bitcast_bitwidth) {
Instruction* xorI = BinaryOperator::Create(llvm::Instruction::Xor, shlI, Inst_temp,"");
ReplaceInstWithInst(instr->getNextNode(),xorI);
// 当满足bitWidth大小则跳出循环
break;
}
}
}
// BitCastStruct() and BitCastArray() recursively walk a struct type to build the field-path -> type map used by BitCastLoad.
// BitCastStruct: recursively flattens an LLVM struct type into val_map.
//
//   val_list - index path of val_ty inside the outermost struct. It is taken
//              by value, so each recursion level extends its own copy.
//   val_ty   - the struct type to walk (must be a StructType).
//   val_map  - output: for every integer leaf, (index path -> integer type).
//              Example: for s1 { s2 { s3 { x, y }, p }, q } starting from an
//              empty path: x -> (0 0 0), y -> (0 0 1), p -> (0 1), q -> (1).
//   align    - in/out: raised to the widest integer leaf bit width seen.
//
// Array members are delegated to BitCastArray. Any other member type
// (pointer, float, ...) is reported and triggers assert(0).
void RemoveStruct::BitCastStruct(vector<int> val_list, Type* val_ty, map<vector<int>, Type*> &val_map, int64_t &align) {
  const unsigned num_elements = val_ty->getStructNumElements();
  for (unsigned i = 0; i < num_elements; ++i) {
    Type* elem_ty = val_ty->getStructElementType(i);
    // Extend the path for this member, and pop it again after the branch so
    // that siblings do not inherit it (otherwise struct{x,y} would give y = (0 1)).
    val_list.push_back(static_cast<int>(i));
    if (elem_ty->isIntegerTy()) {
      val_map.insert(map<vector<int>, Type*>::value_type(val_list, elem_ty));
      if (elem_ty->getIntegerBitWidth() > align) {
        align = elem_ty->getIntegerBitWidth();
      }
    } else if (elem_ty->isArrayTy()) {
      BitCastArray(val_list, elem_ty, val_map, align);
    } else if (elem_ty->isStructTy()) {
      BitCastStruct(val_list, elem_ty, val_map, align);
    } else {
      // Print the type itself, not the address of the Type object.
      outs() << "This Type is not considered! " << *elem_ty << "\n";
      assert(0);
    }
    val_list.pop_back();
  }
}
// BitCastArray: the array counterpart of BitCastStruct. It adds one val_map
// entry per integer element, keyed by the element's index path.
//
//   val_list - index path of the array inside the outermost struct (by value).
//   val_ty   - the array type to walk (must be an ArrayType).
//   val_map  - output: (index path -> integer element type).
//   align    - in/out: raised to the element bit width if it is wider.
//
// Only arrays of integers are supported. Arrays of arrays and arrays of
// structs hit assert(0) (TODO); their traversal code below only runs when
// assertions are compiled out.
void RemoveStruct::BitCastArray(vector<int> val_list, Type* val_ty, map<vector<int>, Type*> &val_map, int64_t &align) {
  const uint64_t num_elements = val_ty->getArrayNumElements();
  Type* elem_ty = val_ty->getArrayElementType();
  outs() << "Array Num: " << num_elements << "\n";
  if (elem_ty->isIntegerTy()) {
    for (uint64_t i = 0; i < num_elements; ++i) {
      val_list.push_back(static_cast<int>(i));
      val_map.insert(map<vector<int>, Type*>::value_type(val_list, elem_ty));
      if (elem_ty->getIntegerBitWidth() > align) {
        align = elem_ty->getIntegerBitWidth();
      }
      val_list.pop_back();
    }
  } else if (elem_ty->isArrayTy()) {
    assert(0);  // TODO: nested arrays are not supported yet
    for (uint64_t i = 0; i < num_elements; ++i) {
      val_list.push_back(static_cast<int>(i));
      BitCastArray(val_list, elem_ty, val_map, align);
      val_list.pop_back();
    }
  } else if (elem_ty->isStructTy()) {
    assert(0);  // TODO: arrays of structs are not supported yet
    for (uint64_t i = 0; i < num_elements; ++i) {
      val_list.push_back(static_cast<int>(i));
      BitCastStruct(val_list, elem_ty, val_map, align);
      val_list.pop_back();
    }
  } else {
    // Print the type itself, not the address of the Type object.
    outs() << "This Type is not considered! " << *elem_ty << "\n";
    assert(0);
  }
}