前情描述
这次要讲的问题很是有含金量,一是打破了之前固有的不要改库函数的思想,二是提供了一个分析问题的解决途径。
问题描述
这次的问题在之前几篇有意无意的提起过,但因为当时没有很好的解决办法,所以这块就没有展开去说,问题是什么呢?在用配置文件noise.yaml、content.yaml、unrelated.yaml、frontier_stitching.yaml来进行训练模型加水印时报错RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Float',这个在网上并没有找到解决办法,有的也只是RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'int',也就是说之前因为int类型不兼容出了问题,那这个float报错的原因是什么呢?
原因分析及解决
仔细分析会发现,在RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'int'的解决办法中提到的是应该输入float或者double而不是int,所以在输入int的时候会出现这个问题,但是呢我这里边输入的就是float啊,为什么会报这种错误呢?这个原因想了很久我都没有想明白。经过各种打印输出也很能确定的是:我的输入类型就是float:
那就奇了怪了,他需要的是float,我输入的也是float,但是却一直报错RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'Float'。
于是不得不求助大佬,上来就说是类型的问题,当然这个我是知道的,就是犯迷糊,那为什么我给了他想要的,程序不给我我想要的呢?
只见大佬一个操作,解决了我的疑问:
原来导致问题的一直是问题本身,因为这本来就是一个误解的问题,我没想到啊万万没想到啊,输入的另外一个参数时需要long而不是float,这不是大水冲了龙王庙吗?代码本身是个多参数,能看到的形式只有一个包装后的
所以这个封装的就很深,能想到的一个解决办法就是直接在torch包里进行修改target.dtype,这是个很大胆的行为,因为之前还真没改过torch包本身的东西,修改的话可能会导致其他引用的包出现问题,这是不可避免的。所以需要谨慎一些。
知道了这个原因只需要把torch包的给改掉就行了:也就是进行强制类型转换
target=target.type(torch.long)
这个时候再运行的话就会报另外一个错误:no module named torch,因为这里平白引入了这个torch嘛,所以在最开头或者是引用的位置要执行下面操作:
import torch
不用担心是在torch的包下还import torch可不可行,答案是肯定的,因为python模块的引用嘛。做完这些,程序就可以运行了,很瓦塞。
延伸阅读
使用pyhon时必然会用到各种各样的库,特别是有些流行的库,不得不用,但又不能很好的满足我们的需求,这时就需要对库进行修改,那么如何修改呢?
完结撒花
当然除了这个之外还有其他的问题,我们在后面的博客中继续讲。