在(一)中介绍了对图片的处理，在(二)中开始介绍训练过程。
# Spatial size (height == width) of the single-channel input patches and of
# the label patches the network output is compared against in the loss.
size_input = 33
size_label = 21

# Graph inputs: batches of one-channel patches; the batch dimension is open.
images = tf.placeholder(tf.float32, [None, size_input, size_input, 1], name='images')
labels = tf.placeholder(tf.float32, [None, size_label, size_label, 1], name='labels')

# Optimisation hyper-parameters.
learning_rate = 1e-4
num_epoch = 15000
batch_size = 128

# Presumably the number of patches stored in train.h5 / test.h5 — verify.
num_training = 21712
num_testing = 1113

# HDF5 datasets and checkpoints all live under ./checkpoint.
train_path = os.path.join('./checkpoint', 'train.h5')
test_path = os.path.join('./checkpoint', 'test.h5')
ckpt_dir = './checkpoint/'
读取 h5 (HDF5) 文件中的数据 (data) 与标签 (label)
def load_data(path):
    """Load the ``data`` and ``label`` datasets from an HDF5 file.

    Args:
        path: path of the ``.h5`` file to read.

    Returns:
        A ``(data, label)`` tuple of NumPy arrays, read fully into memory.

    Raises:
        KeyError: if the file has no ``data`` or no ``label`` dataset.
    """
    with h5py.File(path, 'r') as hf:
        # Index with [] rather than .get(): .get() returns None for a missing
        # key, and np.array(None) silently produces a 0-d object array
        # instead of failing, which only surfaces much later in training.
        data = np.array(hf['data'])
        label = np.array(hf['label'])
    return data, label
存储和加载模型检查点 (checkpoint)
def save_ckpt(sess, step, saver):
    """Save the session's variables as a checkpoint tagged with ``step``.

    The checkpoint goes to ``<ckpt_dir>/srcnn_<size_label>/SRCNN.model``
    (``saver.save`` appends ``-<step>`` because ``global_step`` is given).
    The target directory is created first if it does not exist yet.
    """
    target_dir = os.path.join(ckpt_dir, "srcnn_{}".format(size_label))
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    saver.save(sess, os.path.join(target_dir, 'SRCNN.model'), global_step=step)
def load_ckpt(sess, checkpoint_dir, saver):
    """Restore the most recent SRCNN checkpoint into ``sess``, if any.

    Looks for checkpoint state in ``<checkpoint_dir>/srcnn_<size_label>/``.

    Returns:
        True if a checkpoint was found and restored, False otherwise.
    """
    print(" [*] Reading checkpoints...")
    model_path = os.path.join(checkpoint_dir, "srcnn_{}".format(size_label))
    print('checkpoint_dir is', model_path)

    state = tf.train.get_checkpoint_state(model_path)
    if not (state and state.model_checkpoint_path):
        return False

    # Keep only the file name of the recorded checkpoint so it is looked up
    # inside model_path, whatever directory the state file recorded.
    restore_path = os.path.join(model_path, os.path.basename(state.model_checkpoint_path))
    print('Restoring from', restore_path)
    saver.restore(sess, restore_path)
    return True
建立模型
def conv2d(x, W, padding='VALID'):
    """Apply a stride-1 2-D convolution of ``x`` with the filter bank ``W``.

    Args:
        x: input tensor in NHWC layout (the ``tf.nn.conv2d`` default).
        W: filter tensor of shape [height, width, in_channels, out_channels].
        padding: 'VALID' (default, no padding) or 'SAME'.

    Returns:
        The convolution output. With 'VALID' each spatial dimension shrinks
        by ``kernel_size - 1``; with 'SAME' it is preserved.
    """
    # SRCNN uses un-padded convolutions: the 9x9, 1x1 and 5x5 layers shrink a
    # 33x33 input patch (size_input) to 21x21, the size of the `labels`
    # placeholder (size_label). With 'SAME' the output stays 33x33 and the
    # loss `labels - model` fails with an incompatible-shape error.
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding=padding)
# Define the computation graph of SRCNN. def SRCNN(x): # Define weights and biases. # f1 = 9, f3 = 5, n1 = 64, n2 = 32. weights = {'w1' : tf.Variable(tf.random_normal([9, 9, 1, 64], stddev=1e-3)), 'w2' : tf.Variable(tf.random_normal([1, 1, 64, 32], stddev=1e-3)), 'w3' : tf.Variable(tf.random_normal([5, 5, 32, 1], stddev=1e-3))} biases = {'b1' : tf.Variable(tf.zeros([64])), 'b2' : tf.Variable(tf.zeros([32])), 'b3' : tf.Variable(tf.zeros([1]))} conv1 = tf.nn.relu(conv2d(x, weights['w1']) + biases['b1']) conv2 = tf.nn.relu(conv2d(conv1, weights['w2']) + biases['b2']) conv3 = conv2d(conv2, weights['w3']) + biases['b3'] return conv3
训练模型
def train_SRCNN(x):
    """Build the SRCNN training graph on ``x`` and run mini-batch SGD.

    The loss is the mean squared error between the network output and the
    module-level ``labels`` placeholder. Training data are loaded from
    ``train_path`` and fed through the module-level ``images`` / ``labels``
    placeholders, so ``x`` should be ``images`` (or a tensor computed from
    it) for the fed batches to reach the model.

    If a checkpoint exists under ``ckpt_dir`` it is restored before training.
    Progress is printed every 10 steps, a checkpoint is saved every 500
    steps, and a final checkpoint is saved when training finishes.
    """
    # Initialization.
    model = SRCNN(x)
    l2_loss = tf.reduce_mean(tf.square(labels - model))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(l2_loss)
    train_data, train_label = load_data(train_path)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print('Training...')
        start_time = time.time()
        counter = 0
        saver = tf.train.Saver()
        if load_ckpt(sess, ckpt_dir, saver):
            print('Successfully loaded checkpoint.')
        else:
            print('Failed to load checkpoint.')

        # Training
        num_samples = len(train_data)
        for epoch in range(num_epoch):  # number of passes over the data
            # Iterate over batch start offsets: one step per mini-batch, never
            # an empty slice (an empty batch makes reduce_mean return NaN).
            # The last batch may be smaller than batch_size; the placeholders
            # leave the batch dimension open.
            for start in range(0, num_samples, batch_size):
                batch_images = train_data[start:start + batch_size]
                batch_labels = train_label[start:start + batch_size]
                _, c = sess.run([optimizer, l2_loss],
                                feed_dict={images: batch_images, labels: batch_labels})
                counter += 1

                # Log the training process every 10 steps.
                if counter % 10 == 0:
                    print('Epoch:', epoch + 1, 'step:', counter, 'loss:', c,
                          'duration:', time.time() - start_time)
                # Save the checkpoint every 500 steps.
                if counter % 500 == 0:
                    save_ckpt(sess, counter, saver)

        # Persist the final weights unless the last step was already saved.
        if counter % 500 != 0:
            save_ckpt(sess, counter, saver)