EfficientNet网络详解_太阳花的小绿豆的博客-程序员秘密

技术标签: 深度学习  EfficientNet  分类网络  

原论文名称:EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks
论文下载地址:https://arxiv.org/abs/1905.11946
原论文提供代码:https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
自己使用Pytorch实现的代码: pytorch_classification/Test9_efficientNet
自己使用Tensorflow实现的代码: tensorflow_classification/Test9_efficientNet
不想看文章的可以看下我在bilibili上录制的视频:https://www.bilibili.com/video/BV1XK4y1U7PX



0 前言

在之前的一些手工设计网络中(AlexNet,VGG,ResNet等等)经常有人问,为什么输入图像分辨率要固定为224,为什么卷积的个数要设置为这个值,为什么网络的深度设为这么深?这些问题你要问设计作者的话,估计回复就四个字——工程经验。而这篇论文主要是用NAS(Neural Architecture Search)技术来搜索网络的图像输入分辨率 r r r,网络的深度 d e p t h depth depth以及channel的宽度 w i d t h width width三个参数的合理化配置。在之前的一些论文中,基本都是通过改变上述3个参数中的一个来提升网络的性能,而这篇论文就是同时来探索这三个参数的影响。在论文中提到,本文提出的EfficientNet-B7Imagenet top-1上达到了当年最高准确率84.3%,与之前准确率最高的GPipe相比,参数数量(Params)仅为其1/8.4,推理速度提升了6.1倍(看上去又快又轻量,但个人实际使用起来发现很吃显存)。下图是EfficientNet与其他网络的对比(注意,参数数量少并不意味推理速度就快)。

acc


1 论文思想

在之前的一些论文中,有的会通过增加网络的width即增加卷积核的个数(增加特征矩阵的channels)来提升网络的性能如图(b)所示,有的会通过增加网络的深度即使用更多的层结构来提升网络的性能如图(c)所示,有的会通过增加输入网络的分辨率来提升网络的性能如图(d)所示。而在本篇论文中会同时增加网络的width、网络的深度以及输入网络的分辨率来提升网络的性能如图(e)所示:

modelscaling

  • 根据以往的经验,增加网络的深度depth能够得到更加丰富、复杂的特征并且能够很好的应用到其它任务中。但网络的深度过深会面临梯度消失,训练困难的问题。

The intuition is that deeper ConvNet can capture richer and more complex features, and generalize well on new tasks. However, deeper networks are also more difficult to train due to the vanishing gradient problem

  • 增加网络的width能够获得更高细粒度的特征并且也更容易训练,但对于width很大而深度较浅的网络往往很难学习到更深层次的特征。

wider networks tend to be able to capture more fine-grained features and are easier to train. However, extremely wide but shallow networks tend to have difficulties in capturing higher level features.

  • 增加输入网络的图像分辨率能够潜在得获得更高细粒度的特征模板,但对于非常高的输入分辨率,准确率的增益也会减小。并且大分辨率图像会增加计算量。

With higher resolution input images, ConvNets can potentially capture more fine-grained patterns. but the accuracy gain diminishes for very high resolutions.

下图展示了在基准EfficientNetB-0上分别增加widthdepth以及resolution后得到的统计结果。通过下图可以看出大概在Accuracy达到80%时就趋于饱和了。
scalingup
接着作者又做了一个实验,采用不同的 d , r d, r d,r组合,然后不断改变网络的width就得到了如下图所示的4条曲线,通过分析可以发现在相同的FLOPs下,同时增加 d d d r r r的效果最好。
figure4
为了方便后续理解,我们先看下论文中通过 NAS(Neural Architecture Search) 技术搜索得到的EfficientNetB0的结构,如下图所示,整个网络框架由一系列Stage组成, F ^ i \widehat{F}_i F i表示对应Stage的运算操作, L ^ i \widehat{L}_i L i表示在该Stage中重复 F ^ i \widehat{F}_i F i的次数:
EfficientNetb0
作者在论文中对整个网络的运算进行抽象:
N ( d , w , r ) = ⊙ i = 1... s F i L i ( X ⟨ H i , W i , C i ⟩ ) N(d,w,r)=\underset{i=1...s}{\odot} {F}_i^{L_i}(X_{\left\langle{ {H}_i, {W}_i, {C}_i } \right\rangle}) N(d,w,r)=i=1...sFiLi(XHi,Wi,Ci)
其中:

  • ⊙ i = 1... s \underset{i=1...s}{\odot} i=1...s表示连乘运算。
  • F i {F}_i Fi表示一个运算操作(如上图中的Operator),那么 F i L i {F}_i^{L_i} FiLi表示在 S t a g e i {\rm Stage}i Stagei F i {F}_i Fi运算被重复执行 L i L_i Li次。
  • X X X表示输入 S t a g e i {\rm Stage}i Stagei的特征矩阵(input tensor)。
  • ⟨ H i , W i , C i ⟩ {\left\langle{ {H}_i, {W}_i, {C}_i } \right\rangle} Hi,Wi,Ci表示 X X X的高度,宽度,以及Channels(shape)。

为了探究 d , r , w d, r, w d,r,w这三个因子对最终准确率的影响,则将 d , r , w d, r, w d,r,w加入到公式中,我们可以得到抽象化后的优化问题(在指定资源限制下),其中 s . t . s.t. s.t.代表限制条件:

Our target is to maximize the model accuracy for any given resource constraints, which can be formulated as an optimization problem:

m a x d , w , r       A c c u r a c y ( N ( d , w , r ) ) s . t .      N ( d , w , r ) = ⊙ i = 1... s F ^ i d ⋅ L ^ i ( X ⟨ r ⋅ H ^ i ,   r ⋅ W ^ i ,   w ⋅ C ^ i ⟩ ) M e m o r y ( N ) ≤ t a r g e t _ m e m o r y            F L O P s ( N ) ≤ t a r g e t _ f l o p s          ( 2 ) \underset{d, w, r}{max} \ \ \ \ \ Accuracy(N(d, w, r)) \\ s.t. \ \ \ \ N(d,w,r)=\underset{i=1...s}{\odot} \widehat{F}_i^{d \cdot \widehat{L}_i}(X_{\left\langle{r \cdot \widehat{H}_i, \ r \cdot \widehat{W}_i, \ w \cdot \widehat{C}_i } \right\rangle}) \\ Memory(N) \leq {\rm target\_memory} \\ \ \ \ \ \ \ \ \ \ \ FLOPs(N) \leq {\rm target\_flops} \ \ \ \ \ \ \ \ (2) d,w,rmax     Accuracy(N(d,w,r))s.t.    N(d,w,r)=i=1...sF idL i(XrH i, rW i, wC i)Memory(N)target_memory          FLOPs(N)target_flops        (2)
其中:

  • d d d用来缩放深度 L ^ i \widehat{L}_i L i
  • r r r用来缩放分辨率即影响 H ^ i \widehat{H}_i H i W ^ i \widehat{W}_i W i
  • w w w就是用来缩放特征矩阵的channel C ^ i \widehat{C}_i C i
  • target_memorymemory限制
  • target_flops为FLOPs限制

接着作者又提出了一个混合缩放方法 ( compound scaling method) 在这个方法中使用了一个混合因子 ϕ \phi ϕ去统一的缩放width,depth,resolution参数,具体的计算公式如下,其中 s . t . s.t. s.t.代表限制条件:
d e p t h : d = α ϕ w i d t h : w = β ϕ        r e s o l u t i o n : r = γ ϕ            ( 3 ) s . t .         α ⋅ β 2 ⋅ γ 2 ≈ 2 α ≥ 1 , β ≥ 1 , γ ≥ 1        depth: d={\alpha}^{\phi} \\ width: w={\beta}^{\phi} \\ \ \ \ \ \ \ resolution: r={\gamma}^{\phi} \ \ \ \ \ \ \ \ \ \ (3) \\ s.t. \ \ \ \ \ \ \ {\alpha} \cdot {\beta}^{2} \cdot {\gamma}^{2} \approx 2 \\ \alpha \geq 1, \beta \geq 1, \gamma \geq 1 \ \ \ \ \ \ depth:d=αϕwidth:w=βϕ      resolution:r=γϕ          (3)s.t.       αβ2γ22α1,β1,γ1      

注意:

  • FLOPs(理论计算量)与depth的关系是:当depth翻倍,FLOPs也翻倍。
  • FLOPs与width的关系是:当width翻倍(即channal翻倍),FLOPs会翻4倍,因为卷积层的FLOPs约等于 f e a t u r e w × f e a t u r e h × f e a t u r e c × k e r n e l w × k e r n e l h × k e r n e l n u m b e r feature_w \times feature_h \times feature_c \times kernel_w \times kernel_h \times kernel_{number} featurew×featureh×featurec×kernelw×kernelh×kernelnumber(假设输入输出特征矩阵的高宽不变),当width翻倍,输入特征矩阵的channels( f e a t u r e c feature_c featurec)和输出特征矩阵的channels或卷积核的个数( k e r n e l n u m b e r kernel_{number} kernelnumber)都会翻倍,所以FLOPs会翻4倍
  • FLOPs与resolution的关系是:当resolution翻倍,FLOPs也会翻4倍,和上面类似因为特征矩阵的宽度 f e a t u r e w feature_w featurew和特征矩阵的高度 f e a t u r e h feature_h featureh都会翻倍。

所以总的FLOPs倍率可以用近似用 ( α ⋅ β 2 ⋅ γ 2 ) ϕ (\alpha \cdot \beta^{2} \cdot \gamma^{2})^{\phi} (αβ2γ2)ϕ来表示,当限制 α ⋅ β 2 ⋅ γ 2 ≈ 2 \alpha \cdot \beta^{2} \cdot \gamma^{2} \approx 2 αβ2γ22时,对于任意一个 ϕ \phi ϕ而言FLOPs相当增加了 2 ϕ 2^{\phi} 2ϕ倍。

接下来作者在基准网络EfficientNetB-0(在后面的网络详细结构章节会详细讲)上使用NAS来搜索 α , β , γ \alpha, \beta, \gamma α,β,γ这三个参数。

  • (step1)首先固定 ϕ = 1 \phi=1 ϕ=1,并基于上面给出的公式(2)和(3)进行搜索,作者发现对于EfficientNetB-0最佳参数为 α = 1.2 , β = 1.1 , γ = 1.15 \alpha=1.2, \beta=1.1, \gamma=1.15 α=1.2,β=1.1,γ=1.15
  • (step2)接着固定 α = 1.2 , β = 1.1 , γ = 1.15 \alpha=1.2, \beta=1.1, \gamma=1.15 α=1.2,β=1.1,γ=1.15,在EfficientNetB-0的基础上使用不同的 ϕ \phi ϕ分别得到EfficientNetB-1至EfficientNetB-7(在后面的EfficientNet(B0-B7)参数章节有给出详细参数)

需要注意的是,对于不同的基准网络搜索出的 α , β , γ \alpha, \beta, \gamma α,β,γ也不定相同。还需要注意的是,在原论文中,作者也说了,如果直接在大模型上去搜索 α , β , γ \alpha, \beta, \gamma α,β,γ可能获得更好的结果,但是在较大的模型中搜索成本太大(Google大厂居然说这种话),所以这篇文章就在比较小的EfficientNetB-0模型上进行搜索的。

Notably, it is possible to achieve even better performance by searching for α, β, γ directly around a large model, but the search cost becomes prohibitively more expensive on larger models. Our method solves this issue by only doing search once on the small baseline network (step 1), and then use the same scaling coefficients for all other models (step 2).


2 网络详细结构

下表为EfficientNet-B0的网络框架(B1-B7就是在B0的基础上修改ResolutionChannels以及Layers),可以看出网络总共分成了9个Stage,第一个Stage就是一个卷积核大小为3x3步距为2的普通卷积层(包含BN和激活函数Swish),Stage2~Stage8都是在重复堆叠MBConv结构(最后一列的Layers表示该Stage重复MBConv结构多少次),而Stage9由一个普通的1x1的卷积层(包含BN和激活函数Swish)一个平均池化层和一个全连接层组成。表格中每个MBConv后会跟一个数字1或6,这里的1或6就是倍率因子nMBConv中第一个1x1的卷积层会将输入特征矩阵的channels扩充为n倍,其中k3x3k5x5表示MBConvDepthwise Conv所采用的卷积核大小。Channels表示通过该Stage后输出特征矩阵的Channels
EfficientNetb0


2.1 MBConv结构

MBConv其实就是MobileNetV3网络中的InvertedResidualBlock,但也有些许区别。一个是采用的激活函数不一样(EfficientNet的MBConv中使用的都是Swish激活函数),另一个是在每个MBConv中都加入了SE(Squeeze-and-Excitation)模块。下图是我自己绘制的MBConv结构。

mbblock

如图所示,MBConv结构主要由一个1x1的普通卷积(升维作用,包含BN和Swish),一个kxkDepthwise Conv卷积(包含BN和Swish)k的具体值可看EfficientNet-B0的网络框架主要有3x35x5两种情况,一个SE模块,一个1x1的普通卷积(降维作用,包含BN),一个Droupout层构成。搭建过程中还需要注意几点:

  • 第一个升维的1x1卷积层,它的卷积核个数是输入特征矩阵channel n n n倍, n ∈ { 1 , 6 } n \in \left\{1, 6\right\} n{ 1,6}
  • n = 1 n=1 n=1时,不要第一个升维的1x1卷积层,即Stage2中的MBConv结构都没有第一个升维的1x1卷积层(这和MobileNetV3网络类似)。
  • 关于shortcut连接,仅当输入MBConv结构的特征矩阵与输出的特征矩阵shape相同时才存在(代码中可通过stride==1 and inputc_channels==output_channels条件来判断)。
  • SE模块如下所示,由一个全局平均池化,两个全连接层组成。第一个全连接层的节点个数是输入该MBConv特征矩阵channels 1 4 \frac{1}{4} 41,且使用Swish激活函数。第二个全连接层的节点个数等于Depthwise Conv层输出的特征矩阵channels,且使用Sigmoid激活函数。
  • Dropout层的dropout_rate在tensorflow的keras源码中对应的是drop_connect_rate后面会细讲(注意,在源码实现中只有使用shortcut的时候才有Dropout层)。

semodule


2.2 EfficientNet(B0-B7)参数

还是先给出EfficientNetB0的网络结构,方便后面理解。
EfficientNetb0
通过上面的内容,我们是可以搭建出EfficientNetB0网络的,其他版本的详细参数可见下表:

Model input_size width_coefficient depth_coefficient drop_connect_rate dropout_rate
EfficientNetB0 224x224 1.0 1.0 0.2 0.2
EfficientNetB1 240x240 1.0 1.1 0.2 0.2
EfficientNetB2 260x260 1.1 1.2 0.2 0.3
EfficientNetB3 300x300 1.2 1.4 0.2 0.3
EfficientNetB4 380x380 1.4 1.8 0.2 0.4
EfficientNetB5 456x456 1.6 2.2 0.2 0.4
EfficientNetB6 528x528 1.8 2.6 0.2 0.5
EfficientNetB7 600x600 2.0 3.1 0.2 0.5
  • input_size代表训练网络时输入网络的图像大小
  • width_coefficient代表channel维度上的倍率因子,比如在 EfficientNetB0中Stage13x3卷积层所使用的卷积核个数是32,那么在B6中就是 32 × 1.8 = 57.6 32 \times 1.8=57.6 32×1.8=57.6接着取整到离它最近的8的整数倍即56,其它Stage同理。
  • depth_coefficient代表depth维度上的倍率因子(仅针对Stage2Stage8),比如在EfficientNetB0中Stage7 L ^ i = 4 {\widehat L}_i=4 L i=4,那么在B6中就是 4 × 2.6 = 10.4 4 \times 2.6=10.4 4×2.6=10.4接着向上取整即11.
  • drop_connect_rate是在MBConv结构中dropout层使用的drop_rate,在官方keras模块的实现中MBConv结构的drop_rate是从0递增到drop_connect_rate的(具体实现可以看下官方源码注意,在源码实现中只有使用shortcut的时候才有Dropout层)。还需要注意的是,这里的Dropout层是Stochastic Depth,即会随机丢掉整个block的主分支(只剩捷径分支,相当于直接跳过了这个block)也可以理解为减少了网络的深度。具体可参考Deep Networks with Stochastic Depth这篇文章。
  • dropout_rate是最后一个全连接层前的dropout层(在stage9的Pooling与FC之间)的dropout_rate

最后给出原论文中关于EfficientNet与当时主流网络的性能参数对比:

EfficientNetvsothers

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_37541097/article/details/114434046

智能推荐

送干货,实用内联gulp插件——gulp-embed_dcof99817的博客-程序员秘密

现在npm上有很多gulp内联工具,用于把脚本和样式内嵌到HTML页面上,之前搞项目我也在这些插件中寻觅许久,但均不满足公司项目的一个需求—— HTML上同时插入了开发(dev版,src文件夹下,比如 src/index.html)和gulp处理后(build版,dest文件夹下,比如 dest/index.html)的两种版本的脚本,要求运行src文件夹下的该页面时,能忽略掉引入的bu...

IDA基本使用_ida pro函数偏移的彩色条有什么用_KOLOXIYA的博客-程序员秘密

1.可以被IDA解析的文件包括.exe、.so、.o等格式。——打开方式:1.选择文件(代码如下),点ok。(十进制转二进制,八进制代码)`#include<stdio.h>#include<math.h>int main(){int a,b,n;while(scanf("%d&a

新谈:为什么你觉得FPGA难学?如何入门?_FPGA技术江湖的博客-程序员秘密

新谈:为什么你觉得FPGA难学?如何入门?今天给大侠带来新谈:为什么你觉得FPGA难学?如何入门?以前发过一篇,但是也是很多年前了,大体上还是可以参考,随着技术的发展革新,有些内容还是要与时俱进一下的,可以聊一聊个人的最新看法,仅供参考学习,话不多说,上货。各位大侠会发现,FPGA技术江湖一直都在推送各种设关于FPGA的设计实例或者项目研发案例,会把设计方法和设计思想都阐述清楚,甚至有些会分享源码供各位大侠参考学习,很少去搞一些纯粹噱头性的文章,这就是因为现在社会的大环境差点意思,噱头性浮夸或者

android8.0 Launcher源码 ---Launcher的整体概述之桌面结构_安卓兼职framework应用工程师的博客-程序员秘密

概述说道Launcher,想必大家也都不陌生,很多人感觉很深奥的一个东西,其实他就是一个,launcher其实就是一个app,从功能上说,是对手机上其他app的一个管理和启动,从代码上说比其他app多了一个属性,就是在AndroidManifest.xml文件中多了一个“”属性 和,考虑的方面比较多,逻辑处理和代码规范性比较强,安卓各方面知识的应用比较多。如果系统只安装了一个launcher,就...

联通软件研究院笔试题1_联通软研杯题库_shizi599的博客-程序员秘密

试题描述:小明同学想将自己的零花钱存起来,捐献给贫困地区的同龄人。为了方便记录自己存钱的总数,于是,当他存的钱满10元,他就去换取一张10元的纸币;当满100元,他就换取一张100元的纸币,当满1000元,他就将10张100元纸币放在一起...。为了方便统计,小明构建了一个由正整数组成的数组,数组中每个元素都只存储当个数字。小明存储的零花钱的最高为放在数组的首位,最低位放在数组的末尾。现在他刚获得...

vue + ArcGIS 地图应用系列三:添加常规的地图组件_vue arcgis measurement_LuckRain7的博客-程序员秘密

为了页面的美观,这里我们使用的UI库为: Ant Design Vue项目源码仓库地址:https://github.com/LuckRain7/arcgis-api-for-javascript-vue 1. 首先创建工具菜单组件创建文件 src\components\ToolBar.vue并通过组件通信写好对应接口<template> <div class="toolbar"> <!-- 使用按钮组 --> <a-button-.

随便推点

在Spring中优雅关闭Pulsar消息消费者?_Java架构狮的博客-程序员秘密

这个github创建的示例应用程序以演示如何使用 Spring Boot 在 Java 中正确实现 Apache Pulsar 队列消费者的正常关闭。队列消费者实施强大的优雅关闭策略:我们是立即停止处理飞行中的队列消息,还是等待它们完成? 我们是否停止接受新的队列消息? 我们该如何处理本地排队的消息?想象一下,您的应用程序是一组汽车(容器的部署)愉快地行驶——灯是绿色的。现在,您需要停止应用程序,以便部署新版本。你可以告诉应用程序它需要立即停止(立即红灯)——但是,就像汽车接近十字路口一样,这

Maven问题-访问servlet报错 cannot be cast to javax.servlet__我的天哪的博客-程序员秘密

原因:jar包冲突tomcat 启动后先将tomcat/lib目录下的jar包全部读入内存,如果webapps目录里的应用程序中WEB-INF/lib目录下有相同的包,将无法加载,不同版本的包之间也会造成类似问题解决这个问题的方法就是对于servlet-ap.jar 使用 <scope>标签,编译的时候用到servlet-api,但在打包的时候不用这个依赖,配置成provided<dependency> <groupId>javax.servlet<..

电子邮箱格式是什么?电子邮箱怎么申请注册?_ZYX郑的博客-程序员秘密

电子邮箱的格式为:用户名、@符号、域名,每个人的邮箱账号都是独一无二的,如果想要申请注册电子邮箱,可以在Tom企业邮箱的官网中查看邮箱介绍、申请注册邮箱。

oracle如何导出导出(转)_weixin_30916125的博客-程序员秘密

问题:1、 使用oracle dump方式导出的数据在导入的时候不能随意选择一张表的数据进行导入;一般业务的表的数量都有1000左右,在出现由于某张表数据异常导致的故障时,用dump文件进行恢复基本不可行或者很费事;2、 表存在外键约束和触发器,使用oracle dump import方式导入表的时候不能非常简单的禁止触发器和外键,导致大量错误产生,导入的表的数据不全,部...

centos8安装Docker出现 package docker-ce-3:19.03.13-3.el7.x86_64 requires containerd.io >= 1.2.2-3_java后端指南的博客-程序员秘密

环境centos8过程今天安装centos8安装Docker的时候出现一个问题原因是:centos8默认使用podman代替docker,所以需要containerd.io,那我们就安装一下就好了解决办法是:安装containerd.io即可yum install https://download.docker.com/linux/fedora/30/x86_64/stable/Packages/containerd.io-1.2.6-3.3.fc30.x86_64.rpm然

推荐文章

热门文章

相关标签