pytorch 轉(zhuǎn)tensorflow注意

最近在作pytorch 模型轉(zhuǎn)tensorflow,通過onnx 中間轉(zhuǎn)換和容易,但是再轉(zhuǎn)換時(shí)有一個(gè)注意事項(xiàng),即如何處理batch

pytorch 模型轉(zhuǎn)tensorflow: http://m.itdecent.cn/p/3e5623696a8e

通過 onnx 手動(dòng)修改batch 為動(dòng)態(tài)值: model_onnx.graph.input[0].type.tensor_type.shape.dim[0].dim_param ='?'

這下就可以使用batch 了。

最后別忘了用transform_graph 壓縮下模型大?。?http://m.itdecent.cn/p/d2637646cda1

self.model.load_state_dict(model_dict)

example = torch.ones(1,1,112,112).cuda() #限定好的tensor 輸入大小

# traced_script_module = torch.jit.trace(self.model, example)

# traced_script_module.save('./lt_model.pt')

torch.onnx.export(self.model, example,'./model_simple.onnx',input_names=['input'],

output_names=['output'])

model_onnx = onnx.load('./model_simple.onnx')

model_onnx.graph.input[0].type.tensor_type.shape.dim[0].dim_param ='?'

tf_rep = prepare(model_onnx)

print(tf_rep.tensor_dict)

tf_rep.export_graph('./lt_model.pb')

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容