c10::Half and float dtype mismatch errors

Related errors

 # error-1: (all-no-half) self-attn RuntimeError: expected m1 and m2 to have the same dtype, but got: float != c10::Half
 # error-2: (embed-half) self-attn RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
 # error-3: model(half) embed(no-half) conv RuntimeError: Input type (float) and bias type (c10::Half) should be the same
 # error-4: model(half) embed(half) conv RuntimeError: Input type (float) and bias type (c10::Half) should be the same
 # cuda + half all: RuntimeError: Input type (float) and bias type (c10::Half) should be the same

I ran into the errors above while running inference with a large language model.
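Each of these messages names the dtypes involved, and error-2 names a CPU kernel, so a useful first step is to check where the parameters live and in which dtype. The following is a minimal diagnostic sketch and not part of the original code: `summarize_placement` is a hypothetical helper name, and the two-layer `toy` module is only a stand-in for the real model.

```python
from collections import Counter

import torch
from torch import nn


def summarize_placement(module: nn.Module) -> Counter:
    """Count parameter tensors by (device type, dtype) so mixed placement stands out."""
    counts = Counter()
    for param in module.parameters():
        counts[(param.device.type, str(param.dtype))] += 1
    return counts


# Toy stand-in for a real model: one layer cast to float16, one left in float32.
toy = nn.Sequential(nn.Linear(4, 4).half(), nn.Linear(4, 2))
placement = summarize_placement(toy)
for (device, dtype), n in sorted(placement.items()):
    print(f"{n} parameter tensors on {device} as {dtype}")
```

More than one entry in the result means the model mixes devices or dtypes, which is the situation the errors above describe.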

Two things need to be considered first:

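  1. I wanted the model to run inference in half precision, so I loaded it as float16 in from_pretrained:
     self.llama_model = LlamaForCausalLM.from_pretrained(
         args.llama_model, torch_dtype=torch.float16)
  2. I wanted the model to run inference on the GPU, but I had simply assumed it would be loaded onto the GPU automatically. It is not: unless the model is moved explicitly, the weights stay on the CPU.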
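The two points above are independent: casting to float16 changes the dtype but does not move anything to the GPU. A minimal sketch of making both choices explicit is shown here. `place_for_inference` is a hypothetical helper, the `nn.Linear` is a stand-in for the real model, and falling back to float32 on the CPU is an assumption of this sketch rather than something the original code does.

```python
import torch
from torch import nn


def place_for_inference(module: nn.Module) -> nn.Module:
    """Use float16 on the GPU when CUDA is available, else float32 on the CPU."""
    if torch.cuda.is_available():
        return module.to(device="cuda", dtype=torch.float16)
    # Half-precision kernels are missing for some CPU ops in some PyTorch builds
    # (see error-2), so this sketch falls back to float32 on the CPU.
    return module.to(device="cpu", dtype=torch.float32)


toy = place_for_inference(nn.Linear(4, 2).half())  # hypothetical stand-in model
first = next(toy.parameters())
print(first.device, first.dtype)
```

Printing the device and dtype right after loading makes a silent CPU placement visible before the first forward pass.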

Solution

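  1. Check that the llama model has actually been moved to the GPU. When c10::Half shows up in these errors, the model is very likely still running on the CPU (error-2 even names the CPU kernel, addmm_impl_cpu_), so moving it to the GPU makes the error go away.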

  2. Remember to wrap model inference in autocast:

with torch.cuda.amp.autocast():
    ...
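To illustrate what autocast does to dtypes, here is a small self-contained sketch. It uses `torch.autocast`, the device-generic form of the `torch.cuda.amp.autocast` call above, so that it also runs on a machine without a GPU. `run_with_autocast` and the toy `nn.Linear` are hypothetical names for illustration, and choosing bfloat16 on the CPU is an assumption of this sketch.

```python
import torch
from torch import nn


def run_with_autocast(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run one forward pass under autocast without tracking gradients."""
    device_type = x.device.type
    # float16 is the usual autocast dtype on CUDA; CPU autocast uses bfloat16.
    amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=amp_dtype):
        return module(x)


device = "cuda" if torch.cuda.is_available() else "cpu"
layer = nn.Linear(4, 2).to(device)     # float32 weights, hypothetical toy layer
x = torch.randn(3, 4, device=device)   # float32 input
out = run_with_autocast(layer, x)
print(out.shape, out.dtype)
```

Inside the autocast region the linear layer runs in the lower-precision dtype even though both the weights and the input were created as float32, which is why mixed float and half tensors stop raising dtype errors there.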

Final code

This is adapted from R2GenGPT, so the code looks like this:

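import numpy as np
import torch
from PIL import Image

# parser, FieldParser and R2GenGPT come from the R2GenGPT repository;
# import them from wherever they live in your checkout.


class Generator: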
    def __init__(self):
        pass

    def generate(self, input_conv, img_list):
        raise NotImplementedError
    
    
class R2genGPT_shallow(Generator):
    def __init__(self):
        super().__init__()
        args = parser.parse_args()
        # args.precision = "fp16"
        args.delta_file = "../checkpoints/R2genGPT/shallow_checkpoint_step14102.pth"
        args.vision_model = "microsoft/swin-base-patch4-window7-224"
        args.llama_model = "../checkpoints/Llama-2-7b-chat-hf"
        self.filed_parser = FieldParser(args)
        self.model = R2GenGPT(args)
        self.model.eval()
        # The weights are not moved to the GPU automatically; do it explicitly.
        self.model.cuda()
        print("device : ", self.model.device)

    def adapt(self, query):
        query = query.replace("<image>", " ")
        return query

    def get_image_tensor(self, img_file):
        with Image.open(img_file) as pil:
            array = np.array(pil, dtype=np.uint8)
            if array.shape[-1] != 3 or len(array.shape) != 3:
                array = np.array(pil.convert("RGB"), dtype=np.uint8)
            image = self.filed_parser._parse_image(array)
            image = image.to(self.model.device)
        return image
    
    def generate(self, query, img_list):
        self.model.llama_tokenizer.padding_side = "right"

        images = []
        for img_file in img_list:
            image = self.get_image_tensor(img_file)
            images.append(image.unsqueeze(0))
        
        self.model.prompt = self.adapt(query)
        img_embeds, atts_img = self.model.encode_img(images)
        img_embeds = self.model.layer_norm(img_embeds)
        img_embeds, atts_img = self.model.prompt_wrap(img_embeds, atts_img)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=atts_img.dtype,
                         device=atts_img.device) * self.model.llama_tokenizer.bos_token_id
        bos_embeds = self.model.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img], dim=1)

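        # autocast casts matmul inputs to float16 on the GPU, so the float32 image
        # embeddings can be fed to the float16 llama weights
        with torch.inference_mode():
            with torch.cuda.amp.autocast():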
                outputs = self.model.llama_model.generate(
                    inputs_embeds=inputs_embeds,
                    num_beams=self.model.hparams.beam_size,
                    do_sample=self.model.hparams.do_sample,
                    min_new_tokens=self.model.hparams.min_new_tokens,
                    max_new_tokens=self.model.hparams.max_new_tokens,
                    repetition_penalty=self.model.hparams.repetition_penalty,
                    length_penalty=self.model.hparams.length_penalty,
                    temperature=self.model.hparams.temperature,
                )
            
        answer = self.model.decode(outputs[0])
        return answer
    