part1
初始化配置
def __init__(self,
             sess,
             image_size=33,
             label_size=21,
             batch_size=64,
             c_dim=1,
             checkpoint_dir=None,
             sample_dir=None):
    """Store the SRCNN configuration and build the network graph.

    Args:
        sess: TensorFlow session; kept on the instance and used by the
            other methods for running ops and saving/restoring variables.
        image_size: Side length of the square input patches (default 33).
        label_size: Side length of the square label/output patches
            (default 21).
        batch_size: Number of patches per batch (default 64).
        c_dim: Number of image channels; ``1`` marks the model as grayscale.
        checkpoint_dir: Root directory for checkpoints; stored as-is.
        sample_dir: Directory for sample outputs; stored as-is.
    """
    self.sess = sess
    # A single channel means the model works on grayscale images.
    self.is_grayscale = (c_dim == 1)
    self.image_size = image_size
    self.label_size = label_size
    self.batch_size = batch_size

    self.c_dim = c_dim

    self.checkpoint_dir = checkpoint_dir
    self.sample_dir = sample_dir

    # Must come last: building the graph reads the attributes set above.
    self.build_model()
搭建网络
def build_model(self):
    """Create placeholders, trainable parameters, prediction, loss and saver.

    Populates ``self.images``, ``self.labels``, ``self.weights``,
    ``self.biases``, ``self.pred``, ``self.loss`` and ``self.saver``.

    The channel dimension of the first and last filters (and of the last
    bias) follows ``self.c_dim`` so that it always agrees with the
    placeholders; for the default ``c_dim == 1`` the shapes are exactly
    ``[9, 9, 1, 64]``, ``[5, 5, 32, 1]`` and ``[1]``.
    """
    # Input patches and their ground-truth label patches (NHWC layout as
    # given by the placeholder shapes; batch dimension left open).
    self.images = tf.placeholder(
        tf.float32,
        [None, self.image_size, self.image_size, self.c_dim],
        name='images')
    self.labels = tf.placeholder(
        tf.float32,
        [None, self.label_size, self.label_size, self.c_dim],
        name='labels')

    # Filters of the three convolution layers: 9x9 -> 1x1 -> 5x5.
    self.weights = {
        'w1': tf.Variable(
            tf.random_normal([9, 9, self.c_dim, 64], stddev=1e-3), name='w1'),
        'w2': tf.Variable(
            tf.random_normal([1, 1, 64, 32], stddev=1e-3), name='w2'),
        'w3': tf.Variable(
            tf.random_normal([5, 5, 32, self.c_dim], stddev=1e-3), name='w3'),
    }

    # One zero-initialised bias per output channel of each layer.
    self.biases = {
        'b1': tf.Variable(tf.zeros([64]), name='b1'),
        'b2': tf.Variable(tf.zeros([32]), name='b2'),
        'b3': tf.Variable(tf.zeros([self.c_dim]), name='b3'),
    }

    self.pred = self.model()

    # Mean squared error between labels and prediction is the loss.
    self.loss = tf.reduce_mean(tf.square(self.labels - self.pred))

    # NOTE: the saver is created here, so it only tracks variables that
    # exist at this point in graph construction.
    self.saver = tf.train.Saver()
下面就是三层模型
def model(self):
    """Build the three-layer convolutional network on ``self.images``.

    Layers one and two are convolution + bias followed by ReLU; the third
    layer is convolution + bias with no activation. All convolutions use
    stride 1 and ``'VALID'`` padding, so each layer shrinks the spatial
    size by ``kernel_size - 1``.

    Returns:
        The output tensor of the third convolution layer.
    """
    unit_strides = [1, 1, 1, 1]
    padding = 'VALID'

    conv1 = tf.nn.relu(
        tf.nn.conv2d(self.images, self.weights['w1'],
                     strides=unit_strides, padding=padding)
        + self.biases['b1'])
    conv2 = tf.nn.relu(
        tf.nn.conv2d(conv1, self.weights['w2'],
                     strides=unit_strides, padding=padding)
        + self.biases['b2'])
    # The final layer is linear: no ReLU on the reconstructed output.
    conv3 = (
        tf.nn.conv2d(conv2, self.weights['w3'],
                     strides=unit_strides, padding=padding)
        + self.biases['b3'])
    return conv3
保存和加载模型
def save(self, checkpoint_dir, step):
    """Write a checkpoint of the current session.

    The checkpoint is stored under ``<checkpoint_dir>/srcnn_<label_size>/``
    with the file prefix ``SRCNN.model``; ``step`` is passed to the saver
    as ``global_step`` so it is appended to the file name.

    Args:
        checkpoint_dir: Root checkpoint directory.
        step: Training step number used to tag the checkpoint.
    """
    model_name = "SRCNN.model"
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    # Checkpoints live in a per-label-size sub-directory, e.g. "srcnn_21".
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    self.saver.save(self.sess,
                    os.path.join(checkpoint_dir, model_name),
                    global_step=step)
def load(self, checkpoint_dir):
    """Restore the most recent checkpoint into ``self.sess`` if one exists.

    Looks in ``<checkpoint_dir>/srcnn_<label_size>/``, the same
    sub-directory layout that ``save`` writes to.

    Args:
        checkpoint_dir: Root checkpoint directory.

    Returns:
        ``True`` if a checkpoint was found and restored, ``False`` otherwise.
    """
    print(" [*] Reading checkpoints...")
    model_dir = "%s_%s" % ("srcnn", self.label_size)
    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

    # The checkpoint state records the path of the latest checkpoint.
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        return False

    # Only the file name is taken from the recorded path; it is re-joined
    # with the directory computed above before restoring.
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
    return True
part2
def train(self, config):
    """Train the network, or run it once in test mode.

    When ``config.is_train`` is true, runs mini-batch gradient descent on
    the data read from ``train.h5``, printing the loss every 10 steps and
    saving a checkpoint every 500 steps. Otherwise runs the network on the
    data read from ``test.h5``, merges the output patches into one image
    and writes it to ``<cwd>/<config.sample_dir>/MySRCNN.bmp``.

    Args:
        config: Settings object; this method reads ``is_train``,
            ``checkpoint_dir``, ``learning_rate``, ``epoch``,
            ``batch_size`` and ``sample_dir`` from it.
    """
    # presumably input_setup prepares the .h5 file read below — confirm
    # against its definition. In test mode its return value (nx, ny) is
    # later handed to merge().
    if config.is_train:
        input_setup(self.sess, config)
        data_file = "train.h5"
    else:
        nx, ny = input_setup(self.sess, config)
        data_file = "test.h5"
    data_dir = os.path.join('./{}'.format(config.checkpoint_dir), data_file)

    train_data, train_label = read_data(data_dir)

    # Step counter incremented by the optimizer on every update. It is
    # bookkeeping, not a model parameter, hence trainable=False.
    global_step = tf.Variable(0, trainable=False)

    # Staircase decay: the learning rate is multiplied by 0.98 every 1480
    # steps. NOTE(review): the original author's note equates 1480 steps
    # with one epoch, which only holds for their dataset size — confirm.
    learning_rate_exp = tf.train.exponential_decay(
        config.learning_rate, global_step, 1480, 0.98, staircase=True)

    # Plain stochastic gradient descent with the decayed learning rate.
    self.train_op = tf.train.GradientDescentOptimizer(
        learning_rate_exp).minimize(self.loss, global_step=global_step)

    # Run on self.sess explicitly (as the rest of the class does) instead
    # of depending on a default session being installed by the caller.
    self.sess.run(tf.global_variables_initializer())

    counter = 0
    start_time = time.time()

    # Restoring after initialisation overwrites the fresh values with the
    # checkpointed ones when a checkpoint is available.
    if self.load(self.checkpoint_dir):
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")

    if config.is_train:
        print("Training...")

        # Number of full batches per epoch; a trailing partial batch is
        # dropped by the integer division.
        batch_idxs = len(train_data) // config.batch_size

        for ep in range(config.epoch):
            for idx in range(batch_idxs):
                begin = idx * config.batch_size
                end = (idx + 1) * config.batch_size
                batch_images = train_data[begin:end]
                batch_labels = train_label[begin:end]

                counter += 1
                _, err = self.sess.run(
                    [self.train_op, self.loss],
                    feed_dict={self.images: batch_images,
                               self.labels: batch_labels})

                # Report progress every 10 steps.
                if counter % 10 == 0:
                    print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]"
                          % ((ep + 1), counter, time.time() - start_time, err))

                # Save a checkpoint every 500 steps.
                if counter % 500 == 0:
                    self.save(config.checkpoint_dir, counter)
    else:
        print("Testing...")

        # In test mode train_data / train_label hold the contents of test.h5.
        result = self.sess.run(
            self.pred,
            feed_dict={self.images: train_data, self.labels: train_label})

        result = merge(result, [nx, ny])
        # Drop all size-1 dimensions before writing the image.
        result = result.squeeze()

        image_path = os.path.join(os.getcwd(), config.sample_dir, "MySRCNN.bmp")
        imsave(result, image_path)