import torch
import torch.nn as nn
import os
from allennlp.nn import util
# Expose only physical GPU 1 to this process. CUDA honours this variable
# only if it is set before the CUDA runtime is initialised, so keep it
# ahead of any GPU work.
os.environ.update(CUDA_VISIBLE_DEVICES="1")
# Fix torch's RNG seed so the randomly generated inputs are reproducible.
torch.manual_seed(1)
class Test(nn.Module):
    """Pick vectors out of ``mix`` at per-row positions given by ``offsets``.

    For 2-D ``offsets`` the result satisfies ``out[i, j] == mix[i, offsets[i, j]]``.
    Higher-rank ``offsets`` are first folded to 2-D (all leading dims merged)
    and the result is unfolded back to ``offsets.size() + (mix.size(-1),)``.
    """

    def forward(self, mix, offsets):
        """Gather rows of ``mix`` selected by ``offsets``.

        Args:
            mix: tensor whose first dim equals the number of rows of the folded
                offsets (e.g. (batch_size, seq_len, dim) for 2-D offsets).
            offsets: integer tensor (batch_size, ..., orig_sequence_length)
                holding positions along dim 1 of ``mix``.

        Returns:
            Tensor of shape ``offsets.size() + (mix.size(-1),)`` for rank > 2
            offsets, or (rows, orig_sequence_length, ...) for 2-D offsets.
        """
        # Fold leading dims (same semantics as allennlp's combine_initial_dims).
        if offsets.dim() <= 2:
            offsets2d = offsets
        else:
            offsets2d = offsets.view(-1, offsets.size(-1))
        # now offsets2d is (batch_size * d1 * ... * dn, orig_sequence_length)

        # Build the row index with torch.arange on mix's device so the ONNX
        # tracer records it as a dynamically-sized op. The previous
        # util.get_range_vector path presumably was not traced that way on GPU,
        # freezing the export-time batch size into the graph (the "2 by 3"
        # broadcast failure when running with batch size 3).
        range_vector = torch.arange(offsets2d.size(0), device=mix.device).unsqueeze(1)

        # selected_embeddings is (batch_size * d1 * ... * dn, orig_sequence_length, ...)
        selected_embeddings = mix[range_vector, offsets2d]

        # Unfold back (same semantics as allennlp's uncombine_initial_dims).
        if offsets.dim() <= 2:
            return selected_embeddings
        return selected_embeddings.view(*offsets.size(), selected_embeddings.size(-1))
# Use the GPU when present so the script also runs on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Test()
model.to(device)
model.eval()

mix = torch.randn([2, 5, 3]).to(device)
print(mix)
offsets = torch.tensor([[1, 3, 0], [1, 2, 4]]).to(device)
out = model(mix, offsets)
print(out)

# Export to ONNX. Batch and length dims are declared dynamic on both inputs
# and on the output; otherwise the output shape stays pinned to the
# export-time batch size.
ONNX_FILE_PATH = "./test.onnx"
torch.onnx.export(
    model,
    (mix, offsets),
    ONNX_FILE_PATH,
    opset_version=12,
    verbose=True,
    input_names=["input_ids", "offsets"],
    output_names=["output"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "offsets": {0: "batch_size", 1: "word_len"},
        "output": {0: "batch_size", 1: "word_len"},
    },
    export_params=True,
)

# Run the exported graph with onnxruntime on a batch of a different size (3).
import onnxruntime as ort

# GPU builds of onnxruntime (>= 1.9) require an explicit provider list.
# Prefer CUDA when this build offers it, otherwise fall back to the CPU.
available_providers = ort.get_available_providers()
providers = [
    p for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if p in available_providers
]
ort_session = ort.InferenceSession(ONNX_FILE_PATH, providers=providers)
print(ort.get_device())

mix = torch.randn([3, 5, 3]).to(device)
print(mix)
offsets = torch.tensor([[1, 3, 0], [1, 3, 0], [1, 2, 4]]).to(device)
ort_inputs = {
    ort_session.get_inputs()[0].name: mix.cpu().numpy(),
    ort_session.get_inputs()[1].name: offsets.cpu().numpy(),
}
outputs = ort_session.run(None, ort_inputs)
print(outputs[0])

# Cross-check the ONNX Runtime result against PyTorch on the same batch.
with torch.no_grad():
    expected = model(mix, offsets).cpu()
print("matches PyTorch:", torch.allclose(torch.from_numpy(outputs[0]), expected))
将一段简单的处理（沿张量的第二个维度按索引取值）导出为 ONNX 模型后，用 onnxruntime 以不同的 batch size 运行（导出时为 2，运行时为 3）时出现如下错误：
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Add node. Name:'Add_9' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:487 void onnxruntime::BroadcastIterator::Append(int64_t, int64_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 2 by 3
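解决方法：把按索引取值的高级索引操作 `mix[range_vector, offsets2d]` 替换为 view + index_select 的实现（见下面的 forward）。原因推测：allennlp 的 `util.get_range_vector` 在 GPU 上生成的 range_vector 没有被导出器记录为动态形状，导出时的长度 2 被固化进了计算图，因此运行时 batch size 变为 3 时 Add 节点无法广播（即报错中的 "2 by 3"）。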
def forward(self, mix, offsets):
    """Gather ``mix[b, offsets[b, w]]`` for every (b, w) without advanced indexing.

    Replacement for ``mix[range_vector, offsets]`` built only from
    arange / add / index_select / view operations.

    Args:
        mix: 3-D tensor of shape (B, S, D).
        offsets: 2-D integer tensor of shape (B, W) holding positions along
            dim 1 of ``mix``.

    Returns:
        Tensor of shape (B, W, D).
    """
    B, S, D = mix.size()
    _, W = offsets.size()
    # reshape (rather than view) also accepts non-contiguous inputs.
    new_mix = mix.reshape(-1, D)
    # mix[b, s] lives at flat row b * S + s of new_mix. The arange is created
    # on mix's device instead of via a hard-coded .cuda(), so CPU tensors work.
    # Broadcasting (B, 1) + (B, W) makes an explicit expand unnecessary
    # (the original expand() result was discarded anyway).
    right_add = torch.arange(0, B, device=mix.device).unsqueeze(-1) * S
    new_offsets = (right_add + offsets).view(-1)
    return new_mix.index_select(0, new_offsets).view(B, W, -1)
问题解决。