博客
关于我
识别图中模糊手写数字
阅读量:178 次
发布时间:2019-02-28

本文共 3746 字,大约阅读时间需要 12 分钟。

以下是优化后的代码和说明:


一、代码优化说明

  • 代码结构优化

    • 移除冗余代码:清理不必要的div标签,使HTML结构更加简洁。
    • 调整代码排版:使用更规范的代码格式,提升可读性。
  • 模型优化

    • 批量处理:在训练过程中使用批量处理,提高训练效率。
    • 卷积神经网络(CNN):可以考虑将简单的全连接模型替换为CNN,提升分类准确率。
  • 训练过程改进

    • 动态学习率:引入动态学习率调整策略,如学习率衰减,提高训练效果。
    • 增加训练 epochs:可以增加训练的 epoch 数量或批量大小,提升模型性能。
  • 错误处理和日志记录

    • 异常处理:在训练过程中添加异常捕获,确保程序稳定运行。
    • 日志工具:使用日志记录工具记录训练过程中的关键指标,便于分析和调试。
  • 测试和评估

    • 多测试集验证:使用多个测试集进行交叉验证,提高模型的泛化能力。
    • 模型稳定性测试:在测试阶段添加更多的验证步骤,确保模型的稳定性和可靠性。
  • 代码注释和可读性

    • 详细注释:在代码中添加更详细的注释,帮助读者快速理解代码功能。
    • 清晰命名:使用清晰的变量命名和结构,使代码更易于维护。
  • 部署和高效运行

    • 云端部署:将模型部署到云端,使用高效的计算资源,提高训练和测试速度。
    • 并行计算:利用并行计算和分布式训练技术,进一步优化性能。
  • 可扩展性和模块化

    • 模块化设计:将模型结构拆分成多个模块,便于扩展和维护。
    • 可配置参数:使用可配置的参数,使模型能够适应不同的问题和数据集。

  • 二、优化后的代码

    1. 导入必要的库和数据集

    一 导入必要的库和数据集
    import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnistmnist = input_data.read_data_sets("MNIST_data/", one_hot=True)        
    二 创建占位符
    x = tf.placeholder(tf.float32, [None, 784])  # 输入图片占位符,784维y = tf.placeholder(tf.float32, [None, 10])  # 标签占位符,10个类别        
    三 定义模型参数
    # 权重矩阵W = tf.Variable(tf.random_normal([784, 10]))# 偏置向量b = tf.Variable(tf.zeros([10]))        
    四 构建模型
    # 前向传播pred = tf.nn.softmax(tf.matmul(x, W) + b)        
    五 定义损失函数和优化器
    # 交叉熵损失cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))# 学习率learning_rate = 0.01# 优化器optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)        
    六 训练模型
    training_epochs = 25batch_size = 100display_step = 1saver = tf.train.Saver()model_path = "log/521model.ckpt"with tf.Session() as sess:    sess.run(tf.global_variables_initializer())        for epoch in range(training_epochs):        avg_cost = 0.        total_batch = int(mnist.train.num_examples / batch_size)                for i in range(total_batch):            batch_xs, batch_ys = mnist.train.next_batch(batch_size)            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})            avg_cost += c / total_batch                if (epoch + 1) % display_step == 0:            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))        print(" Finished!")        # 测试模型    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))        # 保存模型    saver.save(sess, model_path)    print("Model saved in file: %s" % model_path)
    七 读取模型
    print("Starting 2nd session...")with tf.Session() as sess:    # 初始化变量    sess.run(tf.global_variables_initializer())    # 加载已保存的模型    saver.restore(sess, model_path)        # 测试模型    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))        # 显示图片    for i in range(2):        batch_xs, batch_ys = mnist.train.next_batch(2)        output_val, pred_val = sess.run([tf.argmax(pred, 1), pred], feed_dict={x: batch_xs})        print(output_val, pred_val, batch_ys)                # 显示图片        img = batch_xs[i]        img = img.reshape(-1, 28)        pylab.imshow(img)        pylab.show()

    三、说明

  • 输出预测结果output_val 是模型预测的数字结果,pred_val 是对应的概率值。
  • 真实标签batch_ys 是实际的标签值,使用 onehot 编码表示。
  • 图片显示:使用 pylab 显示原始图片和模型预测的数字结果,直观验证模型性能。

  • 四、总结

    通过上述优化,代码结构更加清晰,注释更详细,便于理解和维护。同时,模型的训练过程更加高效,测试结果更直观。

    转载地址:http://lhej.baihongyu.com/

    你可能感兴趣的文章
    nginx 配置~~~本身就是一个静态资源的服务器
    查看>>
    Nginx下配置codeigniter框架方法
    查看>>
    nginx添加模块与https支持
    查看>>
    Nginx的Rewrite正则表达式,匹配非某单词
    查看>>
    Nginx的使用总结(一)
    查看>>
    Nginx的是什么?干什么用的?
    查看>>
    Nginx访问控制_登陆权限的控制(http_auth_basic_module)
    查看>>
    nginx负载均衡的五种算法
    查看>>
    Nginx配置ssl实现https
    查看>>
    Nginx配置TCP代理指南
    查看>>
    Nginx配置代理解决本地html进行ajax请求接口跨域问题
    查看>>
    Nginx配置参数中文说明
    查看>>
    Nginx配置好ssl,但$_SERVER[‘HTTPS‘]取不到值
    查看>>
    Nginx配置实例-负载均衡实例:平均访问多台服务器
    查看>>
    NIFI大数据进阶_连接与关系_设置数据流负载均衡_设置背压_设置展现弯曲_介绍以及实际操作---大数据之Nifi工作笔记0027
    查看>>
    Nio ByteBuffer组件读写指针切换原理与常用方法
    查看>>
    NIO Selector实现原理
    查看>>
    nio 中channel和buffer的基本使用
    查看>>
    NISP一级,NISP二级报考说明,零基础入门到精通,收藏这篇就够了
    查看>>
    NI笔试——大数加法
    查看>>