概况
从1月11日到23日这段时间,实现的内容主要有两个部分
part1
1,针对github的srcnn的代码,进一步的理解和掌握,在它的基础上,通过cv2等库实现了自己的srcnn的代码,实现的功能有一下几个部分。
1.1 实现图片的模糊和图像的重建,实现的结构都是单通道的,换句话说,就是灰色的图片
1.2 在1.1的基础上,得到的模糊图片和复原图片都是彩色的
1.3 实现图片的放大功能
part2
2,将1中的几个功能实现之后,就开始开发系统
2.1,选择flask和mysql进行开发,在这个基础上搭建项目
2.2,实现了图片的上传功能,图片的详情功能
2.3,实现图片的放大功能
注
part2中主要实现的是后台部分,前端页面的部分做的比较简单,仅仅可以展示结果
srcnn代码的实现
训练
def train(self, config):
prepare_for_train(self.sess)
data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")
train_data, train_label = read_data(data_dir)
global_step = tf.Variable(0)
learning_rate_exp = tf.train.exponential_decay(config.learning_rate, global_step, 1480, 0.98,
staircase=True) # 每1个Epoch 学习率*0.98
self.train_op = tf.train.GradientDescentOptimizer(learning_rate_exp).minimize(self.loss,
global_step=global_step)
tf.global_variables_initializer().run()
start_time = time.time()
res = self.load(self.checkpoint_dir)
counter = 0
if res:
print(" [*] Load SUCCESS")
else:
print(" [!] Load failed...")
print("Training...")
for ep in range(config.epoch):
# 以batch为单元
batch_idxs = len(train_data) // config.batch_size
for idx in range(0, batch_idxs):
batch_images = train_data[idx * config.batch_size: (idx + 1) * config.batch_size]
batch_labels = train_label[idx * config.batch_size: (idx + 1) * config.batch_size]
counter += 1
_, err = self.sess.run([self.train_op, self.loss],
feed_dict={self.images: batch_images, self.labels: batch_labels})
if counter % 10 == 0: # 10的倍数step显示
print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
% ((ep + 1), counter, time.time() - start_time, err))
if counter % 500 == 0: # 500的倍数step存储
self.save(config.checkpoint_dir, counter)以上是训练的代码,我们分步进行分析
训练集的处理
prepare_for_train(self.sess)
data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")
train_data, train_label = read_data(data_dir)以上的代码就是为训练过程准备训练集,训练集是将91张图片先分割成许多的小图片,每一个小图片都分为两个,一个是清晰的,一个是模糊的,模糊的就是33x33的图片,而清晰的是21x21的图片,每一个33x33的图片经过模型的处理都会变成21x21的较为清晰的图片,然后这个时候和原本的清晰的21x21的图片进行比较就可以得到损失的大小,进而进行训练
def prepare_for_train(sess):
data_dir = os.path.join(os.getcwd(), "Train")
data = glob.glob(os.path.join(data_dir, "*.bmp"))
print(len(data)) #
sub_input_sequence = []
sub_label_sequence = []
padding = 6
for i in range(len(data)):
input_, label_ = preprocess_for_train(data[i], 3)
if len(input_.shape) == 3:
h, w, _ = input_.shape
else:
h, w = input_.shape
for x in range(0, h - 33 + 1, 14):
for y in range(0, w - 33 + 1, 14):
sub_input = input_[x:x + 33, y:y + 33]
sub_label = label_[x + 6:x + 6 + 21,
y + 6:y + 6 + 21]
sub_input = sub_input.reshape([33, 33, 1])
sub_label = sub_label.reshape([21, 21, 1])
sub_input_sequence.append(sub_input)
sub_label_sequence.append(sub_label)
# 上面的部分和训练是一样的
arrdata = np.asarray(sub_input_sequence) # [?, 33, 33, 1]
arrlabel = np.asarray(sub_label_sequence) # [?, 21, 21, 1]
savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5')
with h5py.File(savepath, 'w') as hf:
hf.create_dataset('data', data=arrdata)
hf.create_dataset('label', data=arrlabel)以上就是得到训练集的方式,看最后几行的代码,我们就可以知道是33x33和21x21对应的集合.针对于91张每一张图片,我们是如何得到清晰和模糊的图片的呢,这里的处理是preprocess_for_train中实现的,如下所示
def preprocess_for_train(path, scale=3):
scale -= 1
image = imread(path, is_grayscale=True)
label_ = modcrop(image, scale)
# Must be normalized
image = image / 255.
label_ = label_ / 255.
input_ = scipy.ndimage.interpolation.zoom(label_, (1. / scale), prefilter=False)
input_ = scipy.ndimage.interpolation.zoom(input_, (scale / 1.), prefilter=False)
return input_, label_大概就是几个步骤,读入,剪切,得到模糊的,input_就是这图的模糊,label_就是这图的清晰。
训练的过程
for ep in range(config.epoch):
# 以batch为单元
batch_idxs = len(train_data) // config.batch_size
for idx in range(0, batch_idxs):
batch_images = train_data[idx * config.batch_size: (idx + 1) * config.batch_size]
batch_labels = train_label[idx * config.batch_size: (idx + 1) * config.batch_size]
counter += 1
_, err = self.sess.run([self.train_op, self.loss],
feed_dict={self.images: batch_images, self.labels: batch_labels})
if counter % 10 == 0: # 10的倍数step显示
print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
% ((ep + 1), counter, time.time() - start_time, err))
if counter % 500 == 0: # 500的倍数step存储
self.save(config.checkpoint_dir, counter)训练不断的更新模型的参数,每10次就打印以此,每500次就记录一次参数,我们训练保存的文件在checkpoint的文件中
这个就是保存训练参数的文件,每次执行训练的时候,我们就可以读取文件,并且每次在这个基础上进行训练
显示读取成功,我们在这个的基础上,进行训练,发现一开始的损失函数就是已经达到的不错的数值
图片的复原
1月10日实现的图片复原,得到的仅仅是灰色的图片,现在可以得到彩色的图片
预处理
def preprocess_for_test(path, scale=3):
scale -= 1
im = cv.imread(path)
im = cv.cvtColor(im, cv.COLOR_BGR2YCR_CB)
img = im2double(im)
label_ = modcrop(img, scale=scale)
color_base = modcrop(im, scale=scale)
label_ = label_[:, :, 0]
input_ = scipy.ndimage.interpolation.zoom(label_, (1. / scale), prefilter=False)
input_ = scipy.ndimage.interpolation.zoom(input_, (scale / 1.), prefilter=False)
label_small = modcrop_small(label_) # 把原图裁剪成和输出一样的大小
input_small = modcrop_small(input_) # 把原图裁剪成和输出一样的大小
color_small = modcrop_small(color_base[:, :, 1:3])
imsave(input_small, '/home/chengcongyue/PycharmProjects/SRCNN13/sample/input_.bmp')
imsave(label_small, '/home/chengcongyue/PycharmProjects/SRCNN13/sample/label_.bmp')
return input_, label_, color_small这里的操作我们通过cv进行操作,首先是读入,转换图像空间,归一化,裁剪,这里我们要分为两个部分,然后仅针对单颜色空间进行模糊,然后剩下的两个通道,要传给结果。对于每一张的子图都是由33变成21,所以是变小的,因此,我们要通过方法,将这两个通道也变小
def parpare_for_test(sess, path):
sub_input_sequence = []
sub_label_sequence = []
padding = 6
# 测试
input_, label_, color = preprocess_for_test(path, 3) # 测试图片
if len(input_.shape) == 3:
h, w, _ = input_.shape
else:
h, w = input_.shape
nx = 0
ny = 0
for x in range(0, h - 33 + 1, 21):
nx += 1
ny = 0
for y in range(0, w - 33 + 1, 21):
ny += 1
sub_input = input_[x:x + 33, y:y + 33]
sub_label = label_[x + 6:x + 6 + 21,
y + 6:y + 6 + 21]
sub_input = sub_input.reshape([33, 33, 1])
sub_label = sub_label.reshape([21, 21, 1])
sub_input_sequence.append(sub_input)
sub_label_sequence.append(sub_label)
color = np.array(color)
data = np.asarray(sub_input_sequence)
label = np.asarray(sub_label_sequence)
return data, label, color, nx, ny以上也就是将这个转化成33x33和21x21的子图,这里的color就是上面剪切过的两个通道
复原
def test(self, path, config):
tf.global_variables_initializer().run()
if self.load(self.checkpoint_dir):
print(" [*] Load SUCCESS")
else:
print(" [!] Load failed...")以上是加载checkpoint的未见
data, label, color, nx, ny = parpare_for_test(self.sess, path)
data就是模糊处理过的33x33的子图
label就是的21x21的子图
color就是剪切过的两个通道,用来合成
nx,ny用来合成小图片
conv_out = self.pred.eval({self.images: data})
conv_out = merge(conv_out, [nx, ny])
conv_out = conv_out.squeeze()
result_bw = revert(conv_out)上面就是处理过的小图片再合成大的图片
image_path = os.path.join(os.getcwd(), config.sample_dir)
image_path = os.path.join(image_path, "MySRCNN_simple.bmp")
imsave(result_bw, image_path)上面是经过处理过的单通道的图片
result = np.zeros([result_bw.shape[0], result_bw.shape[1], 3], dtype=np.uint8)
result[:, :, 0] = result_bw
result[:, :, 1:3] = color
result = cv.cvtColor(result, cv.COLOR_YCrCb2RGB)
image_dir = os.path.join(os.getcwd(), config.sample_dir)
image_path = os.path.join(image_dir, "MySRCNN.bmp")
imsave(result, image_path)这里就是合成彩色的图片,得到的就是处理过的彩色图片
conv_out = merge(label, [nx, ny])
conv_out = conv_out.squeeze()
result_bw = revert(conv_out)
result = np.zeros([result_bw.shape[0], result_bw.shape[1], 3], dtype=np.uint8)
result[:, :, 0] = result_bw
result[:, :, 1:3] = color
bicubic = cv.cvtColor(result, cv.COLOR_YCrCb2RGB)
bicubic_path = os.path.join(image_dir, 'Orig_MySRCNN.bmp')
imsave(bicubic, bicubic_path)得到相同大小的原图,实际上就是在原图的基础上进行裁剪
图片的放大
图片的放大,就是在预处理的过程进行放大,然后在进行复原,因为通过双三次插值的方法放大的图片实际上就是模糊处理过的了
def preprocess_for__upscaling(path, scale=3):
im = cv.imread(path)
size = im.shape
im = scipy.misc.imresize(im, [size[0] * scale, size[1] * scale], interp='bicubic')
im = cv.cvtColor(im, cv.COLOR_BGR2YCR_CB)
img = im2double(im)
label_ = modcrop(img, scale=scale)
color_base = modcrop(im, scale=scale)
label_ = label_[:, :, 0]
label_small = modcrop_small(label_) # 把原图裁剪成和输出一样的大小
color_small = modcrop_small(color_base[:, :, 1:3])
imsave(label_small, '/home/chengcongyue/PycharmProjects/SRCNN5/sample/label_.bmp')
return label_, label_, color_small就是在放大的过程中,处理和原来不同
功能展示
图片的复原
从左到右依次是,模糊的,复原的,原图
分别是复原的图片和原图
图片的放大
使用srcnn进行图片放大的效果比较不错,再举几个例子
项目的开发
项目目录的展示
项目使用了flask+mysql进行的搭建,如下就是项目目录
从上到下
checkpoint就是srcnn训练得到的参数文件
migrations是执行下面的脚本生成的,和数据库相关
mytest这里就相当于图片服务器的目录,存储了这个项目的所有的图片,将来会进行改进
static存放js,css的静态文件
templates就是html页面
app.py相当于是项目的初始化类,初始化flask的app类,整个项目的配置等
from flask_sqlalchemy import SQLAlchemy import configs from flask import Flask # 操作数据库的对象 db = SQLAlchemy() # 操作flask的对象 app = Flask(__name__) # 加载配置 app.config.from_object(configs) # db绑定app db.init_app(app) from models import Picture from views import index
configs.py就是项目的配置,通过app.py进行加载
HOST = '127.0.0.1'
PORT = '3306'
DATABASE = 'srcnn'
USERNAME = 'root'
PASSWORD = 'mysql'
DB_URI = "mysql+pymysql://{username}:{password}@{host}:{port}/{db}?charset=utf8".format(username=USERNAME,
password=PASSWORD, host=HOST,
port=PORT, db=DATABASE)
SQLALCHEMY_DATABASE_URI = DB_URI
SQLALCHEMY_TRACK_MODIFICATIONS = False
SQLALCHEMY_ECHO = True
UPLOAD_FOLDER = "./sample/"main.py没有用,可以删除
manage.py这个是生成数据库的脚本,需要在命令框中运行命令
from flask_script import Manager
from flask_migrate import Migrate, MigrateCommand
from app import db, app
from models import Picture
manager = Manager(app)
Migrate(app=app, db=db)
manager.add_command('db', MigrateCommand) # 创建数据库映射命令
if __name__ == '__main__':
manager.run()model.py就是srcnn的模型
models.py就是数据库表对应的模型
runserver.py就是该项目的运行脚本
from app import app
if __name__ == '__main__':
app.run(debug=True)utils.py图片的预处理脚本
views.py就是和html进行交互的,也就是我们的表现层
数据库的实现
对于当前项目,我们只需要一个类,就是picutre
from app import db
import time
class Picture(db.Model):
__tablename__ = 'picture'
# 图片的id
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
# 图片的名称
name = db.Column(db.String(50))
# 后缀名
suffix = db.Column(db.String(50))
# 地址
url = db.Column(db.String(100))
# 上传或者重建的时间
changetime = db.Column(db.String(20))
# 行为:Bicubic,SRCNN,Origin,Upscale_X
# Bicubic:模糊处理,双三次插值
# SRCNN:卷积神经网络处理
# Origin:原图或者没有处理
# Upscale_X:放大多少倍数,如Upscale_3X表示放大3倍
action = db.Column(db.String(20))
# 原图id
# 如果是原图的话,表示为-1
orig_id = db.Column(db.Integer)
def __init__(self, name, url, action, orig_id):
self.name = name
self.suffix = name[name.find('.') + 1:]
self.url = url
self.changetime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
self.action = action
self.orig_id = orig_id
def __repr__(self):
return '<Picture %s %s %s %s>' % (self.name, self.url, self.changetime, self.action)
class Result(object):
def __init__(self,
status,message,dict):
self.status=status
self.message=message
self.dict=dict
def __repr__(self):
return '<Result %s %s %s>' % (self.status, self.message, str(self.dict))然后我们在当前项目目录下执行如下命令,就可以生成对应的数据库表
>python manage.py db init >python manage.py db migrate >python manage.py db upgrade
如下是数据库的结构
具体功能的实现
图片上传
# 上传图片
@app.route('/uploader', methods=['GET', 'POST'])
def upload_file():
if request.method == 'POST':
f = request.files['file']
# 判断是否是图片
name = f.filename
suffix = name[name.find('.') + 1:]
if suffix != 'bmp' and suffix != 'jpg' and suffix != 'jpeg' and suffix != 'png':
return jsonify(code=500, message="the file is not pic")
# 存入数据库
# 重新生成名字
randomCode = get_code()
name = randomCode + '.' + suffix
# 设置url
url = request.url
url = url[0:url.rfind('/') + 1] + 'pic/' + randomCode
pic = Picture(name, url, 'Origin', -1)
# 添加
db.session.add(pic)
db.session.commit()
# 上传图片
basepath = os.path.dirname(__file__)
upload_path = os.path.join(basepath, 'myTest', secure_filename(name))
f.save(upload_path)
# 保存缩略图50x50
url = convertToThumbnail(name, url, suffix)
return jsonify(code=200, message="success upload", url=url, name=name)以上就是和上传图片相关的功能,第一个就是图片的上传,实现图片上传需要有如下的操作
设置上传图片的名称 存入数据库 存入图片服务器,这里就是myTest文件夹,将来打算实现图片服务器 创建这个图片的缩略图 缩略图存入数据库
我们来展示一下
我们选择这个图片进行上传
上传成功,这里展示的缩略图,我们来看一下图片服务器和数据库
我们上传了两张图片,一张是原图,一张是缩略图,那么数据库也应该有两个记录
我们来看最后两条记录,就是我们刚刚上传的图片以及缩略图的记录。
然后就是主页面的展示,除了上传的按钮,我还要展示曾经上传过的图片,这里展示的那些图片的缩略图
@app.route('/')
def index():
pics = Picture.query.filter_by(action='Thumbnail_50x50').order_by(Picture.changetime.desc()).limit(10).all()
return render_template('index.html', pics=pics)我们查找行为是缩略图的所有图片,以及按照时间倒叙排列,并展示其中的10张图片,对于图片的上传,我们还需要唯一定位到这个图片,不是通过文件的路径,而是通过url,使用如下的代码
# 通过url访问图片
@app.route('/pic/<picName>', methods=['GET', 'POST'])
def indexPic(picName):
picture = Picture.query.filter(Picture.name.like(picName + "%")).first()
suffix = picture.suffix
picName = os.path.join(os.getcwd(), 'myTest', picName + '.' + suffix)
print(picName)
image_data = open(picName, "rb").read()
response = make_response(image_data)
response.headers['Content-Type'] = 'image/png'
return response
如上图所示我们就可以通过url唯一定位到我们的图片。和以上功能相关的html如下
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>超分辨率图像处理系统</title>
</head>
<script type="text/javascript"
src="/static/js/jquery-1.8.0.min.js"></script>
<body>
<input type="file" name="file" id="file">
<input type="button" onclick="postData()" value="上传"/>
<div id="showpicture">
{% for pic in pics %}
<a target="_blank" href='/detail/{{pic.name[0:32]}}'><img src="{{pic.url}}"/></a>
{% endfor%}
</div>
<script type="text/javascript">
function postData() {
var formData = new FormData();
formData.append('file', $('#file')[0].files[0]);
$.ajax({
url: '/uploader',
type: 'post',
data: formData,
contentType: false,
processData: false,
success: function (res) {
var picIndex = "<a target=\"_blank\" href='/detail/" + res.name.substr(0, 32) + "'><img src='" + res.url + "'/></a>"
$('#showpicture').prepend(picIndex)
}
})
}
</script>
</form>
</body>
</html>图片详情页
然后就是图片详情页的操作,也就是在主页点击图片可以展示这个图片的详细信息,信息包括这个图片的名称,以及放大,模糊,复原的样图,这里操作比较容易,如下
# 点击图片访问该图片的详情
@app.route('/detail/<picname>', methods=['GET', 'POST'])
def detail(picname):
pictures = Picture.query.filter(Picture.name.like(picname + "%")).all()
return render_template('detail.html', pictures=pictures)展示一下
我们点击最后那个小女孩的图片
我们将所有的情况都展示了出来,这就是图片详情页的主要功能
图片放大
然后就是图片放大的功能
@app.route('/upscaling', methods=['GET', 'POST'])
def upscaling():
data = json.loads(request.get_data(as_text=True))
times = data['times']
picname = data['picname']
# 路径
path = os.path.join(os.getcwd(), FLAGS.sample_dir, picname)
print(path)
# 名字
picture = Picture.query.filter_by(name=picname).first()
picname = picname[0:picname.find('.')] + times + 'x_.' + picture.suffix
# url
url = picture.url
url = url[0:url.rfind('/')] + '/' + picname[0:picname.find('_') + 1]
# 放大图片
with tf.Session() as sess:
srcnn = SRCNN(sess,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
srcnn.upscaling(picname,
path,
FLAGS, int(times))
# 保存数据库
action='Upscale_' + times + 'X'
newpic = Picture(picname, url,action , picture.id)
db.session.add(newpic)
db.session.commit()
return jsonify(code=200, message="success upscaling", name=picname, url=url, action=action)思路还是十分简单的,我们来展示一下,我们使用刚刚上传的女人头像的图片
我们进入图片详情页之后,点击放大,就会弹出一个下拉框,我们可以选择放大的倍数,有2倍和3倍
我们可以随意的选择,比如我们可以选择2x
然后在选择3x
我们再来看一下文件夹和数据库
然后再来看一下数据库
相关的html页面如下
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>图片详情页</title>
</head>
<script type="text/javascript"
src="/static/js/jquery-1.8.0.min.js"></script>
<body>
<div id="content">
{% for pic in pictures%}
<p>图片名称:{{pic.name}}</p>
<p>图片如下所示:</p><img src="{{pic.url}}"/>
{% if pic.action=='Origin' %}
<h1> 上面的图片是原图 </h1>
<input id="picname" type="hidden" name="picname" value="{{pic.name}}"/>
<input type="button" name="upscaling" value="放大" onclick="createSelect()"/> <span
id="times"></span>
{% elif pic.action=='Thumbnail_50x50' %}
<h1>上面的图片是缩略图</h1>
{% elif pic.action.find('Upscale')!=-1 %}
<h1>上面的图片是{{pic.action}}的图片</h1>
{% endif%}
{% endfor%}
<div>
<script type="text/javascript">
function createSelect() {
if ($("#upscalingMenu").length > 0) {
return;
}
var times = "<select id=\"upscalingMenu\" onchange=\"upscaling()\"> \n" +
"<option value=\"1\">1x</option> \n" +
"<option value=\"2\">2x</option> \n" +
"<option value=\"3\">3x</option> \n" +
"</select> "
$('#times').append(times)
}
function upscaling() {
var times = $("#upscalingMenu option:selected").val();
if (parseInt(times) == 1) {
return;
}
var picname = $("#picname").val();
$.ajax({
url: '/upscaling',
type: 'post',
dataType: 'json',
data: JSON.stringify({
"times": times,
"picname": picname
}),
headers: {
"Content-Type": "application/json;charset=utf-8"
},
contentType: 'application/json; charset=utf-8',
success: function (res) {
if(res['code']=='200')
{
var picContent = "<p>图片名称:" + res['name'] + "</p>\n" +
"<p>图片如下所示:</p><img src='" + res['url'] + "'/>\n" +
"<h1>上面的图片是" + res['action'] + "的图片</h1>";
$("#content").append(picContent)
}else
{
}
}
})
}
</script>
</body>
</html>今后的计划
1,之后实现图片的模糊处理和复原处理,做一个图片比较的功能,展示srcnn的优越性,然后通过原图和复原图进行性能的比较,得到这个复原的效果,将复原的性能指标通过图表的方式展示出来。
2,现在图片和项目都是在同一目录下的,打算搭建一个fastDFS的图片服务器,将图片和项目分开管理
3,学习一些前端的框架,将页面展示的更加人性化。
4,然后开始进行实验来验证srcnn的优越性

京公网安备 11010502036488号