智能评卷系统的第二种实现方式具体实现

文本在上文的基础上,对智能评卷系统的第二种实现方式进行代码级别的实现,以展示训练和实现的具体过程。

首先,本次的实现改进了互注意力机制的实现方式,在原来的numpy的实现方式改为了torch的实现,利用torch的softmax函数进行归一化操作,该方式的归一化提高了数据的准确性。

class ShuoFaObj(object):

    def __init__(self):
        self.ref_vector = load_data('./ref.csv')

    def get_ax(self, x: list):
        r = self.ref_vector[0]
        mat_m = torch.mm(torch.tensor(r, dtype=float), torch.tensor(x, dtype=float).t())
        mat_m_r = F.softmax(mat_m, dim=1)
        mat_m_c = F.softmax(mat_m, dim=0)
        row_avg = mat_m_r.mean(dim=0)

        param_ax = torch.mm(mat_m_c, torch.tensor([row_avg.tolist()], dtype=float).t())

        ax = param_ax.mul(torch.tensor(r, dtype=float))

        return ax.tolist()

下一步的训练也是基于torch的实现,进行了一定的改进,提高了训练的效率。

from saveVec import load_data_to_vec_list_together, load_data
from utils import reshape_into_slices, get_padded_vector, label_to_1D
import torch
import torch.nn as nn
from HenYouShuoFa import ShuoFaObj

vec, lable = load_data_to_vec_list_together('vec.csv')

sf = ShuoFaObj()
vec_ax = []
first = []
for i in vec:
    vec_ax.append(sf.get_ax(i))

vec_flat_t = torch.tensor(vec_ax).flatten(start_dim=1)

label_id = []
for li in lable:
    if li == 0:label_id.append(0)
    elif li == 1.5:label_id.append(1)
    else:label_id.append(2)

print(label_id)
label_id_t = torch.tensor(label_id, dtype=torch.long).view(300, -1)

print(vec_flat_t.size())


Model = torch.nn.Sequential(
    nn.Linear(3840, 1500),
    nn.ReLU(True),
    nn.Dropout(0.3),
    nn.Linear(1500, 500),
    nn.ReLU(True),
    nn.Dropout(0.2),
    nn.Linear(500, 3)
)
Model.train()
Model.load_state_dict(torch.load('params.pth'))
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(Model.parameters(), lr=0.001)

for epoch in range(50):
    i = 0
    for flat_t in vec_flat_t.view(300, 10, -1):
        optimizer.zero_grad()
        out = Model(flat_t)
        loss_b = loss(out, label_id_t[i])
        i += 1
        loss_b.backward()
        optimizer.step()
    print('epoch:'+str(epoch)+', loss is:'+str(loss_b.item()))
    if (epoch+1) % 10 == 0:
        torch.save(Model.state_dict(), 'params.pth')

至此,第二种的训练方式介绍完成,同样会在最后一种评阅方式后进行结果的比对。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值