仅作为记录,大佬请跳过。
程序

新建pulseclass.py文件(main.py中通过`import pulseclass`导入该模块,因此文件名必须为pulseclass.py):
import torch
import numpy as np
import os
import codecs
import csv
import pandas as pd
class pulsedataset(torch.utils.data.Dataset):
    """Map-style dataset that loads (signal, target) samples from a CSV file.

    The first row of the CSV is treated as a header and skipped, as are blank
    lines. For every remaining row the last CSV column is discarded, the
    column before it becomes the target, and all earlier columns form the
    signal. Every kept cell is converted with ``float``.

    Args:
        path: CSV file to read. Defaults to the original hard-coded Windows
            location, so ``pulsedataset()`` keeps working unchanged.
    """

    def __init__(self, path=r'E:\generate_data.csv'):
        self.data = []     # one list[float] signal per sample
        self.targets = []  # one float target per sample
        # The with-block guarantees the file handle is closed; newline='' is
        # the mode the csv module documents for files it reads.
        with open(path, 'r', encoding='utf-8', newline='') as fh:
            reader = csv.reader(fh)
            next(reader, None)  # skip the header row (no-op on an empty file)
            for row in reader:
                if not row:
                    # Blank line: csv yields [], which would otherwise make
                    # values[-1] below raise IndexError.
                    continue
                # The final CSV column is excluded, exactly as before.
                # NOTE(review): presumably it is empty (trailing comma) or not
                # part of the sample -- confirm against the CSV generator.
                values = [float(cell) for cell in row[:-1]]
                self.data.append(values[:-1])
                self.targets.append(values[-1])

    def __getitem__(self, index):
        """Return the ``(signal, target)`` pair stored at *index*.

        NOTE(review): signal is a plain Python list, so torch's default
        collate batches it position-wise (a list of per-column tensors) rather
        than as a single (batch, L) tensor; convert it here if that is wanted.
        """
        return self.data[index], self.targets[index]

    def __len__(self):
        """Return the number of samples loaded from the CSV."""
        return len(self.data)
再新建main.py文件:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os, sys
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import pulseclass
# Build the dataset and a shuffling, multi-worker DataLoader over it.
# The __main__ guard matters because num_workers > 0: under the "spawn" start
# method (the default on Windows and macOS) every worker process re-imports
# this script. Without the guard, each worker would rerun this code and reload
# the CSV, and iterating the loader at top level would fail with a
# multiprocessing RuntimeError. When run as a script, `dataset` and
# `dataloader` are still module globals, so they remain visible in a debugger.
if __name__ == '__main__':
    dataset = pulseclass.pulsedataset()
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        num_workers=8,
    )
    print('ok')
将两个文件放在同一个文件夹下,运行main.py即可。
(在main.py中设置断点并以调试模式运行,即可查看dataset、dataloader等变量中的数据。)