1. 读取.txt打乱顺序保存到test.csv
def create_csv(txt_path, csv_path):
lists = pd.read_csv(txt_path, sep=r"\t", header=None)
lists = lists.sample(frac=1)
lists.to_csv(csv_path, index=None)
print("Finish save csv")
2. 准确率 & 损失结果可视化
# 解决中文乱码问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
def plot_loss(train_loss, val_loss):
plt.plot(train_loss, label='train_loss')
plt.plot(val_loss, label='val_loss')
plt.legend(loc='best')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.title("训练集和验证集loss值得对比图")
plt.savefig('results/loss.png')
plt.show()
def plot_acc(train_acc, val_acc):
plt.plot(train_acc, label='train_acc')
plt.plot(val_acc, label='val_acc')
plt.legend(loc='best')
plt.ylabel('acc')
plt.xlabel('epoch')
plt.title("训练集和验证集acc值得对比图")
plt.savefig('results/acc.png')
plt.show()
def plot_results(epochs, train_acc, train_loss, test_acc, test_loss):
x = np.arange(epochs)
plt.plot(x, train_acc, label='train_acc')
plt.plot(x, train_loss, label='train_loss')
plt.plot(x, test_acc, label='test_acc')
plt.plot(x, test_loss, label='test_loss')
plt.title("Results", fontsize=15)
plt.xlabel("X", fontsize=13)
plt.ylabel("Y", fontsize=13)
plt.legend()
plt.savefig('results/result.png')
plt.show()
3. 保存文件路径到txt
from glob import glob
def save_txt(data_path, save_path):
files = glob(data_path + '/*.jpg')
lists = sorted(files) # 所有图片文件路径
a = open(save_path, "w", encoding='UTF-8')
for i in range(len(lists)):
a.write(lists[i])
a.write('\n')
a.close()
4. 读入txt
def ReadTxt(rootdir='test.txt'):
lines = []
with open(rootdir, 'r') as file_to_read:
while True:
line = file_to_read.readline()
if not line:
break
line = line.strip('\n')
lines.append(line)
return lines
5. 获取文件名写入txt
path = '/data/'
files = []
a = open("test.txt", "w", encoding='UTF-8')
for item in os.listdir(path):
if item.split('.')[-1].lower() in ['jpg', 'jpeg', 'png']:
a.write(str(item.split('.')[:-1]) + '\n')
a.close()
6. 其它常用函数
- permute()
permute()可以对某个张量的任意维度进行调换(对象是tensor)。类似transpose(对象是numpy)
import torch
x = torch.randn(2, 5, 7)
y = x.permute(1, 2, 0) # 参数对应x的维度
print(f'shape of x: {x.shape}\nshape of y: {y.shape}')
输出:
shape of x: torch.Size([2, 5, 7])
shape of y: torch.Size([5, 7, 2])
- transpose()
import numpy as np
x = np.random.randn(2, 5, 7)
y = x.transpose(1, 2, 0)
print(f'shape of x: {x.shape}\nshape of y: {y.shape}')
输出:
shape of x: (2, 5, 7)
shape of y: (5, 7, 2)
# 获取当前路径和父路径
```python
import os
cur_path = os.getcwd()
root_path = os.path.dirname(cur_path)
.strip()
删除任何前导/尾随空格。
enumerate(iteration, start)
参数:
- iteration参数为需要遍历的参数,比如字典、列表、元组等
- start参数为开始的参数,默认为0(不写start那就是从0开始
返回值:
- 第一个返回值为从start参数开始的数(序号)
- 第二个参数为iteration参数中的值。
e.g.
X = ['a', 'b', 'c', 'd', 'e', 'f']
for i, x in enumerate(X):
print(i, x)
输出:
0 a
1 b
2 c
3 d
4 e
5 f
torch.cat()
作用:在给定维度上对输入的张量(tensor)进行连接(拼接)操作。
参数:
- inputs : 待连接的张量序列
- dim : 选择的扩维, 沿着此维连接张量序列。
注:输入tensor的shape必须相同
e.g.
import torch
x = torch.randint(0, 10, (1, 3, 4))
y = torch.randint(0, 10, (1, 3, 4))
z_0 = torch.cat((x, y), dim=0)
z_1 = torch.cat((x, y), dim=1)
print(f'x: {x}\n\ny: {y}\n\nz_0: {z_0}\n\nz_1: {z_1}')
输出:
x: tensor([[[3, 3, 4, 7],
[3, 5, 7, 7],
[7, 5, 5, 9]]])
y: tensor([[[1, 6, 4, 0],
[9, 8, 5, 9],
[5, 6, 7, 8]]])
z_0: tensor([[[3, 3, 4, 7],
[3, 5, 7, 7],
[7, 5, 5, 9]],
[[1, 6, 4, 0],
[9, 8, 5, 9],
[5, 6, 7, 8]]])
z_1: tensor([[[3, 3, 4, 7],
[3, 5, 7, 7],
[7, 5, 5, 9],
[1, 6, 4, 0],
[9, 8, 5, 9],
[5, 6, 7, 8]]])