代码复现
代码链接:https://github.com/TuBui/RoSteALS
一定要按照dockerfile,requirements.txt和requirements2.txt配置环境
需要补充的库:
pip安装:omegaconf slack slackclient bchlib (0.14.0版本) einops imagenet-c
conda安装:scikit-image,matplotlib
按照作者git的dependency配置环境的话就不会出现下面的3和4问题
遇见的bug
No module ‘xformers’. Proceeding without it
参考官网Installing xFormers:https://github.com/facebookresearch/xformers
conda install xformers -c xformers
需要python版本>=3.9,这是个加速库,不装也可以运行(我没装)
-
Import Error: MagickWand shared library not found
解决方案:Ubuntu安装imagemagick这个软件,因为sudo apt-get install libmagickwand-dev
一直报依赖的错装不上,我是在官网下载源码编译成功安装的, 此外还要装apt-get install python-wand
官网链接:https://docs.wand-py.org/en/latest/guide/install.html#install-wand-on-debian-ubuntu
ModuleNotFoundError: No module named 'pytorch_lightning.utilities.distributed'
版本问题,pytorch_lightning如果高于1.6.5就会出现,可以降级到1.4.2
解决方案:将from pytorch_lightning.utilities.distributed import rank_zero_only
修改为:from pytorch_lightning.utilities.rank_zero import rank_zero_only
-
版本问题,新版本的pytorch_lightning中TestTubeLogger不再受支持Traceback (most recent call last): File "/root/autodl-tmp/RoSteALS/train.py", line 133, in <module> app(args) File "/root/autodl-tmp/RoSteALS/train.py", line 118, in app trainer_kwargs.update(trainer_settings(config, output)) File "/root/autodl-tmp/RoSteALS/train.py", line 55, in trainer_settings logger = instantiate_from_config(logger) File "/root/autodl-tmp/RoSteALS/ldm/util.py", line 79, in instantiate_from_config return get_obj_from_str(config["target"])(**config.get("params", dict())) File "/root/autodl-tmp/RoSteALS/ldm/util.py", line 87, in get_obj_from_str return getattr(importlib.import_module(module, package=None), cls) AttributeError: module 'pytorch_lightning.loggers' has no attribute 'TestTubeLogger'. Did you mean: 'NeptuneLogger'?
解决:https://github.com/Lightning-AI/lightning/issues/13958
使用CSVLogger
ImportError: cannot import name 'VectorQuantizer2' from 'taming.modules.vqvae.quantize' (/root/miniconda3/envs/RoSte/lib/python3.8/site-packages/taming/modules/vqvae/quantize.py)
解决方案:按 https://github.com/CompVis/stable-diffusion/issues/72 中的方法处理,即用 https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py 新文件,替换错误提示文件中的全部内容即可。
注意bchlib的版本是0.14.0 ,不然运行inference的时候会出错
复现成功!
修改嵌入bit为256时报错
File "/root/autodl-tmp/RoSteALS/RoSteALS-main/cldm/ae.py", line 516, in forward return x + eps, posterior RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0
- 解决方案:修改训练集和测试集中的图片数量,需要是batch_size的整数倍
深度学习中的Tensor 数据格式(N,C,H,W)
1.1、4DTensor格式
4DTensor格式:使用4D张量描述符来定义具有4个字母的2D图像批处理的格式
数据布局由二维图像的四个字母表示:
N:Batch,批处理大小,表示一个batch中的图像数量
C:Channel,通道数,表示一张图像中的通道数
H:Height,高度,表示图像垂直维度的像素数
W:Width,宽度,表示图像水平维度的像素数
常用的4-D tensor 格式为:
NCHW
NHWC
CHWN
NC/32HW32
不同框架的支持
目前的主流ML 框架对 NCHW 和 NHWC 数据格式做了支持,有些框架可以支持两种且用户未作设置时有一个缺省值:
- TensorFlow:默认NHWC,GPU也支持 NCHW
- Caffe:NCHW
- PyTorch:NCHW
要修改的部分
01 VQ4_mir.yaml
VQ4_mir.yaml
中的secret_len
,修改为要嵌入的长度
如果要修改数据集的话也修改train
和test
中的data_dir
和data_list
改完之后可以训练,但是运行inference.py的时候提取准确率不高,发现推理时用了BCH编码,需要做对应修改
02 inference.py
secret ecc = ECC() # 在这里调用的ecc纠错码
secret = ecc.encode_text([args.secret]) # 1, 100
secret = torch.from_numpy(secret).cuda().float() # 1, 100
03 找到ecc.py
文件修改
上图中的BCH_POLYNOMIAL
需要根据要嵌入的比特数进行修改,下面是多项式表格
1024bit的参数修改
image_dataset.py里面的secret_len=1024
yaml文件里面的secret_len: 1024
ecc.py里面的
BCH_POLYNOMIAL = 1033
return 1024
s = s + ’ '*(121-len(s)) # 7 chars
data = np.random.binomial(1, .5, 208)
256bit的参数修改
image_dataset.py里面的secret_len=256
yaml文件里面的secret_len: 256
ecc.py里面的
BCH_POLYNOMIAL = 529
return 256
s = s + ’ '*(26-len(s)) # 7 chars
data = np.random.binomial(1, .5, 968)
BCH编码的参考资料:
https://zhuanlan.zhihu.com/p/481158241?utm_id=0
https://zhuanlan.zhihu.com/p/500425263?utm_id=0 有部分介绍
https://link.springer.com/content/pdf/bbm:978-1-4615-1509-8/1.pdf 这是提供的表格
1.4 重新配置服务器RoSteALS环境记录
- pip和conda按照官方装好,python=3.8
- pip安装
torch pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
下载链接https://pytorch.org/get-started/previous-versions/ - pytorch-lighting报错 https://blog.csdn.net/weixin_39379635/article/details/129159713
- 第三步之后报错 https://blog.csdn.net/qq_60592939/article/details/129177520