dnn应用的python代码_Python cudnn.descriptor方法代码示例

# 需要导入模块: from torch.backends import cudnn [as 别名]

# 或者: from torch.backends.cudnn import descriptor [as 别名]

def flatten_parameters(self):

"""Resets parameter data pointer so that they can use faster code paths.

Right now, this works only if the module is on the GPU and cuDNN is enabled.

Otherwise, it's a no-op.

"""

any_param = next(self.parameters()).data

if not any_param.is_cuda or not torch.backends.cudnn.is_acceptable(any_param):

self._data_ptrs = []

return

with torch.cuda.device_of(any_param):

# This is quite ugly, but it allows us to reuse the cuDNN code without larger

# modifications. It's really a low-level API that doesn't belong in here, but

# let's make this exception.

from torch.backends.cudnn import rnn

from torch.backends import cudnn

from torch.nn._functions.rnn import CudnnRNN

handle = cudnn.get_handle()

with warnings.catch_warnings(record=True):

fn = CudnnRNN(

self.mode,

self.input_size,

self.hidden_size,

num_layers=self.num_layers,

batch_first=self.batch_first,

dropout=self.dropout,

train=self.training,

bidirectional=self.bidirectional,

dropout_state=self.dropout_state,

)

# Initialize descriptors

fn.datatype = cudnn._typemap[any_param.type()]

fn.x_descs = cudnn.descriptor(any_param.new(1, self.input_size), 1)

fn.rnn_desc = rnn.init_rnn_descriptor(fn, handle)

# Allocate buffer to hold the weights

self._param_buf_size = rnn.get_num_weights(handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)

fn.weight_buf = any_param.new(self._param_buf_size).zero_()

fn.w_desc = rnn.init_weight_descriptor(fn, fn.weight_buf)

# Slice off views into weight_buf

params = rnn.get_parameters(fn, handle, fn.weight_buf)

all_weights = [[p.data for p in l] for l in self.all_weights]

# Copy weights and update their storage

rnn._copyParams(all_weights, params)

for orig_layer_param, new_layer_param in zip(all_weights, params):

for orig_param, new_param in zip(orig_layer_param, new_layer_param):

orig_param.set_(new_param.view_as(orig_param))

self._data_ptrs = list(p.data.data_ptr() for p in self.parameters())

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值