// 基于【OpenCV3.4】GoogleNet Caffe进行图片分类
// (Image classification with a Caffe GoogLeNet model using OpenCV 3.4's dnn module.)
//
// 时间:2022-08-17 03:33:29
#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS
#endif
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/dnn/dnn.hpp>
#include <highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>

using namespace std;
using namespace cv;
using namespace cv::dnn;

// Input files, opened by relative path (i.e. from the current working directory).
// They are read-only configuration, so they are const; all three use std::string
// for consistency (OpenCV's String converts from it where an API needs one).
const string model_txt = "bvlc_googlenet.prototxt";   // Caffe network definition
const string model_bin = "bvlc_googlenet.caffemodel"; // Caffe trained weights
const string label_file = "synset_words.txt";         // class label table, one class per line

// Reads the class names from label_file; defined after main().
vector<String> readLabels();
int main(int argc, char* argv[])
{
    // 1.加载图片
    Mat src = imread("test1.jpg");
    if (src.empty())
    {
        cout << "The image is empty, please check it." << endl;
        return -1;

    }
    imshow("test1", src);

    // 2.加载caffe模型
    Net net = readNetFromCaffe(model_txt, model_bin);
    if (net.empty())
    {
        cout << "load net model data failed..." << endl;
        return -1;
    }

    // 3.读入分类标签
    vector<String> labels = readLabels();

    // 4.将输入图像转换成GoogleNet可识别的blob格式
    Mat inputblob = blobFromImage(src, 1.0, Size(224, 224), Scalar(255, 0, 0));

    // 5.预测
    Mat prob_result;
    for (int i = 0; i < 10; i++) { // 进行10次预测,取可能性最大的类别
        net.setInput(inputblob, "data");
        prob_result = net.forward("prob");
    }
    Mat probMat = prob_result.reshape(1, 1); // 1-channel,1-rows, 变成1行10列
    Point class_position; 
    double class_probability; 
    minMaxLoc(probMat, NULL, &class_probability, NULL, &class_position); // 找出最大的可能性及其位置

    // 打印最大可能性的值
    int classidx = class_position.x;
    printf("\n current image classification : %s, possible : %.2f", labels.at(classidx).c_str(), class_probability);

    // 在图上打印类别
    putText(src, labels.at(classidx), Point(20, 20), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(0, 0, 255), 2, 8);
    imshow("Image Classification", src);

    waitKey();
    return 0;
}
/**
 * Read the class names from label_file, one per line.
 *
 * Each non-empty line is expected to look like "<id> <description>";
 * everything up to and including the first space is dropped so only the
 * description is kept. A line with no space is kept whole. A trailing '\r'
 * (CRLF files read without line-ending translation) is removed, and lines
 * that are empty after that are skipped.
 *
 * Terminates the process with exit(-1) if the file cannot be opened.
 *
 * @return the class names in file order (empty lines skipped).
 */
vector<String> readLabels()
{
    ifstream in(label_file);
    if (!in.is_open())
    {
        cout << "标签文件不能打开" << endl; // "the label file cannot be opened"
        exit(-1);
    }

    vector<String> classNames;
    string line;
    // Loop on getline itself rather than on !eof(): a read error sets
    // badbit/failbit without eofbit, which would make an eof()-driven loop
    // spin forever while re-pushing the last line read.
    while (getline(in, line))
    {
        if (!line.empty() && line.back() == '\r')
        {
            line.pop_back();
        }
        if (line.empty())
        {
            continue;
        }

        // Drop the leading class id (e.g. a WordNet id) and the space after it.
        const size_t space = line.find(' ');
        classNames.push_back(space == string::npos ? line : line.substr(space + 1));
    }
    // The ifstream closes itself when it goes out of scope.
    return classNames;
}