概况

从1月11日到23日这段时间,实现的内容主要有两个部分

part1

1,针对github的srcnn的代码,进一步的理解和掌握,在它的基础上,通过cv2等库实现了自己的srcnn的代码,实现的功能有一下几个部分。
1.1 实现图片的模糊和图像的重建,实现的结构都是单通道的,换句话说,就是灰色的图片
1.2 在1.1的基础上,得到的模糊图片和复原图片都是彩色的
1.3 实现图片的放大功能

part2

2,将1中的几个功能实现之后,就开始开发系统
2.1,选择flask和mysql进行开发,在这个基础上搭建项目
2.2,实现了图片的上传功能,图片的详情功能
2.3,实现图片的放大功能

part2中主要实现的是后台部分,前端页面的部分做的比较简单,仅仅可以展示结果

srcnn代码的实现

训练

    def train(self, config):
                prepare_for_train(self.sess)
        data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")
        train_data, train_label = read_data(data_dir)
        global_step = tf.Variable(0)

        learning_rate_exp = tf.train.exponential_decay(config.learning_rate, global_step, 1480, 0.98,
                                                       staircase=True)  # 每1个Epoch 学习率*0.98
        self.train_op = tf.train.GradientDescentOptimizer(learning_rate_exp).minimize(self.loss,
                                                                                      global_step=global_step)
        tf.global_variables_initializer().run()
        start_time = time.time()
        res = self.load(self.checkpoint_dir)
        counter = 0
        if res:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
        print("Training...")
        for ep in range(config.epoch):
            # 以batch为单元
            batch_idxs = len(train_data) // config.batch_size
            for idx in range(0, batch_idxs):
                batch_images = train_data[idx * config.batch_size: (idx + 1) * config.batch_size]
                batch_labels = train_label[idx * config.batch_size: (idx + 1) * config.batch_size]
                counter += 1
                _, err = self.sess.run([self.train_op, self.loss],
                                       feed_dict={self.images: batch_images, self.labels: batch_labels})
                if counter % 10 == 0:  # 10的倍数step显示
                    print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
                          % ((ep + 1), counter, time.time() - start_time, err))
                if counter % 500 == 0:  # 500的倍数step存储
                    self.save(config.checkpoint_dir, counter)

以上是训练的代码,我们分步进行分析

训练集的处理

                prepare_for_train(self.sess)
        data_dir = os.path.join('./{}'.format(config.checkpoint_dir), "train.h5")
        train_data, train_label = read_data(data_dir)

以上的代码就是为训练过程准备训练集,训练集是将91张图片先分割成许多的小图片,每一个小图片都分为两个,一个是清晰的,一个是模糊的,模糊的就是33x33的图片,而清晰的是21x21的图片,每一个33x33的图片经过模型的处理都会变成21x21的较为清晰的图片,然后这个时候和原本的清晰的21x21的图片进行比较就可以得到损失的大小,进而进行训练

def prepare_for_train(sess):
    data_dir = os.path.join(os.getcwd(), "Train")
    data = glob.glob(os.path.join(data_dir, "*.bmp"))
    print(len(data))  #
    sub_input_sequence = []
    sub_label_sequence = []
    padding = 6
    for i in range(len(data)):
        input_, label_ = preprocess_for_train(data[i], 3)
        if len(input_.shape) == 3:
            h, w, _ = input_.shape
        else:
            h, w = input_.shape
        for x in range(0, h - 33 + 1, 14):
            for y in range(0, w - 33 + 1, 14):
                sub_input = input_[x:x + 33, y:y + 33]
                sub_label = label_[x + 6:x + 6 + 21,
                            y + 6:y + 6 + 21]
                sub_input = sub_input.reshape([33, 33, 1])
                sub_label = sub_label.reshape([21, 21, 1])
                sub_input_sequence.append(sub_input)
                sub_label_sequence.append(sub_label)
    # 上面的部分和训练是一样的
    arrdata = np.asarray(sub_input_sequence)  # [?, 33, 33, 1]
    arrlabel = np.asarray(sub_label_sequence)  # [?, 21, 21, 1]
    savepath = os.path.join(os.getcwd(), 'checkpoint/train.h5')
    with h5py.File(savepath, 'w') as hf:
        hf.create_dataset('data', data=arrdata)
        hf.create_dataset('label', data=arrlabel)

以上就是得到训练集的方式,看最后几行的代码,我们就可以知道是33x33和21x21对应的集合.针对于91张每一张图片,我们是如何得到清晰和模糊的图片的呢,这里的处理是preprocess_for_train中实现的,如下所示

def preprocess_for_train(path, scale=3):
    scale -= 1
    image = imread(path, is_grayscale=True)
    label_ = modcrop(image, scale)

    # Must be normalized
    image = image / 255.
    label_ = label_ / 255.

    input_ = scipy.ndimage.interpolation.zoom(label_, (1. / scale), prefilter=False)
    input_ = scipy.ndimage.interpolation.zoom(input_, (scale / 1.), prefilter=False)

    return input_, label_

大概就是几个步骤,读入,剪切,得到模糊的,input_就是这图的模糊,label_就是这图的清晰。

训练的过程

    for ep in range(config.epoch):
            # 以batch为单元
            batch_idxs = len(train_data) // config.batch_size
            for idx in range(0, batch_idxs):
                batch_images = train_data[idx * config.batch_size: (idx + 1) * config.batch_size]
                batch_labels = train_label[idx * config.batch_size: (idx + 1) * config.batch_size]
                counter += 1
                _, err = self.sess.run([self.train_op, self.loss],
                                       feed_dict={self.images: batch_images, self.labels: batch_labels})
                if counter % 10 == 0:  # 10的倍数step显示
                    print("Epoch: [%2d], step: [%2d], time: [%4.4f], loss: [%.8f]" \
                          % ((ep + 1), counter, time.time() - start_time, err))
                if counter % 500 == 0:  # 500的倍数step存储
                    self.save(config.checkpoint_dir, counter)

训练不断的更新模型的参数,每10次就打印以此,每500次就记录一次参数,我们训练保存的文件在checkpoint的文件中
图片说明
这个就是保存训练参数的文件,每次执行训练的时候,我们就可以读取文件,并且每次在这个基础上进行训练
图片说明
显示读取成功,我们在这个的基础上,进行训练,发现一开始的损失函数就是已经达到的不错的数值

图片的复原

1月10日实现的图片复原,得到的仅仅是灰色的图片,现在可以得到彩色的图片

预处理

def preprocess_for_test(path, scale=3):
    scale -= 1
    im = cv.imread(path)
    im = cv.cvtColor(im, cv.COLOR_BGR2YCR_CB)
    img = im2double(im)

    label_ = modcrop(img, scale=scale)
    color_base = modcrop(im, scale=scale)
    label_ = label_[:, :, 0]
    input_ = scipy.ndimage.interpolation.zoom(label_, (1. / scale), prefilter=False)
    input_ = scipy.ndimage.interpolation.zoom(input_, (scale / 1.), prefilter=False)

    label_small = modcrop_small(label_)  # 把原图裁剪成和输出一样的大小
    input_small = modcrop_small(input_)  # 把原图裁剪成和输出一样的大小
    color_small = modcrop_small(color_base[:, :, 1:3])
    imsave(input_small, '/home/chengcongyue/PycharmProjects/SRCNN13/sample/input_.bmp')
    imsave(label_small, '/home/chengcongyue/PycharmProjects/SRCNN13/sample/label_.bmp')

    return input_, label_, color_small

这里的操作我们通过cv进行操作,首先是读入,转换图像空间,归一化,裁剪,这里我们要分为两个部分,然后仅针对单颜色空间进行模糊,然后剩下的两个通道,要传给结果。对于每一张的子图都是由33变成21,所以是变小的,因此,我们要通过方法,将这两个通道也变小

def parpare_for_test(sess, path):
    sub_input_sequence = []
    sub_label_sequence = []
    padding = 6
    # 测试
    input_, label_, color = preprocess_for_test(path, 3)  # 测试图片
    if len(input_.shape) == 3:
        h, w, _ = input_.shape
    else:
        h, w = input_.shape
    nx = 0
    ny = 0
    for x in range(0, h - 33 + 1, 21):
        nx += 1
        ny = 0
        for y in range(0, w - 33 + 1, 21):
            ny += 1
            sub_input = input_[x:x + 33, y:y + 33]
            sub_label = label_[x + 6:x + 6 + 21,
                        y + 6:y + 6 + 21]
            sub_input = sub_input.reshape([33, 33, 1])
            sub_label = sub_label.reshape([21, 21, 1])
            sub_input_sequence.append(sub_input)
            sub_label_sequence.append(sub_label)
    color = np.array(color)
    data = np.asarray(sub_input_sequence)
    label = np.asarray(sub_label_sequence)
    return data, label, color, nx, ny

以上也就是将这个转化成33x33和21x21的子图,这里的color就是上面剪切过的两个通道

复原

    def test(self, path, config):
        tf.global_variables_initializer().run()
        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

以上是加载checkpoint的未见

        data, label, color, nx, ny = parpare_for_test(self.sess, path)

data就是模糊处理过的33x33的子图
label就是的21x21的子图
color就是剪切过的两个通道,用来合成
nx,ny用来合成小图片

        conv_out = self.pred.eval({self.images: data})
        conv_out = merge(conv_out, [nx, ny])
        conv_out = conv_out.squeeze()
        result_bw = revert(conv_out)

上面就是处理过的小图片再合成大的图片

        image_path = os.path.join(os.getcwd(), config.sample_dir)
        image_path = os.path.join(image_path, "MySRCNN_simple.bmp")
        imsave(result_bw, image_path)

上面是经过处理过的单通道的图片

        result = np.zeros([result_bw.shape[0], result_bw.shape[1], 3], dtype=np.uint8)
        result[:, :, 0] = result_bw
        result[:, :, 1:3] = color
        result = cv.cvtColor(result, cv.COLOR_YCrCb2RGB)
        image_dir = os.path.join(os.getcwd(), config.sample_dir)
        image_path = os.path.join(image_dir, "MySRCNN.bmp")
        imsave(result, image_path)

这里就是合成彩色的图片,得到的就是处理过的彩色图片

        conv_out = merge(label, [nx, ny])
        conv_out = conv_out.squeeze()
        result_bw = revert(conv_out)
        result = np.zeros([result_bw.shape[0], result_bw.shape[1], 3], dtype=np.uint8)
        result[:, :, 0] = result_bw
        result[:, :, 1:3] = color
        bicubic = cv.cvtColor(result, cv.COLOR_YCrCb2RGB)
        bicubic_path = os.path.join(image_dir, 'Orig_MySRCNN.bmp')
        imsave(bicubic, bicubic_path)

得到相同大小的原图,实际上就是在原图的基础上进行裁剪

图片的放大

图片的放大,就是在预处理的过程进行放大,然后在进行复原,因为通过双三次插值的方法放大的图片实际上就是模糊处理过的了

def preprocess_for__upscaling(path, scale=3):
    im = cv.imread(path)
    size = im.shape
    im = scipy.misc.imresize(im, [size[0] * scale, size[1] * scale], interp='bicubic')
    im = cv.cvtColor(im, cv.COLOR_BGR2YCR_CB)
    img = im2double(im)
    label_ = modcrop(img, scale=scale)
    color_base = modcrop(im, scale=scale)
    label_ = label_[:, :, 0]

    label_small = modcrop_small(label_)  # 把原图裁剪成和输出一样的大小
    color_small = modcrop_small(color_base[:, :, 1:3])
    imsave(label_small, '/home/chengcongyue/PycharmProjects/SRCNN5/sample/label_.bmp')

    return label_, label_, color_small

就是在放大的过程中,处理和原来不同

功能展示

图片的复原

图片说明
从左到右依次是,模糊的,复原的,原图
图片说明
分别是复原的图片和原图
图片说明
图片说明

图片的放大

图片说明
使用srcnn进行图片放大的效果比较不错,再举几个例子
图片说明
图片说明

项目的开发

项目目录的展示

项目使用了flask+mysql进行的搭建,如下就是项目目录
图片说明
从上到下
checkpoint就是srcnn训练得到的参数文件
migrations是执行下面的脚本生成的,和数据库相关
mytest这里就相当于图片服务器的目录,存储了这个项目的所有的图片,将来会进行改进
static存放js,css的静态文件
templates就是html页面
app.py相当于是项目的初始化类,初始化flask的app类,整个项目的配置等

from flask_sqlalchemy import SQLAlchemy
import configs
from flask import Flask

# 操作数据库的对象
db = SQLAlchemy()
# 操作flask的对象
app = Flask(__name__)
# 加载配置
app.config.from_object(configs)
# db绑定app
db.init_app(app)

from models import Picture
from views import index

configs.py就是项目的配置,通过app.py进行加载

HOST = '127.0.0.1'
PORT = '3306'
DATABASE = 'srcnn'
USERNAME = 'root'
PASSWORD = 'mysql'

DB_URI = "mysql+pymysql://{username}:{password}@{host}:{port}/{db}?charset=utf8".format(username=USERNAME,
                                                                                        password=PASSWORD, host=HOST,
                                                                                        port=PORT, db=DATABASE)

SQLALCHEMY_DATABASE_URI = DB_URI
SQLALCHEMY_TRACK_MODIFICATIONS = False
SQLALCHEMY_ECHO = True
UPLOAD_FOLDER = "./sample/"

main.py没有用,可以删除
manage.py这个是生成数据库的脚本,需要在命令框中运行命令

from flask_script import Manager
from flask_migrate import Migrate, MigrateCommand
from app import db, app
from models import Picture

manager = Manager(app)
Migrate(app=app, db=db)
manager.add_command('db', MigrateCommand)  # 创建数据库映射命令


if __name__ == '__main__':
    manager.run()

model.py就是srcnn的模型
models.py就是数据库表对应的模型
runserver.py就是该项目的运行脚本

from app import app
if __name__ == '__main__':
    app.run(debug=True)

utils.py图片的预处理脚本
views.py就是和html进行交互的,也就是我们的表现层

数据库的实现

对于当前项目,我们只需要一个类,就是picutre

from app import db
import time


class Picture(db.Model):
    __tablename__ = 'picture'
    # 图片的id
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    # 图片的名称
    name = db.Column(db.String(50))
    # 后缀名
    suffix = db.Column(db.String(50))
    # 地址
    url = db.Column(db.String(100))
    # 上传或者重建的时间
    changetime = db.Column(db.String(20))
    # 行为:Bicubic,SRCNN,Origin,Upscale_X
    # Bicubic:模糊处理,双三次插值
    # SRCNN:卷积神经网络处理
    # Origin:原图或者没有处理
    # Upscale_X:放大多少倍数,如Upscale_3X表示放大3倍
    action = db.Column(db.String(20))
    # 原图id
    # 如果是原图的话,表示为-1
    orig_id = db.Column(db.Integer)

    def __init__(self, name, url, action, orig_id):
        self.name = name
        self.suffix = name[name.find('.') + 1:]
        self.url = url
        self.changetime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        self.action = action
        self.orig_id = orig_id
    def __repr__(self):
        return '<Picture %s %s %s %s>' % (self.name, self.url, self.changetime, self.action)
class Result(object):
    def __init__(self,
                status,message,dict):
        self.status=status
        self.message=message
        self.dict=dict
    def __repr__(self):
        return '<Result %s %s %s>' % (self.status, self.message, str(self.dict))

然后我们在当前项目目录下执行如下命令,就可以生成对应的数据库表

>python manage.py db init
>python manage.py db migrate
>python manage.py db upgrade

如下是数据库的结构
图片说明

具体功能的实现

图片上传

# 上传图片
@app.route('/uploader', methods=['GET', 'POST'])
def upload_file():
    if request.method == 'POST':
        f = request.files['file']
        # 判断是否是图片
        name = f.filename
        suffix = name[name.find('.') + 1:]
        if suffix != 'bmp' and suffix != 'jpg' and suffix != 'jpeg' and suffix != 'png':
            return jsonify(code=500, message="the file is not pic")
        # 存入数据库
        # 重新生成名字
        randomCode = get_code()
        name = randomCode + '.' + suffix
        # 设置url
        url = request.url
        url = url[0:url.rfind('/') + 1] + 'pic/' + randomCode
        pic = Picture(name, url, 'Origin', -1)
        # 添加
        db.session.add(pic)
        db.session.commit()
        # 上传图片
        basepath = os.path.dirname(__file__)
        upload_path = os.path.join(basepath, 'myTest', secure_filename(name))
        f.save(upload_path)
        # 保存缩略图50x50
        url = convertToThumbnail(name, url, suffix)
        return jsonify(code=200, message="success upload", url=url, name=name)

以上就是和上传图片相关的功能,第一个就是图片的上传,实现图片上传需要有如下的操作

设置上传图片的名称
存入数据库
存入图片服务器,这里就是myTest文件夹,将来打算实现图片服务器
创建这个图片的缩略图
缩略图存入数据库

我们来展示一下
图片说明
我们选择这个图片进行上传
图片说明
上传成功,这里展示的缩略图,我们来看一下图片服务器和数据库
图片说明
我们上传了两张图片,一张是原图,一张是缩略图,那么数据库也应该有两个记录
图片说明
我们来看最后两条记录,就是我们刚刚上传的图片以及缩略图的记录。
然后就是主页面的展示,除了上传的按钮,我还要展示曾经上传过的图片,这里展示的那些图片的缩略图

@app.route('/')
def index():
    pics = Picture.query.filter_by(action='Thumbnail_50x50').order_by(Picture.changetime.desc()).limit(10).all()
    return render_template('index.html', pics=pics)

我们查找行为是缩略图的所有图片,以及按照时间倒叙排列,并展示其中的10张图片,对于图片的上传,我们还需要唯一定位到这个图片,不是通过文件的路径,而是通过url,使用如下的代码

# 通过url访问图片
@app.route('/pic/<picName>', methods=['GET', 'POST'])
def indexPic(picName):
    picture = Picture.query.filter(Picture.name.like(picName + "%")).first()
    suffix = picture.suffix
    picName = os.path.join(os.getcwd(), 'myTest', picName + '.' + suffix)
    print(picName)
    image_data = open(picName, "rb").read()
    response = make_response(image_data)
    response.headers['Content-Type'] = 'image/png'
    return response

图片说明
如上图所示我们就可以通过url唯一定位到我们的图片。和以上功能相关的html如下

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>超分辨率图像处理系统</title>
</head>
<script type="text/javascript"
        src="/static/js/jquery-1.8.0.min.js"></script>
<body>
<input type="file" name="file" id="file">
<input type="button" onclick="postData()" value="上传"/>
<div id="showpicture">
    {% for pic in pics %}
    <a target="_blank" href='/detail/{{pic.name[0:32]}}'><img src="{{pic.url}}"/></a>
    {% endfor%}
</div>
<script type="text/javascript">
    function postData() {
        var formData = new FormData();
        formData.append('file', $('#file')[0].files[0]);
        $.ajax({
            url: '/uploader',
            type: 'post',
            data: formData,
            contentType: false,
            processData: false,
            success: function (res) {
                var picIndex = "<a target=\"_blank\" href='/detail/" + res.name.substr(0, 32) + "'><img src='" + res.url + "'/></a>"
                $('#showpicture').prepend(picIndex)
            }
        })
    }
</script>
</form>
</body>
</html>

图片详情页

然后就是图片详情页的操作,也就是在主页点击图片可以展示这个图片的详细信息,信息包括这个图片的名称,以及放大,模糊,复原的样图,这里操作比较容易,如下

# 点击图片访问该图片的详情
@app.route('/detail/<picname>', methods=['GET', 'POST'])
def detail(picname):
    pictures = Picture.query.filter(Picture.name.like(picname + "%")).all()
    return render_template('detail.html', pictures=pictures)

展示一下
图片说明
我们点击最后那个小女孩的图片
图片说明
我们将所有的情况都展示了出来,这就是图片详情页的主要功能

图片放大

然后就是图片放大的功能

@app.route('/upscaling', methods=['GET', 'POST'])
def upscaling():
    data = json.loads(request.get_data(as_text=True))
    times = data['times']
    picname = data['picname']
    # 路径
    path = os.path.join(os.getcwd(), FLAGS.sample_dir, picname)
    print(path)
    # 名字
    picture = Picture.query.filter_by(name=picname).first()
    picname = picname[0:picname.find('.')] + times + 'x_.' + picture.suffix
    # url
    url = picture.url
    url = url[0:url.rfind('/')] + '/' + picname[0:picname.find('_') + 1]
    # 放大图片
    with tf.Session() as sess:
        srcnn = SRCNN(sess,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      sample_dir=FLAGS.sample_dir)
        srcnn.upscaling(picname,
                        path,
                        FLAGS, int(times))
    # 保存数据库
    action='Upscale_' + times + 'X'
    newpic = Picture(picname, url,action , picture.id)
    db.session.add(newpic)
    db.session.commit()
    return jsonify(code=200, message="success upscaling", name=picname, url=url, action=action)

思路还是十分简单的,我们来展示一下,我们使用刚刚上传的女人头像的图片
图片说明
我们进入图片详情页之后,点击放大,就会弹出一个下拉框,我们可以选择放大的倍数,有2倍和3倍
图片说明
我们可以随意的选择,比如我们可以选择2x
图片说明
然后在选择3x
图片说明
我们再来看一下文件夹和数据库
图片说明
然后再来看一下数据库
图片说明
相关的html页面如下

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>图片详情页</title>
</head>
<script type="text/javascript"
        src="/static/js/jquery-1.8.0.min.js"></script>
<body>
<div id="content">
    {% for pic in pictures%}

    <p>图片名称:{{pic.name}}</p>
    <p>图片如下所示:</p><img src="{{pic.url}}"/>

    {% if pic.action=='Origin' %}
    <h1> 上面的图片是原图 </h1>
    <input id="picname" type="hidden" name="picname" value="{{pic.name}}"/>
    <input type="button" name="upscaling" value="放大" onclick="createSelect()"/>&nbsp;&nbsp;&nbsp;<span
        id="times"></span>
    {% elif pic.action=='Thumbnail_50x50' %}
    <h1>上面的图片是缩略图</h1>
    {% elif pic.action.find('Upscale')!=-1 %}
    <h1>上面的图片是{{pic.action}}的图片</h1>
    {% endif%}
    {% endfor%}
    <div>
        <script type="text/javascript">
            function createSelect() {
                if ($("#upscalingMenu").length > 0) {
                    return;
                }
                var times = "<select id=\"upscalingMenu\" onchange=\"upscaling()\"> \n" +
                    "<option value=\"1\">1x</option> \n" +
                    "<option value=\"2\">2x</option> \n" +
                    "<option value=\"3\">3x</option> \n" +
                    "</select> "
                $('#times').append(times)
            }

            function upscaling() {
                var times = $("#upscalingMenu option:selected").val();
                if (parseInt(times) == 1) {
                    return;
                }
                var picname = $("#picname").val();
                $.ajax({
                    url: '/upscaling',
                    type: 'post',
                    dataType: 'json',
                    data: JSON.stringify({
                        "times": times,
                        "picname": picname
                    }),
                    headers: {
                        "Content-Type": "application/json;charset=utf-8"
                    },
                    contentType: 'application/json; charset=utf-8',
                    success: function (res) {
                        if(res['code']=='200')
                        {
                            var picContent = "<p>图片名称:" + res['name'] + "</p>\n" +
                            "<p>图片如下所示:</p><img src='" + res['url'] + "'/>\n" +
                            "<h1>上面的图片是" + res['action'] + "的图片</h1>";
                            $("#content").append(picContent)
                        }else
                        {

                        }
                    }
                })
            }
        </script>
</body>
</html>

今后的计划 

1,之后实现图片的模糊处理和复原处理,做一个图片比较的功能,展示srcnn的优越性,然后通过原图和复原图进行性能的比较,得到这个复原的效果,将复原的性能指标通过图表的方式展示出来。
2,现在图片和项目都是在同一目录下的,打算搭建一个fastDFS的图片服务器,将图片和项目分开管理
3,学习一些前端的框架,将页面展示的更加人性化。
4,然后开始进行实验来验证srcnn的优越性