onnx 模型转换及推理时间对比_onnxruntime 推理时间_CV-deeplearning的博客-程序员宅基地

技术标签: 人工智能  pytorch  模型部署  onnx  

目录

1. 环境准备

2. 测试过程

3. 测试结果与分析


1. 环境准备

       对比时间,和模型训练的环境相同,可能额外要安装的包是onnxruntime.

pip install onnxruntime      # for cpu
pip install onnxruntime-gpu  # for gpu

2. 测试过程

    直接上代码吧,代码就是最好的解释。

import cv2
import time
import torch
import numpy as np
from torch.nn import DataParallel
from MobileNetV2 import mobilenet_v2
from collections import OrderedDict
from torchvision import transforms as T
import onnxruntime as rt


def torch2onnx(model, save_path):
    model.eval()
    data = torch.rand(1,3,256,256)
    input_names = ['input']
    output_names = ['out']
    torch.onnx.export(model,
                      data,
                      save_path,
                      export_params=True,
                      opset_version=11,
                      input_names=input_names,
                      output_names=output_names)
    print("torch2onnx finish")


def img_process(img_path):
    normalize = T.Normalize(mean = [0.5, 0.5, 0.5],
                            std = [0.5, 0.5, 0.5])
    transforms = T.Compose([T.ToTensor(),
                            normalize])
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (256, 256))
    img = transforms(img)
    img = img.unsqueeze(0)
    return img


def onnx_runtime(img):
    sess = rt.InferenceSession("mobilenet-v2.onnx")
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name
    t0 = time.time()
    for i in range(1000):
        pred_onnx = sess.run([output_name], {input_name:np.array(img)})
    t1 = time.time()
    print("用onnx完成1000次推理消耗的时间:%s" % (t1-t0))
    print("用onnx推理的结果如下:")
    print(pred_onnx[0].tolist())


def model_load(model_pth):
    state_dict = torch.load(model_pth, map_location=device)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k
        if name.startswith("module."):
            name = name[7:]
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.eval()
    return model


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = mobilenet_v2().to(device)
    model_pth = "./mobilenet-v2.pth"
    model = model_load(model_pth)

    img = img_process("test.jpg")
    t0 = time.time()
    for i in range(1000):
        outputs = model(img)
    t1 = time.time()
    print("用pytorch完成1000次推理消耗的时间:%s" % (t1-t0))
    print("用pytorch推理的结果如下:")
    print(outputs[0].detach().tolist())
    print()

    torch2onnx(model, "mobilenet-v2.onnx")

    onnx_runtime(img)

3. 测试结果与分析

       运行上面代码,输出如下(我是用cpu跑的结果):

      可以看到用pytorch和onnx的推理结果几乎相同,完全可以接受。然而,用onnx的推理速度是pytorch的好几倍。

 

OK,就是这么简单~

想要完整代码和模型,请联系博主,下面是微信二维码。

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Guo_Python/article/details/116276618

智能推荐

if-else嵌套太深?教你一个新手都能掌握的设计模式搞定!-程序员宅基地

△Hollis, 一个对Coding有着独特追求的人△这是Hollis的第259篇原创分享作者 l 南山狮来源 l Hollis(ID:hollischuang)我也不用设计模式很多...

乘法逆元的几种求法总结-程序员宅基地

乘法逆元对于缩系中的元素,每个数a均有唯一的与之对应的乘法逆元x,使得ax≡1(mod n) 一个数有逆元的充分必要条件是gcd(a,n)=1,此时逆元唯一存在 逆元的含义:模n意义下,1个数a如果有逆元x,那么除以a相当于乘以x。下面给出求逆元的几种方法1 循环找解法给定模m和需要求逆的数x,直接暴力枚举1~m-1 检查是否有x*i=1(mod m)

目标检测深度学习方法综述(一)_深度学习的目标检测最新方法-程序员宅基地

0.前言从去年九月份以来,我断断续续的接触并了解了深度学习中目标检测方面的知识。读了几篇论文,也尝试着跑了几个代码,对目标检测领域的深度学习方法有了大致的了解,一直准备写一篇综述性的学习报告,来总结我所学到,看到的知识点。但由于所学太过于零碎,不成体系,一直没有动手整理。直到前两天我导让我将18年之后比较典型的神经网络模型总结一哈,我这下定决心,准备将我这大半年所学到的关于深度学习目标检测的知识..._深度学习的目标检测最新方法

python绘制等值线图_python,matplotlib_python菜鸟求助,使用matplotlib 绘制contour等高线图,z为2D数组?,python,matplotlib - php...-程序员宅基地

python菜鸟求助,使用matplotlib 绘制contour等高线图,z为2D数组?有一系列点坐标如下所示:x,y,z74,781,51373,731,111321,1791,280,1787,41049,2127,121647,2728,62883,3617,152383,3692,72708,2295,222933,1767,74233,895,64043,1895,14想通过conto..._contourf的input z must be2d

解决CentOS添加新网卡后找不到网卡配置文件_etc下找不到网卡配置文件-程序员宅基地

进入CentOS7系统后,使用ip addr 查看状态如下:发现ens33和ens77均有IP地址,且可正常使用,ens33使用的是手动配置IP,ens37使用的是dhcp自动获取的IP地址,但是/etc/sysconfig/network-scripts/目录下找不到ifcfg-ens37配置文件解决方案:1.使用nmcli con show命令,查看网卡的UUID信息,记下UUID值2.使用ip addr命令查看网卡信息,记下ens37网卡的MAC地址3.将 /etc/s_etc下找不到网卡配置文件

回文词(Palindromes, UVa 401)_回文词ac-程序员宅基地

输入一个字符串,判断它是否为回文串以及镜像串。输入字符串保证不含数字0。所谓回文串,就是反转以后与原串相同,如abba和madam。所谓镜像串,就是左右镜像之后和原串相同,如2S和3AIAE。注意,并不是每个字符在镜像之后都能得到一个合法字符。(空白项表示该字符镜像后不能得到一个合法字符。)Character Reverse Character Reverse Character Reverse_回文词ac

随便推点

CountVectorizer参数-程序员宅基地

https://zhuanlan.zhihu.com/p/37644086

编写高性能的Java代码需要注意的4个问题-程序员宅基地

一、并发无法创建新的本机线程…问题1:Java的中创建一个线程消耗多少内存?每个线程有独自的栈内存,共享堆内存问题2:一台机器可以创建多少线程?CPU,内存,操作系统,JVM,应用服务器我们编写一段示例代码,来验证下线程池与非线程池的区别://线程池和非线程池的区别public class ThreadPool {public static int times = 100;//1...

impdp\expdp 指定表,导入到不同用户不同表空间里_impdp导入不同用户不同表空间命令-程序员宅基地

A用户下面的数据导入到B用户下面,前提B用户下面不存在A用户所存在的表--tables 指定的具体表名“,”隔开 directory:文件路径 DUMPFILE:指定导出数据文件名 version :指定版本REUSE_DUMPFILES 覆盖执行的时候,一定要把执行语句放在同一行,不然复制到cmd窗口会少导出:expdp userName01/password01 ..._impdp导入不同用户不同表空间命令

部门工资前三高所有员工_公司id部门的工资有多高-程序员宅基地

Employee 表包含所有员工信息,每个员工有其对应的工号 Id,姓名 Name,工资 Salary 和部门编号 DepartmentId 。±—±------±-------±-------------+| Id | Name | Salary | DepartmentId |±—±------±-------±-------------+| 1 | Joe | 85000 ..._公司id部门的工资有多高

超分辨率 EDSR开源项目_edsr超分辨率继续训练-程序员宅基地

对于常用框架EDSR的一点点理解_edsr超分辨率继续训练

wmsys.wmconcat mysql_oracle 12C wmsys.wm_concat()函数-程序员宅基地

对于一些业务,需要连接函数把内容拼接文本文件的时候,借助合适的函数,非常重要,减少很多工作。目前常用的连接函数有wmsys.wm_concat()和LISTAGG()函数,当然还有看拼接内容的长度来选。oracle数据库中,还有一个根据版本选择。最新的两个版本中,11G中,自带有两个函数,但在12C中,oracle不再自带wmsys.wm_concat(),如果实际业务中需要到,需要自己创建上。当..._wmsys.wm_conca