cv2调用tensorflow模型


问题概述

之前已经通过TensorFlow的C/C++接口成功调用了saved_model。回顾这一过程时,注意到OpenCV的dnn模块也可以直接加载训练好的模型;不过它加载的不是saved_model,而是由h5模型转换得到的另一种pb模型(冻结图)。

转换代码

转换代码还会打印输入层和输出层的名字:用TensorFlow加载转换后的pb模型时需要用到这些名字,而cv2的dnn模块不需要。转换结果保存为./pb/model.pb;下文的示例从当前目录读取model.pb,运行前请将其复制到运行目录或相应修改路径。

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

def h5_to_pb(h5_save_path, logdir="./pb", name="model.pb"):
    """Freeze a Keras HDF5 model and write it as a binary GraphDef file.

    The model is wrapped in a ``tf.function``, traced for the shape and dtype
    of its first input, and its variables are folded into constants. The
    operation names and the input/output tensors of the frozen graph are
    printed so they can be looked up when the graph is loaded by name.

    Only ``model.inputs[0]`` is used to build the input signature.

    Args:
        h5_save_path: Path of the ``.h5``/``.hdf5`` model file to load.
        logdir: Directory the frozen graph is written to.
        name: File name of the frozen graph inside ``logdir``.

    Returns:
        Whatever ``tf.io.write_graph`` returns (documented as the path of the
        written file).
    """
    model = tf.keras.models.load_model(h5_save_path, compile=False)
    model.summary()

    # NOTE(review): the lambda parameter is deliberately left as ``Input``;
    # it presumably determines the name of the graph's input placeholder, so
    # renaming it could change the names printed below -- confirm before
    # touching it.
    full_model = tf.function(lambda Input: model(Input))
    full_model = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)

    separator = "-" * 50
    print(separator)
    print("Frozen model layers: ")
    for op in frozen_func.graph.get_operations():
        print(op.name)

    print(separator)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # as_text=False writes the binary protobuf form rather than text.
    return tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                             logdir=logdir,
                             name=name,
                             as_text=False)

# Run the conversion on ./model.hdf5 (path is relative to the working directory).
h5_to_pb(h5_save_path='./model.hdf5')

Python调用

import cv2
import time
import pickle
import os

# Load the label encoder whose inverse_transform() maps a network output back
# to a character (presumably a fitted scikit-learn LabelBinarizer -- confirm).
# NOTE: pickle can execute arbitrary code on load; only use a trusted file.
# The `with` block closes the file, so no explicit close() is needed.
with open('labels.dat', 'rb') as f:
    lb = pickle.load(f)

# One tuple per character; only elements 0 and 2 are used below, as the left
# and right column bounds of the crop. Neighbouring windows overlap slightly.
region = [(0, 0, 16, 25), (14, 0, 31, 25), (30, 0, 46, 25), (44, 0, 60, 25)]

# `total` counts the images actually evaluated instead of assuming 100 files.
true_count, total = 0, 0
captcha_image_files = os.listdir('./img')
model = cv2.dnn.readNetFromTensorflow('model.pb')

start_time = time.time()
for image_file in captcha_image_files:

    image = cv2.imread('./img/' + image_file)
    if image is None:
        # cv2.imread returns None for files it cannot decode; skip them
        # rather than crashing in cvtColor.
        print("skip unreadable file: {}".format(image_file))
        continue
    total += 1

    # The expected text is the file name up to its first dot.
    result = image_file.split('.')[0]

    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    predictions = []
    for reg in region:
        letter_image = image[:, reg[0]:reg[2]]
        letter_image = cv2.resize(letter_image, (15, 25))

        model.setInput(cv2.dnn.blobFromImage(image=letter_image,
                                             scalefactor=1.0,
                                             size=(15, 25),
                                             mean=(0, 0, 0),
                                             swapRB=True,
                                             crop=False))

        output = model.forward()
        letter = lb.inverse_transform(output)[0]
        predictions.append(letter)
    captcha_text = ''.join(predictions)
    if captcha_text == result:
        print("RESULT is: {}, CAPTCHA text is: {}, True".format(result, captcha_text))
        true_count += 1
    else:
        print("RESULT is: {}, CAPTCHA text is: {}, False".format(result, captcha_text))

end_time = time.time()
# Guard the division so an empty (or fully unreadable) ./img reports 0.0.
print("predict rate: ", true_count / total if total else 0.0)
print('time: {:.6f}s'.format(end_time - start_time))

与调用h5、saved_model结果完全一致。

C++调用

需要预先安装(Windows)或编译(Linux)C++版的OpenCV。Windows去官网下载安装,然后将build\x64\vc15\bin目录配置到环境变量。Linux可参考以下文章:
ubuntu安装opencv的正确方法
注意:使用VS 2019时需要配置包含目录和库目录,并链接opencv_worldxxx.lib(xxx对应版本,如4.5.2就是452,具体看lib目录下的文件)

#include <iostream>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/imgproc/types_c.h>
#include <opencv2/dnn.hpp>
#include <io.h>

using namespace std;
using namespace cv;
using namespace dnn;

// Relative path of the frozen TensorFlow graph to load.
String modelWeights = "model.pb";
// Namespace-scope network object: it is constructed from modelWeights during
// static initialization, i.e. before main() runs, so it must stay declared
// after modelWeights.
Net model = readNetFromTensorflow(modelWeights);
// Timestamps set with clock() around the batch evaluation in main().
clock_t start_time, end_time;

// Class index -> character lookup table; cv_predict() indexes it with the
// position of the highest network output.
char letter[36] = {
    '0','1','2','3','4','5','6','7','8','9','a','b',
    'c','d','e','f','g','h','i','j','k','l','m','n',
    'o','p','q','r','s','t','u','v','w','x','y','z'
};

// {x_start, x_end} column range of each of the four character crops used by
// cv_predict(); neighbouring ranges overlap slightly.
int region[4][2] = { {0,16},{14,31},{30,46},{44,60} };

// Recognize the characters of the CAPTCHA image stored at `filepath`.
//
// The image is converted to grayscale, one crop per entry of the global
// `region` table (rows 0..24) is resized to 15x25 and fed to the global
// `model`; the character for the highest-scoring class is appended to the
// result.
//
// Returns the predicted text, one character per region.
// Throws cv::Exception if the image cannot be read.
string cv_predict(const char* filepath) {
    Mat img = imread(filepath);
    if (img.empty()) {
        // imread signals failure with an empty Mat; report which file failed
        // instead of letting cvtColor raise an anonymous assertion.
        CV_Error(cv::Error::StsBadArg, String("cannot read image: ") + filepath);
    }

    Mat gray, glyph, blob, out;
    string result = "";

    // Derive the loop bounds from the tables instead of repeating 4 and 36.
    const int region_count = (int)(sizeof(region) / sizeof(region[0]));
    const int letter_count = (int)(sizeof(letter) / sizeof(letter[0]));

    // COLOR_BGR2GRAY is the C++ API name of the legacy CV_BGR2GRAY macro.
    cvtColor(img, gray, COLOR_BGR2GRAY);

    for (int r = 0; r < region_count; r++) {
        Mat roi = gray(Rect(region[r][0], 0, region[r][1] - region[r][0], 25));
        resize(roi, glyph, Size(15, 25));
        blob = blobFromImage(glyph, 1.0, Size(15, 25), Scalar(0, 0, 0), true, false);
        model.setInput(blob);
        out = model.forward();

        // Assumes `out` is a 1 x N float matrix of class scores -- TODO confirm.
        // Never read more scores than the network produced or than `letter`
        // has entries for.
        const int produced = (int)out.total();
        const int class_count = produced < letter_count ? produced : letter_count;

        // Argmax seeded with the first score (not 0.0), so it is also correct
        // when every score is zero or negative. Ties keep the lowest index.
        int best = 0;
        for (int i = 1; i < class_count; i++) {
            if (out.at<float>(0, i) > out.at<float>(0, best)) {
                best = i;
            }
        }
        result += letter[best];
    }
    return result;
}

int main(int nargv, const char* argvs[])
{
    model.setPreferableBackend(DNN_BACKEND_OPENCV);
    model.setPreferableTarget(DNN_TARGET_CPU);
    // 单张预测
    if (nargv > 1) {
        string result = "";
        result = cv_predict(argvs[1]);
        const char* test = result.c_str();
        printf("%s\n", test);
    }
    // 验证识别率
    else {
        start_time = clock();
        int true_count = 0;
        std::string inPath = "img\\*.png";
        intptr_t handle;
        struct _finddata_t fileinfo;
        handle = _findfirst(inPath.c_str(), &fileinfo);
        if (handle == -1)
            return -1;
        do
        {
            string filepath = "img/" + (string)fileinfo.name;
            string result = cv_predict(filepath.c_str());
            printf("RESULT is: %s, CAPTCHA text is: %s, ", result.c_str(), fileinfo.name);
            if (filepath.find(result) != filepath.npos) {
                true_count += 1;
                printf("True\n");
            }
            else
                printf("False\n");
        } while (!_findnext(handle, &fileinfo));

        _findclose(handle);
        end_time = clock();
        printf("predict rate: %f\n", (float)true_count / 100.0);
        printf("time: %fs\n", (float)(end_time - start_time) / CLOCKS_PER_SEC);
    }
}

资源整理

相关资源可在另一篇文章中得到。