集团站切换校区

验证码已发送,请查收短信

图标

学习文章

当前位置:首页 > >学习文章 > >

{HCNA-AI TensorFlow编程基础}之模型保存与使用

发布时间: 2019-01-03 23:52:17

8.1 实验介绍8.1.1 关于本实验本实验主要介绍如何保存模型和使用保存的模型,一般而言,训练好的模型都需要保存。8.1.2 实验目的理解如何保存模型。

理解如何载入模型,并使用。8.1.3 实验介绍本实验主要是基于前面的实验添加模型的保存及载入功能。在文件中生成模拟数据之后,加入对图变量的充值,在 session 创建之前定义 saver 及保存路径,在 session 种训练结束后,保存模型。8.1.4 实验步骤步骤 1  登陆华为云。

步骤 2 点击右上方的控制台。

步骤 3 选择弹性云服务器,网页中会显示该弹性云的可进行的操作,选择远程登录。即登录到弹性云服务器。

步骤 4 输入指令 ll,查看当前目录下的文件。

步骤 5 输入命令 vi mnist_train.py,创建新的 Python 脚本。

步骤 6 输入命令 i,进入编辑模式开始编辑,输入脚本内容。

步骤 7 输入命令 :wq!,保存并退出。

步骤 8 输入命令 cat mnist_train.py 查看代码。

步骤 9 运行测试。输入命令 python3 mnist_train.py。




8.2 实验过程8.2.1 导入数据集# -*- coding: utf-8 -*- #!/usr/bin/env python

# 导入 mnist 数据库

from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)


import tensorflow as tf import os8.2.2 定义变量# 定义输入变量

x = tf.placeholder(tf.float32, [None, 784])


# 定义参数

W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10]))

# 定义激励函数

y = tf.nn.softmax(tf.matmul(x, W) + b)


# 定义输出变量

y_ = tf.placeholder(tf.float32, [None, 10])


# 定义成本函数

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))


# 定义优化函数

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)8.2.3 初始化# 初始化变量

init = tf.global_variables_initializer()


# 定义会话

sess = tf.Session()8.2.4 运行 session# 运行初始化

sess.run(init)




# 定义模型保存对象

saver = tf.train.Saver()


# 循环训练 1000 次

for i in range(1000):

batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys})

print("训练完成!")8.2.5 创建模型保存目录# 创建模型保存目录

model_dir = "mnist_model" model_name = "ckp"

if not os.path.exists(model_dir): os.mkdir(model_dir)8.2.6 保存模型# 保存模型

saver.save(sess, os.path.join(model_dir, model_name))


print("保存模型成功!")8.2.7 实验结果训练完成!

保存模型成功!


8.3 使用模型8.3.1 导入数据集# -*- coding: utf-8 -*- #!/usr/bin/env python

# 导入 mnist 数据库

from tensorflow.examples.tutorials.mnist import input_data mnist=input_data.read_data_sets("MNIST_data",one_hot=True)


import tensorflow as tf8.3.2 创建回话# 创建会话

sess = tf.Session()



8.3.3 定义变量# 定义输入变量

x = tf.placeholder(tf.float32, [None, 784])


# 定义参数

W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10]))

# 定义模型和激励函数

y = tf.nn.softmax(tf.matmul(x, W) + b)


# 定义模型保存对象

saver = tf.train.Saver([W, b])8.3.4 恢复模型# 恢复模型

saver.restore(sess, "mnist/ckp")


print("恢复模型成功!")

# 取出一个测试图片

idx=0

img = mnist.test.images[idx]8.3.5 计算结果# 根据模型计算结果

ret = sess.run(y, feed_dict = {x : img.reshape(1, 784)})


print("计算模型结果成功!")

# 显示测试结果

print("预测结果:%d"%(ret.argmax()))

print("实际结果:%d"%(mnist.test.labels[idx].argmax()))8.3.6 实验结果恢复模型成功!

计算模型结果成功! 预测结果:7

实际结果:7


8.4 实例描述

本实验主要是保存以后模型,方便后续的模型载入与使用,这里输出的是中间状态 cost 损

上一篇: {HTML5}DOM节点操作-第一节

下一篇: {springboot}工程文件介绍

十年老品牌
QQ咨询:450959328 微信咨询:togogozhong 咨询电话:400-885-2225 咨询网站客服:在线客服

相关课程推荐

在线咨询 ×

您好,请问有什么可以帮您?我们将竭诚提供最优质服务!