在(一)中讲的是对图片的处理,在(二)中开始看训练过程。

# Patch geometry: side length (in pixels) of the square network input and of
# the square label patch it is compared against.
size_input = 33
size_label = 21

# Optimisation hyperparameters.
batch_size = 128
num_epoch = 15000
learning_rate = 1e-4

# Presumably the number of patches stored in train.h5 / test.h5; neither value
# is referenced by the code visible here -- TODO confirm against the data prep.
num_training = 21712
num_testing = 1113

# On-disk locations: the HDF5 datasets and the root directory for checkpoints.
ckpt_dir = './checkpoint/'
train_path = os.path.join('./{}'.format('checkpoint'), "train.h5")
test_path = os.path.join('./{}'.format('checkpoint'), "test.h5")

# Graph inputs: single-channel patches with a variable batch dimension.
images = tf.placeholder(
    dtype=tf.float32,
    shape=[None, size_input, size_input, 1],
    name='images',
)
labels = tf.placeholder(
    dtype=tf.float32,
    shape=[None, size_label, size_label, 1],
    name='labels',
)

读取h5文件

def load_data(path):
    """Read the 'data' and 'label' datasets of an HDF5 file into NumPy arrays.

    Args:
        path: Path of the .h5 file to open read-only.

    Returns:
        A ``(data, label)`` tuple of NumPy arrays.
    """
    with h5py.File(path, 'r') as h5_file:
        # Materialise both datasets while the file is still open.
        inputs = np.array(h5_file.get('data'))
        targets = np.array(h5_file.get('label'))
    return inputs, targets

存储和加载检查点(checkpoint)文件

def save_ckpt(sess, step, saver):
    """Save the session's variables under ``<ckpt_dir>/srcnn_<size_label>/``.

    Args:
        sess: Session whose variables are written out.
        step: Training step, forwarded to the saver as ``global_step``.
        saver: Saver object used to write the checkpoint.
    """
    target_dir = os.path.join(ckpt_dir, "%s_%s" % ("srcnn", size_label))

    # Create the checkpoint directory the first time it is needed.
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    save_prefix = os.path.join(target_dir, 'SRCNN.model')
    saver.save(sess, save_prefix, global_step=step)
def load_ckpt(sess, checkpoint_dir, saver):
    """Restore the latest recorded SRCNN checkpoint into ``sess`` if one exists.

    Args:
        sess: Session to restore the variables into.
        checkpoint_dir: Root checkpoint directory; the lookup happens in its
            ``srcnn_<size_label>`` subdirectory.
        saver: Saver object used to restore the variables.

    Returns:
        True when a checkpoint was found and restored, False otherwise.
    """
    print(" [*] Reading checkpoints...")
    model_subdir = "%s_%s" % ("srcnn", size_label)  # e.g. srcnn_21
    search_dir = os.path.join(checkpoint_dir, model_subdir)
    print('checkpoint_dir is', search_dir)

    # The checkpoint state records the path of the most recent checkpoint.
    state = tf.train.get_checkpoint_state(search_dir)
    if not (state and state.model_checkpoint_path):
        return False

    # Only the file name is taken from the recorded path; it is re-anchored
    # on the directory we were asked to search.
    restore_path = os.path.join(
        search_dir, os.path.basename(state.model_checkpoint_path))
    print('Restoring from', restore_path)
    saver.restore(sess, restore_path)
    return True

建立模型

def conv2d(x, W, padding='SAME'):
    """Convolve ``x`` with the filter bank ``W`` using a stride of 1.

    Args:
        x: 4-D input tensor.
        W: 4-D filter tensor
            ``[filter_height, filter_width, in_channels, out_channels]``.
        padding: Padding mode forwarded to ``tf.nn.conv2d``. ``'SAME'`` (the
            default, matching the previous hard-coded behaviour) keeps the
            spatial size of ``x``; ``'VALID'`` applies no padding, so each
            spatial dimension shrinks by ``kernel_size - 1``.

    Returns:
        The convolved tensor.
    """
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding=padding)
# Build the three-layer SRCNN computation graph.
def SRCNN(x):
    """Create the SRCNN variables and return the network output for ``x``.

    The network is 9x9 conv (1 -> 64 channels) + ReLU, 1x1 conv (64 -> 32
    channels) + ReLU, then a linear 5x5 conv (32 -> 1 channel).

    Args:
        x: Single-channel 4-D input tensor.

    Returns:
        The output tensor of the last convolution (no activation).
    """
    init_std = 1e-3

    # Variables are unnamed, so they are created in the same order as before
    # (all weights, then all biases) to keep their auto-generated names stable.
    w1 = tf.Variable(tf.random_normal([9, 9, 1, 64], stddev=init_std))
    w2 = tf.Variable(tf.random_normal([1, 1, 64, 32], stddev=init_std))
    w3 = tf.Variable(tf.random_normal([5, 5, 32, 1], stddev=init_std))

    b1 = tf.Variable(tf.zeros([64]))
    b2 = tf.Variable(tf.zeros([32]))
    b3 = tf.Variable(tf.zeros([1]))

    features = tf.nn.relu(conv2d(x, w1) + b1)
    mapped = tf.nn.relu(conv2d(features, w2) + b2)
    return conv2d(mapped, w3) + b3

训练模型

def train_SRCNN(x):
    """Build the SRCNN training graph on ``x`` and run the training loop.

    Creates the network with ``SRCNN(x)``, a mean-squared-error loss against
    the module-level ``labels`` placeholder and a plain gradient-descent
    optimizer. The training set is read from ``train_path``; if a checkpoint
    exists under ``ckpt_dir`` it is restored first. Training then runs for
    ``num_epoch`` epochs in mini-batches of ``batch_size`` samples.

    Progress is printed every 10 steps and a checkpoint is written every
    500 steps.

    Args:
        x: Input tensor of the network. Batches are fed through the
            module-level ``images`` placeholder, so ``x`` is expected to be
            (or to depend on) that placeholder.
    """
    # Initialization.
    model = SRCNN(x)

    # conv2d in this file uses 'SAME' padding, so the network output keeps the
    # height/width of its input (size_input when x is `images`), whereas
    # `labels` is size_label x size_label. Subtracting the two would fail
    # shape inference, so crop the centre size_label x size_label window of
    # the output. When the shapes already agree nothing is cropped.
    out_h, out_w = model.get_shape().as_list()[1:3]
    if (out_h is not None and out_w is not None
            and out_h >= size_label and out_w >= size_label
            and (out_h, out_w) != (size_label, size_label)):
        top = (out_h - size_label) // 2
        left = (out_w - size_label) // 2
        model = model[:, top:top + size_label, left:left + size_label, :]

    l2_loss = tf.reduce_mean(tf.square(labels - model))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(l2_loss)
    train_data, train_label = load_data(train_path)

    # Number of mini-batches per epoch (the last one may be smaller).
    # Iterating len(train_data) times, as before, fed empty slices to the
    # graph for every step past the end of the data.
    idx_batch = (len(train_data) + batch_size - 1) // batch_size

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print('Training...')
        start_time = time.time()
        counter = 0
        saver = tf.train.Saver()

        # Resume from the latest checkpoint when one is available; otherwise
        # keep the freshly initialised variables.
        if load_ckpt(sess, ckpt_dir, saver):
            print('Successfully loaded checkpoint.')
        else:
            print('Failed to load checkpoint.')

        # Training
        for epoch in range(num_epoch):  # number of passes over the data
            epoch_loss = 0
            for i in range(idx_batch):
                epoch_images = train_data[i * batch_size : (i + 1) * batch_size]
                epoch_labels = train_label[i * batch_size : (i + 1) * batch_size]

                _, c = sess.run([optimizer, l2_loss], feed_dict = {images: epoch_images, labels: epoch_labels})
                epoch_loss += c
                counter += 1

                # Log the training process every 10 steps.
                if counter % 10 == 0:
                    print('Epoch:', epoch + 1, 'step:', counter, 'loss:', c, 'duration:', time.time() - start_time)

                # Save the checkpoint every 500 steps.
                if counter % 500 == 0:
                    save_ckpt(sess, counter, saver)