LR-CNN 核心代码

CNN-Based Chinese NER with Lexicon Rethinking
基于LR-CNN的中文命名实体识别
作者: Tao Gui, Ruotian Ma 等
单位:复旦大学、Video++
发表会议及时间: IJCAI 2019

基于LR-CNN的中文命名实体识别模型结构


首先输入字符信息,经过CNN1,CNN2,和Bigram词信息用Attention进行融合,之后经过CNN3,和Trigram词信息进行融合,CNN4类似,最后输出最顶层Attention。这是属于一个模块,然后将这一模块复制一遍同时每层加入X1,也就是刚刚最顶层的Attention值,得到第二个模块的最顶层Attention之后,将每个模块每层的Attention值进行融合,以多尺度特征Attention的方式。融合之后得到最终的输入表示,输入到CRF层中。
上面的过程还用到了一些tricks,比如残差连接,位置编码,在代码中有所体现。

代码结构

在这里插入图片描述

1. Attention机制

在这里插入图片描述

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    '''Scaled Dot-Product Attention: softmax(q @ k^T / temperature) @ v.

    Args:
        temperature: divisor for the raw attention logits (typically sqrt(d_k)).
        hidden_dim: unused; kept only for backward compatibility with callers.
        attn_dropout: dropout probability applied to the attention weights.
    '''

    def __init__(self, temperature, hidden_dim, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        # Softmax over the key dimension of the (b, len_q, len_k) score matrix.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """Compute attention output and weights.

        q: (b, len_q, d), k: (b, len_k, d), v: (b, len_k, d_v).
        mask: optional, broadcastable to (b, len_q, len_k); nonzero/True entries
        are masked OUT (receive zero attention weight).
        Returns (output, attn) with shapes (b, len_q, d_v) and (b, len_q, len_k).
        """
        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature

        if mask is not None:
            # float('-inf') instead of -np.inf: same result after softmax, but
            # drops the file's otherwise-unimported numpy dependency.
            attn = attn.masked_fill(mask.bool(), float('-inf'))

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)

        return output, attn
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module (post-norm variant).

    Projects q/k/v into n_head subspaces, runs scaled dot-product attention
    for all heads in one batched bmm (heads folded into the batch dim),
    concatenates the heads, then applies fc + dropout + residual + LayerNorm.
    '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        """n_head: number of heads; d_model: input/output dim;
        d_k / d_v: per-head query-key and value dims."""
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # One linear per role producing all heads at once: d_model -> n_head * d_{k,v}.
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        # Normal init with std sqrt(2/(fan_in+fan_out)) (Glorot-style).
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), hidden_dim=self.d_v)
        self.layer_norm = nn.LayerNorm(d_model)

        # Final projection back to d_model after head concatenation.
        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)

        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """q/k/v: (b, len, d_model); mask broadcastable to (b, len_q, len_k),
        True/nonzero = masked out. Returns (output (b, len_q, d_model), attn)."""
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        # Residual taken from q before projection (standard post-norm residual).
        residual = q

        # (b, len, n_head * d) -> (b, len, n_head, d)
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Fold heads into the batch dim so one bmm handles all heads.
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)  # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)  # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v)  # (n*b) x lv x dv

        if mask is not None:
            # Replicate the mask once per head to match the folded batch.
            mask = mask.repeat(n_head, 1, 1).bool()  # (n*b) x .. x ..
        output, attn = self.attention(q, k, v, mask=mask)

        # Unfold heads and concatenate them along the feature dim.
        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)  # b x lq x (n*dv)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output, attn


class GlobalGate(nn.Module):
    """Sentence-level global memory for LR-CNN.

    Each call self-attends over the current layer output; on the first call
    (no existing memory) that result becomes the memory, afterwards a sigmoid
    update gate interpolates between the new attention result and the old memory.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.head = 1  # single-head self-attention over the layer output
        per_head = self.hidden_dim // self.head
        self.self_attention = MultiHeadAttention(
            self.head, self.hidden_dim, per_head, per_head)
        # Gate over the concatenation [attended ; old memory] -> (b, l, h).
        self.G2Gupdategate = nn.Linear(2 * self.hidden_dim, self.hidden_dim, bias=True)

    def forward(self, layer_output, global_matrix=None):
        # Self-attention runs exactly once regardless of which branch we take.
        attended, _ = self.self_attention(layer_output, layer_output, layer_output)  # (b, l, h)

        if global_matrix is None:
            # First layer: initialize the memory directly from the attention result.
            return attended

        gate_input = torch.cat([attended, global_matrix], dim=2)  # (b, l, 2h)
        update = torch.sigmoid(self.G2Gupdategate(gate_input))
        # Convex interpolation: update=1 keeps the new result, update=0 keeps the old memory.
        return update * attended + (1 - update) * global_matrix

2. 特征抽取

在这里插入图片描述

在这里插入图片描述

class LayerGate(nn.Module):
    """Gating layer of LR-CNN that fuses CNN features, lexicon (gaz) word
    embeddings and the global memory into a per-position hidden state.

    Four gate groups are computed in parallel by stacking the inputs 4x along
    the sequence axis ("para run") and using a single wide Linear whose output
    is re-sliced per group via precomputed gather indices.

    NOTE(review): the 4 groups presumably correspond to the forward/backward x
    bigram/trigram combinations described in the article -- confirm in caller.
    """

    def __init__(self, hidden_dim, input_dim, use_gaz=True, gpu=False):
        # hidden_dim: per-position hidden size h; input_dim: gaz embedding size.
        super().__init__()
        self.hidden_dim = hidden_dim  # layer hidden dim
        self.input_dim = input_dim  # input gaz embed dim
        self.use_gaz = use_gaz

        # Gather indices selecting each group's slice of the wide Linear output.
        # NOTE: plain attributes, not registered buffers, so they only move to
        # GPU via the explicit `gpu` flag below, not via model.to(device).
        self.index = torch.LongTensor([[i + j * hidden_dim * 4 for i in range(self.hidden_dim * 4)] for j in range(4)])
        # (4,H) [[0:H],[h:2H],[2H:3H],[3H:4H]]  H=4h
        self.index2 = torch.LongTensor([[i + j * hidden_dim * 3 for i in range(self.hidden_dim * 3)] for j in range(4)])
        # (4,H) H=3h
        if gpu:
            self.index = self.index.cuda()
            self.index2 = self.index2.cuda()

        # One Linear computes candidate state + 3 gates (4h) for all 4 groups at once.
        self.cat2gates = nn.Linear(self.hidden_dim * 2 + self.input_dim, self.hidden_dim * 4 * 4)  # para run
        # Extra-gate contribution from the previous module's output ("exper" input).
        self.exper2gates = nn.Linear(self.hidden_dim, self.hidden_dim * 3 * 4)
        self.reset_paras()

    def reset_paras(self, ):
        """Zero the biases and xavier-init each h-sized weight slice separately,
        so every gate of every group gets its own properly scaled init."""
        for layer in range(4):
            nn.init.constant_(self.cat2gates.bias[self.hidden_dim * 4 * layer:self.hidden_dim * 4 * (layer + 1)].data,
                              val=0)
            nn.init.constant_(self.exper2gates.bias[self.hidden_dim * 3 * layer:self.hidden_dim * 3 * (layer + 1)].data,
                              val=0)
            for i in range(4):
                start = layer * 4 + i
                nn.init.xavier_normal_(self.cat2gates.weight[self.hidden_dim * start:self.hidden_dim * (start + 1), :])
            for i in range(3):
                start = layer * 3 + i
                nn.init.xavier_normal_(
                    self.exper2gates.weight[self.hidden_dim * start:self.hidden_dim * (start + 1), :])

    def forward(self, CNN_output, gaz_input, gaz_input_back, global_matrix, exper_input=None, gaz_mask=None):
        """Fuse inputs and return a 4-tuple of (b, l, h) tensors, one per group.

        CNN_output: (b, 4l, h) -- the 4 groups already stacked along dim 1.
        gaz_input / gaz_input_back: (b, l, input_dim) forward/backward lexicon embeds.
        global_matrix: (b, l, h) global memory. exper_input: optional (b, l, h).
        gaz_mask is accepted but unused here.
        """
        batch_size = global_matrix.size(0)
        seq_len = global_matrix.size(1)

        # Expand the per-group gather indices to (b, 4l, 4h) / (b, 4l, 3h).
        index = self.index.unsqueeze(1).repeat(1, seq_len, 1)  # (4,l,4h)
        index = index.view(1, -1, self.hidden_dim * 4).repeat(batch_size, 1, 1)  # (b,4l,4h)

        index2 = self.index2.unsqueeze(1).repeat(1, seq_len, 1)  # (4,l,3h)
        index2 = index2.view(1, -1, self.hidden_dim * 3).repeat(batch_size, 1, 1)  # (b,4l,3h)

        if self.use_gaz:
            # Stack fwd/back gaz embeds to line up with the 4 stacked groups.
            gaz_input = torch.cat([gaz_input, gaz_input_back, gaz_input, gaz_input_back], dim=1)  # (b,4l,i)

        seq_len_cat = seq_len * 4
        # Tile the shared inputs so every group sees the same global memory.
        global_matrix = global_matrix.repeat(1, 4, 1)  # (b,4l,h)

        if exper_input is not None:
            exper_input = exper_input.repeat(1, 4, 1)  # (b,4l,h)

        if self.use_gaz:
            cat_input = torch.cat([CNN_output, gaz_input, global_matrix], dim=2)  # (b,4l,2*h+gaz_dim)
        else:
            cat_input = torch.cat([CNN_output, global_matrix], dim=2)  # (b,4l,2*h+gaz_dim)

        # Wide projection, then gather each group's own 4h slice.
        cat_gates_ = self.cat2gates(cat_input)  # (b,4l,4h*4)
        cat_gates = torch.gather(cat_gates_, dim=-1, index=index)  # (b,4l,4h)

        # First h channels = candidate state; remaining 3h = gate logits.
        new_state = torch.tanh(cat_gates[:, :, :self.hidden_dim])  # (b,4l,h)
        gates = cat_gates[:, :, self.hidden_dim:]  # (b,4l,3h)

        if exper_input is not None:
            exper_gates_ = self.exper2gates(exper_input)  # (b,4l,3h*4)
            exper_gates = torch.gather(exper_gates_, dim=-1, index=index2)

            gates = gates + exper_gates  # (b,4l,3h)

            # Three candidates to mix: new state, raw CNN output, exper input.
            state_cat = torch.cat([new_state.unsqueeze(2), CNN_output.unsqueeze(2), exper_input.unsqueeze(2)], dim=2)
        else:
            # Without exper input only two candidates, so keep only two gates.
            gates = gates[:, :, :self.hidden_dim * 2]  # (b,l,2h)

            state_cat = torch.cat([new_state.unsqueeze(2), CNN_output.unsqueeze(2)], dim=2)

        # Squash the gate logits, then normalize across the candidate states so
        # the mixture weights at each channel sum to 1.
        gates = torch.sigmoid(gates)
        gates = F.softmax(gates.view(batch_size, seq_len_cat, -1, self.hidden_dim), dim=2)

        # Weighted sum of candidate states per channel.
        layer_output = torch.sum(torch.mul(gates, state_cat), dim=2)  # (b,4l,h)

        # Split the stacked groups back into a 4-tuple of (b, l, h).
        output = torch.split(layer_output, seq_len, dim=1)

        return output

3. 融合多尺度特征的注意力机制

在这里插入图片描述

class MultiscaleAttention(nn.Module):
    """Fuse multi-scale layer features with learned attention weights.

    Expects input shaped (batch, seq_len, num_layer, feat): for every position
    an MLP scores the num_layer candidates and the output is their softmax-
    weighted sum, shaped (batch, seq_len, feat).
    """

    def __init__(self, num_layer, dropout):
        super().__init__()
        # Two tanh-activated linear layers (with dropout) over per-layer scores.
        self.MLP_layer = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(num_layer, num_layer),
            nn.Tanh(),
            nn.Dropout(p=dropout),
            nn.Linear(num_layer, num_layer),
            nn.Tanh(),
        )

    def forward(self, X_list):
        seq_len = X_list.size(1)

        # Summarize each layer by summing its feature vector, then score it.
        fusion_logits = self.MLP_layer(X_list.sum(dim=-1))     # (b, seq, num_layer)
        fusion_weights = F.softmax(fusion_logits, dim=-1)

        # Collapse batch and sequence dims so the weighted sum is a single bmm.
        flat_weights = fusion_weights.unsqueeze(2).view(-1, 1, fusion_weights.size(-1))  # (b*seq, 1, L)
        flat_feats = X_list.view(-1, X_list.size(2), X_list.size(3))                     # (b*seq, L, k)
        fused = torch.bmm(flat_weights, flat_feats).squeeze(1)                           # (b*seq, k)

        return fused.view(-1, seq_len, fused.size(-1))  # (b, seq, k)

4. Run Model

!python lr-cnn/main.py --train lr-cnn/data/demo.train.char \
--dev lr-cnn/data/demo.dev.char \
--test lr-cnn/data/demo.test.char \
--modelname demo \
--modelpath lr-cnn/saved_model \
--savedset lr-cnn/data/data.dset \
--resultfile lr-cnn/result.txt
CuDNN: True
GPU available: False
Status: train
Seg:  True
Train file: lr-cnn/data/demo.train.char
Dev file: lr-cnn/data/demo.dev.char
Test file: lr-cnn/data/demo.test.char
Raw file: None
Char emb: ./data/gigaword_chn.all.a2b.uni.ite50.vec
Bichar emb: None
Gaz file: ./data/ctb.50d.vec
Loading processed data
DATA SUMMARY START:
     Tag          scheme: BMES
     MAX SENTENCE LENGTH: 250
     MAX   WORD   LENGTH: -1
     Number   normalized: True
     Use          bigram: False
     Word  alphabet size: 2578
     Biword alphabet size: 31749
     Char  alphabet size: 2578
     Gaz   alphabet size: 13652
     Label alphabet size: 18
     Word embedding size: 50
     Biword embedding size: 50
     Char embedding size: 30
     Gaz embedding size: 50
     Norm     word   emb: True
     Norm     biword emb: True
     Norm     gaz    emb: False
     Norm   gaz  dropout: 0.5
     Train instance number: 1147
     Dev   instance number: 113
     Test  instance number: 316
     Raw   instance number: 0
     Hyperpara  iteration: 10
     Hyperpara  batch size: 1
     Hyperpara          lr: 0.0015
     Hyperpara    lr_decay: 0.05
     Hyperpara     HP_clip: 5.0
     Hyperpara    momentum: 0
     Hyperpara  hidden_dim: 128
     Hyperpara     dropout: 0.5
     Hyperpara  lstm_layer: 1
     Hyperpara      bilstm: True
     Hyperpara         GPU: False
     Hyperpara     use_gaz: True
     Hyperpara fix gaz emb: False
     Hyperpara    use_char: False
DATA SUMMARY END.
Training model...
build batched crf...
finish building model.
Epoch: 0/10
 Learning rate is setted as: 0.0015
     Instance: 500; Time: 99.58s; loss: 13690.3076; acc: 18723/21586=0.8674
     Instance: 1000; Time: 103.04s; loss: 11134.5771; acc: 36963/42831=0.8630
     Instance: 1147; Time: 31.86s; loss: 2639.1494; acc: 42880/49471=0.8668
Epoch: 0 training finished. Time: 234.49s, speed: 4.89st/s,  total loss: tensor(27464.0527)
gold_num =  215  pred_num =  77  right_num =  47
Dev: time: 5.78s, speed: 19.56st/s; acc: 0.9058, p: 0.6104, r: 0.2186, f: 0.3219
Exceed previous best f score: -1
gold_num =  340  pred_num =  74  right_num =  14
Test: time: 15.84s, speed: 19.98st/s; acc: 0.9277, p: 0.1892, r: 0.0412, f: 0.0676
Best dev score: p:0.6103896103896104, r:0.2186046511627907, f:0.3219178082191781
Test score: p:0.1891891891891892, r:0.041176470588235294, f:0.06763285024154589
Epoch: 1/10
 Learning rate is setted as: 0.001425
     Instance: 500; Time: 110.16s; loss: 7914.4004; acc: 20198/22563=0.8952
     Instance: 1000; Time: 107.38s; loss: 6434.8618; acc: 38719/43109=0.8982
     Instance: 1147; Time: 31.27s; loss: 1979.6349; acc: 44423/49471=0.8980
Epoch: 1 training finished. Time: 248.81s, speed: 4.61st/s,  total loss: tensor(16328.9062)
gold_num =  215  pred_num =  125  right_num =  82
Dev: time: 5.92s, speed: 19.11st/s; acc: 0.9001, p: 0.6560, r: 0.3814, f: 0.4824
Exceed previous best f score: 0.3219178082191781
gold_num =  340  pred_num =  193  right_num =  61
Test: time: 15.34s, speed: 20.62st/s; acc: 0.9297, p: 0.3161, r: 0.1794, f: 0.2289
Best dev score: p:0.656, r:0.3813953488372093, f:0.4823529411764706
Test score: p:0.3160621761658031, r:0.17941176470588235, f:0.22889305816135083
Epoch: 2/10
 Learning rate is setted as: 0.00135375
     Instance: 500; Time: 107.27s; loss: 5768.9414; acc: 20116/22016=0.9137
     Instance: 1000; Time: 105.81s; loss: 5558.0767; acc: 39348/43234=0.9101
     Instance: 1147; Time: 29.74s; loss: 1212.4352; acc: 45135/49471=0.9124
Epoch: 2 training finished. Time: 242.83s, speed: 4.72st/s,  total loss: tensor(12539.4600)
gold_num =  215  pred_num =  180  right_num =  100
Dev: time: 5.54s, speed: 20.41st/s; acc: 0.9129, p: 0.5556, r: 0.4651, f: 0.5063
Exceed previous best f score: 0.4823529411764706
gold_num =  340  pred_num =  358  right_num =  104
Test: time: 14.66s, speed: 21.58st/s; acc: 0.9229, p: 0.2905, r: 0.3059, f: 0.2980
Best dev score: p:0.5555555555555556, r:0.46511627906976744, f:0.5063291139240507
Test score: p:0.2905027932960894, r:0.3058823529411765, f:0.2979942693409742
Epoch: 3/10
 Learning rate is setted as: 0.0012860624999999999
     Instance: 500; Time: 101.25s; loss: 4788.1846; acc: 19457/21272=0.9147
     Instance: 1000; Time: 102.18s; loss: 4560.3555; acc: 39612/43179=0.9174
     Instance: 1147; Time: 29.92s; loss: 1062.2068; acc: 45496/49471=0.9196
Epoch: 3 training finished. Time: 233.35s, speed: 4.92st/s,  total loss: tensor(10410.7480)
gold_num =  215  pred_num =  134  right_num =  98
Dev: time: 5.51s, speed: 20.53st/s; acc: 0.9260, p: 0.7313, r: 0.4558, f: 0.5616
Exceed previous best f score: 0.5063291139240507
gold_num =  340  pred_num =  242  right_num =  79
Test: time: 14.63s, speed: 21.62st/s; acc: 0.9343, p: 0.3264, r: 0.2324, f: 0.2715
Best dev score: p:0.7313432835820896, r:0.4558139534883721, f:0.5616045845272206
Test score: p:0.32644628099173556, r:0.2323529411764706, f:0.2714776632302406
Epoch: 4/10
 Learning rate is setted as: 0.0012217593749999998
     Instance: 500; Time: 104.86s; loss: 4107.0537; acc: 21269/22931=0.9275
     Instance: 1000; Time: 101.37s; loss: 3652.6150; acc: 40450/43620=0.9273
     Instance: 1147; Time: 29.68s; loss: 1106.2850; acc: 45850/49471=0.9268
Epoch: 4 training finished. Time: 235.91s, speed: 4.86st/s,  total loss: tensor(8865.9590)
gold_num =  215  pred_num =  142  right_num =  92
Dev: time: 5.69s, speed: 19.86st/s; acc: 0.9222, p: 0.6479, r: 0.4279, f: 0.5154
gold_num =  340  pred_num =  232  right_num =  88
Test: time: 15.11s, speed: 20.92st/s; acc: 0.9361, p: 0.3793, r: 0.2588, f: 0.3077
Best dev score: p:0.7313432835820896, r:0.4558139534883721, f:0.5616045845272206
Test score: p:0.32644628099173556, r:0.2323529411764706, f:0.2714776632302406
Epoch: 5/10
 Learning rate is setted as: 0.0011606714062499996
     Instance: 500; Time: 104.10s; loss: 3709.4668; acc: 20218/21847=0.9254
     Instance: 1000; Time: 102.69s; loss: 3248.8923; acc: 40318/43358=0.9299
     Instance: 1147; Time: 30.17s; loss: 897.1221; acc: 46020/49471=0.9302
Epoch: 5 training finished. Time: 236.96s, speed: 4.84st/s,  total loss: tensor(7855.4824)
gold_num =  215  pred_num =  151  right_num =  105
Dev: time: 5.67s, speed: 19.96st/s; acc: 0.9198, p: 0.6954, r: 0.4884, f: 0.5738
Exceed previous best f score: 0.5616045845272206
gold_num =  340  pred_num =  305  right_num =  104
Test: time: 15.02s, speed: 21.05st/s; acc: 0.9325, p: 0.3410, r: 0.3059, f: 0.3225
Best dev score: p:0.695364238410596, r:0.4883720930232558, f:0.5737704918032787
Test score: p:0.34098360655737703, r:0.3058823529411765, f:0.32248062015503876
Epoch: 6/10
 Learning rate is setted as: 0.0011026378359374996
     Instance: 500; Time: 101.53s; loss: 3128.2778; acc: 19688/21069=0.9345
     Instance: 1000; Time: 103.20s; loss: 2775.3899; acc: 40171/42792=0.9388
     Instance: 1147; Time: 31.19s; loss: 953.9064; acc: 46446/49471=0.9389
Epoch: 6 training finished. Time: 235.91s, speed: 4.86st/s,  total loss: tensor(6857.5693)
gold_num =  215  pred_num =  134  right_num =  78
Dev: time: 5.62s, speed: 20.10st/s; acc: 0.9165, p: 0.5821, r: 0.3628, f: 0.4470
gold_num =  340  pred_num =  238  right_num =  84
Test: time: 14.99s, speed: 21.09st/s; acc: 0.9376, p: 0.3529, r: 0.2471, f: 0.2907
Best dev score: p:0.695364238410596, r:0.4883720930232558, f:0.5737704918032787
Test score: p:0.34098360655737703, r:0.3058823529411765, f:0.32248062015503876
Epoch: 7/10
 Learning rate is setted as: 0.0010475059441406246
     Instance: 500; Time: 104.45s; loss: 2734.6147; acc: 20487/21803=0.9396
     Instance: 1000; Time: 104.03s; loss: 2777.1926; acc: 40927/43515=0.9405
     Instance: 1147; Time: 29.69s; loss: 615.7896; acc: 46602/49471=0.9420
Epoch: 7 training finished. Time: 238.16s, speed: 4.82st/s,  total loss: tensor(6127.5962)
gold_num =  215  pred_num =  134  right_num =  103
Dev: time: 5.65s, speed: 20.03st/s; acc: 0.9313, p: 0.7687, r: 0.4791, f: 0.5903
Exceed previous best f score: 0.5737704918032787
gold_num =  340  pred_num =  202  right_num =  78
Test: time: 15.07s, speed: 20.98st/s; acc: 0.9377, p: 0.3861, r: 0.2294, f: 0.2878
Best dev score: p:0.7686567164179104, r:0.4790697674418605, f:0.5902578796561604
Test score: p:0.38613861386138615, r:0.22941176470588234, f:0.2878228782287823
Epoch: 8/10
 Learning rate is setted as: 0.0009951306469335934
     Instance: 500; Time: 103.95s; loss: 2349.3535; acc: 20543/21679=0.9476
     Instance: 1000; Time: 103.65s; loss: 2417.0288; acc: 41031/43297=0.9477
     Instance: 1147; Time: 29.46s; loss: 657.4135; acc: 46879/49471=0.9476
Epoch: 8 training finished. Time: 237.07s, speed: 4.84st/s,  total loss: tensor(5423.7979)
gold_num =  215  pred_num =  140  right_num =  106
Dev: time: 5.46s, speed: 20.70st/s; acc: 0.9287, p: 0.7571, r: 0.4930, f: 0.5972
Exceed previous best f score: 0.5902578796561604
gold_num =  340  pred_num =  253  right_num =  125
Test: time: 14.60s, speed: 21.67st/s; acc: 0.9452, p: 0.4941, r: 0.3676, f: 0.4216
Best dev score: p:0.7571428571428571, r:0.4930232558139535, f:0.5971830985915493
Test score: p:0.49407114624505927, r:0.36764705882352944, f:0.42158516020236086
Epoch: 9/10
 Learning rate is setted as: 0.0009453741145869137
     Instance: 500; Time: 100.88s; loss: 2217.9604; acc: 20239/21321=0.9493
     Instance: 1000; Time: 102.44s; loss: 2388.3457; acc: 40928/43212=0.9471
     Instance: 1147; Time: 29.88s; loss: 562.2151; acc: 46906/49471=0.9482
Epoch: 9 training finished. Time: 233.20s, speed: 4.92st/s,  total loss: tensor(5168.5195)
gold_num =  215  pred_num =  140  right_num =  107
Dev: time: 5.49s, speed: 20.60st/s; acc: 0.9295, p: 0.7643, r: 0.4977, f: 0.6028
Exceed previous best f score: 0.5971830985915493
gold_num =  340  pred_num =  255  right_num =  141
Test: time: 15.01s, speed: 21.07st/s; acc: 0.9470, p: 0.5529, r: 0.4147, f: 0.4739
Best dev score: p:0.7642857142857142, r:0.49767441860465117, f:0.6028169014084508
Test score: p:0.5529411764705883, r:0.4147058823529412, f:0.47394957983193275
  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
以下是使用PyCharm实现Faster R-CNN的Python代码示例: ```python # 导入必要的库 import torch import torchvision from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.datasets import CocoDetection from torchvision.transforms import ToTensor from torch.utils.data import DataLoader # 加载COCO数据集 train_dataset = CocoDetection(root="./data/train2017", annFile="./data/annotations/instances_train2017.json", transform=ToTensor()) test_dataset = CocoDetection(root="./data/val2017", annFile="./data/annotations/instances_val2017.json", transform=ToTensor()) # 创建数据加载器 train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4) test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, num_workers=4) # 加载预训练的Faster R-CNN模型 model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) # 替换分类器 num_classes = 91 # 用于COCO数据集的类别数 in_features = model.roi_heads.box_predictor.cls_score.in_features model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) # 训练和测试模型 device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') model.to(device) params = [p for p in model.parameters() if p.requires_grad] optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1) num_epochs = 10 for epoch in range(num_epochs): model.train() for images, targets in train_loader: images = list(image.to(device) for image in images) targets = [{k: v.to(device) for k, v in t.items()} for t in targets] loss_dict = model(images, targets) losses = sum(loss for loss in loss_dict.values()) optimizer.zero_grad() losses.backward() optimizer.step() lr_scheduler.step() model.eval() test_loss = 0.0 for images, targets in test_loader: images = list(image.to(device) for image in images) targets = [{k: v.to(device) for k, v in t.items()} for t in targets] with torch.no_grad(): loss_dict = model(images, targets) test_loss += 
sum(loss for loss in loss_dict.values()).item() print(f"Epoch {epoch}: train_loss = {losses.item()}, test_loss = {test_loss / len(test_loader)}") ``` 需要注意的是,此处的代码仅提供了一个基本框架,具体实现需要根据实际需求进行相应的修改和调整。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值