博客
关于我
识别图中模糊手写数字
阅读量:178 次
发布时间:2019-02-28

本文共 3746 字,大约阅读时间需要 12 分钟。

以下是优化后的代码和说明:


一、代码优化说明

  • 代码结构优化

    • 移除冗余代码:清理不必要的div标签,使HTML结构更加简洁。
    • 调整代码排版:使用更规范的代码格式,提升可读性。
  • 模型优化

    • 批量处理:在训练过程中使用批量处理,提高训练效率。
    • 卷积神经网络(CNN):可以考虑将简单的全连接模型替换为CNN,提升分类准确率。
  • 训练过程改进

    • 动态学习率:引入动态学习率调整策略,如学习率衰减,提高训练效果。
    • 增加训练 epochs:可以增加训练的 epoch 数量或批量大小,提升模型性能。
  • 错误处理和日志记录

    • 异常处理:在训练过程中添加异常捕获,确保程序稳定运行。
    • 日志工具:使用日志记录工具记录训练过程中的关键指标,便于分析和调试。
  • 测试和评估

    • 多测试集验证:使用多个测试集进行交叉验证,提高模型的泛化能力。
    • 模型稳定性测试:在测试阶段添加更多的验证步骤,确保模型的稳定性和可靠性。
  • 代码注释和可读性

    • 详细注释:在代码中添加更详细的注释,帮助读者快速理解代码功能。
    • 清晰命名:使用清晰的变量命名和结构,使代码更易于维护。
  • 部署和高效运行

    • 云端部署:将模型部署到云端,使用高效的计算资源,提高训练和测试速度。
    • 并行计算:利用并行计算和分布式训练技术,进一步优化性能。
  • 可扩展性和模块化

    • 模块化设计:将模型结构拆分成多个模块,便于扩展和维护。
    • 可配置参数:使用可配置的参数,使模型能够适应不同的问题和数据集。

  • 二、优化后的代码

    1. 导入必要的库和数据集

    一 导入必要的库和数据集
    import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnistmnist = input_data.read_data_sets("MNIST_data/", one_hot=True)        
    二 创建占位符
    x = tf.placeholder(tf.float32, [None, 784])  # 输入图片占位符,784维y = tf.placeholder(tf.float32, [None, 10])  # 标签占位符,10个类别        
    三 定义模型参数
    # 权重矩阵W = tf.Variable(tf.random_normal([784, 10]))# 偏置向量b = tf.Variable(tf.zeros([10]))        
    四 构建模型
    # 前向传播pred = tf.nn.softmax(tf.matmul(x, W) + b)        
    五 定义损失函数和优化器
    # 交叉熵损失cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))# 学习率learning_rate = 0.01# 优化器optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)        
    六 训练模型
    training_epochs = 25batch_size = 100display_step = 1saver = tf.train.Saver()model_path = "log/521model.ckpt"with tf.Session() as sess:    sess.run(tf.global_variables_initializer())        for epoch in range(training_epochs):        avg_cost = 0.        total_batch = int(mnist.train.num_examples / batch_size)                for i in range(total_batch):            batch_xs, batch_ys = mnist.train.next_batch(batch_size)            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})            avg_cost += c / total_batch                if (epoch + 1) % display_step == 0:            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))        print(" Finished!")        # 测试模型    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))        # 保存模型    saver.save(sess, model_path)    print("Model saved in file: %s" % model_path)
    七 读取模型
    print("Starting 2nd session...")with tf.Session() as sess:    # 初始化变量    sess.run(tf.global_variables_initializer())    # 加载已保存的模型    saver.restore(sess, model_path)        # 测试模型    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))        # 显示图片    for i in range(2):        batch_xs, batch_ys = mnist.train.next_batch(2)        output_val, pred_val = sess.run([tf.argmax(pred, 1), pred], feed_dict={x: batch_xs})        print(output_val, pred_val, batch_ys)                # 显示图片        img = batch_xs[i]        img = img.reshape(-1, 28)        pylab.imshow(img)        pylab.show()

    三、说明

  • 输出预测结果output_val 是模型预测的数字结果,pred_val 是对应的概率值。
  • 真实标签batch_ys 是实际的标签值,使用 onehot 编码表示。
  • 图片显示:使用 pylab 显示原始图片和模型预测的数字结果,直观验证模型性能。

  • 四、总结

    通过上述优化,代码结构更加清晰,注释更详细,便于理解和维护。同时,模型的训练过程更加高效,测试结果更直观。

    转载地址:http://lhej.baihongyu.com/

    你可能感兴趣的文章
    Openlayers实战:绘制带箭头的线
    查看>>
    Openlayers实战:输入WKT数据,输出GML、Polyline、GeoJSON格式数据
    查看>>
    Openlayers实战:非4326,3857的投影
    查看>>
    Openlayers高级交互(10/20):绘制矩形,截取对应部分的地图并保存
    查看>>
    Openlayers高级交互(11/20):显示带箭头的线段轨迹,箭头居中
    查看>>
    Openlayers高级交互(14/20):汽车移动轨迹动画(开始、暂停、结束)
    查看>>
    Openlayers高级交互(15/20):显示海量多边形,10ms加载完成
    查看>>
    Openlayers高级交互(16/20):两个多边形的交集、差集、并集处理
    查看>>
    Openlayers高级交互(17/20):通过坐标显示多边形,计算出最大幅宽
    查看>>
    Openlayers高级交互(19/20): 地图上点击某处,列表中显示对应位置
    查看>>
    Openlayers高级交互(2/20):清除所有图层的有效方法
    查看>>
    Openlayers高级交互(20/20):超级数据聚合,页面不再混乱
    查看>>
    Openlayers高级交互(3/20):动态添加 layer 到 layerGroup,并动态删除
    查看>>
    Openlayers高级交互(4/20):手绘多边形,导出KML文件,可以自定义name和style
    查看>>
    Openlayers高级交互(5/20):右键点击,获取该点下多个图层的feature信息
    查看>>
    Openlayers高级交互(6/20):绘制某点,判断它是否在一个电子围栏内
    查看>>
    Openlayers高级交互(7/20):点击某点弹出窗口,自动播放视频
    查看>>
    Openlayers高级交互(8/20):选取feature,平移feature
    查看>>
    Openlayers:DMS-DD坐标形式互相转换
    查看>>
    openlayers:圆孔相机根据卫星经度、纬度、高度、半径比例推算绘制地面的拍摄的区域
    查看>>