keras提取模型中的某一层_Tensorflow笔记:高级封装——Keras-程序员宅基地

技术标签: keras提取模型中的某一层  

068b778f8d29be8579b91dfa12a569a3.png

前言

之前在《Tensorflow笔记:高级封装——tf.Estimator》中介绍了Tensorflow的一种高级封装,本文介绍另一种高级封装Keras。Keras的特点就是两个字——简单,不用花时间和脑子去研究各种细节问题。

1. 贯序结构

最简单的情况就是贯序模型,就是将网络层一层一层堆叠起来,比如DNN、LeNet等,与之相对的非贯序模型的层和层之间可能存在分叉、合并等复杂结构。下面通过一个LeNet的例子来展示Keras如何实现贯序模型,我们依然采用MNIST数据集举例:

d140d75af7841fdcd31c106286352211.png
LeNet-5模型结构

首先假设我们已经读到了数据,对于MNIST数据可以通过官方API直接获取,如果是其他数据可以自行进行数据预处理,由于数据读取内容不是本篇介绍重点,所以不做介绍。

(train_images, train_labels), (valid_images, valid_labels) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(-1, 28, 28, 1)
valid_images = valid_images.reshape(-1, 28, 28, 1)
train_images, valid_images = train_images / 255.0, valid_images / 255.0

最后数据的格式为 (n, height, width, channel) ,数据和标签的dtype分别为float和int,Keras相比与原生和tf.Estimator相比对于数据type的要求比较友好

print(train_images.shape)
# (60000, 28, 28)
print(train_labels.shape)
# (60000,)
print(train_images.dtype)
# float64 / float32 都可以
print(train_labels.dtype)
# uint8 / int16 / int32 / int64 等都可以

接下来开始构建模型

import tensorflow as tf
from tensorflow.keras.optimizers import SGD

# 构建模型结构
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(6, (5,5), activation='tanh', input_shape=(28,28,1)),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(16, (5,5), activation='tanh'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120, activation='tanh'),
    tf.keras.layers.Dense(84, activation='tanh'),
    tf.keras.layers.Dense(10, activation='softmax')
])

# 模型编译(告诉模型怎么优化)
model.compile(loss='sparse_categorical_crossentropy',    # 损失函数
             optimizer=SGD(lr=0.05, decay=1e-6, momentum=0.9, nesterov=True),     # 优化器
             metrics=['acc'])    # 评估指标

对于贯序模型,只需要调用tf.keras.models.Sequential(),他的参数是一个由tf.keras.layers组成的列表,就可以确定一个模型的结构,然后再简单通过model.compile()就可以确定模型关于“如何优化”方面的信息。很像sklearn的那样简单易用,没有原生tensorflow那种结构和对话的分离,没有必要维护tensor的name。下面看一些怎么开始训练:

history = model.fit(train_images, train_labels, batch_size=32, epochs=1, verbose=1, shuffle=True, validation_data=(valid_images, valid_labels))

就一句fit就解决了!很sklearn。对于evaluate任务也超简单

model.evaluate(test_images, test_labels, verbose=2)
# [0.06203492795133497, 0.9811]

最后对于predict任务,也和sklearn一样

model.predict(test_images)

可见Keras的另一个优势就是,不需要人为的去考虑每一个batch,只需要指定一个batch_size即可,即使是在predict时也可以直接吧全部数据集喂进去。相比之下在原生Tensorflow中要通过一个for循环一个batch一个batch的去sess.run(train_op),就比较麻烦。

2. 复杂结构

贯序模型对于结构复杂的模型,比如层之间出现了分叉、拼接等操作就无法表示了(比如Inception家族)。但是Keras并没有因此放弃,依然是可以很容易的构建复杂结构的网络的。下面来实现一个下图所示的多塔Inception块(该Inception块及其改进是在各种Inception网络的基础结构):

728d2140aff6d56e8465472f5bb57344.png
Inception块结构

假设我们在Previous layer处的输入数据的shape为(256, 256, 3),该结构用Keras这样实现:

import tensorflow as tf

# input数据接口
input_img = tf.keras.layers.Input(shape=(256, 256, 3))

# 分支0
tower0 = tf.keras.layers.Conv2D(64, (1,1), padding="same", activation="relu")(input_img)
# 分支1
tower1 = tf.keras.layers.Conv2D(64, (1,1), padding="same", activation
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_34374684/article/details/112312268

智能推荐

Vlog简介-程序员宅基地

文章浏览阅读995次。https://www.ifanr.com/1138470转载于:https://www.cnblogs.com/pengwang52/p/10683069.html_校园vlog简介怎么写

python稳健回归_【Stata教程】如何用stata做稳健回归-程序员宅基地

文章浏览阅读1.2k次。“社会科学中的数据可视化”第411篇推送导言大量的线性回归模型是基于最小二乘法实现的,但其仍存在一些局限性。比如说,样本点出现许多异常点时,传统的最小二乘法将不再适用,此时则可以使用稳健回归(robust regression)代替最小二乘法。操作下面的稳健回归使用的是犯罪数据,该数据来自Alan Agresti和Barbara Finlay的《社会科学统计方法》。变量包括美国各州编号(sid)、..._margins 贫困

爬虫.requests.exceptions.ConnectionErro-程序员宅基地

文章浏览阅读219次。requests.exceptions.ConnectionError: HTTPConnectionPool(host='jy-qj.com.cn', port=80): Max retries exceeded with url: / (Caused by NewConnectionError('<requests.packages.urllib3.connection.HTTPConn..._requests.exceptions.connectionerror: errno1104 getaddrinfo failed

[C++]欧几里得辗转相除求最大公约数,练习_欧几里得算法c++练习题-程序员宅基地

文章浏览阅读1.1k次。编程实现求解最大公约数的欧几里德算法,用户输入两个任意正整数,程序输出他们的最大公约数。算法如下:拆解步骤如下:步骤1: 如果p < q,则交换p和q。步骤2: 令r是p / q 的余数。步骤3: 如果r = 0,则令g = q并终止;否则令p = q, q = r并转向步骤2#include<iostream>#include<stdio.h>//编程实现求解最大公约数的欧几里德算法,用户输入两..._欧几里得算法c++练习题

ViewPager的notifyDataSetChanged()没有效果?来从源码上解决这个问题_viewpager notifydatasetchanged-程序员宅基地

文章浏览阅读841次。前言最近发现自己有很多颇为基础的内容“不会写”了,就比如今天写的内容:ViewPager。最近有小伙伴,在后台私信一些技术细节,大家真的好勤奋~~因为工作的原因,有些私信回复的不是很及时,多多包涵。996伤不起啊!正文平时我们很容易遇到这样的需求:页面底部很多Tab,可以点击或者活动切换不同的页面…估计话还没有说完,有朋友就会脱口而出:ViewPager+ Fragment实现。说起..._viewpager notifydatasetchanged

unity鼠标右键按住不放_在Windows中如何在不按住鼠标键的情况下突出显示和拖放...-程序员宅基地

文章浏览阅读775次。unity鼠标右键按住不放If you use a touchpad or trackpad, or if you have arthritis or other problems when using a mouse, you may find it difficult to hold the primary mouse button down and move the mouse at the..._untiy3d 鼠标右键一直按着

随便推点

计算机考试网上报名系统-程序员宅基地

文章浏览阅读732次,点赞27次,收藏20次。目 录(一)计算机等级考试发展状况与趋势……………………………………………………1(二)开发系统的意义………………………………………………………………………1(三)用户群及特点…………………………………………………………………………1二、系统分析………………………………………………………………………………………2(一)系统要达到的目的……………………………………………………………………2(二)系统可行性分析………………………………………………………………………2(三)业务流程分析………

linux C应用开发_linux应用开发-程序员宅基地

文章浏览阅读4.4k次,点赞2次,收藏34次。linux应用开发_linux应用开发

iPhone 4 Cydia使用教程!精选Cydia源!cydia怎么添加源!Cydia源使用方法!越狱后使用cydia全攻略!_ihpone4里cydia软件源游戏-程序员宅基地

文章浏览阅读4.3k次。转载自:http://hi.baidu.com/tyc6982/blog/item/7793eb18c9071a1635fa4191.html  2008年11月19日18:40许,iPhone中文网Cydia软件源正式上线(源地址为:iphone.tgbus.com/cydia)。这次Cydia源推出的目的主要是为了给一些WiFi用户提供方便。Cydia源中提供了一些像OpenSSH_ihpone4里cydia软件源游戏

sklearn计算余弦相似度_from sklearn.metrics.pairwise import cosine_simila-程序员宅基地

文章浏览阅读5.4k次,点赞4次,收藏16次。余弦相似度在计算文本相似度等问题中有着广泛的应用,scikit-learn中提供了方便的调用方法第一种,使用cosine_similarity,传入一个变量a时,返回数组的第i行第j列表示a[i]与a[j]的余弦相似度例:from sklearn.metrics.pairwise import cosine_similaritya=[[1,3,2],[2,2,1]]cosine_s..._from sklearn.metrics.pairwise import cosine_similarity

NV21 to NV12(YUV420SP)_nv21tonv12-程序员宅基地

文章浏览阅读1.5w次。setPreviewFormat(ImageFormat.NV21)NV21 颜色空间排列 :YYYYYYYY VUVU在用MediaCodec编码的时候,如果设置颜色空间为YUV420SP,那么则需要转换一下,YUV420SP颜色排列顺序为:YYYYYYY UVUV多说一下,YUV420 是于NV12对应的,但是5.0一下的安卓手机支持这个预览颜色的不多则需要将VU顺序进行转_nv21tonv12

Java并发编程: TransmittableThreadLocal实现父子线程之间值传递_transmittablethreadlocal父子线程数据传递-程序员宅基地

文章浏览阅读1k次,点赞23次,收藏18次。TransmittableThreadLocal 是 Alibaba 开源框架 transmittable-thread-local 中的一个核心类,它扩展了 Java 的标准 ThreadLocal 类。与标准的 ThreadLocal 不同,TransmittableThreadLocal 的值可以在线程之间传递,尤其是在线程池中的线程复用场景下。_transmittablethreadlocal父子线程数据传递

推荐文章

热门文章

相关标签