pytorch每日一学7(torch.get_default_dtype())

本文介绍了PyTorch中获取默认浮点数类型的方法torch.get_default_dtype()。此方法用于查询系统为未指定数据类型的浮点张量所设定的默认类型。通过示例展示如何使用该方法,并解释了其与torch.set_default_dtype()方法的关系。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

7.第七个方法

torch.get_default_dtype()
  • 如果你看过我的上一个方法torch.set_defautl_dtype()),那么这个方法马上就能理解,这两个方法是一对的,torch.set_defautl_dtype())设置默认浮点类型,而torch.get_defautl_dtype())获取默认浮点类型。
  • 在pytorch中,我们的浮点tensor如果我们不指定其数据类型的话系统会给它一个默认的类型,此方法的作用就是获得这个类型。
import torch
a = torch.tensor([1.,2.])
a.dtype

在这里插入图片描述

  • 可以知道,这个默认类型是torch.float32,我们调用此方法
    在这里插入图片描述
  • 所以此方法返回的就是默认的float类型,如果我们改变默认类型
    在这里插入图片描述
  • 那么对应返回的类型也就改变了。
这是个crossattention模块:class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) self.scale = dim_head ** -0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) def forward(self, x, context=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) k = self.to_k(context) v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) # force cast to fp32 to avoid overflowing if _ATTN_PRECISION =="fp32": with torch.autocast(enabled=False, device_type = 'cuda'): q, k = q.float(), k.float() sim = einsum('b i d, b j d -> b i j', q, k) * self.scale else: sim = einsum('b i d, b j d -> b i j', q, k) * self.scale del q, k if exists(mask): mask = rearrange(mask, 'b ... -> b (...)') max_neg_value = -torch.finfo(sim.dtype).max mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of sim = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', sim, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=h) return self.to_out(out) 我如何从中提取各个提示词的注意力热力图并用Gradio可视化?
07-15
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值