2. Tensorflow进阶

本节我们将学习以下知识点:

  • 张量
  • 变量
  • 名称域
  • 会话

2.1. 张量的阶和数据类型

TensorFlow用张量这种数据结构来表示所有的数据.你可以把一个张量想象成一个n维的数组或列表.一个张量有一个静态类型和动态类型的维数.张量可以在图中的节点之间流通.其实张量更代表的就是一种多位数组。

在TensorFlow系统中,张量的维数来被描述为阶.但是张量的阶和矩阵的阶并不是同一个概念.张量的阶(有时是关于如顺序或度数或者是n维)是张量维数的一个数量描述.比如,下面的张量(使用Python中list定义的)就是2阶.

t = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

你可以认为一个二阶张量就是我们平常所说的矩阵,一阶张量可以认为是一个向量.

数学实例 Python 例子
0 纯量 (只有大小) s = 483
1 向量 (大小和方向) v = [1.1, 2.2, 3.3]
2 矩阵 (数据表) m = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
3 3阶张量 (数据立体) t = [[[2], [4], [6]], [[8], [10], [12]], [[14], [16], [18]]]
n n阶 (自己想想看)

数据类型

Tensors有一个数据类型属性.你可以为一个张量指定下列数据类型中的任意一个类型:

数据类型 Python 类型 描述
DT_FLOAT tf.float32 32 位浮点数.
DT_DOUBLE tf.float64 64 位浮点数.
DT_INT64 tf.int64 64 位有符号整型.
DT_INT32 tf.int32 32 位有符号整型.
DT_INT16 tf.int16 16 位有符号整型.
DT_INT8 tf.int8 8 位有符号整型.
DT_UINT8 tf.uint8 8 位无符号整型.
DT_STRING tf.string 可变长度的字节数组.每一个张量元素都是一个字节数组.
DT_BOOL tf.bool 布尔型.
DT_COMPLEX64 tf.complex64 由两个32位浮点数组成的复数:实数和虚数.
DT_QINT32 tf.qint32 用于量化Ops的32位有符号整型.
DT_QINT8 tf.qint8 用于量化Ops的8位有符号整型.
DT_QUINT8 tf.quint8 用于量化Ops的8位无符号整型.

2.2. 张量操作

在tensorflow中,有很多操作张量的函数,有生成张量、创建随机张量、张量类型与形状变换和张量的切片与运算

生成张量

固定值张量

tf.zeros(shape, dtype=tf.float32, name=None)

创建所有元素设置为零的张量。此操作返回一个dtype具有形状shape和所有元素设置为零的类型的张量。

tf.zeros_like(tensor, dtype=None, name=None)

给tensor定单张量(),此操作返回tensor与所有元素设置为零相同的类型和形状的张量。

tf.ones(shape, dtype=tf.float32, name=None)

创建一个所有元素设置为1的张量。此操作返回一个类型的张量,dtype形状shape和所有元素设置为1。

tf.ones_like(tensor, dtype=None, name=None)

给tensor定单张量(),此操作返回tensor与所有元素设置为1 相同的类型和形状的张量。

tf.fill(dims, value, name=None)

创建一个填充了标量值的张量。此操作创建一个张量的形状dims并填充它value。

tf.constant(value, dtype=None, shape=None, name=‘Const’)

创建一个常数张量。

用常数张量作为例子

t1 = tf.constant([1, 2, 3, 4, 5, 6, 7])

t2 = tf.constant(-1.0, shape=[2, 3])

print(t1,t2)

我们可以看到在没有运行的时候,输出值为:

(<tf.Tensor 'Const:0' shape=(7,) dtype=int32>, <tf.Tensor 'Const_1:0' shape=(2, 3) dtype=float32>)

一个张量包含了一下几个信息

  • 一个名字,它用于键值对的存储,用于后续的检索:Const: 0
  • 一个形状描述, 描述数据的每一维度的元素个数:(2,3)
  • 数据类型,比如int32,float32

创建随机张量

一般我们经常使用的随机数函数 Math.random() 产生的是服从均匀分布的随机数,能够模拟等概率出现的情况,例如 扔一个骰子,1到6点的概率应该相等,但现实生活中更多的随机现象是符合正态分布的,例如20岁成年人的体重分布等。

假如我们在制作一个游戏,要随机设定许许多多 NPC 的身高,如果还用Math.random(),生成从140 到 220 之间的数字,就会发现每个身高段的人数是一样多的,这是比较无趣的,这样的世界也与我们习惯不同,现实应该是特别高和特别矮的都很少,处于中间的人数最多,这就要求随机函数符合正态分布。

tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)

从截断的正态分布中输出随机值,和 tf.random_normal() 一样,但是所有数字都不超过两个标准差

tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)

从正态分布中输出随机值,由随机正态分布的数字组成的矩阵

# 正态分布的 4X4X4 三维矩阵,平均值 0, 标准差 1
normal = tf.truncated_normal([4, 4, 4], mean=0.0, stddev=1.0)

a = tf.Variable(tf.random_normal([2,2],seed=1))
b = tf.Variable(tf.truncated_normal([2,2],seed=2))
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(a))
    print(sess.run(b))

输出:
[[-0.81131822  1.48459876]
 [ 0.06532937 -2.44270396]]
[[-0.85811085 -0.19662298]
 [ 0.13895047 -1.22127688]]

tf.random_uniform(shape, minval=0.0, maxval=1.0, dtype=tf.float32, seed=None, name=None)

从均匀分布输出随机值。生成的值遵循该范围内的均匀分布 [minval, maxval)。下限minval包含在范围内,而maxval排除上限。

a = tf.random_uniform([2,3],1,10)

with tf.Session() as sess:
  print(sess.run(a))

tf.random_shuffle(value, seed=None, name=None)

沿其第一维度随机打乱

tf.set_random_seed(seed)

设置图级随机种子

要跨会话生成不同的序列,既不设置图级别也不设置op级别的种子:

a = tf.random_uniform([1])
b = tf.random_normal([1])

print "Session 1"
with tf.Session() as sess1:
  print sess1.run(a)  
  print sess1.run(a)  
  print sess1.run(b)
  print sess1.run(b)  

print "Session 2"
with tf.Session() as sess2:
  print sess2.run(a)
  print sess2.run(a)
  print sess2.run(b)
  print sess2.run(b)

要为跨会话生成一个可操作的序列,请为op设置种子:

a = tf.random_uniform([1], seed=1)
b = tf.random_normal([1])


print "Session 1"
with tf.Session() as sess1:
  print sess1.run(a)
  print sess1.run(a)
  print sess1.run(b)
  print sess1.run(b)

print "Session 2"
with tf.Session() as sess2:
  print sess2.run(a)
  print sess2.run(a)  
  print sess2.run(b)
  print sess2.run(b)

为了使所有op产生的随机序列在会话之间是可重复的,设置一个图级别的种子:

tf.set_random_seed(1234)
a = tf.random_uniform([1])
b = tf.random_normal([1])


print "Session 1"
with tf.Session() as sess1:
  print sess1.run(a)
  print sess1.run(a)
  print sess1.run(b)
  print sess1.run(b)

print "Session 2"
with tf.Session() as sess2:
  print sess2.run(a)
  print sess2.run(a)
  print sess2.run(b)
  print sess2.run(b)

我们可以看到结果

张量变换

TensorFlow提供了几种操作,您可以使用它们在图形中改变张量数据类型。

改变类型

提供了如下一些改变张量中数值类型的函数

  • tf.string_to_number(string_tensor, out_type=None, name=None)
  • tf.to_double(x, name=‘ToDouble’)
  • tf.to_float(x, name=‘ToFloat’)
  • tf.to_bfloat16(x, name=‘ToBFloat16’)
  • tf.to_int32(x, name=‘ToInt32’)
  • tf.to_int64(x, name=‘ToInt64’)
  • tf.cast(x, dtype, name=None)

我们用一个其中一个举例子

tf.string_to_number(string_tensor, out_type=None, name=None)

将输入Tensor中的每个字符串转换为指定的数字类型。注意,int32溢出导致错误,而浮点溢出导致舍入值

n1 = tf.constant(["1234","6789"])
n2 = tf.string_to_number(n1,out_type=tf.types.float32)

sess = tf.Session()

result = sess.run(n2)
print result

sess.close()

形状和变换

可用于确定张量的形状并更改张量的形状

  • tf.shape(input, name=None)
  • tf.size(input, name=None)
  • tf.rank(input, name=None)
  • tf.reshape(tensor, shape, name=None)
  • tf.squeeze(input, squeeze_dims=None, name=None)
  • tf.expand_dims(input, dim, name=None)

tf.shape(input, name=None)

返回张量的形状。

t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
shape(t) -> [2, 2, 3]

静态形状与动态形状

静态维度 是指当你在创建一个张量或者由操作推导出一个张量时,这个张量的维度是确定的。它是一个元祖或者列表。TensorFlow将尽最大努力去猜测不同张量的形状(在不同操作之间),但是它不会总是能够做到这一点。特别是如果您开始用未知维度定义的占位符执行操作。tf.Tensor.get_shape方法读取静态形状

t = tf.placeholder(tf.float32,[None,2])
print(t.get_shape())

结果

动态形状 当你在运行你的图时,动态形状才是真正用到的。这种形状是一种描述原始张量在执行过程中的一种张量。如果你定义了一个没有标明具体维度的占位符,即用None表示维度,那么当你将值输入到占位符时,这些无维度就是一个具体的值,并且任何一个依赖这个占位符的变量,都将使用这个值。tf.shape来描述动态形状

t = tf.placeholder(tf.float32,[None,2])
print(tf.shape(t))

tf.squeeze(input, squeeze_dims=None, name=None)

这个函数的作用是将input中维度是1的那一维去掉。但是如果你不想把维度是1的全部去掉,那么你可以使用squeeze_dims参数,来指定需要去掉的位置。

import tensorflow as tf

sess = tf.Session()
data = tf.constant([[1, 2, 1], [3, 1, 1]])
print sess.run(tf.shape(data))
d_1 = tf.expand_dims(data, 0)
d_1 = tf.expand_dims(d_1, 2)
d_1 = tf.expand_dims(d_1, -1)
d_1 = tf.expand_dims(d_1, -1)
print sess.run(tf.shape(d_1))
d_2 = d_1
print sess.run(tf.shape(tf.squeeze(d_1)))
print sess.run(tf.shape(tf.squeeze(d_2, [2, 4])))

tf.expand_dims(input, dim, name=None)

该函数作用与squeeze相反,添加一个指定维度

import tensorflow as tf
import numpy as np

sess = tf.Session()
data = tf.constant([[1, 2, 1], [3, 1, 1]])
print sess.run(tf.shape(data))
d_1 = tf.expand_dims(data, 0)
print sess.run(tf.shape(d_1))
d_1 = tf.expand_dims(d_1, 2)
print sess.run(tf.shape(d_1))
d_1 = tf.expand_dims(d_1, -1)
print sess.run(tf.shape(d_1))

切片与扩展

TensorFlow提供了几个操作来切片或提取张量的部分,或者将多个张量加在一起

  • tf.slice(input_, begin, size, name=None)
  • tf.split(split_dim, num_split, value, name=‘split’)
  • tf.tile(input, multiples, name=None)
  • tf.pad(input, paddings, name=None)
  • tf.concat(concat_dim, values, name=‘concat’)
  • tf.pack(values, name=‘pack’)
  • tf.unpack(value, num=None, name=‘unpack’)
  • tf.reverse_sequence(input, seq_lengths, seq_dim, name=None)
  • tf.reverse(tensor, dims, name=None)
  • tf.transpose(a, perm=None, name=‘transpose’)
  • tf.gather(params, indices, name=None)
  • tf.dynamic_partition(data, partitions, num_partitions, name=None)
  • tf.dynamic_stitch(indices, data, name=None)

其它一些张量运算(了解查阅)

张量复制与组合

  • tf.identity(input, name=None)
  • tf.tuple(tensors, name=None, control_inputs=None)
  • tf.group(*inputs, **kwargs)
  • tf.no_op(name=None)
  • tf.count_up_to(ref, limit, name=None)

逻辑运算符

  • tf.logical_and(x, y, name=None)
  • tf.logical_not(x, name=None)
  • tf.logical_or(x, y, name=None)
  • tf.logical_xor(x, y, name=‘LogicalXor’)

比较运算符

  • tf.equal(x, y, name=None)
  • tf.not_equal(x, y, name=None)
  • tf.less(x, y, name=None)
  • tf.less_equal(x, y, name=None)
  • tf.greater(x, y, name=None)
  • tf.greater_equal(x, y, name=None)
  • tf.select(condition, t, e, name=None)
  • tf.where(input, name=None)

判断检查

  • tf.is_finite(x, name=None)
  • tf.is_inf(x, name=None)
  • tf.is_nan(x, name=None)
  • tf.verify_tensor_all_finite(t, msg, name=None) 断言张量不包含任何NaN或Inf
  • tf.check_numerics(tensor, message, name=None)
  • tf.add_check_numerics_ops()
  • tf.Assert(condition, data, summarize=None, name=None)
  • tf.Print(input_, data, message=None, first_n=None, summarize=None, name=None)

2.3. 变量的的创建、初始化

其实变量的作用在语言中相当,都有存储一些临时值的作用或者长久存储。在Tensorflow中当训练模型时,用变量来存储和更新参数。变量包含张量(Tensor)存放于内存的缓存区。建模时它们需要被明确地初始化,模型训练后它们必须被存储到磁盘。值可在之后模型训练和分析是被加载。

Variable类

tf.Variable.init(initial_value, trainable=True, collections=None, validate_shape=True, name=None)

创建一个带值的新变量initial_value

  • initial_value:A Tensor或Python对象可转换为a Tensor.变量的初始值.必须具有指定的形状,除非 validate_shape设置为False.
  • trainable:如果True,默认值也将该变量添加到图形集合GraphKeys.TRAINABLE_VARIABLES,该集合用作Optimizer类要使用的变量的默认列表
  • collections:图表集合键列表,新变量添加到这些集合中.默认为[GraphKeys.VARIABLES]
  • validate_shape:如果False允许使用未知形状的值初始化变量,如果True,默认形状initial_value必须提供.
  • name:变量的可选名称,默认’Variable’并自动获取

变量的创建

创建当一个变量时,将你一个张量作为初始值传入构造函数Variable().TensorFlow提供了一系列操作符来初始化张量,值初始的英文常量或是随机值。像任何一样Tensor,创建的变量Variable()可以用作图中其他操作的输入。此外,为Tensor该类重载的所有运算符都被转载到变量中,因此您也可以通过对变量进行算术来将节点添加到图形中。

x = tf.Variable(5.0,name="x")
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")

调用tf.Variable()向图中添加了几个操作:

  • 一个variable op保存变量值。
  • 初始化器op将变量设置为其初始值。这实际上是一个tf.assign操作。
  • 初始值的ops,例如 示例中biases变量的zeros op 也被添加到图中。

变量的初始化

  • 变量的初始化必须在模型的其它操作运行之前先明确地完成。最简单的方法就是添加一个给所有变量初始化的操作,并在使用模型之前首先运行那个操作。最常见的初始化模式是使用便利函数 initialize_all_variables()将Op添加到初始化所有变量的图形中。
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
  sess.run(init_op)
  • 还可以通过运行其初始化函数op来初始化变量,从保存文件还原变量,或者简单地运行assign向变量分配值的Op。实际上,变量初始化器op只是一个assignOp,它将变量的初始值赋给变量本身。assign是一个方法,后面方法的时候会提到
with tf.Session() as sess:
    sess.run(w.initializer)

通过另一个变量赋值

你有时候会需要用另一个变量的初始化值给当前变量初始化,由于tf.global_variables_initializer()初始化所有变量,所以需要注意这个方法的使用。

就是将已初始化的变量的值赋值给另一个新变量!

weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),name="weights")

w2 = tf.Variable(weights.initialized_value(), name="w2")

w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice")

所有变量都会自动收集到创建它们的图形中。默认情况下,构造函数将新变量添加到图形集合GraphKeys.GLOBAL_VARIABLES。方便函数 global_variables()返回该集合的内容。

属性

name

返回变量的名字

weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),name="weights")
print(weights.name)

op

返回op操作

weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35))
print(weights.op)

方法

assign

为变量分配一个新值。

x = tf.Variable(5.0,name="x")
w.assign(w + 1.0)

eval

在会话中,计算并返回此变量的值。这不是一个图形构造方法,它不会向图形添加操作。方便打印结果

v = tf.Variable([1, 2])
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    # 指定会话
    print(v.eval(sess))
    # 使用默认会话
    print(v.eval())

变量的静态形状与动态形状

TensorFlow中,张量具有静态(推测)形状和动态(真实)形状

  • 静态形状:

创建一个张量或者由操作推导出一个张量时,初始状态的形状

  • tf.Tensor.get_shape:获取静态形状

  • tf.Tensor.set_shape():更新Tensor对象的静态形状,通常用于在不能直接推断的情况下

  • 动态形状:

一种描述原始张量在执行过程中的一种形状

  • tf.shape(tf.Tensor):如果在运行的时候想知道None到底是多少,只能通过tf.shape(tensor)[0]这种方式来获得
  • tf.reshape:创建一个具有不同动态形状的新张量

要点

1、转换静态形状的时候,1-D到1-D,2-D到2-D,不能跨阶数改变形状

2、 对于已经固定或者设置静态形状的张量/变量,不能再次设置静态形状

3、tf.reshape()动态创建新张量时,元素个数不能不匹配

4、运行时候,动态获取张量的形状值,只能通过tf.shape(tensor)[]

管理图中收集的变量

tf.global_variables()

返回图中收集的所有变量

weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35))

print(tf.global_variables())

2.4. 名称域与共享变量

tensorflow提供了变量作用域和共享变量这样的概念,有几个重要的作用。

  • 让模型代码更加清晰,作用分明

变量作用域域

通过tf.variable_scope()创建指定名字的变量作用域

with tf.variable_scope("itcast") as scope:
  print("----")

加上with语句就可以在整个itcast变量作用域下就行操作。

嵌套使用

变量作用域可以嵌套使用

with tf.variable_scope("itcast") as itcast:
    with tf.variable_scope("python") as python:
      print("----")

变量作用域下的变量

在同一个变量作用域下,如果定义了两个相同名称的变量(这里先用tf.Variable())会怎么样呢?

with tf.variable_scope("itcast") as scope:
    a = tf.Variable([1.0,2.0],name="a")
    b = tf.Variable([2.0,3.0],name="a")

我们通过tensoflow提供的计算图界面观察

我们发现取了同样的名字,其实tensorflow并没有当作同一个,而是另外又增加了一个a_1,来表示b的图

变量范围

当每次在一个变量作用域中创建变量的时候,会在变量的name前面加上变量作用域的名称

with tf.variable_scope("itcast"):
    a = tf.Variable(1.0,name="a")
    b = tf.get_variable("b", [1])
    print(a.name,b.name)

得道结果

(u'itcast/a:0', u'itcast/b:0')

对于嵌套的变量作用域来说

with tf.variable_scope("itcast"):
    with tf.variable_scope("python"):
        python3 = tf.get_variable("python3", [1])
assert python3.name == "itcast/python/python3:0"
var2 = tf.get_variable("var",[3,4],initializer=tf.constant_initializer(0.0))

```

2.5 图与会话

tf.Graph

TensorFlow计算,表示为数据流图。一个图包含一组表示 tf.Operation计算单位的对象和tf.Tensor表示操作之间流动的数据单元的对象。默认Graph值始终注册,并可通过调用访问 tf.get_default_graph。

a = tf.constant(1.0)
assert c.graph is tf.get_default_graph()

我们可以发现这两个图是一样的。那么如何创建一个图呢,通过tf.Graph()

g1= tf.Graph()
g2= tf.Graph()

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(g1,g2,tf.get_default_graph())

图的其它属性和方法

作为一个图的类,自然会有一些图的属性和方法。

as_default()

返回一个上下文管理器,使其成为Graph默认图形。

如果要在同一过程中创建多个图形,则应使用此方法。为了方便起见,提供了一个全局默认图形,如果不明确地创建一个新的图形,所有操作都将添加到此图形中。使用该with关键字的方法来指定在块的范围内创建的操作应添加到此图形中。

g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0)
  assert c.graph is g

会话

tf.Session

运行TensorFlow操作图的类,一个包含ops执行和tensor被评估

a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b

sess = tf.Session()

print(sess.run(c))

在开启会话的时候指定图

with tf.Session(graph=g) as sess:

资源释放

会话可能拥有很多资源,如 tf.Variable,tf.QueueBase和tf.ReaderBase。在不再需要这些资源时,重要的是释放这些资源。要做到这一点,既可以调用tf.Session.close会话中的方法,也可以使用会话作为上下文管理器。以下两个例子是等效的:

# 使用close手动关闭
sess = tf.Session()
sess.run(...)
sess.close()

# 使用上下文管理器
with tf.Session() as sess:
  sess.run(...)

run方法介绍

run(fetches, feed_dict=None, options=None, run_metadata=None)

运行ops和计算tensor

  • fetches 可以是单个图形元素,或任意嵌套列表,元组,namedtuple,dict或OrderedDict
  • feed_dict 允许调用者覆盖图中指定张量的值

如果a,b是其它的类型,比如tensor,同样可以覆盖原先的值

a = tf.placeholder(tf.float32, shape=[])
b = tf.placeholder(tf.float32, shape=[])
c = tf.constant([1,2,3])

with tf.Session() as sess:
    a,b,c = sess.run([a,b,c],feed_dict={
   a: 1, b: 2,c:[4,5,6]})
    print(a,b,c)

错误

  • RuntimeError:如果它Session处于无效状态(例如已关闭)。
  • TypeError:如果fetches或feed_dict键是不合适的类型。
  • ValueError:如果fetches或feed_dict键无效或引用 Tensor不存在。

其它属性和方法

graph

返回本次会话中的图

as_default()

返回使此对象成为默认会话的上下文管理器。

获取当前的默认会话,请使用 tf.get_default_session

c = tf.constant(..)
sess = tf.Session()

with sess.as_default():
  assert tf.get_default_session() is sess
  print(c.eval())

注意: 使用这个上下文管理器并不会在退出的时候关闭会话,还需要手动的去关闭

c = tf.constant(...)
sess = tf.Session()
with sess.as_default():
  print(c.eval())
# ...
with sess.as_default():
  print(c.eval())

sess.close()

2.6 模型保存与恢复、自定义命令行参数

在我们训练或者测试过程中,总会遇到需要保存训练完成的模型,然后从中恢复继续我们的测试或者其它使用。模型的保存和恢复也是通过tf.train.Saver类去实现,它主要通过将Saver类添加OPS保存和恢复变量到checkpoint。它还提供了运行这些操作的便利方法。

tf.train.Saver(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=tf.SaverDef.V2, pad_step_number=False)

  • var_list:指定将要保存和还原的变量。它可以作为一个dict或一个列表传递.
  • max_to_keep:指示要保留的最近检查点文件的最大数量。创建新文件时,会删除较旧的文件。如果无或0,则保留所有检查点文件。默认为5(即保留最新的5个检查点文件。)
  • keep_checkpoint_every_n_hours:多久生成一个新的检查点文件。默认为10,000小时

保存

保存我们的模型需要调用Saver.save()方法。save(sess, save_path, global_step=None),checkpoint是专有格式的二进制文件,将变量名称映射到张量值。

import tensorflow as tf

a = tf.Variable([[1.0,2.0]],name="a")
b = tf.Variable([[3.0],[4.0]],name="b")
c = tf.matmul(a,b)

saver=tf.train.Saver()
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(c))
    saver.save(sess, '/tmp/ckpt/test/matmul')

我们可以看保存了什么文件

在多次训练的时候可以指定多少间隔生成检查点文件

saver.save(sess, '/tmp/ckpt/test/matmu', global_step=0) ==> filename: 'matmu-0'

saver.save(sess, '/tmp/ckpt/test/matmu', global_step=1000) ==> filename: 'matmu-1000'

恢复

恢复模型的方法是restore(sess, save_path),save_path是以前保存参数的路径,我们可以使用tf.train.latest_checkpoint来获取最近的检查点文件(也恶意直接写文件目录)

import tensorflow as tf

a = tf.Variable([[1.0,2.0]],name="a")
b = tf.Variable([[3.0],[4.0]],name="b")
c = tf.matmul(a,b)

saver=tf.train.Saver(max_to_keep=1)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(c))
    saver.save(sess, '/tmp/ckpt/test/matmul')

    # 恢复模型
    model_file = tf.train.latest_checkpoint('/tmp/ckpt/test/')
    saver.restore(sess, model_file)
    print(sess.run([c], feed_dict={
   a: [[5.0,6.0]], b: [[7.0],[8.0]]}))

自定义命令行参数

tf.app.run(),默认调用main()函数,运行程序。main(argv)必须传一个参数。

tf.app.flags,它支持应用从命令行接受参数,可以用来指定集群配置等。在tf.app.flags下面有各种定义参数的类型

  • DEFINE_string(flag_name, default_value, docstring)
  • DEFINE_integer(flag_name, default_value, docstring)
  • DEFINE_boolean(flag_name, default_value, docstring)
  • DEFINE_float(flag_name, default_value, docstring)

第一个也就是参数的名字,路径、大小等等。第二个参数提供具体的值。第三个参数是说明文档

tf.app.flags.FLAGS,在flags有一个FLAGS标志,它在程序中可以调用到我们前面具体定义的flag_name.

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('data_dir', '/tmp/tensorflow/mnist/input_data',
                           """数据集目录""")
tf.app.flags.DEFINE_integer('max_steps', 2000,
                            """训练次数""")
tf.app.flags.DEFINE_string('summary_dir', '/tmp/summary/mnist/convtrain',
                           """事件文件目录""")


def main(argv):
    print(FLAGS.data_dir)
    print(FLAGS.max_steps)
    print(FLAGS.summary_dir)
    print(argv)


if __name__=="__main__":
    tf.app.run()