Part 1:模型定义(初始化、搭建网络、保存与加载)

初始化配置

  def __init__(self,
               sess,
               image_size=33,
               label_size=21,
               batch_size=64,
               c_dim=1,
               checkpoint_dir=None,
               sample_dir=None):
    """Store the configuration and build the network graph.

    Args:
      sess: session object kept on the instance; the other methods use it
        to run the training op and to save/restore checkpoints.
      image_size: side length of the square input images (default 33).
      label_size: side length of the square label images (default 21).
      batch_size: batch size kept on the instance (default 64). Note that
        ``train`` reads the batch size from its ``config`` argument, not
        from this attribute.
      c_dim: number of image channels; 1 is treated as grayscale.
      checkpoint_dir: directory passed to ``load`` when ``train`` starts.
      sample_dir: stored as-is on the instance.
    """
    # Session and storage locations.
    self.sess = sess
    self.checkpoint_dir = checkpoint_dir
    self.sample_dir = sample_dir

    # Data geometry.
    self.image_size = image_size
    self.label_size = label_size
    self.batch_size = batch_size
    self.c_dim = c_dim
    self.is_grayscale = (c_dim == 1)

    # Must come last: the graph is built from the attributes set above.
    self.build_model()

搭建网络

def build_model(self):
    """Create the placeholders, parameters, prediction and loss of the net.

    Sets the following attributes:
      images:  float32 placeholder [batch, image_size, image_size, c_dim].
      labels:  float32 placeholder [batch, label_size, label_size, c_dim].
      weights: filters 'w1'..'w3' of the three convolution layers.
      biases:  biases 'b1'..'b3' of the three convolution layers.
      pred:    output tensor returned by ``self.model()``.
      loss:    mean squared error between ``labels`` and ``pred``.
      saver:   ``tf.train.Saver`` used by ``save`` and ``load``.
    """
    # Network input images and the label images the prediction is compared to.
    self.images = tf.placeholder(
        tf.float32,
        [None, self.image_size, self.image_size, self.c_dim],
        name='images')
    self.labels = tf.placeholder(
        tf.float32,
        [None, self.label_size, self.label_size, self.c_dim],
        name='labels')

    # Filters of the three layers, laid out as [height, width, in, out].
    # The input depth of the first layer and the output depth of the last
    # one follow self.c_dim so that they agree with the placeholders above
    # for any channel count; for the default c_dim == 1 the shapes are the
    # same as the previously hard-coded ones.
    self.weights = {
        'w1': tf.Variable(tf.random_normal([9, 9, self.c_dim, 64], stddev=1e-3), name='w1'),
        'w2': tf.Variable(tf.random_normal([1, 1, 64, 32], stddev=1e-3), name='w2'),
        'w3': tf.Variable(tf.random_normal([5, 5, 32, self.c_dim], stddev=1e-3), name='w3'),
    }
    # One bias per output channel of each layer, initialised to zero.
    self.biases = {
        'b1': tf.Variable(tf.zeros([64]), name='b1'),
        'b2': tf.Variable(tf.zeros([32]), name='b2'),
        'b3': tf.Variable(tf.zeros([self.c_dim]), name='b3'),
    }

    self.pred = self.model()
    # Mean squared error is used as the loss function.
    self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))
    self.saver = tf.train.Saver()

下面就是三层卷积模型(卷积核依次为 9×9、1×1、5×5,均使用 VALID 填充)

  def model(self):
    """Apply the three convolution layers to ``self.images``.

    Every layer is a stride-1 'VALID' convolution followed by a bias
    addition. The first two layers go through ReLU; the third one is
    returned without an activation.

    Returns:
      The output tensor of the third layer.
    """
    unit_strides = [1, 1, 1, 1]
    filters = self.weights
    offsets = self.biases

    # Layer 1: convolution + bias + ReLU.
    hidden1 = tf.nn.conv2d(self.images, filters['w1'], strides=unit_strides, padding='VALID')
    hidden1 = tf.nn.relu(hidden1 + offsets['b1'])

    # Layer 2: convolution + bias + ReLU.
    hidden2 = tf.nn.conv2d(hidden1, filters['w2'], strides=unit_strides, padding='VALID')
    hidden2 = tf.nn.relu(hidden2 + offsets['b2'])

    # Layer 3: convolution + bias, no activation.
    output = tf.nn.conv2d(hidden2, filters['w3'], strides=unit_strides, padding='VALID')
    return output + offsets['b3']

保存和加载模型

 def save(self, checkpoint_dir, step):
     """Write a checkpoint of the current session.

     The files go into ``<checkpoint_dir>/srcnn_<label_size>``, which is
     created when it does not exist yet. The checkpoint base name is
     ``SRCNN.model`` and ``step`` is handed to the saver as ``global_step``.

     Args:
       checkpoint_dir: root directory for checkpoints.
       step: training step forwarded to the saver.
     """
     # Sub-directory named after the label size, e.g. "srcnn_21".
     target_dir = os.path.join(checkpoint_dir,
                               "%s_%s" % ("srcnn", self.label_size))
     if not os.path.exists(target_dir):
         os.makedirs(target_dir)

     checkpoint_prefix = os.path.join(target_dir, "SRCNN.model")
     self.saver.save(self.sess, checkpoint_prefix, global_step=step)
def load(self, checkpoint_dir):
    """Restore the checkpoint recorded for this model, if there is one.

    Looks in ``<checkpoint_dir>/srcnn_<label_size>`` for a checkpoint state
    and restores the checkpoint it names into ``self.sess``.

    Args:
      checkpoint_dir: root directory for checkpoints.

    Returns:
      True when a checkpoint was restored, False when none was found.
    """
    print(" [*] Reading checkpoints...")

    # Same sub-directory layout as in save(), e.g. "srcnn_21".
    model_dir = os.path.join(checkpoint_dir,
                             "%s_%s" % ("srcnn", self.label_size))

    state = tf.train.get_checkpoint_state(model_dir)
    if not (state and state.model_checkpoint_path):
        return False

    # Only the file name of the recorded path is used; it is re-joined with
    # model_dir before being handed to the saver.
    latest_name = os.path.basename(state.model_checkpoint_path)
    self.saver.restore(self.sess, os.path.join(model_dir, latest_name))
    return True

Part 2:训练与测试

  def train(self, config):
    """Train the network, or run it on the test data and save the result.

    When ``config.is_train`` is true, batches read from ``train.h5`` are fed
    to a gradient-descent optimizer whose learning rate decays in steps.
    The loss is printed every 10 steps and a checkpoint is written every
    500 steps. Otherwise the data in ``test.h5`` is run through the network
    and the merged output is written to ``MySRCNN.bmp`` under
    ``config.sample_dir``.

    Args:
      config: object providing ``is_train``, ``checkpoint_dir``,
        ``learning_rate``, ``epoch``, ``batch_size`` and ``sample_dir``.
    """
    training = config.is_train

    # input_setup is called in both modes; only in test mode is its return
    # value unpacked (presumably the patch grid used by merge -- confirm
    # against input_setup).
    setup_result = input_setup(self.sess, config)
    if not training:
      nx, ny = setup_result

    # train.h5 for training, test.h5 for testing, both under checkpoint_dir.
    h5_name = "train.h5" if training else "test.h5"
    data_dir = os.path.join('./{}'.format(config.checkpoint_dir), h5_name)
    train_data, train_label = read_data(data_dir)

    # global_step is incremented by minimize() on every training step.
    global_step = tf.Variable(0)
    # Staircase decay: the rate is multiplied by 0.98 every 1480 steps.
    # NOTE(review): the original author's note treats 1480 steps as one
    # epoch; the value is hard-coded, not derived from the data.
    decayed_rate = tf.train.exponential_decay(
        config.learning_rate, global_step, 1480, 0.98, staircase=True)
    # Plain stochastic gradient descent on the loss. The original notes list
    # tf.train.AdamOptimizer(config.learning_rate) as a possible alternative.
    optimizer = tf.train.GradientDescentOptimizer(decayed_rate)
    self.train_op = optimizer.minimize(self.loss, global_step=global_step)

    # Replaces the deprecated tf.initialize_all_variables().
    tf.global_variables_initializer().run()

    counter = 0
    start_time = time.time()

    restored = self.load(self.checkpoint_dir)
    print(" [*] Load SUCCESS" if restored else " [!] Load failed...")

    if not training:
      # Test mode: one forward pass over everything read from test.h5.
      print("Testing...")
      patches = self.pred.eval({self.images: train_data, self.labels: train_label})
      # squeeze() drops the dimensions of size 1 from the merged result.
      stitched = merge(patches, [nx, ny]).squeeze()
      out_path = os.path.join(os.getcwd(), config.sample_dir, "MySRCNN.bmp")
      imsave(stitched, out_path)
      return

    print("Training...")
    for epoch_no in range(1, config.epoch + 1):
      batch_size = config.batch_size
      # Samples that do not fill a whole batch are skipped.
      num_batches = len(train_data) // batch_size
      for batch_no in range(num_batches):
        lo = batch_no * batch_size
        hi = lo + batch_size
        feed = {self.images: train_data[lo:hi], self.labels: train_label[lo:hi]}

        counter += 1
        _, err = self.sess.run([self.train_op, self.loss], feed_dict=feed)

        if counter % 10 == 0:
          # Progress line every 10 steps.
          elapsed = time.time() - start_time
          print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]"
                % (epoch_no, counter, elapsed, err))
        if counter % 500 == 0:
          # Checkpoint every 500 steps.
          self.save(config.checkpoint_dir, counter)