访问模型参数
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
# Build a small MLP: 2 -> 4 (ReLU) -> 1, then inspect the first layer's parameters.
net = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))

# state_dict() maps parameter names ('weight', 'bias') to their current tensors.
print(net[0].state_dict())
# .bias is an nn.Parameter (participates in autograd); .data is the raw tensor value.
print(net[0].bias)
print(net[0].bias.data)
OrderedDict([('weight', tensor([[ 0.2167, -0.4856],
        [-0.6567, -0.4602],
        [ 0.1185,  0.1153],
        [-0.0756, -0.6647]])), ('bias', tensor([0.4804, 0.4804, 0.3834, 0.7041]))])
Parameter containing:
tensor([0.4804, 0.4804, 0.3834, 0.7041], requires_grad=True)
tensor([0.4804, 0.4804, 0.3834, 0.7041])
保存Tensor数据
# Round-trip a tensor through disk: torch.save serializes it, torch.load restores it.
x = torch.arange(4)
torch.save(x, 'x-file')
x2 = torch.load('x-file')
print(x2)  # tensor([0, 1, 2, 3])
保存模型参数
class MLP(nn.Module):
    """Toy MLP (1 -> 2 -> 2) used to demonstrate saving/loading model parameters."""

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(1, 2)
        self.output = nn.Linear(2, 2)

    def forward(self, x):
        # Requires `from torch.nn import functional as F` at the top of the file.
        return self.output(F.relu(self.hidden(x)))


net = MLP()
# Persist only the parameters (the recommended practice), not the whole module object.
torch.save(net.state_dict(), 'mlp.params')

# Rebuild an identical architecture, then load the saved parameters into it.
clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
# Both models now hold identical parameter tensors.
net.state_dict(), clone.state_dict()
(OrderedDict([('hidden.weight', tensor([[-0.6865],
         [-0.2476]])),
     ('hidden.bias', tensor([ 0.3468, -0.3090])),
     ('output.weight', tensor([[ 0.1645, -0.1768],
         [ 0.7000, -0.6238]])),
     ('output.bias', tensor([-0.0137, -0.4359]))]),
 OrderedDict([('hidden.weight', tensor([[-0.6865],
         [-0.2476]])),
     ('hidden.bias', tensor([ 0.3468, -0.3090])),
     ('output.weight', tensor([[ 0.1645, -0.1768],
         [ 0.7000, -0.6238]])),
     ('output.bias', tensor([-0.0137, -0.4359]))]))