Zero.写作动机
对给定参数区间内部进行搜索,寻找到最优参数近似解的方法有很多。比如网格搜索。但是网格搜索太过暴力,往往花销过大。这里介绍一种新的参数寻优方法——蒙特卡洛树搜索。
网络上关于蒙特卡洛方法几乎清一色都是在介绍Buffon实验并以此估计某个量。这里,我们介绍蒙特卡洛树用于参数寻优。
一、模型原理
下面推荐几个博客,这些文章已经介绍得很好了:
①https://blog.csdn.net/ljyt2/article/details/78332802
②https://www.jianshu.com/p/a34f06885ef8
二、编程实现
Version one: Python
https://www.jianshu.com/p/a34f06885ef8
Version Two: Matlab
鉴于实际需求,笔者在Python版本的基础上实现了matlab版本,涉及到matlab的面向对象编程。读者诸君按需获取即可
state.m文件
classdef State < handle
properties
value
round
choices
PATH
x2 %因为假定现在只用MCST找到第三步迭代的最优参数
y
sigma
im
end
methods
function self= State(x2,y,sigma,im)
%在这里进行初始化
self.value = 0;
self.round = 0;
self.choices = [];
self.PATH = [0.1:0.2:3];
self.x2 = x2;
self.y = y;
self.sigma = sigma;
self.im = im;
end
function state = new_state(self)
choice = randperm(numel(self.PATH));
choice = self.PATH(choice(1));%从一维数组中进行随机采样
state = State(self.x2,self.y,self.sigma,self.im);
%对于辣椒的彩色图片,第三步迭代的默认两个参数是0.7, 0.8
value_ = 0;
if numel(self.choices) == 1 %当前在选择第二个参数
%计算潜在的value
x3 = step(self.x2, self.y, self.sigma^2, 15, 7, self.choices(1), choice);
value_ = - (sum(sum((x3 - self.im).^2)) / numel(x3)); %反向来
elseif numel(self.choices) == 0 %当前在选择第一个参数
%计算潜在的value
x3 = step(self.x2, self.y, self.sigma^2, 15, 7, choice,0.8);
value_ = - (sum(sum((x3 - self.im).^2)) / numel(x3)); %反向来
else
value_ = 0;
end
%得到一个参数的选择结果
state.value = self.value + value_; %价值计算函数需要更改
state.round = self.round+1;
state.choices = [ self.choices,choice ];%扩充当前的选择
end
function display(self)
fprintf(1,'class State:\n');%表示在终端上进行输出
fprintf(1,'value = %f\n',self.value);
fprintf(1,'round = %d\n',self.round);
fprintf(1,'ready to show the choice array:\n');
for i = 1:numel(self.choices)
if i == 1
fprintf(1,'[');
end
fprintf(1,'%d,',self.choices(i));
if i == numel(self.choices)
fprintf(1,']');
end
end
end
end
end
Node.m文件
classdef Node < handle
properties
parent
children
quality
visit
state
MAX_DEPTH = 2
MAX_CHOICE = numel([0.1:0.2:3]) %其实代表的是children数组的长度的上限
end
methods
function self= Node()
self.quality = 0.0;
self.visit = 0;
%剩下的变量没有定义
end
function add_child(self,node)
fprintf(1,'printing node in function add_child\n');
node
self.children = [self.children,node];
node.parent = self;
end
function display(self)
fprintf(1,'class Node:\n');%表示在终端上进行输出
fprintf(1,'quality = %f\n',self.quality);
fprintf(1,'visit = %d\n',self.visit);
end
function child_node = expand(cnt_node)
%随机选择一个之前没有扩展过的——也就是不在children列表中的一个子节点进行扩展,随机性在new_state的时候的随机函数中体现出来
%返回当前结点扩展出的子节点
fprintf(1,'printing node in function EXPAND\n');
cnt_node
fprintf(1,'printing value of the ori_state in function EXPAND:%f\n\n',cnt_node.state.value);
cnt_node.state.choices
state = new_state(cnt_node.state);
%拿到当前结点的children列表中的子节点的状态
sub_state_value_list = [];
for i = 1:numel(cnt_node.children)
sub_state_value_list(i) = cnt_node.children(i).state.value;
end
fprintf(1,'printing value of the new_state in function EXPAND:%f\n\n',state.value);
state.choices
while ismember(state.value,sub_state_value_list)
fprintf(1,'printing value of the new_state in function EXPAND:%f\n\n',state.value);
state.choices
state = new_state(cnt_node.state);
end
child_node = Node();
child_node.state = state;
add_child(cnt_node,child_node);
fprintf(1,'printing value of the end_child_state in function EXPAND\n');
for i = 1:numel(cnt_node.children)
fprintf(1,'printing value of the end_child_state in function EXPAND:%f\n\n',cnt_node.children(i).state.value);
cnt_node.children(i).state.choices
end
end
function best = best_child(node)
%返回当前结点的children列表中最适合作为扩展结点的子节点
fprintf(1,'printing node in function BEST_CHILD\n');
node
best_score = -100000000; %代表负无穷
best = -1 ;%初始化
for i= 1:numel(node.children)
C = 1/sqrt(2.0);
sub_node = node.children(i);
left = sub_node.quality / sub_node.visit; %分母是被访问的次数
right = 2.0*log(node.visit)/sub_node.visit;
score = left+C*sqrt(right);
if score >best_score
best = sub_node;
best_score = score;
end
end
end
function node = tree_policy(node)
fprintf(1,'printing node in function TREE_POLICY\n');
node
%选择+expand扩展
%调用逻辑:如果当前结点还有子节点没有被添加到children列表——也就是还没有expand过,那么就从还没有扩展过的子节点中随机选择一个进行扩展,并返回该被需选中的子节点
%调用逻辑:如果当前结点是叶子结点,直接返回该结点
%调用逻辑:如果当前结点的所有子节点都已经被加入到了children列表,那么就从中选择一个收益最高的结点进行扩展,并且返回该结点
%选择是否是叶子结点
count = 0;
while node.state.round < node.MAX_DEPTH
fprintf(1,'running while-end with count:%d in Node.m/line73\n',count);
count = count +1;
if numel(node.children) < node.MAX_CHOICE
node = expand(node);
return
else
node = best_child(node);
end
end
end
function expanded_value = default_policy(node)
fprintf(1,'printing node in function DEFAULT_POLICY\n');
node
%模拟
%算一次从当前结点随机走到叶节点的收益
now_state = node.state;
count= 0;
while now_state.round < node.MAX_DEPTH
fprintf(1,'running while-end with count:%d in Node.m/line90\n',count);
count = count +1;
now_state = new_state(now_state);
end
expanded_value = now_state.value;
end
function backup(node,reward)
fprintf(1,'printing node in function BACKUP\n');
node
%从当前结点带着reward回溯到根节点,并且增加路径上的每个结点的visit次数和quality
while ~isempty(node)
fprintf(1,'not empty\n');
node.visit = node.visit +1;
node.quality = node.quality+reward;
node = node.parent;
end
end
function best = mcts(node)
%似乎是多次尝试扩展,选择当前扩展到children列表中的子节点中的收益最好的一个子结点进行扩展,并且返回该被选中的子节点
% times = 5 ;%为什么是5?
times = 20;
for i = 1:times
expand = tree_policy(node);%当前结点向下选择扩展一个结点
reward = default_policy(expand);%计算从该扩展结点走到叶子结点的随机一条路径的一种收益情况
backup(expand,reward);
end
best = best_child(node);
end
function main(self)
init_state = State();
init_node = Node();
init_node.state = init_state;
cnt_node = init_node;
for i = 1:self.MAX_DEPTH
cnt_node = mcts(cnt_node);
end
end
end
end
Notice.
在matlab的实现版本中,注意两种不同的类的写法。classdef name < handle是引用类型,这样的类可以作为另外一个类的属性存在。classdef name是按value类型,这样的类如果想要使用自己的实例对象作为类的一个属性会报错。
上面的Node类和State类都属于引用类型。