#coding=utf-8
import numpy as np
import tensorflow as tf
import os

def get_data(number):
    """Generate ``number`` noisy samples of the line y = 2*x + 10.

    Used for both the training data and the test data. Each x is drawn from
    a standard normal distribution. Its label is 2*x + 10 plus standard-normal
    noise scaled by 0.01.

    Returns:
        A tuple ``(xs, ys)`` of two lists with ``number`` entries each. Every
        entry is a NumPy array of shape (1,).
    """
    def draw_sample():
        # Draw x before its noise term so the RNG is consumed in the order
        # x_0, noise_0, x_1, noise_1, ...
        x = np.random.randn(1)
        noise = np.random.randn(1) * 0.01
        return x, 2 * x + 10 + noise

    samples = [draw_sample() for _ in range(number)]
    xs = [x for x, _ in samples]
    ys = [y for _, y in samples]
    return xs, ys

def inference(x):
    """Build the linear model ``x * weight + bias`` and return its output tensor.

    Fetches two variables of shape [1] via ``tf.get_variable``. They are
    created on the first call, or shared when the enclosing variable scope is
    set to reuse. They are published as the module globals ``weight`` and
    ``bias`` so the training loop can print their current values.
    """
    # Declared global so the variables can be inspected outside this function.
    global weight, bias
    # NOTE(review): the graph name "bise" looks like a typo for "bias". It is
    # kept because Saver keys checkpoint entries by variable name, so renaming
    # it would stop previously saved checkpoints from restoring.
    weight, bias = tf.get_variable("weight", [1]), tf.get_variable("bise", [1])
    return x * weight + bias

# Feeds for one training pair and one evaluation pair.
train_x = tf.placeholder(tf.float32)
train_lable = tf.placeholder(tf.float32)

test_x = tf.placeholder(tf.float32)
test_lable = tf.placeholder(tf.float32)

with tf.variable_scope("inference"):
    train_y = inference(train_x)
    # From here on, tf.get_variable in this scope returns the variables that
    # were already created above (matched by name) instead of creating new
    # ones, and raises an error if no variable with that name exists. The
    # test graph therefore shares weight/bias with the training graph.
    tf.get_variable_scope().reuse_variables()
    test_y = inference(test_x)

# Reduce the squared errors to scalars. Each feed has shape (1,), so the
# value and gradient are unchanged, but the loss is now a true scalar.
# That makes "%.3f" formatting below well-defined; NumPy deprecates
# converting ndim > 0 arrays to scalars.
train_loss = tf.reduce_mean(tf.square(train_y - train_lable))
test_loss = tf.reduce_mean(tf.square(test_y - test_lable))
opt = tf.train.GradientDescentOptimizer(0.002)
train_op = opt.minimize(train_loss)

init = tf.global_variables_initializer()

train_data_x, train_data_lable = get_data(1000)
test_data_x, test_data_lable = get_data(1)

saver = tf.train.Saver()

# Checkpoint prefix. With the default (V2) format, Saver.save() writes files
# derived from this prefix (".index", ".data-*", ".meta") plus a "checkpoint"
# state file in the same directory. It never writes a file at exactly this
# path, so os.path.exists(CKPT_PATH) cannot detect a previous run.
CKPT_PATH = "modules/xaddy.ckpt"
CKPT_DIR = os.path.dirname(CKPT_PATH)

with tf.Session() as sess:
    sess.run(init)
    last_ckpt = tf.train.latest_checkpoint(CKPT_DIR)
    if last_ckpt is not None:
        # Resume from the previous run's weights instead of the fresh init.
        saver.restore(sess, last_ckpt)
    for i, (x_val, label_val) in enumerate(zip(train_data_x, train_data_lable)):
        sess.run(train_op, feed_dict={train_x: x_val, train_lable: label_val})

        if i % 10 == 0:
            test_loss_value = sess.run(
                test_loss,
                feed_dict={test_x: test_data_x[0], test_lable: test_data_lable[0]},
            )
            print("step %d eval loss is %.3f" % (i, test_loss_value))
            weight_value, bias_value = sess.run([weight, bias])
            print(weight_value)
            print(bias_value)
    # Saver.save() fails if the parent directory is missing, so create it.
    # The isdir check stands in for exist_ok=True to stay Python 2 friendly.
    if not os.path.isdir(CKPT_DIR):
        os.makedirs(CKPT_DIR)
    saver.save(sess, CKPT_PATH)