Skip to content

TensorFlow读取.pb模型文件并预测图片

引言

TensorFlow通过Python接口训练好模型,然后使用C++ API加载.pb模型文件,再对图片进行预测判断,这需要解决几个问题:

  • 如何将训练好的模型保存为.pb格式文件?
  • 如何在C++ API中加载训练好的模型文件?
  • 如何在C++ API中读取图片并变成需要的格式?
  • 如何在C++ API中根据模型对图片进行预测判断?

以上几个问题通过这个链接基本找到了答案: https://www.cnblogs.com/buyizhiyou/p/10412967.html

Python接口保存graph为.pb格式

import tensorflow as tf
import numpy as np

with tf.compat.v1.Session() as sess:
    # Build a tiny graph: two float32 placeholders named 'a' and 'b' whose
    # element-wise product is exposed as the op named 'c'. These node names
    # are what a consumer of the exported graph has to feed and fetch.
    a = tf.compat.v1.placeholder(tf.float32, shape=None, name='a')
    b = tf.compat.v1.placeholder(tf.float32, shape=None, name='b')
    c = tf.multiply(a, b, name='c')

    # Run the variable initializer op.
    init_op = tf.compat.v1.global_variables_initializer()
    sess.run(init_op)

    # Serialize the graph definition in binary form to models/train.pb.
    tf.compat.v1.train.write_graph(
        sess.graph_def, 'models/', 'train.pb', as_text=False)

    # Evaluate 'c' once, feeding the placeholders by tensor name.
    feeds = {'a:0': 2.0, 'b:0': 3.0}
    res = sess.run(c, feed_dict=feeds)
    print("res = ", res)

C++ API读取.pb格式graph

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <tensorflow/core/platform/env.h>
#include <tensorflow/core/public/session.h>

using namespace tensorflow;

int main(int argc, char* argv[]) {
  // Initialize a tensorflow session
  Session* session;
  Status status = NewSession(SessionOptions(), &session);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }


  // Read in the protobuf graph we exported
  // (The path seems to be relative to the cwd. Keep this in mind
  // when using `bazel run` since the cwd isn't where you call
  // `bazel run` but from inside a temp folder.)
  GraphDef graph_def;
  status = ReadBinaryProto(Env::Default(), "./models/train.pb", &graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Add the graph to the session
  status = session->Create(graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Setup inputs and outputs:

  // Our graph doesn't require any inputs, since it specifies default values,
  // but we'll change an input to demonstrate.
  Tensor a(DT_FLOAT, TensorShape());
  a.scalar<float>()() = 3.0;

  Tensor b(DT_FLOAT, TensorShape());
  b.scalar<float>()() = 2.0;

  std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
    { "a", a },
    { "b", b },
  };

  // The session will initialize the outputs
  std::vector<tensorflow::Tensor> outputs;

  // Run the session, evaluating our "c" operation from the graph
  status = session->Run(inputs, {"c"}, {}, &outputs);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Grab the first output (we only evaluated one graph node: "c")
  // and convert the node to a scalar representation.
  auto output_c = outputs[0].scalar<float>();

  // (There are similar methods for vectors and matrices here:
  // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h)

  // Print the results
  std::cout << outputs[0].DebugString() << "\n"; // Tensor<type: float shape: [] values: 30>
  std::cout << output_c() << "\n"; // 30

  // Free any resources used by the session
  session->Close();
  return 0;
}

读取图片变成tensor

  • 使用OpenCV库:
#include "opencv2/core/core.hpp"

// V1
// Loads `filename` as an 8-bit grayscale image, scales it to 240 (width) x 40
// (height), transposes it and copies the pixels into a float tensor of shape
// {1, 240, 40, 1}. Pixel values are stored unscaled.
tensorflow::Tensor readTensor(string filename){
    // Flag 0 asks imread for a single-channel image.
    Mat gray = imread(filename,0);
    Mat scaled;
    resize(gray,scaled,Size(240,40));
    // After the transpose the matrix has 240 rows and 40 columns, which lines
    // up with tensor dimensions 1 and 2.
    Mat rotated=scaled.t();

    tensorflow::Tensor input_tensor(DT_FLOAT,TensorShape({1,240,40,1}));
    auto view = input_tensor.tensor<float,4>();

    // Copy the matrix into the tensor one row at a time.
    for(int row=0;row<240;++row){
        const uchar* line = rotated.ptr<uchar>(row);
        for(int col=0;col<40;++col){
            view(0,row,col,0)=line[col];
        }
    }

    return input_tensor;
}

// V2
// Loads `filename` as an 8-bit grayscale image, scales it to 240 (width) x 40
// (height), transposes it and converts it to float directly inside the buffer
// of a {1, 240, 40, 1} tensor, avoiding an element-by-element copy.
//
// NOTE: imread yields an empty Mat when the file cannot be read; resize is
// then called on that empty Mat. Callers should make sure the file exists.
tensorflow::Tensor readTensor(string filename){
    tensorflow::Tensor input_tensor(DT_FLOAT,TensorShape({1,240,40,1}));
    float *tensor_data_ptr = input_tensor.flat<float>().data();

    // Flag 0 asks imread for a single-channel image.
    Mat src = imread(filename,0);
    Mat dst;
    resize(src,dst,Size(240,40));//dst has 40 rows and 240 columns

    // The tensor is laid out as 240 x 40, so the resized image has to be
    // transposed before its rows are written into the tensor buffer; copying
    // the 40 x 240 matrix as-is would scramble the pixel positions.
    Mat dst_transpose = dst.t();

    // Mat header that aliases the tensor memory (no allocation, no ownership).
    // Its size and type must equal what convertTo produces, otherwise
    // convertTo would allocate a fresh buffer and leave the tensor untouched.
    cv::Mat fake_mat(dst_transpose.rows, dst_transpose.cols, CV_32FC(1), tensor_data_ptr);
    dst_transpose.convertTo(fake_mat, fake_mat.type());

    return input_tensor;
}
  • 使用TensorFlow API:
// Reads the image stored in `file_name` and converts it to a float tensor.
//
// A small TensorFlow graph is built and executed: read file -> decode (PNG if
// the name contains ".png", JPEG otherwise) as a single-channel image -> cast
// to float -> add a leading batch dimension -> bilinear resize to
// input_height x input_width -> swap dimensions 1 and 2.
//
// file_name     path of the image to load
// input_height  target height passed to the resize op
// input_width   target width passed to the resize op
// out_tensors   receives the fetched "transpose" tensor
//
// Returns the first non-OK status reported while building or running the
// graph, or the status of the final Run call.
Status ReadTensorFromImageFile(string file_name, const int input_height,
                               const int input_width,
                               vector<Tensor>* out_tensors) {
    auto root = Scope::NewRootScope();
    using namespace ops;

    auto file_reader = ops::ReadFile(root.WithOpName("file_reader"),file_name);
    const int wanted_channels = 1;
    Output image_reader;
    std::size_t found = file_name.find(".png");
    // Choose the decoder from the file name.
    if (found!=std::string::npos) {
       image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,DecodePng::Channels(wanted_channels));
    }
    else {
       image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,DecodeJpeg::Channels(wanted_channels));
    }

    // Pre-processing steps applied to the decoded image.
    auto float_caster =Cast(root.WithOpName("float_caster"), image_reader, DT_FLOAT);
    auto dims_expander = ExpandDims(root, float_caster, 0);
    auto resized = ResizeBilinear(root, dims_expander,Const(root.WithOpName("resize"), {input_height, input_width}));
    // Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),{input_std});
    Transpose(root.WithOpName("transpose"),resized,{0,2,1,3});

    // Errors raised while adding ops are recorded in the scope and surface
    // through ToGraphDef, so its status must not be dropped.
    GraphDef graph;
    Status status = root.ToGraphDef(&graph);
    if (!status.ok()) {
        return status;
    }

    // Use the status-returning NewSession overload so that a failed session
    // creation is reported instead of dereferencing a null pointer.
    Session* raw_session = nullptr;
    status = NewSession(SessionOptions(), &raw_session);
    if (!status.ok()) {
        return status;
    }
    unique_ptr<Session> session(raw_session);

    status = session->Create(graph);
    if (!status.ok()) {
        return status;
    }

    // Run the graph and store the image data in out_tensors; a failure here
    // (e.g. missing or undecodable file) is handed back to the caller.
    return session->Run({}, {"transpose"}, {}, out_tensors);
}

//... 
int input_height = 40;
int input_width = 240;

vector<Tensor> inputs;

string image_path("test.jpg");
if (!ReadTensorFromImageFile(image_path, input_height, input_width,&inputs).ok()) {
    cout<<"Read image file failed"<<endl;
    return -1;
}

一个完整例子

#include <iostream>
#include <map>

#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

using namespace std ;
using namespace tensorflow;
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32;


// Reads the image stored in `file_name` and converts it to a float tensor.
//
// A small TensorFlow graph is built and executed: read file -> decode (PNG if
// the name contains ".png", JPEG otherwise) as a single-channel image -> cast
// to float -> add a leading batch dimension -> bilinear resize to
// input_height x input_width -> swap dimensions 1 and 2.
//
// file_name     path of the image to load
// input_height  target height passed to the resize op
// input_width   target width passed to the resize op
// out_tensors   receives the fetched "transpose" tensor
//
// Returns the first non-OK status reported while building or running the
// graph, or the status of the final Run call.
Status ReadTensorFromImageFile(string file_name, const int input_height,
                               const int input_width,
                               vector<Tensor>* out_tensors) {
    auto root = Scope::NewRootScope();
    using namespace ops;

    auto file_reader = ops::ReadFile(root.WithOpName("file_reader"),file_name);
    const int wanted_channels = 1;
    Output image_reader;
    std::size_t found = file_name.find(".png");
    // Choose the decoder from the file name.
    if (found!=std::string::npos) {
        image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,DecodePng::Channels(wanted_channels));
    }
    else {
        image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,DecodeJpeg::Channels(wanted_channels));
    }
    // Pre-processing steps applied to the decoded image.
    auto float_caster =Cast(root.WithOpName("float_caster"), image_reader, DT_FLOAT);
    auto dims_expander = ExpandDims(root, float_caster, 0);
    auto resized = ResizeBilinear(root, dims_expander,Const(root.WithOpName("resize"), {input_height, input_width}));
    // Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),{input_std});
    Transpose(root.WithOpName("transpose"),resized,{0,2,1,3});

    // Errors raised while adding ops are recorded in the scope and surface
    // through ToGraphDef, so its status must not be dropped.
    GraphDef graph;
    Status status = root.ToGraphDef(&graph);
    if (!status.ok()) {
        return status;
    }

    // Use the status-returning NewSession overload so that a failed session
    // creation is reported instead of dereferencing a null pointer.
    Session* raw_session = nullptr;
    status = NewSession(SessionOptions(), &raw_session);
    if (!status.ok()) {
        return status;
    }
    unique_ptr<Session> session(raw_session);

    status = session->Create(graph);
    if (!status.ok()) {
        return status;
    }

    // Run the graph and store the image data in out_tensors; a failure here
    // (e.g. missing or undecodable file) is handed back to the caller.
    return session->Run({}, {"transpose"}, {}, out_tensors);
}

int main(int argc, char* argv[]) {

    string graph_path = "aov_crnn.pb";
    GraphDef graph_def;
    //读取模型文件
    if (!ReadBinaryProto(Env::Default(), graph_path, &graph_def).ok()) {
        cout << "Read model .pb failed"<<endl;
        return -1;
    }

    //新建session
    unique_ptr<Session> session;
    SessionOptions sess_opt;
    sess_opt.config.mutable_gpu_options()->set_allow_growth(true);
    (&session)->reset(NewSession(sess_opt));
    if (!session->Create(graph_def).ok()) {
        cout<<"Create graph failed"<<endl;
        return -1;
    }

    //读取图像到inputs中
    int input_height = 40;
    int input_width = 240;
    vector<Tensor> inputs;
    // string image_path(argv[1]);
    string image_path("test.jpg");
    if (!ReadTensorFromImageFile(image_path, input_height, input_width,&inputs).ok()) {
        cout<<"Read image file failed"<<endl;
        return -1;
    }

    vector<Tensor> outputs;
    string input = "inputs_sq";
    string output = "results_sq";//graph中的输入节点和输出节点,需要预先知道

    pair<string,Tensor>img(input,inputs[0]);
    Status status = session->Run({img},{output}, {}, &outputs);//Run,得到运行结果,存到outputs中
    if (!status.ok()) {
        cout<<"Running model failed"<<endl;
        cout<<status.ToString()<<endl;
        return -1;
    }


    //得到模型运行结果
    Tensor t = outputs[0];
    auto tmap = t.tensor<int64, 2>();
    int output_dim = t.shape().dim_size(1);


    return 0;
}

资源