重庆分公司,新征程启航
为企业提供网站建设、域名注册、服务器等服务
这篇文章主要讲解了tensorflow pb to tflite精度下降的问题,内容清晰明了,对此有兴趣的小伙伴可以学习一下,相信大家阅读完之后会有帮助。
成都做网站、网站制作、成都外贸网站建设的关注点不是能为您做些什么网站,而是怎么做网站,有没有做好网站,给创新互联一个展示的机会来证明自己,这并不会花费您太多时间,或许会给您带来新的灵感和惊喜。面向用户友好,注重用户体验,一切以用户为中心。之前希望在手机端使用深度模型做OCR,于是尝试在手机端部署tensorflow模型,用于图像分类。
思路主要是想使用tflite部署到安卓端,但是在使用tflite的时候发现模型的精度大幅度下降,已经不能支持业务需求了,最后就把OCR模型调用写在服务端了,但是精度下降的原因目前也没有找到,现在这里记录一下。
工作思路:
1.训练图像分类模型;2.模型固化成pb;3.由pb转成tflite文件;
但是使用python 的tf interpreter 调用tflite文件就已经出现精度下降的问题,android端部署也是一样。
1.网络结构
from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf slim = tf.contrib.slim def ttnet(images, num_classes=10, is_training=False, dropout_keep_prob=0.5, prediction_fn=slim.softmax, scope='TtNet'): end_points = {} with tf.variable_scope(scope, 'TtNet', [images, num_classes]): net = slim.conv2d(images, 32, [3, 3], scope='conv1') # net = slim.conv2d(images, 64, [3, 3], scope='conv1_2') net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn1') # net = slim.conv2d(net, 128, [3, 3], scope='conv2_1') net = slim.conv2d(net, 64, [3, 3], scope='conv2') net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') net = slim.conv2d(net, 128, [3, 3], scope='conv3') net = slim.max_pool2d(net, [2, 2], 2, scope='pool3') net = slim.conv2d(net, 256, [3, 3], scope='conv4') net = slim.max_pool2d(net, [2, 2], 2, scope='pool4') net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn2') # net = slim.conv2d(net, 512, [3, 3], scope='conv5') # net = slim.max_pool2d(net, [2, 2], 2, scope='pool5') net = slim.flatten(net) end_points['Flatten'] = net # net = slim.fully_connected(net, 1024, scope='fc3') net = slim.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout3') logits = slim.fully_connected(net, num_classes, activation_fn=None, scope='fc4') end_points['Logits'] = logits end_points['Predictions'] = prediction_fn(logits, scope='Predictions') return logits, end_points ttnet.default_image_size = 28 def ttnet_arg_scope(weight_decay=0.0): with slim.arg_scope( [slim.conv2d, slim.fully_connected], weights_regularizer=slim.l2_regularizer(weight_decay), weights_initializer=tf.truncated_normal_initializer(stddev=0.1), activation_fn=tf.nn.relu) as sc: return sc