引言:
pytorch的官方教程中:
- 例子1是用numpy+Funciton手动写的一个前向传播和反向传播
- 例子2是一个有需要优化参数的例子
- 例子3是用c++写的拓展,1.0之前是用c写的拓展
这里记录用vscode配置c++,然后导入torch的TH.h头文件的过程
ubuntu配置vscode支持c++调试
具体参考网上的教程
Ubuntu16.04下配置VScode的C/C++开发环境
注意:
- tasks.json中需要加上“-std=c++11”这个参数表示支持c++11标准
将pytorch的头文件TH.h等加入到C++的编译路径
1. 软链接的方式
- 找到c++编译器的路径,本人的在
/usr/include/c++/5/
, 可以通过一下命令找到位置:
locate iostream
- 找到pytorch的拓展TH.h所在的位置,命令:
locate TH.h
或者find -name TH.h
,如下是本人的:
(base) chenjun@chenjun-ThinkCentre-M910t-N000:~$ find -name TH.h
find: `./.cache/dconf': 权限不够
./.local/lib/python3.5/site-packages/torch/lib/include/TH/TH.h
./anaconda2/lib/python2.7/site-packages/torch/include/TH/TH.h
./anaconda2/envs/pytorch3/lib/python3.6/site-packages/torch/lib/include/TH/TH.h
./anaconda2/envs/caffe36/lib/python3.6/site-packages/torch/include/TH/TH.h
./anaconda2/envs/chainer/lib/python3.6/site-packages/torch/lib/include/TH/TH.h
./anaconda2/envs/mypytorch/lib/python2.7/site-packages/torch/lib/include/TH/TH.h
./anaconda2/pkgs/pytorch-nightly-1.0.0.dev20190301-py3.6_cuda9.0.176_cudnn7.4.2_0/lib/python3.6/site-packages/torch/include/TH/TH.h
./anaconda2/pkgs/pytorch-nightly-1.0.0.dev20190301-py2.7_cuda10.0.130_cudnn7.4.2_0/lib/python2.7/site-packages/torch/include/TH/TH.h
./anaconda2/pkgs/pytorch-nightly-1.0.0.dev20190301-py2.7_cuda9.0.176_cudnn7.4.2_0/lib/python2.7/site-packages/torch/include/TH/TH.h
./anaconda2/pkgs/pytorch-0.4.1-py27__9.0.176_7.1.2_2/lib/python2.7/site-packages/torch/lib/include/TH/TH.h
./anaconda2/pkgs/pytorch-0.4.1-py36_py35_py27__9.0.176_7.1.2_2/lib/python3.6/site-packages/torch/lib/include/TH/TH.h
这里因为有多个conda环境所以有多个TH.h
- 建立软链接,
sudo ln -s /home/chenjun/anaconda2/envs/mypytorch/lib/python2.7/site-packages/torch/lib/include/ATen /usr/include/c++/5/ # 可能要sudo,不加可能权限不够
2. 在.bashrc中添加环境变量
- 在~/.bashrc中为c++添加环境变量,具体语句如下:
export CPLUS_INCLUDE_PATH=/home/chenjun/anaconda2/envs/mypytorch/lib/python2.7/site-packages/torch/lib/include/:$CPLUS_INCLUDE_PATH # 改为自己的位置就行
添加完成之后,好像一定要重启,不然无法导入。
测试
直接上代码,如果编译能通过则没有问题。
#include <TH/TH.h> // pytorch的头文件
#include <math.h>
#include <iostream>
int main(int argc, char const *argv[])
{
std::cout << "chenjun" << std::endl;
return 0;
}
以上代码只是在普通的c++代码中,导入了pytorch的头文件。