Python+Android进行TensorFlow开发

软件发布|下载排行|最新软件

当前位置:首页IT学院IT技术

Python+Android进行TensorFlow开发

码农突围   2020-03-16 我要评论
Tensorflow是Google开源的一套机器学习框架,支持GPU、CPU、Android等多种计算平台。本文将介绍在Tensorflow在Android上的使用。 Android使用Tensorflow框架需要引入两个文件libtensorflow_inference.so、libandroid_tensorflow_inference_java.jar。这两个文件可以使用官方预编译的文件。如果预编译的so不满足要求(比如不支持训练模型中的某些操作符运算),也可以自己通过bazel编译生成这两个文件。 将libandroid_tensorflow_inference_java.jar放在app下的libs目录下,so文件命名为libtensorflow_jni.so放在src/main/jniLibs目录下对应的ABI文件夹下。目录结构如下: ![在这里插入图片描述](https://img-blog.csdnimg.cn/20200316213610942.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2hlamp1bmxpbg==,size_16,color_FFFFFF,t_70) Android目录结构 同时在app的build.gradle中的dependencies模块下添加如下配置: ``` dependencies { ... compile files('libs/libandroid_tensorflow_inference_java.jar') ... } ``` 使用tensorflow框架进行机器学习分为四个步骤: - 构造神经网络 - 训练神经网络模型 - 将训练好的模型输出为pb文件 - ndroid上加载pb模型进行计算 前三步是模型的构造,我们通过python实现,下面给出了一个二分类的简单模型的构造过程,首先是训练过程: ``` # -*-coding:utf-8 -*- from __future__ import print_function import os import tensorflow as tf from numpy.random import RandomState os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' """ 训练模型 """ def train(): # 定义训练数据集batch大小为8 batch_size = 8 # 定义神经网络参数,参数体现出神经网络结构,一个输入层,一个输出层,一个隐藏层 w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1), name="w1_val") w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1), name="w2_val") # 定义输入输出格式 x = tf.placeholder(tf.float32, shape=(None, 2), name='x_input') y_ = tf.placeholder(tf.float32, shape=(None, 1)) # 定义神经网络前向传播过程 a = tf.matmul(x, w1) y = tf.matmul(a, w2, name="cal_node") # 定义交叉熵和反向传播算法 cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy) # 生成随机训练集 rdm = RandomState(1) dataset_size = 128 # 定义映射关系 X = rdm.rand(dataset_size, 2) Y = [[int(x1 + x2 < 1)] for (x1, x2) in X] with tf.Session() as sess: # 初始化所有参数 init_op = tf.global_variables_initializer() sess.run(init_op) # print sess.run(w1) # print sess.run(w2) STEPS = 500 for i in range(STEPS): start = (i * batch_size) % dataset_size end = min(start + batch_size, dataset_size) # 训练神经网络,更新神经网络参数 sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]}) if i % 100 == 0: total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y}) print("After %d training step(s), cross entropy on all data is %g" % (i, total_cross_entropy)) print(sess.run(w1)) print(sess.run(w2)) # 保存check point saver = tf.train.Saver(tf.trainable_variables()) saver.save(sess, './model/checpt') ``` 上面的代码首先定义神经网络,初始化训练数据,进行500次训练过程,并将训练结果checkpoints保存到model文件夹下,checkpoints包含了训练模型得到的参数信息,共生成四个相关的文件,如下图: ![checkpoint相关文件](https://img-blog.csdnimg.cn/20200316214101163.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2hlamp1bmxpbg==,size_16,color_FFFFFF,t_70) 由于checkpoint文件众多,为了方便使用,我们通过下面的代码将它们生成一个pb文件,在android上只需要这个pb文件即可使用这个训练好的模型: ``` """ 存储pb模型 """ def dump_graph_to_pb(pb_path): with tf.Session() as sess: check_point = tf.train.get_checkpoint_state("./model/") if check_point: saver = tf.train.import_meta_graph(check_point.model_checkpoint_path + '.meta') saver.restore(sess, check_point.model_checkpoint_path) else: raise ValueError("Model load failed from {}".format(check_point.model_checkpoint_path)) graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), "cal_node".split(",")) with tf.gfile.GFile(pb_path, "wb") as f: f.write(graph_def.SerializeToString()) ``` 拿到生成的pb模型,我们可以在android上使用了。将pb文件在这main/assets下: ![在这里插入图片描述](https://img-blog.csdnimg.cn/20200316214148856.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2hlamp1bmxpbg==,size_16,color_FFFFFF,t_70) 接下来就可以载入pb,进行计算了: ``` public class MainActivity extends AppCompatActivity { private Graph graph_; private Session session_; private AssetManager assetManager; private static ExecutorService executorService; private static Handler handler; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); executorService = Executors.newFixedThreadPool(5); // 初始化tensorflow initTensorFlow("outmodel.pb"); // 使用tensorflow进行计算 runTensorFlow(); } ... } ``` 通过如下方式载入pb模型,初始化tensorflow: ``` private boolean initTensorFlow(String modelFile) { assetManager = getAssets(); // 新建Graph graph_ = new Graph(); InputStream is = null; try { // 读取Assets pb文件 is = assetManager.open(modelFile); } catch (IOException e) { e.printStackTrace(); return false; } try { // 加载pb到Graph TensorUtil.loadGraph(is, graph_); is.close(); } catch (IOException e) { e.printStackTrace(); return false; } // 初始化session session_ = new Session(graph_); if (session_ == null) { return false; } return true; } ``` 然后就可以使用tensorflow API进行运算了: ``` private void runTensorFlow() { executorService.execute(generatePredictRunnable(handler)); } private Runnable generatePredictRunnable(Handler handler) { return new Runnable() { @Override public void run() { float[][] input = new float[1][2]; input[0][0] = 1; input[0][1] = 2; // 定义输入tensor Tensor inputTensor = Tensor.create(input); // 指定输入,输出节点,运行并得到结果 Tensor resultTensor = session_.runner() .feed("x_input", inputTensor) .fetch("cal_node") .run() .get(0); float[][] dst = new float[1][1]; resultTensor.copyTo(dst); // 处理结果 ArrayList

Copyright 2022 版权所有 软件发布 访问手机版

声明:所有软件和文章来自软件开发商或者作者 如有异议 请与本站联系 联系我们