常用功能函数

1. 读取.txt打乱顺序保存到test.csv

def create_csv(txt_path, csv_path):
    lists = pd.read_csv(txt_path, sep=r"\t", header=None)
    lists = lists.sample(frac=1)
    lists.to_csv(csv_path, index=None)
    print("Finish save csv")

2. 准确率 & 损失结果可视化

# 解决中文乱码问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

def plot_loss(train_loss, val_loss):
    plt.plot(train_loss, label='train_loss')
    plt.plot(val_loss, label='val_loss')
    plt.legend(loc='best')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.title("训练集和验证集loss值得对比图")
    plt.savefig('results/loss.png')
    plt.show()


def plot_acc(train_acc, val_acc):
    plt.plot(train_acc, label='train_acc')
    plt.plot(val_acc, label='val_acc')
    plt.legend(loc='best')
    plt.ylabel('acc')
    plt.xlabel('epoch')
    plt.title("训练集和验证集acc值得对比图")
    plt.savefig('results/acc.png')
    plt.show()


def plot_results(epochs, train_acc, train_loss, test_acc, test_loss):
    x = np.arange(epochs)

    plt.plot(x, train_acc, label='train_acc')
    plt.plot(x, train_loss, label='train_loss')
    plt.plot(x, test_acc, label='test_acc')
    plt.plot(x, test_loss, label='test_loss')

    plt.title("Results", fontsize=15)
    plt.xlabel("X", fontsize=13)
    plt.ylabel("Y", fontsize=13)
    plt.legend()
    plt.savefig('results/result.png')
    plt.show()

3. 保存文件路径到txt

from glob import glob
def save_txt(data_path, save_path):
    files = glob(data_path + '/*.jpg')
    lists = sorted(files)  # 所有图片文件路径

    a = open(save_path, "w", encoding='UTF-8')
    for i in range(len(lists)):
        a.write(lists[i])
        a.write('\n')
    a.close()

4. 读入txt

def ReadTxt(rootdir='test.txt'):
    lines = []
    with open(rootdir, 'r') as file_to_read:
        while True:
            line = file_to_read.readline()
            if not line:
                break
            line = line.strip('\n')
            lines.append(line)
    return lines

5. 获取文件名写入txt

path = '/data/'
files = []
a = open("test.txt", "w", encoding='UTF-8')
for item in os.listdir(path):
    if item.split('.')[-1].lower() in ['jpg', 'jpeg', 'png']:
        a.write(str(item.split('.')[:-1]) + '\n')
a.close()

6. 其它常用函数

  • permute()
    permute()可以对某个张量的任意维度进行调换(对象是tensor)。类似transpose(对象是numpy)
import torch
x = torch.randn(2, 5, 7)
y = x.permute(1, 2, 0)  # 参数对应x的维度
print(f'shape of x: {x.shape}\nshape of y: {y.shape}')

输出:
shape of x: torch.Size([2, 5, 7])
shape of y: torch.Size([5, 7, 2])
  • transpose()
import numpy as np
x = np.random.randn(2, 5, 7)
y = x.transpose(1, 2, 0)
print(f'shape of x: {x.shape}\nshape of y: {y.shape}')

输出:
shape of x: (2, 5, 7)
shape of y: (5, 7, 2)

# 获取当前路径和父路径
```python
import os
cur_path = os.getcwd()
root_path = os.path.dirname(cur_path)

.strip()

删除任何前导/尾随空格。

enumerate(iteration, start)

参数:

  • iteration参数为需要遍历的参数,比如字典、列表、元组等
  • start参数为开始的参数,默认为0(不写start那就是从0开始

返回值:

  • 第一个返回值为从start参数开始的数(序号)
  • 第二个参数为iteration参数中的值。

e.g.

X = ['a', 'b', 'c', 'd', 'e', 'f']
for i, x in enumerate(X):
    print(i, x)
输出:
0 a
1 b
2 c
3 d
4 e
5 f

torch.cat()

作用:在给定维度上对输入的张量(tensor)进行连接(拼接)操作。
参数:

  • inputs : 待连接的张量序列
  • dim : 选择的扩维, 沿着此维连接张量序列。

注:输入tensor的shape必须相同

e.g.

import torch
x = torch.randint(0, 10, (1, 3, 4))
y = torch.randint(0, 10, (1, 3, 4))
z_0 = torch.cat((x, y), dim=0)
z_1 = torch.cat((x, y), dim=1)

print(f'x: {x}\n\ny: {y}\n\nz_0: {z_0}\n\nz_1: {z_1}')
输出:
x: tensor([[[3, 3, 4, 7],
         [3, 5, 7, 7],
         [7, 5, 5, 9]]])

y: tensor([[[1, 6, 4, 0],
         [9, 8, 5, 9],
         [5, 6, 7, 8]]])

z_0: tensor([[[3, 3, 4, 7],
         [3, 5, 7, 7],
         [7, 5, 5, 9]],

        [[1, 6, 4, 0],
         [9, 8, 5, 9],
         [5, 6, 7, 8]]])

z_1: tensor([[[3, 3, 4, 7],
         [3, 5, 7, 7],
         [7, 5, 5, 9],
         [1, 6, 4, 0],
         [9, 8, 5, 9],
         [5, 6, 7, 8]]])

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

An_37

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值