技术人生,  机器学习,  网络编程

基于web平台的手写体数字识别

本文使用numpy搭建基础的神经网络,用于手写体数字识别。用户可以在网页上进行书写,前端将图像信息发送给后端识别并接收识别结果。

大致流程如下:

  • 针对MNIST数据集的神经网络模型的训练,完成后保存训练模型
  • 后端:编写识别接口,使用训练得到的模型识别手写体数字图片
  • 前端:编写web页面,完成笔迹绘制功能,截图处理后传到后端进行识别
  • 后端:编写Tornado服务器,接收图片,进一步处理后识别图片并返回结果

第一部分:模型训练

数据集采用MNIST手写体数字数据集,使用numpy搭建三层人工神经网络,使用Sigmoid函数作为激活函数,误差梯度下降法作为权重更新方法,经过测试,识别准确率高达约97%。完成训练后保存模型权重参数。

该部分具体内容参考文章:使用numpy搭建基础的神经网络

这里给出保存权重参数的方法,numpy.save()方法。

    def save_array(self):
        np.savez('weight',wih=self.wih,who=self.who)

第二部分:识别接口

模型训练完成后,需要编写一个用于识别的接口使用训练得到的模型权重参数识别输入图片数据,用于后端服务器的直接调用识别。该接口输入图片对应的numpy数组,输出识别结果。

参考第一部分测试部分的代码,还原神经网络信息正向传播的模型。

import numpy as np
import scipy.special

def query(inputs):
    '''输入参数为784维(28*28延展后)灰度图像的numpy数组,输出识别的数字'''

    #归一化至0.01~0.99
    inputs = inputs / 255.0 * 0.99 + 0.01

    #读取权重参数
    npzfile = np.load('weight.npz')
    wih = npzfile['wih']
    who = npzfile['who']

    #输入层
    inputs = np.array(inputs, ndmin=2).T

    #计算隐藏层
    hidden_inputs = wih @ inputs
    hidden_outputs = scipy.special.expit(hidden_inputs)

    #计算输出层
    final_inputs = who @ hidden_outputs
    final_outputs = scipy.special.expit(final_inputs)

    #返回概率最大的标签
    result = np.argmax(final_outputs)

    return int(result)

第三部分:前端页面

首先编写html页面,我使用canvas提供用户的书写区域,两个按钮分别用于回传图片数据和清空画布。

<!DOCTYPE html>
<html>
    <head>
        <meta charset="utf-8"></meta>
        <meta name="viewport" content="width=device-width, initial-scale=1">
        <title>手写体数字识别</title>
        <script src="/static/process.js"></script>
    </head>
    <body style="background-image: linear-gradient(to right , orange, mediumpurple, #3399FF);">
        <div style="text-align:center;">
            <h1 id='result'>手写体数字识别</h1>
            <canvas id="cv" width="300" height="300" style="border:1px solid #000000;"></canvas>
            <script src="/static/draw.js"></script>
            <br>
            <button onclick="image_process()">提交</button>
            <button onclick="clear_canvas()">清空</button>
        </div>
    </body>
</html>
首页

绘图部分采用javascript实现在canvas上书写,这里我们允许用户可以使用鼠标和手指来进行书写操作,由于二者所绑定的事件不同,所以需要分别编写代码。

首先定义一个获得笔触坐标的函数,由于鼠标或手指的默认坐标是相对于页面而言的,所以我需要通过计算得到它们基于画布的坐标。

我将画布使用白色填充,使用变量isAllowDraw来定义当前是否属于可书写状态,只有按下鼠标或手指时可以进行书写。

从鼠标事件开始,当用户按下鼠标时绘制起始点,按下并移动鼠标时使用stroke函数连续绘制线段以达成书写的目的,松开鼠标时不再允许书写。

手指事件与鼠标类似,细小区别在于当滑动手指时需要禁止页面滚动以确保用户可以正常在画布上书写,由于页面滚动属于手指滑动的默认事件,因此使用preventDefault函数即可。

//获取笔触在canvas上的坐标
function get_pencil (canvas, x, y) {
    var rect = canvas.getBoundingClientRect()
    //x和y参数分别传入的是鼠标距离窗口的坐标,然后减去canvas距离窗口左边和顶部的距离
    return {x: x - rect.left * (canvas.width / rect.width), y: y - rect.top * (canvas.height / rect.height)}
}

var ctx = cv.getContext('2d');
ctx.fillStyle="#FFFFFF";
ctx.fillRect(0,0,cv.width,cv.height);
var isAllowDraw = false;

//鼠标按下事件
onmousedown = function (e) {
    isAllowDraw = true;
    //获得鼠标按下的点相对canvas的坐标
    var {x, y} = get_pencil(cv, e.clientX, e.clientY);
    //绘制起点
    ctx.moveTo(x, y);
}

//鼠标移动事件
onmousemove = function (e) {
    //移动时获取新的坐标位置,用lineTo记录当前的坐标,然后stroke绘制上一个点到当前点的路径
    if (isAllowDraw) {
        var {x, y} = get_pencil(cv, e.clientX, e.clientY);
        ctx.lineTo(x, y);
        ctx.strokeStyle = 'black';
        ctx.lineWidth = 15;
        ctx.stroke();
    }
}

//鼠标抬起事件
onmouseup = function () {
    //鼠标抬起停止作画
    isAllowDraw = false;
}

//手指按下事件
window.addEventListener('touchstart', function (e) {
    isAllowDraw = true;
    //获得手指按下的点相对canvas的坐标
    var {x, y} = get_pencil(cv, e.touches[0].clientX, e.touches[0].clientY);
    //绘制起点
    ctx.moveTo(x, y);
})

//手指移动事件
window.addEventListener('touchmove', function (e) {
    //禁止页面滚动
    e.preventDefault();
    //移动时获取新的坐标位置,用lineTo记录当前的坐标,然后stroke绘制上一个点到当前点的路径
    if (isAllowDraw) {
        var {x, y} = get_pencil(cv, e.touches[0].clientX, e.touches[0].clientY);
        ctx.lineTo(x, y);
        ctx.strokeStyle = 'black';
        ctx.lineWidth = 15;
        ctx.stroke();
    }
}, {passive: false})

//手指抬起事件
window.addEventListener('touchend', function () {
    //手指抬起停止作画
    isAllowDraw = false;
})

书写方式优化与改进:

使用线段连接的方式进行书写很容易造成边缘的锯齿和不连贯,解决的方法是使用二次贝塞尔曲线进行绘制,二次贝塞尔曲线通过计算三个控制点来绘制出平滑的曲线。js中使用quadraticCurveTo方法可以绘制二次贝塞尔曲线,函数需要传入两个控制点,第一个控制点为上一次鼠标到达的点,第二个控制点为上一次鼠标到达的点与此时鼠标所在点的中点。因此需要额外使用变量保存上一次到达的点。另外,将绘制时的线帽和线条连接设置为圆形可以使曲线看上去更平滑。

改进后的采用鼠标绘制部分的代码如下,使用手指也是类似的。

var ctx = cv.getContext('2d');
ctx.fillStyle="#FFFFFF";
ctx.fillRect(0,0,cv.width,cv.height);
var isAllowDraw = false;
var x_c, y_c;

//鼠标按下事件
onmousedown = function (e) {
    isAllowDraw = true;
    //获得鼠标按下的点相对canvas的坐标
    var {x, y} = get_pencil(cv, e.clientX, e.clientY);
    x_c = x;
    y_c = y;
    //绘制起点
    ctx.beginPath();
    ctx.lineCap = "round";                 //圆形线帽
    ctx.lineJoin = "round";                //圆形线条连接
    ctx.moveTo(x, y);
}

//鼠标移动事件
onmousemove = function (e) {
    if (isAllowDraw){
        var {x, y} = get_pencil(cv, e.clientX, e.clientY);
        ctx.quadraticCurveTo(x_c, y_c, (x+x_c)/2, (y+y_c)/2);
        ctx.strokeStyle = 'black';
        ctx.lineWidth = 15;
        ctx.stroke();
        x_c = x;
        y_c = y;
    }
}

//鼠标抬起事件
onmouseup = function () {
    //鼠标抬起停止作画
    isAllowDraw = false;
}

当点击提交按钮后,需要对canvas区域进行截图并进行一系列处理后将数据返回给后端,处理内容包括截取用户书写区域的最小正方形内的内容,将截取内容等比压缩至28*28(适应识别模型的输入,其要求为28*28的矩阵),将图片进行base64编码后通过websocket将数据发送至后端。

先将用户绘画板中的图像转移到一块新的两倍长宽大小的绘画板中并居中放置,这样可以确保之后可以取到最小正方形区域,通过getImageData函数获取canvas内图像的imageData对象,遍历所有像素,通过找到所有非白色像素来确定书写内容,接着找到这些像素中x,y坐标的最小和最大值确定书写区域框,通过判断宽和高的大小将这一区域截取为最小正方形。最终,将所需图像放置到一块新的28*28的画布中实现图像的等比压缩,这里放置时还要注意居中放置四周略微留空来模拟MNIST训练数据,使识别达到最佳效果。

最后,使用toDataURL方法将canvas图像转换为base64编码并建立WebSocket连接发送给后端服务器,当收到后端回复时,更新相应页面显示结果。

当点击清空按钮后需要清空画布重新书写,这里我通过重置画布的高度实现画布的重新创建以达成清空重置的目的,创建后依然使用白色填充。

function clear_canvas() {
    cv.height = cv.height;                             //通过重置高度重新生成画布
    ctx.fillStyle="#FFFFFF";
    ctx.fillRect(0,0,cv.width,cv.height);
}

function image_process() {
    var c = document.getElementById('cv');             //用户绘画板

    //将图像转移到一个两倍长宽的画布以保证之后能取到最小正方形区域
    var c2 = document.createElement('canvas');
    var ctx2 = c2.getContext("2d");
    c2.width = c.width * 2;
    c2.height = c.height * 2;
    ctx2.fillStyle="#FFFFFF";
    ctx2.fillRect(0,0,c2.width,c2.height);
    ctx2.drawImage(c, parseInt(c2.width/4), parseInt(c2.height/4));

    var img = ctx2.getImageData(0,0,c2.width,c2.height);  //获得imageData对象

    //处理img,仅截取书写部分
    var pixs = img.data;
    var valid_x = [], valid_y = [];
    for (var i = 0; i < pixs.length; i += 4){
        var r = pixs[i],
            g = pixs[i + 1],
            b = pixs[i + 2],
            a = pixs[i + 3];
        if (r != 255 || g != 255 || b != 255){       //书写内容区域条件
            var x = ((i / 4) % c2.width);            //一个像素点由rgba4部分组成  当前像素点相对于画布的X轴坐标
            var y = Math.floor((i / 4) / c2.height); //当前像素点相对于画布的Y轴坐标
            valid_x.push(x)
            valid_y.push(y)
        }
    }
    var minX = Math.min(...valid_x),  //取出X轴最小值
        maxX = Math.max(...valid_x),  //取出X轴最大值
        valid_W = maxX - minX,        //计算实际宽
        minY = Math.min(...valid_y),  //取出Y轴最小值
        maxY = Math.max(...valid_y),  //取出Y轴最大值
        valid_H = maxY - minY;        //计算实际高
    
    //取最小正方形区域
    if (valid_W < valid_H){
        minX -= parseInt((valid_H - valid_W) / 2);
        maxX += parseInt((valid_H - valid_W) / 2);
        valid_W = maxX - minX;
    }
    else if (valid_W > valid_H){
        minY -= parseInt((valid_W - valid_H) / 2);
        maxY += parseInt((valid_W - valid_H) / 2);
        valid_H = maxY - minY;
    }

    var c3 = document.createElement('canvas');                  //最终输出画板
    var ctx3 = c3.getContext("2d");
    c3.width = 28;
    c3.height = 28;
    ctx3.fillStyle="#FFFFFF";
    ctx3.fillRect(0,0,c3.width,c3.height);
    ctx3.drawImage(c2,minX,minY,valid_W,valid_H,5,5,18,18);    //裁剪图像并缩放至28*28,居中处理四周留空
    
    var b64Image = c3.toDataURL("image/png", 1.0);            //canvas图像base64编码

    var ws = new WebSocket("ws://"+window.location.host+"/ws");
    ws.onopen = function () {
        ws.send(b64Image.substring(22));
    }
    ws.onmessage = function (e) {
        document.getElementById('result').innerHTML = e.data;
    }
}

第四部分:后端服务器

和之前一样创建基础的Tornado服务器框架,这里不再过多赘述。由于前端数据使用WebSocket发送,因此后端WebSocket在接收到base64数据后还需进行一部分处理,再识别返回结果。

首先将base64数据解码得到二进制图片,使用PIL库读入图片。将rgba格式图像转换为灰度图像,再进行黑白反转以适应MNIST数据集图像的格式(灰度值与正常的图像是相反的,0为白,255为黑)。将转换好的图像读入numpy矩阵中,拉成一维,并将数据格式由uint8转换为float64。至此,已得到满足条件的输入数据,可以进行识别了。

调用写好的识别接口进行识别,将识别结果通过WebSocket发送至前端即可,最后关闭WebSocket连接。

from tornado import ioloop
from tornado.web import Application,RequestHandler
from tornado.httpserver import HTTPServer
from tornado.websocket import WebSocketHandler
from base64 import b64decode
from PIL import Image,ImageOps
from io import BytesIO
import numpy as np
from ocr_number import query

class index_handler(RequestHandler):
    def get(self):
        self.render('index.html')

class ws_Handler(WebSocketHandler):
    def on_message(self, message):
        b64_img_str = message                      #获取base64编码的图片
        b_img = b64decode(b64_img_str)             #base64解码得到二进制图片
        img = Image.open(BytesIO(b_img))           #PIL读入二进制图像数据
        img = img.convert('L')                     #转换为灰度图像
        img = ImageOps.invert(img)                 #黑白反转,适应MNIST数据集
        img.save('preview.png')                    #保存预览图片

        matrix = np.array(img)                     #转化为numpy矩阵
        matrix = matrix.flatten()                  #拉伸为一维
        matrix = matrix.astype(np.float)           #转换数据类型为float64
        
        ocr_res = query(matrix)                    #OCR识别

        print(ocr_res,end='\r')                    #打印识别结果
        self.write_message(str(ocr_res))
        self.close()

app = Application([(r"/", index_handler),
                    (r"/ws", ws_Handler)],
                static_path="statics",
                template_path="templates")

http_server = HTTPServer(app)
http_server.listen(8888)
ioloop.IOLoop.current().start()
识别结果

A WindRunner. VoyagingOne

留言

您的电子邮箱地址不会被公开。 必填项已用*标注