package com.multitrigger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import com.triggerInfo.InputInfo;
import com.triggerInfo.OutputInfo;
import com.triggerInfo.Protein;
import com.triggerInfo.Segment;
import com.triggerfeature.*;
/**
 * Beam search over token segments of a sentence, labelling each segment with a
 * trigger type, with an in-place weight update on {@code model.triggerWeight}.
 *
 * <p>For every end-token position the search enumerates all segments of up to
 * {@link #MAX_SEGMENT_LENGTH} tokens that end there, extends the states kept for
 * the token just before the segment, sorts the resulting states by descending
 * score, hands them to {@code corpus.containK} and stores them as the agenda for
 * that position.
 */
public class MultiTriggerParsing
{
    /** Maximum number of tokens a candidate segment may span. */
    private static final int MAX_SEGMENT_LENGTH = 5;

    /**
     * Size argument passed to {@code corpus.containK} for every agenda
     * (presumably the beam width, i.e. how many states are kept -- confirm against
     * {@code CorpusProcessing.containK}).
     */
    private static final int BEAM_WIDTH = 20;

    /** Label given to segments that are not triggers (including protein spans). */
    private static final String NONE_LABEL = "None";

    /**
     * Candidate labels, tried in this order for every segment.
     * Labels not currently enabled: "Anaphora", "Entity", "Ubiquitination",
     * "Protein_modification".
     */
    private static final List<String> TRIGGER_TYPES = Collections.unmodifiableList(Arrays.asList(
            "Gene_expression", "Transcription", "Protein_catabolism", "Phosphorylation",
            "Localization", "Binding", "Regulation", "Positive_regulation",
            "Negative_regulation", "None"));

    /** Orders states from highest to lowest {@code score}. */
    private static final Comparator<State> BY_SCORE_DESCENDING = new Comparator<State>()
    {
        @Override
        public int compare(State sta1, State sta2)
        {
            // Double.compare yields the same ordering as the former
            // new Double(..).compareTo(..) without the deprecated boxing constructor.
            return Double.compare(sta2.score, sta1.score);
        }
    };

    State state = new State();
    CorpusProcessing corpus = new CorpusProcessing();

    /**
     * Runs the beam search over one sentence and, at the first gold segment end
     * whose agenda fails the {@code corpus.isContain} check, updates
     * {@code model.triggerWeight} and stops processing the sentence.
     *
     * <p>NOTE(review): stopping at the first failing position looks like a
     * perceptron-style "early update"; confirm against the training driver.
     *
     * @param model          model whose {@code triggerWeight} is read and may be replaced
     * @param sentInputInfo  sentence input; {@code sentToken} and {@code proteinList} are read here
     * @param sentOutputInfo gold annotation; {@code segmentList} and {@code sentGoldFeature} are read here
     */
    public void beamSearch(Model model, InputInfo sentInputInfo, OutputInfo sentOutputInfo)
    {
        // beamList.get(i) holds the agenda (kept states) for segments ending at token i.
        List<List<State>> beamList = new ArrayList<List<State>>();
        List<String> tokenList = sentInputInfo.sentToken;

        for (int endTokenId = 0; endTokenId < tokenList.size(); endTokenId++)
        {
            List<State> endTokenStateList = expandCandidates(beamList, endTokenId, model, sentInputInfo);

            // Collections.sort is stable, so equally scored states keep insertion order.
            Collections.sort(endTokenStateList, BY_SCORE_DESCENDING);
            corpus.containK(endTokenStateList, BEAM_WIDTH);
            beamList.add(endTokenStateList);

            // Weights are only updated when endTokenId is the last token of a gold segment.
            boolean mark = false;
            List<Segment> goldSegList = sentOutputInfo.segmentList;
            List<Segment> partGoldList = new ArrayList<Segment>();
            for (int segId = 0; segId < goldSegList.size(); segId++)
            {
                if (goldSegList.get(segId).segmentEnd == endTokenId)
                {
                    mark = true;
                    partGoldList = corpus.getPartGoldList(goldSegList, segId);
                    break;
                }
            }
            // TODO (from the original author): the way the loop above terminates is suspect.

            // isContain is deliberately still evaluated before the flag, as in the
            // original, so the call sequence on corpus is unchanged.
            if (!corpus.isContain(endTokenStateList, partGoldList) && mark)
            {
                updateWeights(model, sentOutputInfo, endTokenId, endTokenStateList);
                break;
            }
        }
    }

    /**
     * Builds the unsorted candidate states for all segments ending at {@code endTokenId}.
     *
     * <p>Segments whose span equals a protein span are labelled {@code "None"} only,
     * since a protein does not need to be classified as a trigger.
     *
     * <p>NOTE(review): as in the original ({@code break outer}), expansion for this
     * end token stops completely as soon as a protein span is met. Because the
     * protein check does not depend on the label, this always happens while the
     * first label is being tried: the result then contains only the shorter
     * segments under the first label plus the protein segment, and longer segments
     * and all other labels are skipped. Behaviour preserved; confirm it is intended.
     */
    private List<State> expandCandidates(List<List<State>> beamList, int endTokenId,
            Model model, InputInfo sentInputInfo)
    {
        List<State> candidates = new ArrayList<State>();
        // Near the sentence start a segment cannot be longer than the tokens seen so far.
        int maxSegLength = Math.min(MAX_SEGMENT_LENGTH, endTokenId + 1);

        for (String triggerType : TRIGGER_TYPES)
        {
            for (int segLength = 1; segLength <= maxSegLength; segLength++)
            {
                int startTokenId = endTokenId - segLength + 1;
                List<State> predecessors = predecessorStates(beamList, endTokenId, segLength);
                boolean proteinSpan = isProteinSpan(sentInputInfo, startTokenId, endTokenId);
                String label = proteinSpan ? NONE_LABEL : triggerType;

                for (State oldState : predecessors)
                {
                    candidates.add(state.append(startTokenId, endTokenId, label, oldState, model, sentInputInfo));
                }
                if (proteinSpan)
                {
                    return candidates;
                }
            }
        }
        return candidates;
    }

    /**
     * Returns the states a segment of {@code segLength} tokens ending at
     * {@code endTokenId} can extend.
     *
     * <p>When the segment starts at token 0 there is no earlier agenda in
     * {@code beamList}, so a single fresh {@code State} is used instead.
     */
    private List<State> predecessorStates(List<List<State>> beamList, int endTokenId, int segLength)
    {
        if (segLength == endTokenId + 1)
        {
            return Collections.singletonList(new State());
        }
        return beamList.get(endTokenId - segLength);
    }

    /** Tells whether some protein in {@code sentInputInfo.proteinList} spans exactly [startTokenId, endTokenId]. */
    private boolean isProteinSpan(InputInfo sentInputInfo, int startTokenId, int endTokenId)
    {
        for (int proteinId = 0; proteinId < sentInputInfo.proteinList.size(); proteinId++)
        {
            Protein protein = sentInputInfo.proteinList.get(proteinId);
            if (startTokenId == protein.proteinStart && endTokenId == protein.proteinEnd)
            {
                return true;
            }
        }
        return false;
    }

    /**
     * Replaces {@code model.triggerWeight} with the result of
     * {@code corpus.addWeight(model.triggerWeight, corpus.cutWeight(gold, predicted))},
     * where {@code gold} is the gold feature map stored for {@code endTokenId} and
     * {@code predicted} is {@code segmentPreFeatures} of the first (best-scored)
     * state in the agenda.
     */
    private void updateWeights(Model model, OutputInfo sentOutputInfo, int endTokenId,
            List<State> endTokenStateList)
    {
        // Gold features of every token of the sentence, indexed by token id.
        List<Map<String, Double>> goldFeature = sentOutputInfo.sentGoldFeature;
        // TODO (from the original author): verify that endTokenId is the right index here.
        Map<String, Double> endTokenGoldFeature = goldFeature.get(endTokenId);
        // The agenda was just appended to the beam, so this is the list the original
        // read back via beamList.get(beamList.size() - 1).
        State bestCandidate = endTokenStateList.get(0);
        Map<String, Double> bestSegmentPreFeature = bestCandidate.segmentPreFeatures;

        Map<String, Double> result = corpus.cutWeight(endTokenGoldFeature, bestSegmentPreFeature);
        model.triggerWeight = corpus.addWeight(model.triggerWeight, result);
    }
}
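// Multi-beam 解析过程 (Multi-beam parsing process)
// 最新推荐文章于 2024-11-01 14:48:20 发布 (web-page residue: "latest recommended article, published 2024-11-01 14:48:20")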