Pytorch
摸鱼时刻
这个作者很懒,什么都没留下…
展开
-
Pytorch 加载数据加速
利用DataPretcher加速device = torch.device("cuda")class DataPrefetcher(): def __init__(self, loader): self.loader = iter(loader) self.stream = torch.cuda.Stream() self.preloa...原创 2019-11-15 08:48:25 · 769 阅读 · 0 评论 -
如何在 PyTorch 中设定学习率衰减(learning rate decay)
很多时候我们要对学习率(learning rate)进行衰减,下面的代码示范了如何每30个epoch按10%的速率衰减:def adjust_learning_rate(optimizer, epoch): """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" lr = args...转载 2018-10-17 16:24:28 · 23513 阅读 · 1 评论 -
pytorch DataParallel 多GPU训练
当一台服务器有多张GPU时,运行程序默认在一张GPU上运行。通过多GPU训练,可以增大batchsize,加快训练速度。from torch.nn import DataParallelnum_gpu = torch.cuda.device_count()net = DataParallel(net, device_ides=range(num_gpu))device_ides默认情况下...原创 2018-11-01 10:48:43 · 1636 阅读 · 0 评论 -
Pytorch 模型的加载与保存
pytorch的模型和参数是分开的,可以分别保存或加载模型和参数。1、直接保存模型# 保存模型torch.save(model, 'model.pth')# 加载模型model = torch.load('model.pth')2、分别加载模型的结构和参数# 保存模型参数torch.save(model.state_dict(), 'model.pth')# 加载模型参数mo...原创 2019-01-17 13:39:16 · 2067 阅读 · 0 评论 -
Pytorch 计算分类器准确率(总分类及子分类)
分类器平均准确率计算:correct = torch.zeros(1).squeeze().cuda()total = torch.zeros(1).squeeze().cuda()for i, (images, labels) in enumerate(train_loader): images = Variable(images.cuda()) ...原创 2019-01-13 14:53:05 · 24348 阅读 · 6 评论 -
Pytorch 计算模型参数量
转自:https://blog.csdn.net/jdzwanghao/article/details/84196239################### 模型定义# -------------class MyModel(nn.Module): def __init__(self, feat_dim): # input the dim of output fea-ma...转载 2019-05-11 11:44:15 · 6253 阅读 · 0 评论 -
Pytorch DiceLoss MulticlassDiceloss
转自:https://blog.csdn.net/a362682954/article/details/81226427DiceLossclass DiceCoeff(Function): """Dice coeff for individual examples""" def forward(self, input, target): self.save_...转载 2019-05-12 16:42:03 · 2075 阅读 · 0 评论 -
Pytorch 类别标签转换one-hot编码
这里用到了Pytorch的scatter_函数:scatter_(dim, index, src) → TensorWrites all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is sp...原创 2019-05-13 09:08:37 · 21343 阅读 · 3 评论 -
Pytorch 图像分割使用findContours报错
做图像分割时,想在Pytorch的预测结果的基础上进行后处理,用到了cv2.findContours里找标签的连通域,运行时报错:cv2.error: OpenCV(4.0.0) /io/opencv/modules/imgproc/src/contours.cpp:195: error: (-210:Unsupported format or combination of formats) [...原创 2019-06-05 10:47:18 · 6277 阅读 · 3 评论