这个题目是创新工场在微博上发布的题目“工场很忙”。
程序使用python编写,使用遗传算法实现,在python2.7下运行通过。
File.py文件
def Read():
    '''
    Read the students' interview requests from the file "iw.in".

    Every input line holds a student id and a project id separated by
    whitespace. The line "0 0" terminates the input; anything after it is
    ignored. Each accepted line is echoed to standard output.

    Returns :
    dict student id (str) -> list of project ids (str), in file order
    '''
    source = {}
    print("输入初始数据:")
    # "with" closes the file exactly once, even if reading fails.
    with open("iw.in") as f:
        lines = f.readlines()
    for line in lines:
        fields = line.split()
        if len(fields) < 2:
            # Blank or malformed line: skip it instead of raising IndexError.
            continue
        student, project = fields[0], fields[1]
        if student == '0' and project == '0':
            break
        print(student + ' ' + project)
        # Splitting on whitespace (rather than fixed character positions)
        # also supports ids that are longer than one character.
        source.setdefault(student, []).append(project)
    return source
def Write(result={}):
    '''
    Write the schedule to the file "iw.out" and echo it to standard output.

    One line is written per key, in sorted key order. A line consists of
    every value of that key followed by a single space (so each line ends
    with a trailing space before the newline).

    :type result : dict
    :key (str) -> list of str values to print on that key's line
    '''
    print("输出结果:")
    # "with" guarantees a single close, also when a write fails.
    with open("iw.out", 'w') as f:
        for key in sorted(result):
            line = "".join(value + " " for value in result[key])
            print(line)
            f.write(line + '\n')
# Running File.py directly only prints a short description of the module.
if __name__ == "__main__":
    print('这是“工场很忙”题目使用遗传算法的实现中用于读取文件输入数据和写文件输出结果!')
Heredity.py文件
import random
from File import *
# dict: input data, i.e. each student's interview requests
# (key: student name as str, value: list of project names).
dataSource={}
# dict: output result (key: project name as str, value: list giving the
# interview order of that project).
dataOut={}
# dict: students interviewed by each project
# (key: project name as str, value: list of student names).
projectNames={}
# int: number of projects.
projectNamesLen=0
# int: total length of one encoding.
encodeLen=0
# int: length of the encoding segment of a single project.
encodeLenPer=0
# list: two-dimensional list; entry i holds the values that may be chosen for
# position i of an encoding.
encodeSlc=[]
# int: initial population size.
initPplLen=50
# int: number of generations to run.
maxGent=500
# float: selection rate; Select uses int(pplLen*ps) as the number of
# individuals it replaces.
ps=0.1
# float: crossover probability parameter, compared with random.random() in Cross.
pc=0.1
# float: mutation probability parameter, compared with random.random() in Variation.
pm=0.05
def GetAllProjects(dataSource={}):
    '''
    Invert the student -> projects mapping into project -> students.

    :type dataSource : dict
    :student name -> list of the projects that student interviews for

    Returns :
    dict project name -> list of student names, in the iteration order of dataSource
    integer number of distinct projects
    '''
    proNams = {}
    for student, projects in dataSource.items():
        for project in projects:
            # setdefault replaces the has_key() test, which Python 3 removed.
            proNams.setdefault(project, []).append(student)
    return proNams, len(proNams)
def GetEncodeLen(dataSource={},proNams={},proNamLen=0):
    '''
    Compute the total encoding length and the per-project segment length.

    The segment length is the larger of (a) the greatest number of projects
    any single student has and (b) the greatest number of students any single
    project has; it is 0 when both mappings are empty.

    :type dataSource : dict
    :student name -> list of projects
    :type proNams : dict
    :project name -> list of students
    :type proNamLen : integer
    :number of projects

    Returns :
    integer total encoding length (segment length * proNamLen)
    integer segment length of one project
    '''
    # The extra [0] keeps max() defined for empty mappings.
    mostPerStudent = max([len(projects) for projects in dataSource.values()] + [0])
    mostPerProject = max([len(students) for students in proNams.values()] + [0])
    perLen = max(mostPerStudent, mostPerProject)
    return perLen * proNamLen, perLen
def InitEncodeSlc(encodeLen=0,perLen=0,proNams={}):
    '''
    Build, for every encoding position, the list of values it may take.

    Position i belongs to the project at index i // perLen in the iteration
    order of proNams. Its candidates are that project's students plus the
    marker '0'.

    :type encodeLen : integer
    :total encoding length
    :type perLen : integer
    :segment length of one project
    :type proNams : dict
    :project name -> list of students

    Returns :
    list one independent candidate list per encoding position
    '''
    # list(dict) is indexable on Python 2 and 3; keys() is a view on Python 3.
    pn = list(proNams)
    # "//" keeps the index an integer on Python 3 as well. Concatenation
    # builds a fresh list per position, so proNams itself is never modified.
    return [proNams[pn[i // perLen]] + ['0'] for i in range(encodeLen)]
def GetOneSelect(index=0,encodeSlc=[]):
    '''
    Pick a random candidate value for one encoding position.

    :type index : integer
    :encoding position
    :type encodeSlc : list
    :candidate lists, one per encoding position

    Returns :
    one element of encodeSlc[index], chosen uniformly at random
    '''
    candidates = encodeSlc[index]
    pick = random.randint(0, len(candidates) - 1)
    return candidates[pick]
def CheckEncodeInvalidate(ec=[],encodeLen=0,ecPerLen=0,proNams={}):
    '''
    Check whether an encoding is a legal schedule.

    The encoding is read as consecutive segments of ecPerLen positions, one
    segment per project in the iteration order of proNams; '0' marks an
    unused position. The encoding is rejected when
    - it is shorter than encodeLen,
    - a value other than '0' occurs twice inside one project's segment,
    - the same value other than '0' occurs at the same offset of two segments,
    - a segment holds fewer distinct values than the project has students.

    :type ec : list
    :encoding to check
    :type encodeLen : integer
    :required encoding length
    :type ecPerLen : integer
    :segment length of one project
    :type proNams : dict
    :project name -> list of students

    Returns:
    boolean True when legal, False otherwise
    '''
    if len(ec) < encodeLen:
        return False
    # Materialise the key order once; keys() is not indexable on Python 3.
    proKeys = list(proNams)
    proLen = len(proKeys)
    for i in range(proLen):
        seen = set()
        base = i * ecPerLen
        for j in range(ecPerLen):
            cutSlc = ec[base + j]
            if cutSlc == '0':
                continue
            if cutSlc in seen:
                # Same value twice within one project.
                return False
            seen.add(cutSlc)
            # Only later projects need checking; earlier ones already
            # compared themselves against this segment.
            for k in range(i + 1, proLen):
                if cutSlc == ec[k * ecPerLen + j]:
                    return False
        if len(seen) < len(proNams[proKeys[i]]):
            return False
    return True
def GetOneEncode():
    '''
    Draw random encodings until one passes CheckEncodeInvalidate.

    Reads the module globals projectNames, encodeLen, encodeLenPer and
    encodeSlc. The empty list is tested first, then every position is
    redrawn with GetOneSelect until the whole encoding is accepted.

    Returns :
    list the accepted encoding
    '''
    candidate = []
    while not CheckEncodeInvalidate(candidate, encodeLen, encodeLenPer, projectNames):
        candidate = [GetOneSelect(pos, encodeSlc) for pos in range(encodeLen)]
    return candidate
def Init():
    '''
    Derive the module-level encoding tables from the global dataSource.

    Sets projectNames, projectNamesLen, encodeLen, encodeLenPer and encodeSlc.
    '''
    global projectNames, projectNamesLen, encodeLen, encodeLenPer, encodeSlc
    print('初始化数据!')
    projectNames, projectNamesLen = GetAllProjects(dataSource)
    encodeLen, encodeLenPer = GetEncodeLen(dataSource, projectNames, projectNamesLen)
    encodeSlc = InitEncodeSlc(encodeLen, encodeLenPer, projectNames)
def InitPupulation(pplLen=0):
    '''
    Create the initial population and print every individual.

    :type pplLen : integer
    :number of individuals to create

    Returns :
    list pplLen encodings obtained from GetOneEncode
    '''
    # Joining with a single space reproduces the layout of a comma-separated
    # print statement.
    print(' '.join(['初始化种群(', str(pplLen), '):']))
    population = []
    for number in range(1, pplLen + 1):
        individual = GetOneEncode()
        population.append(individual)
        print(' '.join(['第', str(number), '个个体:', str(individual)]))
    return population
def Sort(population=[],populationFitness=[],pplLen=0):
    '''
    Sort the first pplLen individuals in place by ascending fitness.

    Both lists are reordered together so that populationFitness[i] still
    belongs to population[i]. Individuals with equal fitness keep their
    relative order. Nothing is returned.

    :type population : list
    :encodings of the population
    :type populationFitness : list
    :fitness of each individual
    :type pplLen : integer
    :number of leading entries to sort
    '''
    if pplLen < 2:
        # Zero or one entry is already sorted.
        return
    # sorted() is stable, so ties are resolved exactly like a bubble sort
    # that only swaps on a strict ">".
    order = sorted(range(pplLen), key=lambda idx: populationFitness[idx])
    # Slice assignment keeps the caller's list objects, i.e. sorts in place.
    population[:pplLen] = [population[idx] for idx in order]
    populationFitness[:pplLen] = [populationFitness[idx] for idx in order]
def Select(population=[],populationFitness=[],pplLen=0):
    '''
    Selection step of the genetic algorithm.

    Sorts population and populationFitness in place by ascending fitness,
    then returns a new list in which the last int(pplLen*ps) individuals are
    dropped and the first int(pplLen*ps) individuals are appended once more.

    :type population : list
    :encodings of the population
    :type populationFitness : list
    :fitness of each individual
    :type pplLen : integer
    :population size

    Returns :
    list the selected population
    '''
    Sort(population, populationFitness, pplLen)
    eliteCount = int(pplLen * ps)
    survivors = population[:pplLen - eliteCount]
    elite = population[:eliteCount]
    return survivors + elite
def CrossPopulationOfProject(populationLeft=[],populationRight=[],index=0,encodeLen=0,encodeLenPer=0):
    '''
    Exchange the segment of one project between two individuals.

    :type populationLeft : list
    :first individual
    :type populationRight : list
    :second individual
    :type index : integer
    :index of the project whose segment is exchanged
    :type encodeLen : integer
    :encoding length; the results are cut off at this position
    :type encodeLenPer : integer
    :segment length of one project

    Returns :
    list first individual carrying the second one's segment
    list second individual carrying the first one's segment
    '''
    begin = index * encodeLenPer
    end = begin + encodeLenPer
    childLeft = populationLeft[:begin] + populationRight[begin:end] + populationLeft[end:encodeLen]
    childRight = populationRight[:begin] + populationLeft[begin:end] + populationRight[end:encodeLen]
    return childLeft, childRight
def CrossPopulation(populationLeft=[],populationRight=[]):
    '''
    Cross two individuals by exchanging one project segment.

    Projects are tried in index order, starting with project 0. The first
    exchange that leaves both children legal according to
    CheckEncodeInvalidate is returned. When no exchange works, the two
    original individuals are returned unchanged.

    :type populationLeft : list
    :first individual
    :type populationRight : list
    :second individual

    Returns:
    list first result
    list second result
    '''
    # Project 0 is always attempted, even when projectNamesLen is 0.
    for crossIndex in range(max(1, projectNamesLen)):
        left, right = CrossPopulationOfProject(
            populationLeft, populationRight, crossIndex, encodeLen, encodeLenPer)
        leftValid = CheckEncodeInvalidate(left, encodeLen, encodeLenPer, projectNames)
        if leftValid and CheckEncodeInvalidate(right, encodeLen, encodeLenPer, projectNames):
            return left, right
    return populationLeft, populationRight
def Cross(population=[],pplLen=0):
    '''
    Crossover step: cross neighbouring pairs of the population in place.

    Individuals (0,1), (2,3), ... are each crossed with probability pc
    (module global). With an odd pplLen the last individual has no partner
    and is left untouched.

    :type population : list
    :encodings of the population, modified in place
    :type pplLen : integer
    :population size
    '''
    # Stopping at pplLen-1 guarantees that population[i+1] exists.
    for i in range(0, pplLen - 1, 2):
        # random.random() < pc is true with probability pc.
        if random.random() < pc:
            population[i], population[i + 1] = CrossPopulation(population[i], population[i + 1])
def VariationPopulation(population=[]):
    '''
    Mutate one individual by swapping two positions of one project segment.

    A legal encoding contains every student of a project exactly once in that
    project's segment. Redrawing a single position therefore always yields an
    illegal encoding (a duplicate or a missing student) unless the value is
    unchanged. Swapping two positions inside the same segment keeps the
    segment's contents, so only clashes with other projects can make the
    result illegal.

    Up to encodeLen+1 random swaps are tried. The first one that changes the
    encoding and passes CheckEncodeInvalidate is returned.

    :type population : list
    :individual to mutate (not modified)

    Returns :
    list the mutated copy, or the original individual when no legal swap was found
    '''
    if encodeLen <= 0 or encodeLenPer <= 0:
        # Nothing to mutate; also avoids randint() on an empty range.
        return population
    for _ in range(encodeLen + 1):
        first = random.randint(0, encodeLen - 1)
        # Restrict the partner to the segment of the same project.
        segStart = (first // encodeLenPer) * encodeLenPer
        second = random.randint(segStart, segStart + encodeLenPer - 1)
        if population[first] == population[second]:
            # Swapping equal values would not change anything.
            continue
        pop = population[:]
        pop[first], pop[second] = pop[second], pop[first]
        if CheckEncodeInvalidate(pop, encodeLen, encodeLenPer, projectNames):
            return pop
    return population
def Variation(population=[],pplLen=0):
    '''
    Mutation step: mutate individuals of the population in place.

    Each of the first pplLen individuals is replaced by the result of
    VariationPopulation with probability pm (module global).

    :type population : list
    :encodings of the population, modified in place
    :type pplLen : integer
    :population size
    '''
    for i in range(pplLen):
        # random.random() < pm is true with probability pm.
        if random.random() < pm:
            population[i] = VariationPopulation(population[i])
def CalculateStudentTime(encode=[]):
    '''
    Total time the students spend under one schedule.

    For every student in the global dataSource, the time is the number of
    slots from the student's first interview to the last one, both included.
    A student who does not occur in the encoding contributes 0.

    Reads the module globals dataSource, projectNamesLen and encodeLenPer.

    :type encode : list
    :encoding; position slot + project*encodeLenPer holds the student that
    the project interviews in that slot

    Returns :
    integer sum of the students' times
    '''
    total = 0
    for student in dataSource:
        first = None
        last = None
        for slot in range(encodeLenPer):
            for pro in range(projectNamesLen):
                if encode[slot + pro * encodeLenPer] == student:
                    if first is None:
                        first = slot
                    # Every later hit extends the stay, including one in the
                    # slot directly after the first interview.
                    last = slot
                    break
        if first is not None:
            total += last - first + 1
    return total
def CalculateBossTime(encode=[]):
    '''
    Total time the project bosses spend under one schedule.

    For every project, the time is the number of slots from its first used
    position to its last used position, both included; '0' marks an unused
    position. A project without any used position contributes 0.

    Reads the module globals projectNamesLen and encodeLenPer.

    :type encode : list
    :encoding, one segment of encodeLenPer positions per project

    Returns :
    integer sum of the bosses' times
    '''
    total = 0
    for pro in range(projectNamesLen):
        base = pro * encodeLenPer
        busy = [slot for slot in range(encodeLenPer) if encode[base + slot] != '0']
        if busy:
            total += busy[-1] - busy[0] + 1
    return total
def CalculateHRTime(encode=[]):
    '''
    Time HR spends under one schedule.

    HR is counted from slot 0 up to and including the last slot in which any
    project has an interview; '0' marks an unused position. The result is 0
    when no position is used.

    Reads the module globals projectNamesLen and encodeLenPer.

    :type encode : list
    :encoding; position slot + project*encodeLenPer belongs to that slot

    Returns :
    integer HR time
    '''
    # Scan from the last slot backwards: the first used slot found this way
    # is the end of the HR day, whatever slot the first interview was in.
    for slot in range(encodeLenPer - 1, -1, -1):
        for pro in range(projectNamesLen):
            if encode[slot + pro * encodeLenPer] != '0':
                return slot + 1
    return 0
def CalculateFitness(population=[],pplLen=0):
    '''
    Compute the fitness of the first pplLen individuals.

    The fitness of an individual is
    4 * CalculateStudentTime + 2 * CalculateBossTime + CalculateHRTime.
    Sort orders the population by ascending fitness, so a smaller value
    ranks first.

    :type population : list
    :encodings of the population
    :type pplLen : integer
    :population size

    Returns :
    list fitness of each individual, in population order
    '''
    populationFitness = []
    for i in range(pplLen):
        individual = population[i]
        fitness = CalculateStudentTime(individual) * 4
        # The weight applies to the returned time, not to the encoding:
        # multiplying the list would only repeat it.
        fitness += CalculateBossTime(individual) * 2
        fitness += CalculateHRTime(individual)
        populationFitness.append(fitness)
    return populationFitness
def EncodeToResult(encode=[]):
    '''
    Convert an encoding into a dictionary keyed by project name.

    Reads the module globals encodeLenPer, projectNames and projectNamesLen.
    The i-th project in the iteration order of projectNames receives the
    i-th segment of encodeLenPer positions.

    :type encode : list
    :encoding

    Returns :
    dict project name -> list with that project's segment
    '''
    # Build the key list once; keys() is not indexable on Python 3.
    names = list(projectNames)
    result = {}
    for i in range(projectNamesLen):
        result[names[i]] = encode[i * encodeLenPer:(i + 1) * encodeLenPer]
    return result
def Heredity():
    '''
    Run the genetic algorithm and return the best schedule found.

    Creates initPplLen individuals, then repeats fitness evaluation,
    selection, crossover and mutation for maxGent generations (both are
    module globals). Progress is printed for every generation.

    Returns :
    dict project name -> interview order of the best individual of the final population
    '''
    population = InitPupulation(initPplLen)
    # The message announces the number of generations, i.e. maxGent.
    print(' '.join(['遗传过程开始(总', str(maxGent), '代):']))
    for i in range(maxGent):
        populationFitness = CalculateFitness(population, initPplLen)
        population = Select(population, populationFitness, initPplLen)
        print(' '.join(['第', str(i), "代种群,最好个体:", str(population[0])]))
        Cross(population, initPplLen)
        Variation(population, initPplLen)
    print('遗传过程结束!')
    # Crossover and mutation of the last generation changed the population,
    # so the fitness must be recomputed before ranking. This also covers
    # maxGent == 0, where no fitness has been computed yet.
    populationFitness = CalculateFitness(population, initPplLen)
    Sort(population, populationFitness, initPplLen)
    return EncodeToResult(population[0])
def Start():
    '''
    Entry point: read the input, run the algorithm and write the result.

    Stores the input in the global dataSource and the result in the global
    dataOut.
    '''
    global dataSource, dataOut
    dataSource = Read()
    Init()
    schedule = Heredity()
    dataOut = schedule
    Write(schedule)
# Running Heredity.py directly solves a small built-in sample instead of
# reading "iw.in", and prints the resulting schedule.
if __name__ == "__main__":
    print('这是‘工场很忙’题目使用遗传算法实现过程!')
    print('测试数据:')
    for sample in ('1 1', '1 2', '1 3', '2 1', '3 1', '3 2', '0 0'):
        print(sample)
    dataSource = {"1": ["1", "2", "3"], "2": ["1"], "3": ["1", "2"]}
    Init()
    dataOut = Heredity()
    print('计算结果为:')
    for key in sorted(dataOut.keys()):
        # Every value is followed by one space, including the last one.
        print("".join(value + ' ' for value in dataOut[key]))