![](https://img-blog.csdnimg.cn/20190918140012416.png?x-oss-process=image/resize,m_fixed,h_224,w_224)
pytorch
文章平均质量分 62
pytorch实战项目分析
MusicDancing
这个作者很懒,什么都没留下…
展开
-
Tenserflow 情感分类
1. 背景说明 在Pytorch 实现情感分类版本基础上进行tensorflow实现。2. 加载数据1. 获取数据集def load_data(data_path): df = pd.read_csv(data_path, engine='python', header=None, encoding='utf-8') df.columns = ['sentiment', 'id', 'date', 'query', 'user_id', 'text'] ...原创 2021-12-30 08:16:04 · 895 阅读 · 0 评论 -
Pytorch 情感分类进阶
1. 背景说明 在Pytorch 实现情感分类版本基础上进行优化实现,使用到文本文件reviews.txt和标签文件labels.txt两个数据文件。1.1 数据集预览1. reviews.txt此文件是一个包含25001条句子的长文本。print(len(text)) # 33678267 个字符2. labels.txt2. 显示前5条数据:3. 数据分布(各个类别占比)print(dataset[0].value_cou...原创 2021-12-29 15:24:40 · 853 阅读 · 0 评论 -
Pytorch 实现情感分类
1. 背景说明 本实验使用Twitter上发布的数据集Sentiment140,包含160W条记录,其中{0: 负面;2: 中性;4: 正面}原创 2021-12-27 16:00:44 · 993 阅读 · 0 评论 -
文本生成(字符级)
1122原创 2021-12-24 16:14:09 · 323 阅读 · 0 评论 -
x.data()与x.detach()
1. 相同点:1. 都和x共享同一块数据;2. 都和x的计算历史无关;3. requires_grad = False。2. 不同点:x.detach()更安全原创 2021-12-24 12:41:34 · 775 阅读 · 0 评论 -
交通指示牌识别
涉及单分类,多分类问题,共计5W多样本,43个类,包括一个注释的标签数据源:German Traffic Sign Benchmarks原创 2021-12-23 15:10:03 · 2038 阅读 · 0 评论 -
数据可视化
1. 绘制各种曲线1. 损失函数曲线def plot_loss_cure(epoch_list, loss_list): plt.plot(epoch_list, loss_list) plt.xlabel('Number of Epoch') plt.ylabel('Loss') plt.title('RNN') plt.show()epoch_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]loss_list = [0.1原创 2021-12-23 14:51:19 · 3432 阅读 · 1 评论 -
model.train() 与model.eval() 的区别
1. model.train() train模式下,dropout层会按照设定的参数p设置保留激活单元的概率; BN层会继续计算数据的mean和var等参数并更新。2. model.eval() 在validation时,会使用model.eval()切换到测试模式,通知dropout层和BN层在train和val模式间切换。 dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段学到的me...原创 2021-12-22 12:46:45 · 1464 阅读 · 0 评论 -
CIFAR10识别
1. 数据集简介 CIFAR10数据集共有6W张彩色图像,图像大小是32*32*3的,共计10个类,每类6K张图片。 其中训练集5W张,构成了5个训练批,每一批1W张,但一个训练批中的各类图像并不一定数量相同,总的来看训练集,每一类都有5K张;测试集1W单独构成一批,其来自10个分类,每类随机取1K张。...原创 2021-12-21 09:55:29 · 920 阅读 · 0 评论 -
一些常用的函数技巧
1. 计算模型训练时间import timestart = time.time()# 输出训练所耗总时间time_all = time.time() - startprint('Training complete in {:.2f}m | {:.2f}s'.format(time_all // 60, time_all % 60))原创 2021-12-17 21:42:05 · 663 阅读 · 0 评论 -
Pytorch 使用SummaryWriter记录日志信息
11import timefrom torch.utils.tensorboard.writer import SummaryWriter# 获取Tensorboard的writerdef tb_writer(): timestr = time.strftime('%Y%m%d_%H') writer = SummaryWriter('../logs/' + timestr) return writerif __name__ == "__main__": w原创 2021-12-17 21:30:02 · 8304 阅读 · 2 评论 -
Tensor语法进阶
1. 常用操作1.1 GPU判断# 判断是否有GPU,若有则进行部署if torch.cuda.is_available(): device = torch.device('cuda')else: device = torch.device('cpu')x = torch.tensor([1, 2, 3], device=device)2. 分布2.1 指定分布的tensor# 值在0-1之间的均匀分布y = torch.rand(5)# tensor...原创 2021-12-17 10:09:25 · 1259 阅读 · 0 评论 -
肺部感染识别
1. 数据集简介下载地址:https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia/download数据集目录展示: 数据集大小: NORMAL PNEUMONIA train 1342 3876 val 9 9 test 235 391 样本展示:2. 主要流程...原创 2021-12-12 19:57:18 · 2314 阅读 · 0 评论 -
PyTorch手写数字识别
1. MNIST数据集简介 包含了6W训练样本和1W条测试样本,共10类(0-9),每张图片都做了尺寸归一化,都是28*28大小的灰度图。2. 模型训练2.1 超参数定义import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom torchvision import datasets, transforms# 超参数定义EPOC...原创 2021-12-11 09:17:57 · 1075 阅读 · 0 评论 -
PyTorch安装与Tensor基本语法
2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch,它是一个开源的Python机器学习库,用于NLP、CV等应用程序。优点:1. 相当简洁且高效快速的框架,入门简单;2. 动态计算图:提供了一个可以在运行时构建计算图(甚至在运行时可更改)的一个框架;3. 多gpu支持,自定义数据加载器和简化的预处理器。安装地址:https://pytorch.org...原创 2021-11-29 11:25:41 · 960 阅读 · 0 评论