问题概述
之前已经通过 TensorFlow 的 C/C++ 接口成功调用了 saved_model。回顾整个过程时，注意到 OpenCV 的 dnn 模块也可以直接加载训练好的模型，不过它需要的不是 saved_model，而是由 h5 模型转换得到的另一种 pb 模型（所有变量都已固化为常量的冻结图）。
转换代码
转换脚本还会打印各层以及输入、输出节点的名称：用 TensorFlow 调用转换后的 pb 模型时需要用到这些名称，而 cv2 的 dnn 模块则不需要。
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
def h5_to_pb(h5_save_path, logdir="./pb", name="model.pb", as_text=False):
    """Convert a Keras HDF5 model into a frozen TensorFlow graph (.pb).

    All variables are folded into constants. The resulting single GraphDef
    file can be loaded with ``cv2.dnn.readNetFromTensorflow``.

    Args:
        h5_save_path: Path to the Keras ``.h5``/``.hdf5`` model file.
        logdir: Directory the graph file is written into.
        name: File name of the written graph.
        as_text: Write a text-format graph instead of binary when True.

    Returns:
        The path of the written graph file, as returned by
        ``tf.io.write_graph``.
    """
    model = tf.keras.models.load_model(h5_save_path, compile=False)
    model.summary()

    # Trace the model once for the signature of its first input
    # (assumes a single-input model). The lambda argument name "Input"
    # becomes the name of the frozen graph's input node, so keep it stable.
    full_model = tf.function(lambda Input: model(Input))
    full_model = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
    )

    # Get frozen ConcreteFunction: variables are replaced by constants.
    frozen_func = convert_variables_to_constants_v2(full_model)
    graph_def = frozen_func.graph.as_graph_def()

    # Node names are needed when loading the .pb through TensorFlow's own
    # API; cv2.dnn does not need them.
    print("-" * 50)
    print("Frozen model layers: ")
    for op in frozen_func.graph.get_operations():
        print(op.name)
    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    return tf.io.write_graph(
        graph_or_graph_def=graph_def,
        logdir=logdir,
        name=name,
        as_text=as_text,
    )


h5_to_pb('./model.hdf5')
Python调用
import cv2
import time
import pickle
import os
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a labels.dat you produced yourself.
with open('labels.dat', 'rb') as f:
    lb = pickle.load(f)  # label encoder whose inverse_transform maps outputs to characters

# Per-character crop boxes (x0, y0, x1, y1). Only x0/x1 are used; each crop
# keeps the full image height.
region = [(0, 0, 16, 25), (14, 0, 31, 25), (30, 0, 46, 25), (44, 0, 60, 25)]
img_dir = './img'
captcha_image_files = os.listdir(img_dir)
model = cv2.dnn.readNetFromTensorflow('model.pb')

# total counts the images actually evaluated (not a hard-coded 100), so the
# reported rate is correct for any number of test images.
true_count, total = 0, 0
start_time = time.time()
for image_file in captcha_image_files:
    image = cv2.imread(os.path.join(img_dir, image_file))
    if image is None:
        # cv2.imread returns None for unreadable / non-image files; skip them.
        continue
    total += 1
    # The file name (up to the first '.') is the ground-truth captcha text.
    result = image_file.split('.')[0]
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    predictions = []
    for x0, _, x1, _ in region:
        letter_image = cv2.resize(image[:, x0:x1], (15, 25))
        # swapRB has no effect on a single-channel (grayscale) image.
        model.setInput(cv2.dnn.blobFromImage(image=letter_image,
                                             scalefactor=1.0,
                                             size=(15, 25),
                                             mean=(0, 0, 0),
                                             swapRB=True,
                                             crop=False))
        output = model.forward()
        predictions.append(lb.inverse_transform(output)[0])

    captcha_text = ''.join(predictions)
    is_correct = captcha_text == result
    if is_correct:
        true_count += 1
    print("RESULT is: {}, CAPTCHA text is: {}, {}".format(result, captcha_text, is_correct))
end_time = time.time()
print("predict rate: ", true_count / total if total else 0.0)
print('time: {:.6f}s'.format(end_time - start_time))
识别结果与直接调用 h5 模型和 saved_model 时完全一致。
C++调用
需要预先安装（Windows）或编译（Linux）C++ 版的 OpenCV。Windows 下从官网下载安装后，将 build\x64\vc15\bin 目录添加到环境变量 PATH 中；Linux 可参考以下文章：
ubuntu安装opencv的正确方法
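注意：使用 VS 2019 时需要配置包含目录和库目录，并链接 opencv_worldxxx.lib（xxx 对应版本号，如 4.5.2 对应 452，具体以 lib 目录下的文件为准；Debug 配置下应链接带 d 后缀的 opencv_worldxxxd.lib）。另外，下面代码的批量验证部分使用了 Windows 专有的 <io.h>（_findfirst/_findnext），在 Linux 上编译需要改用 std::filesystem 或 dirent.h 遍历目录。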
#include <iostream>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/imgproc/types_c.h>
#include <opencv2/dnn.hpp>
#include <io.h>
using namespace std;
using namespace cv;
using namespace dnn;
// Path of the frozen TensorFlow graph written by the Python conversion script.
const String modelWeights = "model.pb";

// Network shared by cv_predict(); loaded once at program start-up.
// NOTE(review): this runs during static initialisation, so a missing or
// invalid model.pb throws before main() is entered.
Net model = readNetFromTensorflow(modelWeights);

// Timestamps used by main() to time the batch-validation run.
clock_t start_time, end_time;

// Class index -> character lookup table. Read-only, hence const.
// NOTE(review): order is assumed to match the class order of the label
// encoder used at training time (labels.dat) — confirm.
const char letter[36] = {
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b',
    'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
    'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'
};

// [start_x, end_x) column range of each of the four captcha characters.
const int region[4][2] = { {0, 16}, {14, 31}, {30, 46}, {44, 60} };
// Recognise the four-character captcha stored in the image at `filepath`.
// Each character column range from `region` is cut out over the full image
// height (same as the Python version's image[:, x0:x1]), resized to 15x25,
// run through `model`, and the arg-max class index is mapped to a character
// via `letter`.
// Returns the predicted text, or an empty string if the image cannot be read.
string cv_predict(const char* filepath) {
    Mat img = imread(filepath);
    if (img.empty()) {
        cerr << "cannot read image: " << filepath << endl;
        return "";
    }
    cvtColor(img, img, COLOR_BGR2GRAY);

    const int num_classes = (int)(sizeof(letter) / sizeof(letter[0]));
    string result;
    for (int r = 0; r < 4; r++) {
        const int x0 = region[r][0];
        const int width = region[r][1] - region[r][0];
        Mat img_split = img(Rect(x0, 0, width, img.rows));
        resize(img_split, img_split, Size(15, 25));
        // swapRB has no effect on single-channel input.
        Mat blob = blobFromImage(img_split, 1.0, Size(15, 25), Scalar(0, 0, 0), true, false);
        model.setInput(blob);
        Mat out = model.forward();

        // True arg-max over the class scores: the first maximum wins on
        // ties (like numpy.argmax). A search seeded with 0.0 would wrongly
        // return index 0 whenever every score is <= 0.
        int n = (int)out.total();
        if (n > num_classes)
            n = num_classes;
        Mat scores = out.reshape(1, 1).colRange(0, n);
        Point max_loc;
        minMaxLoc(scores, nullptr, nullptr, nullptr, &max_loc);
        result += letter[max_loc.x];
    }
    return result;
}
int main(int nargv, const char* argvs[])
{
model.setPreferableBackend(DNN_BACKEND_OPENCV);
model.setPreferableTarget(DNN_TARGET_CPU);
// 单张预测
if (nargv > 1) {
string result = "";
result = cv_predict(argvs[1]);
const char* test = result.c_str();
printf("%s\n", test);
}
// 验证识别率
else {
start_time = clock();
int true_count = 0;
std::string inPath = "img\\*.png";
intptr_t handle;
struct _finddata_t fileinfo;
handle = _findfirst(inPath.c_str(), &fileinfo);
if (handle == -1)
return -1;
do
{
string filepath = "img/" + (string)fileinfo.name;
string result = cv_predict(filepath.c_str());
printf("RESULT is: %s, CAPTCHA text is: %s, ", result.c_str(), fileinfo.name);
if (filepath.find(result) != filepath.npos) {
true_count += 1;
printf("True\n");
}
else
printf("False\n");
} while (!_findnext(handle, &fileinfo));
_findclose(handle);
end_time = clock();
printf("predict rate: %f\n", (float)true_count / 100.0);
printf("time: %fs\n", (float)(end_time - start_time) / CLOCKS_PER_SEC);
}
}
资源整理
相关资源可在另一篇文章中得到。