根据信号序列csv文件,写自己的train_dataset函数

仅作为记录,大佬请跳过。

文章目录

程序

新建dataset.py文件

import torch
import numpy as np
import os
import codecs
import csv
import pandas as pd


class pulsedataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data=[]
        self.targets=[]

        f = csv.reader(open(r'E:\generate_data.csv', 'r', encoding='utf-8'))

        temp=[]
        i_f = 0
        for i in f:
            # 去除首行
            if i_f == 0:
                i_f = i_f + 1
                continue
            temp = [float(i[n]) for n in range(len(i) - 1)]
            self.data.append(temp[:len(temp)-1])
            self.targets.append(temp[len(temp)-1])

    def __getitem__(self, index):
        signal=self.data[index]
        target=self.targets[index]
        return signal,target

    def __len__(self):
        return len(self.data)



再新建main.py文件:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os, sys
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import pulseclass

dataset=pulseclass.pulsedataset()

dataloader=torch.utils.data.DataLoader(
    dataset,
    batch_size= 8,
    shuffle=True,
    num_workers = 8
)

print('ok')

两个文件放在同一个文件夹下,运行main.py,即可。

(断点运行main.py可以看到变量数据:

在这里插入图片描述

在这里插入图片描述
)

参考

添加链接描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值