概况
从1月11日到23日这段时间,实现的内容主要有两个部分
part1
1,针对github的srcnn的代码,进一步的理解和掌握,在它的基础上,通过cv2等库实现了自己的srcnn的代码,实现的功能有一下几个部分。
1.1 实现图片的模糊和图像的重建,实现的结构都是单通道的,换句话说,就是灰色的图片
1.2 在1.1的基础上,得到的模糊图片和复原图片都是彩色的
1.3 实现图片的放大功能
part2
2,将1中的几个功能实现之后,就开始开发系统
2.1,选择flask和mysql进行开发,在这个基础上搭建项目
2.2,实现了图片的上传功能,图片的详情功能
2.3,实现图片的放大功能
注
part2中主要实现的是后台部分,前端页面的部分做的比较简单,仅仅可以展示结果
srcnn代码的实现
训练
def train(self, config): prepare_for_train(self.sess) data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5") train_data, train_label = read_data(data_dir) global_step = tf.Variable(0) learning_rate_exp = tf.train.exponential_decay(config.learning_rate, global_step, 1480, 0.98, staircase=True) # 每1个Epoch 学习率*0.98 self.train_op = tf.train.GradientDescentOptimizer(learning_rate_exp).minimize(self.loss, global_step=global_step) tf.global_variables_initializer().run() start_time = time.time() res = self.load(self.checkpoint_dir) counter = 0 if res: print(" [*] Load SUCCESS") else: print(" [!] Load failed...") print("Training...") for ep in range(config.epoch): # 以batch为单元 batch_idxs = len(train_data) // config.batch_size for idx in range(0, batch_idxs): batch_images = train_data[idx * config.batch_size: (idx + 1) * config.batch_size] batch_labels = train_label[idx * config.batch_size: (idx + 1) * config.batch_size] counter += 1 _, err = self.sess.run([self.train_op, self.loss], feed_dict={self.images: batch_images, self.labels: batch_labels}) if counter % 10 == 0: # 10的倍数step显示 print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \ % ((ep + 1), counter, time.time() - start_time, err)) if counter % 500 == 0: # 500的倍数step存储 self.save(config.checkpoint_dir, counter)
以上是训练的代码,我们分步进行分析
训练集的处理
prepare_for_train(self.sess) data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5") train_data, train_label = read_data(data_dir)
以上的代码就是为训练过程准备训练集,训练集是将91张图片先分割成许多的小图片,每一个小图片都分为两个,一个是清晰的,一个是模糊的,模糊的就是33x33的图片,而清晰的是21x21的图片,每一个33x33的图片经过模型的处理都会变成21x21的较为清晰的图片,然后这个时候和原本的清晰的21x21的图片进行比较就可以得到损失的大小,进而进行训练
def prepare_for_train(sess): data_dir = os.path.join(os.getcwd(), "Train") data = glob.glob(os.path.join(data_dir, "*.bmp")) print(len(data)) # sub_input_sequence = [] sub_label_sequence = [] padding = 6 for i in range(len(data)): input_, label_ = preprocess_for_train(data[i], 3) if len(input_.shape) == 3: h, w, _ = input_.shape else: h, w = input_.shape for x in range(0, h - 33 + 1, 14): for y in range(0, w - 33 + 1, 14): sub_input = input_[x:x + 33, y:y + 33] sub_label = label_[x + 6:x + 6 + 21, y + 6:y + 6 + 21] sub_input = sub_input.reshape([33, 33, 1]) sub_label = sub_label.reshape([21, 21, 1]) sub_input_sequence.append(sub_input) sub_label_sequence.append(sub_label) # 上面的部分和训练是一样的 arrdata = np.asarray(sub_input_sequence) # [?, 33, 33, 1] arrlabel = np.asarray(sub_label_sequence) # [?, 21, 21, 1] savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5') with h5py.File(savepath, 'w') as hf: hf.create_dataset('data', data=arrdata) hf.create_dataset('label', data=arrlabel)
以上就是得到训练集的方式,看最后几行的代码,我们就可以知道是33x33和21x21对应的集合.针对于91张每一张图片,我们是如何得到清晰和模糊的图片的呢,这里的处理是preprocess_for_train中实现的,如下所示
def preprocess_for_train(path, scale=3): scale -= 1 image = imread(path, is_grayscale=True) label_ = modcrop(image, scale) # Must be normalized image = image / 255. label_ = label_ / 255. input_ = scipy.ndimage.interpolation.zoom(label_, (1. / scale), prefilter=False) input_ = scipy.ndimage.interpolation.zoom(input_, (scale / 1.), prefilter=False) return input_, label_
大概就是几个步骤,读入,剪切,得到模糊的,input_就是这图的模糊,label_就是这图的清晰。
训练的过程
for ep in range(config.epoch): # 以batch为单元 batch_idxs = len(train_data) // config.batch_size for idx in range(0, batch_idxs): batch_images = train_data[idx * config.batch_size: (idx + 1) * config.batch_size] batch_labels = train_label[idx * config.batch_size: (idx + 1) * config.batch_size] counter += 1 _, err = self.sess.run([self.train_op, self.loss], feed_dict={self.images: batch_images, self.labels: batch_labels}) if counter % 10 == 0: # 10的倍数step显示 print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \ % ((ep + 1), counter, time.time() - start_time, err)) if counter % 500 == 0: # 500的倍数step存储 self.save(config.checkpoint_dir, counter)
训练不断的更新模型的参数,每10次就打印以此,每500次就记录一次参数,我们训练保存的文件在checkpoint的文件中
这个就是保存训练参数的文件,每次执行训练的时候,我们就可以读取文件,并且每次在这个基础上进行训练
显示读取成功,我们在这个的基础上,进行训练,发现一开始的损失函数就是已经达到的不错的数值
图片的复原
1月10日实现的图片复原,得到的仅仅是灰色的图片,现在可以得到彩色的图片
预处理
def preprocess_for_test(path, scale=3): scale -= 1 im = cv.imread(path) im = cv.cvtColor(im, cv.COLOR_BGR2YCR_CB) img = im2double(im) label_ = modcrop(img, scale=scale) color_base = modcrop(im, scale=scale) label_ = label_[:, :, 0] input_ = scipy.ndimage.interpolation.zoom(label_, (1. / scale), prefilter=False) input_ = scipy.ndimage.interpolation.zoom(input_, (scale / 1.), prefilter=False) label_small = modcrop_small(label_) # 把原图裁剪成和输出一样的大小 input_small = modcrop_small(input_) # 把原图裁剪成和输出一样的大小 color_small = modcrop_small(color_base[:, :, 1:3]) imsave(input_small, '/home/chengcongyue/PycharmProjects/SRCNN13/sample/input_.bmp') imsave(label_small, '/home/chengcongyue/PycharmProjects/SRCNN13/sample/label_.bmp') return input_, label_, color_small
这里的操作我们通过cv进行操作,首先是读入,转换图像空间,归一化,裁剪,这里我们要分为两个部分,然后仅针对单颜色空间进行模糊,然后剩下的两个通道,要传给结果。对于每一张的子图都是由33变成21,所以是变小的,因此,我们要通过方法,将这两个通道也变小
def parpare_for_test(sess, path): sub_input_sequence = [] sub_label_sequence = [] padding = 6 # 测试 input_, label_, color = preprocess_for_test(path, 3) # 测试图片 if len(input_.shape) == 3: h, w, _ = input_.shape else: h, w = input_.shape nx = 0 ny = 0 for x in range(0, h - 33 + 1, 21): nx += 1 ny = 0 for y in range(0, w - 33 + 1, 21): ny += 1 sub_input = input_[x:x + 33, y:y + 33] sub_label = label_[x + 6:x + 6 + 21, y + 6:y + 6 + 21] sub_input = sub_input.reshape([33, 33, 1]) sub_label = sub_label.reshape([21, 21, 1]) sub_input_sequence.append(sub_input) sub_label_sequence.append(sub_label) color = np.array(color) data = np.asarray(sub_input_sequence) label = np.asarray(sub_label_sequence) return data, label, color, nx, ny
以上也就是将这个转化成33x33和21x21的子图,这里的color就是上面剪切过的两个通道
复原
def test(self, path, config): tf.global_variables_initializer().run() if self.load(self.checkpoint_dir): print(" [*] Load SUCCESS") else: print(" [!] Load failed...")
以上是加载checkpoint的未见
data, label, color, nx, ny = parpare_for_test(self.sess, path)
data就是模糊处理过的33x33的子图
label就是的21x21的子图
color就是剪切过的两个通道,用来合成
nx,ny用来合成小图片
conv_out = self.pred.eval({self.images: data}) conv_out = merge(conv_out, [nx, ny]) conv_out = conv_out.squeeze() result_bw = revert(conv_out)
上面就是处理过的小图片再合成大的图片
image_path = os.path.join(os.getcwd(), config.sample_dir) image_path = os.path.join(image_path, "MySRCNN_simple.bmp") imsave(result_bw, image_path)
上面是经过处理过的单通道的图片
result = np.zeros([result_bw.shape[0], result_bw.shape[1], 3], dtype=np.uint8) result[:, :, 0] = result_bw result[:, :, 1:3] = color result = cv.cvtColor(result, cv.COLOR_YCrCb2RGB) image_dir = os.path.join(os.getcwd(), config.sample_dir) image_path = os.path.join(image_dir, "MySRCNN.bmp") imsave(result, image_path)
这里就是合成彩色的图片,得到的就是处理过的彩色图片
conv_out = merge(label, [nx, ny]) conv_out = conv_out.squeeze() result_bw = revert(conv_out) result = np.zeros([result_bw.shape[0], result_bw.shape[1], 3], dtype=np.uint8) result[:, :, 0] = result_bw result[:, :, 1:3] = color bicubic = cv.cvtColor(result, cv.COLOR_YCrCb2RGB) bicubic_path = os.path.join(image_dir, 'Orig_MySRCNN.bmp') imsave(bicubic, bicubic_path)
得到相同大小的原图,实际上就是在原图的基础上进行裁剪
图片的放大
图片的放大,就是在预处理的过程进行放大,然后在进行复原,因为通过双三次插值的方法放大的图片实际上就是模糊处理过的了
def preprocess_for__upscaling(path, scale=3): im = cv.imread(path) size = im.shape im = scipy.misc.imresize(im, [size[0] * scale, size[1] * scale], interp='bicubic') im = cv.cvtColor(im, cv.COLOR_BGR2YCR_CB) img = im2double(im) label_ = modcrop(img, scale=scale) color_base = modcrop(im, scale=scale) label_ = label_[:, :, 0] label_small = modcrop_small(label_) # 把原图裁剪成和输出一样的大小 color_small = modcrop_small(color_base[:, :, 1:3]) imsave(label_small, '/home/chengcongyue/PycharmProjects/SRCNN5/sample/label_.bmp') return label_, label_, color_small
就是在放大的过程中,处理和原来不同
功能展示
图片的复原
从左到右依次是,模糊的,复原的,原图
分别是复原的图片和原图
图片的放大
使用srcnn进行图片放大的效果比较不错,再举几个例子
项目的开发
项目目录的展示
项目使用了flask+mysql进行的搭建,如下就是项目目录
从上到下
checkpoint就是srcnn训练得到的参数文件
migrations是执行下面的脚本生成的,和数据库相关
mytest这里就相当于图片服务器的目录,存储了这个项目的所有的图片,将来会进行改进
static存放js,css的静态文件
templates就是html页面
app.py相当于是项目的初始化类,初始化flask的app类,整个项目的配置等
from flask_sqlalchemy import SQLAlchemy import configs from flask import Flask # 操作数据库的对象 db = SQLAlchemy() # 操作flask的对象 app = Flask(__name__) # 加载配置 app.config.from_object(configs) # db绑定app db.init_app(app) from models import Picture from views import index
configs.py就是项目的配置,通过app.py进行加载
HOST = '127.0.0.1' PORT = '3306' DATABASE = 'srcnn' USERNAME = 'root' PASSWORD = 'mysql' DB_URI = "mysql+pymysql://{username}:{password}@{host}:{port}/{db}?charset=utf8".format(username=USERNAME, password=PASSWORD, host=HOST, port=PORT, db=DATABASE) SQLALCHEMY_DATABASE_URI = DB_URI SQLALCHEMY_TRACK_MODIFICATIONS = False SQLALCHEMY_ECHO = True UPLOAD_FOLDER = "./sample/"
main.py没有用,可以删除
manage.py这个是生成数据库的脚本,需要在命令框中运行命令
from flask_script import Manager from flask_migrate import Migrate, MigrateCommand from app import db, app from models import Picture manager = Manager(app) Migrate(app=app, db=db) manager.add_command('db', MigrateCommand) # 创建数据库映射命令 if __name__ == '__main__': manager.run()
model.py就是srcnn的模型
models.py就是数据库表对应的模型
runserver.py就是该项目的运行脚本
from app import app if __name__ == '__main__': app.run(debug=True)
utils.py图片的预处理脚本
views.py就是和html进行交互的,也就是我们的表现层
数据库的实现
对于当前项目,我们只需要一个类,就是picutre
from app import db import time class Picture(db.Model): __tablename__ = 'picture' # 图片的id id = db.Column(db.Integer, primary_key=True, autoincrement=True) # 图片的名称 name = db.Column(db.String(50)) # 后缀名 suffix = db.Column(db.String(50)) # 地址 url = db.Column(db.String(100)) # 上传或者重建的时间 changetime = db.Column(db.String(20)) # 行为:Bicubic,SRCNN,Origin,Upscale_X # Bicubic:模糊处理,双三次插值 # SRCNN:卷积神经网络处理 # Origin:原图或者没有处理 # Upscale_X:放大多少倍数,如Upscale_3X表示放大3倍 action = db.Column(db.String(20)) # 原图id # 如果是原图的话,表示为-1 orig_id = db.Column(db.Integer) def __init__(self, name, url, action, orig_id): self.name = name self.suffix = name[name.find('.') + 1:] self.url = url self.changetime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) self.action = action self.orig_id = orig_id def __repr__(self): return '<Picture %s %s %s %s>' % (self.name, self.url, self.changetime, self.action) class Result(object): def __init__(self, status,message,dict): self.status=status self.message=message self.dict=dict def __repr__(self): return '<Result %s %s %s>' % (self.status, self.message, str(self.dict))
然后我们在当前项目目录下执行如下命令,就可以生成对应的数据库表
>python manage.py db init >python manage.py db migrate >python manage.py db upgrade
如下是数据库的结构
具体功能的实现
图片上传
# 上传图片 @app.route('/uploader', methods=['GET', 'POST']) def upload_file(): if request.method == 'POST': f = request.files['file'] # 判断是否是图片 name = f.filename suffix = name[name.find('.') + 1:] if suffix != 'bmp' and suffix != 'jpg' and suffix != 'jpeg' and suffix != 'png': return jsonify(code=500, message="the file is not pic") # 存入数据库 # 重新生成名字 randomCode = get_code() name = randomCode + '.' + suffix # 设置url url = request.url url = url[0:url.rfind('/') + 1] + 'pic/' + randomCode pic = Picture(name, url, 'Origin', -1) # 添加 db.session.add(pic) db.session.commit() # 上传图片 basepath = os.path.dirname(__file__) upload_path = os.path.join(basepath, 'myTest', secure_filename(name)) f.save(upload_path) # 保存缩略图50x50 url = convertToThumbnail(name, url, suffix) return jsonify(code=200, message="success upload", url=url, name=name)
以上就是和上传图片相关的功能,第一个就是图片的上传,实现图片上传需要有如下的操作
设置上传图片的名称 存入数据库 存入图片服务器,这里就是myTest文件夹,将来打算实现图片服务器 创建这个图片的缩略图 缩略图存入数据库
我们来展示一下
我们选择这个图片进行上传
上传成功,这里展示的缩略图,我们来看一下图片服务器和数据库
我们上传了两张图片,一张是原图,一张是缩略图,那么数据库也应该有两个记录
我们来看最后两条记录,就是我们刚刚上传的图片以及缩略图的记录。
然后就是主页面的展示,除了上传的按钮,我还要展示曾经上传过的图片,这里展示的那些图片的缩略图
@app.route('/') def index(): pics = Picture.query.filter_by(action='Thumbnail_50x50').order_by(Picture.changetime.desc()).limit(10).all() return render_template('index.html', pics=pics)
我们查找行为是缩略图的所有图片,以及按照时间倒叙排列,并展示其中的10张图片,对于图片的上传,我们还需要唯一定位到这个图片,不是通过文件的路径,而是通过url,使用如下的代码
# 通过url访问图片 @app.route('/pic/<picName>', methods=['GET', 'POST']) def indexPic(picName): picture = Picture.query.filter(Picture.name.like(picName + "%")).first() suffix = picture.suffix picName = os.path.join(os.getcwd(), 'myTest', picName + '.' + suffix) print(picName) image_data = open(picName, "rb").read() response = make_response(image_data) response.headers['Content-Type'] = 'image/png' return response
如上图所示我们就可以通过url唯一定位到我们的图片。和以上功能相关的html如下
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8"> <title>超分辨率图像处理系统</title> </head> <script type="text/javascript" src="/static/js/jquery-1.8.0.min.js"></script> <body> <input type="file" name="file" id="file"> <input type="button" onclick="postData()" value="上传"/> <div id="showpicture"> {% for pic in pics %} <a target="_blank" href='/detail/{{pic.name[0:32]}}'><img src="{{pic.url}}"/></a> {% endfor%} </div> <script type="text/javascript"> function postData() { var formData = new FormData(); formData.append('file', $('#file')[0].files[0]); $.ajax({ url: '/uploader', type: 'post', data: formData, contentType: false, processData: false, success: function (res) { var picIndex = "<a target=\"_blank\" href='/detail/" + res.name.substr(0, 32) + "'><img src='" + res.url + "'/></a>" $('#showpicture').prepend(picIndex) } }) } </script> </form> </body> </html>
图片详情页
然后就是图片详情页的操作,也就是在主页点击图片可以展示这个图片的详细信息,信息包括这个图片的名称,以及放大,模糊,复原的样图,这里操作比较容易,如下
# 点击图片访问该图片的详情 @app.route('/detail/<picname>', methods=['GET', 'POST']) def detail(picname): pictures = Picture.query.filter(Picture.name.like(picname + "%")).all() return render_template('detail.html', pictures=pictures)
展示一下
我们点击最后那个小女孩的图片
我们将所有的情况都展示了出来,这就是图片详情页的主要功能
图片放大
然后就是图片放大的功能
@app.route('/upscaling', methods=['GET', 'POST']) def upscaling(): data = json.loads(request.get_data(as_text=True)) times = data['times'] picname = data['picname'] # 路径 path = os.path.join(os.getcwd(), FLAGS.sample_dir, picname) print(path) # 名字 picture = Picture.query.filter_by(name=picname).first() picname = picname[0:picname.find('.')] + times + 'x_.' + picture.suffix # url url = picture.url url = url[0:url.rfind('/')] + '/' + picname[0:picname.find('_') + 1] # 放大图片 with tf.Session() as sess: srcnn = SRCNN(sess, checkpoint_dir=FLAGS.checkpoint_dir, sample_dir=FLAGS.sample_dir) srcnn.upscaling(picname, path, FLAGS, int(times)) # 保存数据库 action='Upscale_' + times + 'X' newpic = Picture(picname, url,action , picture.id) db.session.add(newpic) db.session.commit() return jsonify(code=200, message="success upscaling", name=picname, url=url, action=action)
思路还是十分简单的,我们来展示一下,我们使用刚刚上传的女人头像的图片
我们进入图片详情页之后,点击放大,就会弹出一个下拉框,我们可以选择放大的倍数,有2倍和3倍
我们可以随意的选择,比如我们可以选择2x
然后在选择3x
我们再来看一下文件夹和数据库
然后再来看一下数据库
相关的html页面如下
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8"> <title>图片详情页</title> </head> <script type="text/javascript" src="/static/js/jquery-1.8.0.min.js"></script> <body> <div id="content"> {% for pic in pictures%} <p>图片名称:{{pic.name}}</p> <p>图片如下所示:</p><img src="{{pic.url}}"/> {% if pic.action=='Origin' %} <h1> 上面的图片是原图 </h1> <input id="picname" type="hidden" name="picname" value="{{pic.name}}"/> <input type="button" name="upscaling" value="放大" onclick="createSelect()"/> <span id="times"></span> {% elif pic.action=='Thumbnail_50x50' %} <h1>上面的图片是缩略图</h1> {% elif pic.action.find('Upscale')!=-1 %} <h1>上面的图片是{{pic.action}}的图片</h1> {% endif%} {% endfor%} <div> <script type="text/javascript"> function createSelect() { if ($("#upscalingMenu").length > 0) { return; } var times = "<select id=\"upscalingMenu\" onchange=\"upscaling()\"> \n" + "<option value=\"1\">1x</option> \n" + "<option value=\"2\">2x</option> \n" + "<option value=\"3\">3x</option> \n" + "</select> " $('#times').append(times) } function upscaling() { var times = $("#upscalingMenu option:selected").val(); if (parseInt(times) == 1) { return; } var picname = $("#picname").val(); $.ajax({ url: '/upscaling', type: 'post', dataType: 'json', data: JSON.stringify({ "times": times, "picname": picname }), headers: { "Content-Type": "application/json;charset=utf-8" }, contentType: 'application/json; charset=utf-8', success: function (res) { if(res['code']=='200') { var picContent = "<p>图片名称:" + res['name'] + "</p>\n" + "<p>图片如下所示:</p><img src='" + res['url'] + "'/>\n" + "<h1>上面的图片是" + res['action'] + "的图片</h1>"; $("#content").append(picContent) }else { } } }) } </script> </body> </html>
今后的计划
1,之后实现图片的模糊处理和复原处理,做一个图片比较的功能,展示srcnn的优越性,然后通过原图和复原图进行性能的比较,得到这个复原的效果,将复原的性能指标通过图表的方式展示出来。
2,现在图片和项目都是在同一目录下的,打算搭建一个fastDFS的图片服务器,将图片和项目分开管理
3,学习一些前端的框架,将页面展示的更加人性化。
4,然后开始进行实验来验证srcnn的优越性