在之前的文章里,我的多线程装饰器虽然在我们的RF框架上用上了,但那是因为我们装饰的方法是静态方法。如果装饰的方法是写在类中,由于第一个参数是self,这个参数在被装饰时并不会传入,只有调用时才会在第一个位置传入,导致解析参数时报错。
定位发现问题是在
def multiThread(func): # 实现多线程
def infunc(*args):
print "诊断", args # 这里加了这一行,然后发现后面的每个参数都是两两成对的元组,即每个线程的(*args,**kwargs),而第一个参数是孤零零的实例object,即self
global resset
resset = [] # 每次调用前,清除结果集
tli = [] # 线程对象列表
tid = 0 # 初始线程编号
定位后,解决这个问题不难,把代码改变为如下便可:
在迭代参数开始前先把第一个参数(self)拿走并保存
tli = [] # 线程对象列表
tid = 0 # 初始线程编号
# 在这里新增一个判断 #########
if type(args[0])!=tuple: # 说明原来的函数是类里面的,第一个参数是self
sarg=args[0]
args=tuple(args[1:])
else:
sarg=None
###########################
for arg in args: # 处理传入的参数列表
if len(arg) == 1:
接着在迭代结束后,执行前加上参数,也就是说原来的targ = tuple(list(targ) + [tid])改为额外判断,如果原来的类里有self传入,就把self拼在*args的最开头,变成tuple([sarg] + list(targ) + [tid])
except:
raise ValueError, 'arguement format error'
targ = tuple(list(targ) + [tid]) if not sarg else tuple([sarg] + list(targ) + [tid]) # 在位置参数的最后一个位置,加上线程编号传进去。getResult里会把这个参数pop掉,并记录到结果里
tli.append(Thread(target=func, args=targ, kwargs=darg)) # 线程列表填充
tid = tid + 1 # 线程编号递增
for th in tli:
th.setDaemon(True)
这里是完整的脚本和测试代码
# coding: utf-8
from threading import Thread
def getResult(func): # 重定向函数返回的结果
def infunc(*args, **kwargs):
args = list(args)
tid = args.pop() # 去掉最后一个代表线程编号的参数
args = tuple(args)
try:
resset.append((func(*args, **kwargs), tid)) # 结果是一个由二元元组组成的列表,每个线程一个元组,元组内第一个元素是线程运行的结果,第二个是线程编号
except Exception, e:
resset.append((e, tid)) # 如果线程内报错,就返回错误对象
return infunc
def multiThread(func): # 实现多线程
def infunc(*args):
# print "诊断", args
global resset
resset = [] # 每次调用前,清除结果集
tli = [] # 线程对象列表
tid = 0 # 初始线程编号
if type(args[0])!=tuple: # 说明原来的函数是类里面的,第一个参数是self
sarg=args[0]
args=tuple(args[1:])
else:
sarg=None
for arg in args: # 处理传入的参数列表
if len(arg) == 1:
if type(arg[0]) == type({}):
targ = () # 设置位置参数
darg = arg[0] # 设置命名参数
else:
targ = arg[0]
darg = {}
elif len(arg) == 0 or arg == None:
targ = ()
darg = {}
else:
targ = arg[0] # (args,kwargs) [0]
darg = arg[1] # (args,kwargs) [1]
try:
assert type(targ) == type(()) and type(darg) == type({}) # 确定传来的参数格式正确
except:
raise ValueError, 'arguement format error'
targ = tuple(list(targ) + [tid]) if not sarg else tuple([sarg] + list(targ) + [tid]) # 在位置参数的最后一个位置,加上线程编号传进去。getResult里会把这个参数pop掉,并记录到结果里
tli.append(Thread(target=func, args=targ, kwargs=darg)) # 线程列表填充
tid = tid + 1 # 线程编号递增
for th in tli:
th.setDaemon(True)
th.start()
for th in tli:
th.join() # 大家都开始了才join阻塞
return {x[1]: x[0] for x in resset} # 结果字典,resset的结果是二元元组,0号元素是函数返回值,1号元素是线程编号tid
return infunc
def parseArg(*args, **kwargs): # 打包单个线程的参数
return (args, kwargs)
def multipleArg(gtuple, times): # 复制某个线程的参数n次,形成参数元组
return (gtuple,) * times
if __name__=="__main__":
############################################## 演示常规用法 #####################################
from selenium import webdriver
import time, re, random
@multiThread
@getResult
def baidusearch(statement):
driver = webdriver.Chrome()
driver.get(r'http://www.baidu.com')
driver.maximize_window()
driver.implicitly_wait(10)
kw = driver.find_element_by_id('kw')
kw.clear()
kw.send_keys(statement)
driver.find_element_by_id('su').click()
time.sleep(2)
resNumber = driver.find_element_by_xpath("//span[@class='nums_text']").text
resNumber = re.findall('[0-9,]+', resNumber)[0]
driver.close()
return resNumber
li1 = parseArg('mi')
li2 = parseArg('wo')
li3 = parseArg(u'故宫喵')
print(baidusearch(li1, li2, li3))
lis = multipleArg(parseArg('python'), 3) # 或者可以像下面这样,连续生成多个同样的关键字
print(baidusearch(*lis)) # 这种方式不要忘记解包
lit = tuple([parseArg('html' + str(x)) for x in range(1, 6)]) # 使用推导式,分别搜索html1 html2 ... html5
print(baidusearch(*lit))
################################# 可以使用路由的方式,让不同的函数并发执行 #############################
def plus(x, y):
return x + y
def minus(x, y):
return abs(x - y)
@multiThread
@getResult
def union(func_name, *args, **kwargs): # 可以用这种方法来实现路由,让不同的函数同时执行在一个线程里
return eval(func_name)(*args, **kwargs)
li1 = parseArg('plus', 1, 2)
li2 = parseArg('minus', 3, y=5)
li3 = parseArg('plus', x=4, y=6)
print(union(li1, li2, li3))
#################################### 类里面的函数也支持这样处理了 ################################################
class testLibrary(object):
@multiThread
@getResult
def printer(self, tno):
'''
:param unicode tno: 线程编号
:return:
'''
startTime = time.time()
for i in range(50):
print tno, i
time.sleep(random.random())
return time.time() - startTime
@staticmethod
def parseArg(*args, **kwargs): # 打包单个线程的参数
return (args, kwargs)
@staticmethod
def multipleArg(gtuple, times): # 复制某个线程的参数n次,形成参数元组
return (gtuple,) * times
t=[parseArg(x,) for x in range(4)]
lib=testLibrary()
print lib.printer(*t)