import torch
import torch.nn as nn
import torch.nn.functional as F
import onnxruntime as ort
import numpy as np
def demo_scaled_dot_product_attention():
    """Demo of F.scaled_dot_product_attention with a boolean attention mask.

    Tensors follow the (batch, num_heads, seq_len, head_dim) convention:
    the query has seq_len 1 while key/value have seq_len 32 (a typical
    single-step decode shape).

    Returns:
        torch.Tensor: attention output with heads and sequence transposed,
        i.e. shape (batch, seq_len, num_heads, head_dim) = (1, 1, 16, 128).
    """
    query = torch.randn(1, 16, 1, 128)
    key = torch.randn(1, 16, 32, 128)
    value = torch.randn(1, 16, 32, 128)
    # The mask must broadcast to the attention-score shape
    # (batch, heads, L, S) = (1, 16, 1, 32). The original (1, 1, 32, 32)
    # mask broadcast the query-length dim from 1 up to 32, silently
    # producing a (1, 16, 32, 128) output instead of (1, 16, 1, 128).
    attention_mask = torch.ones((1, 1, 1, 32), dtype=torch.bool)
    attn_output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    ).transpose(1, 2)
    # Return the result (the original computed it and dropped it).
    return attn_output
def demo_scaled_dot_product_attention_model():
    """Export a minimal scaled-dot-product-attention module to ONNX.

    Defines a module whose forward pass wraps F.scaled_dot_product_attention
    (then swaps the head and sequence dims of the result), traces it with
    dummy (batch, heads, seq, head_dim) inputs, and writes the graph to
    "MyModel.onnx".
    """

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, query, key, value, attention_mask):
            # (batch, heads, seq, dim) -> attention -> (batch, seq, heads, dim)
            result = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask
            )
            return result.transpose(1, 2)

    dummy_query = torch.randn(1, 16, 1, 128)
    dummy_key = torch.randn(1, 16, 1, 128)
    dummy_value = torch.randn(1, 16, 1, 128)
    dummy_mask = torch.ones((1, 1, 1, 1), dtype=torch.bool)
    torch.onnx.export(
        MyModel(),
        (dummy_query, dummy_key, dummy_value, dummy_mask),
        "MyModel.onnx",
        input_names=["query", "key", "value", "attention_mask"],
        output_names=["output"],
    )
def test_my_model():
    """Smoke-test the exported "MyModel.onnx" with onnxruntime.

    Feeds random float32 query/key/value tensors plus an all-True boolean
    mask into the session and prints the shape of the resulting output.
    """
    session = ort.InferenceSession("MyModel.onnx")
    feeds = {
        "query": np.random.randn(1, 16, 1, 128).astype(np.float32),
        "key": np.random.randn(1, 16, 1, 128).astype(np.float32),
        "value": np.random.randn(1, 16, 1, 128).astype(np.float32),
        "attention_mask": np.ones((1, 1, 1, 1), dtype=np.bool_),
    }
    outputs = session.run(["output"], feeds)
    print(outputs[0].shape)
def main():
    """Entry point: run the in-memory attention demo.

    The ONNX export and the onnxruntime smoke test are left disabled;
    uncomment them to regenerate and validate "MyModel.onnx".
    """
    demo_scaled_dot_product_attention()
    # demo_scaled_dot_product_attention_model()
    # test_my_model()
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
---------------------------------------------------------
导出onnx如下: