builtins自定义_Pytorch学习 (二十一) ------自定义C++/ATen扩展

总说

没办法, 出来混总是要还的, 不会写点底层代码没法混啊. 废话不多说, 简单来说, 有时候我们需要写一些自定义的操作, 这些操作如果用python写会很慢, 我们需要用CUDA写, 然后这些操作与python绑定, 以供python端调用.

主要是简略拿出 https://pytorch.org/tutorials/advanced/cpp_extension.html 的东西, 根据实践, 补充了一些东西(否则, 直接看官方文档可能会有一些地方需要花点实践), 没毛病. 看这个博客, 可以的.

示例程序

class LLTM(torch.nn.Module):

def __init__(self, input_features, state_size):

super(LLTM, self).__init__()

self.input_features = input_features

self.state_size = state_size

# 3 * state_size for input gate, output gate and candidate cell gate.

# input_features + state_size because we will multiply with [input, h].

self.weights = torch.nn.Parameter(

torch.empty(3 * state_size, input_features + state_size))

self.bias = torch.nn.Parameter(torch.empty(3 * state_size))

self.reset_parameters()

def reset_parameters(self):

stdv = 1.0 / math.sqrt(self.state_size)

for weight in self.parameters():

weight.data.uniform_(-stdv, +stdv)

def forward(self, input, state):

old_h, old_cell = state

X = torch.cat([old_h, input], dim=1)

# 自定义C++扩展, 可以让这些操作, 变成一个fused的版本

# Compute the input, output and candidate cell gates with one MM.

gate_weights = F.linear(X, self.weights, self.bias)

# Split the combined gate weight matrix into its components.

gates = gate_weights.chunk(3, dim=1)

input_gate = F.sigmoid(gates[0])

output_gate = F.sigmoid(gates[1])

# Here we use an ELU instead of the usual tanh.

candidate_cell = F.elu(gates[2])

# Compute the new cell state.

new_cell = old_cell + candidate_cell * input_gate

# Compute the new hidden state and output.

new_h = F.tanh(new_cell) * output_gate

return new_h, new_cell

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

import torch

X = torch.randn(batch_size, input_features)

h = torch.randn(batch_size, state_size)

C = torch.randn(batch_size, state_size)

rnn = LLTM(input_features, state_size)

new_h, new_C = rnn(X, (h, C))

1

2

3

4

5

6

7

8

9

C++扩展

简单来说, 我们写完C++程序后, python要用这些程序, 可以用pybind11. 然而, 安装pybind11需要用到pytest, 而pytest貌似只能在python3.5以上才能运行. 所以我们先弄个基于python3的Anaconda. 装好pytorch之后(随便你咋装上的). 然后再进行后续操作.

pybind11安装

git clone https://github.com/pybind/pybind11.git

pip install pytest

1

2

注意一下, 这里的pip -V最好是显示anaconda3中的pip, 从而确保下载的pytest是python3版本.

cd pybind11

mkdir build

cd build

cmake ..

make check -j 4

1

2

3

4

5

编译好的动态库是test目录下的so文件.

pytorch的相关事宜

我们要写pytorch扩展, 得下载pytorch源代码.

git clone --recursive https:

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值