最近一直在学pytorch,copy了几个经典的入门问题。现在作一下总结。
首先,做的小项目主要有
分类问题:Mnist手写体识别、FashionMnist识别、猫狗大战
语义分割:Unet分割肝脏图像、遥感图像
先把语义分割的心得总结一下,目前只是一部分,以后还会随着学习的深入慢慢往里面加新的感悟。
1)对于二分类问题
1. Unet输出channel:对于二分类问题,类别数为2,channel为1,用uint8的单通道灰度图像表示类别就行(0/1)。
2. label是单通道灰度图像,直接传给损失函数。
3. 损失函数:nn.sigmoid + nn.BCELoss / nn.BCEWithLogitsLoss,此时计算loss的ouput和label维度应该保持一致。batchsize*1*h*w
2)对于多分类问题
1. Unet输出channel: 输出channel是类别数。网络的输入是img,网络的输出是one hot编码的多通道图像。
2. Label是单通道灰度图像,不同的灰度级表示不同的类别。用于传给损失函数,计算Loss。
具体操作方面,第一步有人说先将Label进行one hot编码(即转换成多通道图,一个通道一个类别),这样才能用交叉熵计算损失;也有人说不需要one hot编码,直接把单通道Label作为损失函数的Label。
其实这两个人说的都不错,但第一个人并没有用Pytorch做,而第二个人是用Pytorch和nn.CrossEntropyLoss计算损失的。
在多分类问题中&#x