Reference: https://devtalk.nvidia.com/default/topic/1002826/question-about-cudnnsetconvolution2ddescriptor/
Look in cudnn.hpp for the correct way to call cudnnSetConvolution2dDescriptor:
template <typename Dtype>
inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
    cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
    int pad_h, int pad_w, int stride_h, int stride_w) {
#if CUDNN_VERSION_MIN(6, 0, 0)
  // cuDNN 6.0 and later: the two 1s are dilation_h and dilation_w, and the
  // call takes an extra trailing cudnnDataType_t argument.
  CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
      pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION,
      dataType<Dtype>::type));
#else
  // Before cuDNN 6.0: the two 1s are upscalex and upscaley, and there is no
  // data type argument.
  CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
      pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
#endif
}
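This shows that cuDNN 6.0 requires one more argument, of type cudnnDataType_t (dataType<Dtype>::type in the code above). To fix a call that is missing it, append the matching CUDNN_DATA_FLOAT or CUDNN_DATA_DOUBLE as the last argument.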
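To make the choice between the two constants concrete, here is a minimal sketch of how a trait such as dataType<Dtype> can map the element type to the cuDNN constant. It rests on several assumptions: the enum is a stand-in so the snippet compiles without cuDNN installed, the layout of the trait is a guess at what cudnn.hpp does rather than a copy of it, and computeTypeFor is a hypothetical helper that does not exist in cuDNN or in cudnn.hpp.

```cpp
// Stand-in for the cudnnDataType_t enum from <cudnn.h>, declared here only so
// that this sketch compiles without cuDNN. In real code, include <cudnn.h>
// and delete this enum.
enum cudnnDataType_t { CUDNN_DATA_FLOAT = 0, CUDNN_DATA_DOUBLE = 1 };

// Assumed shape of the trait: one specialization per supported element type.
template <typename Dtype> struct dataType;

template <> struct dataType<float> {
  static constexpr cudnnDataType_t type = CUDNN_DATA_FLOAT;
};

template <> struct dataType<double> {
  static constexpr cudnnDataType_t type = CUDNN_DATA_DOUBLE;
};

// Hypothetical helper: returns the value to pass as the extra last argument
// of cudnnSetConvolution2dDescriptor on cuDNN 6.0 and later.
template <typename Dtype>
constexpr cudnnDataType_t computeTypeFor() {
  return dataType<Dtype>::type;
}
```

In code that is not templated on Dtype, the same fix amounts to writing the constant directly: a call working on float data ends with CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT, and one working on double data ends with CUDNN_CROSS_CORRELATION, CUDNN_DATA_DOUBLE.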