adjust_param
是一个辅助函数。在的返回行中断言TypeError
optimize_theta
.
def adjust_param(R, delta, i, theta):
thetaplus = theta.copy()
thetaminus = theta.copy()
thetaplus[i*2] += delta
thetaplus[i*2+1] += delta
thetaminus[i*2] -= delta
thetaminus[i*2+1] -= delta
y = Remp(q_data, labels, R, num_samples, theta)
yplus = Remp(q_data, labels, R, num_samples, thetaplus)
yminus = Remp(q_data, labels, R, num_samples, thetaminus)
if (yplus < y and yplus < yminus and yplus != -1):
return thetaplus, yplus
elif (yminus < y and yminus < yplus and yminus != -1):
return thetaminus, yminus
else:
return theta, y
def optimize_theta(N, R, delta, i, theta, risk):
if N == 0:
print("Theta : " + str(type(theta)))
print("= " + str(theta))
print()
print("Risk : " + str(type(risk)))
print("= " + str(risk))
return theta, risk
else:
theta_new, risk_new = adjust_param(R, delta, i, theta)
if i == (len(theta)/2)-1:
#print("N = " + str(N-1))
#print("theta = " + str(theta))
risk_copy = risk.copy()
risk_copy.append(risk_new)
optimize_theta(N-1, R, delta, 0, theta_new, risk_copy)
else:
optimize_theta(N, R, delta, i+1, theta_new, risk)
输出:
Theta :
= [0.85885111 0.86066499 0.47482528 0.13555158 0.87249245 0.02604654
0.2906744 0.34618303]
Risk :
= [0.6273510217403618]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
in
----> 1 theta, risk = optimize_theta(N, R, delta, 0, theta0, [])
TypeError: cannot unpack non-iterable NoneType object
任何洞察力都将不胜感激。谢谢您!