python cnn 实例_python实现简单的卷积神经网络CNN案例1:定义CNN网络结构_weixin_39883374的博客-程序员秘密

技术标签: python cnn 实例  

本案例中定义的CNN网络模型如下:

cnn.py文件中的__init__()函数主要作用是对卷积神经网络的参数w1,w2,w3等进行初始化,下面是该函数的代码:

def __init__(self, input_dim=(3, 32, 32), num_filters=32, filter_size=7,

hidden_dim=100, num_classes=10, weight_scale=1e-3, reg=0.0,

dtype=np.float32):

self.params = {}

self.reg = reg

self.dtype = dtype

# Initialize weights and biases

C, H, W = input_dim

self.params['W1'] = weight_scale * np.random.randn(num_filters, C, filter_size, filter_size)

self.params['b1'] = np.zeros(num_filters)

self.params['W2'] = weight_scale * np.random.randn(int(num_filters*H*W/4), hidden_dim)

self.params['b2'] = np.zeros(hidden_dim)

self.params['W3'] = weight_scale * np.random.randn(hidden_dim, num_classes)

self.params['b3'] = np.zeros(num_classes)

for k, v in self.params.items():

self.params[k] = v.astype(dtype)

无论是caffe还是tensorflow类型都会把数据类型转为float32类型,所以__init__()函数最后一个参数定义为dtype=np.float32。

cnn.py文件中还有另外一个函数 loss()函数,这两个函数起到了定义CNN卷积网络结构的作用。下面是loss()函数的代码:

def loss(self, X, y=None):

W1, b1 = self.params['W1'], self.params['b1']

W2, b2 = self.params['W2'], self.params['b2']

W3, b3 = self.params['W3'], self.params['b3']

# pass conv_param to the forward pass for the convolutional layer

filter_size = W1.shape[2]

conv_param = {'stride': 1, 'pad': (int)((filter_size - 1) / 2)}

# pass pool_param to the forward pass for the max-pooling layer

pool_param = {'pool_height': 2, 'pool_width': 2, 'stride': 2}

# compute the forward pass

a1, cache1 = conv_relu_pool_forward(X, W1, b1, conv_param, pool_param)

a2, cache2 = affine_relu_forward(a1, W2, b2)

scores, cache3 = affine_forward(a2, W3, b3)

if y is None:

return scores

# compute the backward pass

data_loss, dscores = softmax_loss(scores, y)

da2, dW3, db3 = affine_backward(dscores, cache3)

da1, dW2, db2 = affine_relu_backward(da2, cache2)

dX, dW1, db1 = conv_relu_pool_backward(da1, cache1)

# Add regularization

dW1 += self.reg * W1

dW2 += self.reg * W2

dW3 += self.reg * W3

reg_loss = 0.5 * self.reg * sum(np.sum(W * W) for W in [W1, W2, W3])

loss = data_loss + reg_loss

grads = {'W1': dW1, 'b1': db1, 'W2': dW2, 'b2': db2, 'W3': dW3, 'b3': db3}

return loss, grads

1.参数设置

从loss()函数我们可以看到卷积层的参数为: conv_param = {'stride': 1, 'pad': (int)((filter_size - 1) / 2)}

filter_size在_init_()函数中已经定义为7,所以pad值等于3,根据公式1:

conv卷积层的输出高度和宽度为(32-7+2*3)/1+1=32,所以conv卷积层的输出为32*32*32,32分别代表过滤器个数、高度和宽度。

loss()函数中pool池化层的参数定义为: pool_param = {'pool_height': 2, 'pool_width': 2, 'stride': 2}

同理根据公式1计算池化层输出的高度和宽度为:(32-2+2*0)/2+1=16,所以pool池化层的输出为:32*16*16,也就是说经过池化层之后高度和宽度分别减半了。

所以_init()函数中初始化池化层与FC全连接层间的参数w2时,定义为: self.params['W2'] = weight_scale * np.random.randn(int(num_filters*H*W/4), hidden_dim),num_filters*H*W/4就是pool池化层的输出为:32*16*16。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_39883374/article/details/110885335

智能推荐

勒让德函数(Legendre多项式)_yw_233的博客-程序员秘密

文章目录勒让德函数定义勒让德多项式公式Associated Legendre Function勒让德函数定义勒让德函数指以下勒让德微分方程的解:(1−x2)d2P(x)dx2−2xdP(x)dx+n(n+1)P(x)=0(1-x^2)\frac{d^2P(x)}{dx^2} -2x\frac{dP(x)}{dx}+n(n+1)P(x)=0 (1−x2)dx2d2P(x)​−2xdxdP(x...

现代大学英语精读第二版(第四册)学习笔记(原文及全文翻译)——4A - Lions and Tigers and Bears(狮子、老虎和熊)_预见未来to50的博客-程序员秘密

Unit 4A -Lions and Tigers and BearsLions and Tigers and BearsBill BufordSo I thought I'd spend the night in Central Park, and, having stuffed my small rucksack with a sleeping bag, a big bottle of mineral water, a map, and a toothbrush, I arrived on.

关于Platinum库的MediaRender具体C++代码实现探讨_蓝斯的博客-程序员秘密

接上篇博文 NDK下 将Platinum SDK 编译成so库 (android - upnp)讲述了如何利用该代码库编译给android程序调用的so库,其中也提到了,在使用sample-upnp工程来测试生成的so库是无效的大家比对一下Platinum开发库的Platinum\Source\Platform\Android\module\platinum\jni\platinum-jn

SpringBoot实践之@ControllerAdvice_springboot @controlleradvice_DuanQingCI的博客-程序员秘密

在spring 3.2中,新增了@ControllerAdvice 注解,并且配套有三个注解@ExceptionHandler、@InitBinder、@ModelAttribute,以此来对@RequestMapping注解下的方法进行“切面”环绕。参考:@ControllerAdvice 文档全局异常处理@ExceptionHandler 全局数据绑定@ModelAttribute 全局数据预处理@InitBinder一、最常用的全局异常处理1、自定义异常类package c

wps斜杠日期格式_wps表格,怎样将输入的日期间隔斜线改为横线?_思索bike的博客-程序员秘密

wps表格,怎样将输入的日期间隔斜线改为横线?234游戏网友 提出于 2019-07-23 15:20:35请问:wps表格,怎样将输入的日期间隔斜线改为横线?具体步骤:1、打开WPS后,设置合适的行高和列宽。2、选中进行操作的单元格,然后点击右键,在出来的任务栏里设置单元格格式。工具:计算机、Excel2016方法:1、鼠标左键双击计算机桌面Excel2016程序图标,将其打开运行。...

Mysql_mysql sum之后 包含.0_大肥鹅啊的博客-程序员秘密

Mysql聚合函数COUNT(*)计算表中总的行数,不管某列是否有数值或者为空值。COUNT(字段名)计算指定列下总的行数,计算时将忽略空值的行。AVG()函数通过计算返回的行数和每一行数据的和,求得指定列数据的平均值。AVG()函数()AVG()函数可以与GROUP BY一起使用,来计算每个分组的平均值。SUM()是一个求总和的函数,返回指定列值的总和。SUM()可以与GROUP BY一起使用,来计算每个分组的总和。MAX()返回指定列中的最大值。MAX()也可以和GRO

随便推点

新手如何学习Java三大框架?_coffee801的博客-程序员秘密

Java是世界第一编程语言,这已经达成共识,是毋庸置疑的真理。框架是程序员们必学的知识点,而且是十分重要的应用,Spring、Struts、Hibernate也是经典中的经典,最常用的框架类型。作为Java新手应该如何去学习呢?小编搜集了很多网友的建议,现在为大家总结如下:有同学建议:对于Spring来说,最应该学习的就是Spring的IOC原理,这在使用过程中是必须要

java api中的异常处理_赤丶的博客-程序员秘密

首先我们要知道java中异常分为编译时异常和运行时异常。这里的编译时异常可以理解为错误,即一些基本的语法错误。无法通过编译的,这里我们一般不做深究。我们的主角还是--运行时异常。首先我们要知道最基本的就是如何来处理异常。一异常的捕获。try catch finally便是我们处理异常的方法。我们将可能出现异常的代码放在try块。将出现异常后处理的代码放在catch块。finally则是无论是否有异常产生,都会执行的。我们现在可以来讨论一下它们的执行顺序。当try块中的代码块有异常时会即刻停.

STM32F103单片机ADC功能使用_stm32f103单片机带有12位精度模拟数字转换器,有多达16个外部通道和2个内部信_嵌入式@hxydj的博客-程序员秘密

  stm32f103系列单片机内部ADC为12位ADC。12位ADC是一种逐次逼近型模拟数字转换器。它有多达18个通道,可测量16个外部和2个内部信号源。各通道的A/D转换可以单次、连续、扫描或间断模式执行。ADC的结果可以左对齐或右对齐方式存储在16位数据寄存器中。模拟看门狗特性允许应用程序检测输入电压是否超出用户定义的高/低阀值。ADC的输入时钟不得超过14MHz,它是由PCLK2经分频产生。ADC 主要特征● 12位分辨率● 转换结束、注入转换结束和发生模拟看门狗事件时产生中断●

subprocess.check_output shell参数问题_sky0Lan的博客-程序员秘密

的时候, timeout=5 超时以后不会强制结束,在超时结束后,点击pycharm中的 stop 无法停止掉启动 sub_loop.py程序。时候,设置的 timeout=5 能够正常生效,并结束。

九、Mysql数据备份与恢复_liuqi66的博客-程序员秘密

1、为什么要备份 备份:能够防止由于机器故障以及人为误操作带来的数据丢失,例如将数据库文件保存在了其它地方。 冗余: 数据有多份冗余,但不等备份,只能防止机器故障带来的数据丢失,例如主从模式、数据库集群。 2.MySQL数据备份需要重视的内容备份内容 databases Binlog my.cnf所有备份数据都应放在非数据库本地,而且建议有多份副本。测试环境中做日常恢复演练,恢复较备份更为重要。备份过程中必须考虑因素:1. 数据的一致性2. 服务的可用性3.

逗比同时学习java和Python?是真的吗?_suntsods的博客-程序员秘密

java和Python的比较针对当前最热门的编程语言,无论是Java还是Python都是学习的首选。下面让我们开始学习吧。今天学习的两本书是《learn Python3 the hard way 》and 《head first java》。两本书各有特殊,《learn Python3 the...

推荐文章

热门文章

相关标签