This commit is contained in:
恍兮惚兮 2024-11-05 00:03:53 +08:00
parent 7cd7c77848
commit b493017148

View File

@ -7,7 +7,8 @@ typedef std::vector<cv::Point> TextBox;
typedef std::string TextLine; typedef std::string TextLine;
typedef std::pair<TextBox, TextLine> TextBlock; typedef std::pair<TextBox, TextLine> TextBlock;
struct ScaleParam { struct ScaleParam
{
int srcWidth; int srcWidth;
int srcHeight; int srcHeight;
int dstWidth; int dstWidth;
@ -15,21 +16,6 @@ struct ScaleParam {
float ratioWidth; float ratioWidth;
float ratioHeight; float ratioHeight;
}; };
#define getinputoutputNames(Func1, vec, Func2) \
do \
{ \
Ort::AllocatorWithDefaultOptions allocator; \
const size_t numInputNodes = session->Func1(); \
\
vec.reserve(numInputNodes); \
std::vector<int64_t> input_node_dims; \
\
for (size_t i = 0; i < numInputNodes; i++) \
{ \
auto inputName = session->Func2(i, allocator); \
vec.push_back(std::move(inputName)); \
} \
} while (0);
class CommonOnnxModel class CommonOnnxModel
{ {
@ -65,6 +51,21 @@ class CommonOnnxModel
sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
} }
template <typename T, typename Func, typename Func2>
void getinputoutputNames(T &vec, Func func, Func2 func2)
{
Ort::AllocatorWithDefaultOptions allocator;
const size_t numInputNodes = ((*session.get()).*func)();
vec.reserve(numInputNodes);
std::vector<int64_t> input_node_dims;
for (size_t i = 0; i < numInputNodes; i++)
{
auto inputName = ((*session.get()).*func2)(i, allocator);
vec.push_back(std::move(inputName));
}
}
public: public:
std::pair<std::vector<float>, std::vector<int64_t>> RunSession(cv::Mat src) std::pair<std::vector<float>, std::vector<int64_t>> RunSession(cv::Mat src)
{ {
@ -90,19 +91,18 @@ public:
{ {
setNumThread(numOfThread); setNumThread(numOfThread);
session = std::make_unique<Ort::Session>(env, path.c_str(), sessionOptions); session = std::make_unique<Ort::Session>(env, path.c_str(), sessionOptions);
getinputoutputNames(GetInputCount, inputNamesPtr, GetInputNameAllocated); getinputoutputNames(inputNamesPtr, &Ort::Session::GetInputCount, &Ort::Session::GetInputNameAllocated);
getinputoutputNames(GetOutputCount, outputNamesPtr, GetOutputNameAllocated); getinputoutputNames(outputNamesPtr, &Ort::Session::GetOutputCount, &Ort::Session::GetOutputNameAllocated);
} }
}; };
class CrnnNet:public CommonOnnxModel{ class CrnnNet : public CommonOnnxModel
{
public: public:
CrnnNet(const std::wstring &pathStr, const std::wstring &keysPath, int numOfThread); CrnnNet(const std::wstring &pathStr, const std::wstring &keysPath, int numOfThread);
std::vector<TextLine> getTextLines(std::vector<cv::Mat> &partImg); std::vector<TextLine> getTextLines(std::vector<cv::Mat> &partImg);
private: private:
const float meanValues[3] = {127.5, 127.5, 127.5};
const float normValues[3] = {1.0 / 127.5, 1.0 / 127.5, 1.0 / 127.5};
const int dstHeight = 48; const int dstHeight = 48;
std::vector<std::string> keys; std::vector<std::string> keys;
@ -112,72 +112,83 @@ private:
TextLine getTextLine(const cv::Mat &src); TextLine getTextLine(const cv::Mat &src);
}; };
class DbNet:public CommonOnnxModel{ class DbNet : public CommonOnnxModel
public: {
DbNet(const std::wstring &pathStr, int numOfThread); public:
DbNet(const std::wstring &pathStr, int numOfThread): CommonOnnxModel(pathStr, {0.485 * 255, 0.456 * 255, 0.406 * 255}, {1.0 / 0.229 / 255.0, 1.0 / 0.224 / 255.0, 1.0 / 0.225 / 255.0}, numOfThread)
{
}
std::vector<TextBox> getTextBoxes(cv::Mat &src, ScaleParam &s, float boxScoreThresh, std::vector<TextBox> getTextBoxes(cv::Mat &src, ScaleParam &s, float boxScoreThresh,
float boxThresh, float unClipRatio); float boxThresh, float unClipRatio);
private:
const float meanValues[3] = {0.485 * 255, 0.456 * 255, 0.406 * 255};
const float normValues[3] = {1.0 / 0.229 / 255.0, 1.0 / 0.224 / 255.0, 1.0 / 0.225 / 255.0};
}; };
//onnxruntime init windows // onnxruntime init windows
ScaleParam getScaleParam(cv::Mat &src, const float scale) { ScaleParam getScaleParam(cv::Mat &src, const float scale)
{
int srcWidth = src.cols; int srcWidth = src.cols;
int srcHeight = src.rows; int srcHeight = src.rows;
int dstWidth = int((float) srcWidth * scale); int dstWidth = int((float)srcWidth * scale);
int dstHeight = int((float) srcHeight * scale); int dstHeight = int((float)srcHeight * scale);
if (dstWidth % 32 != 0) { if (dstWidth % 32 != 0)
{
dstWidth = (dstWidth / 32 - 1) * 32; dstWidth = (dstWidth / 32 - 1) * 32;
dstWidth = (std::max)(dstWidth, 32); dstWidth = (std::max)(dstWidth, 32);
} }
if (dstHeight % 32 != 0) { if (dstHeight % 32 != 0)
{
dstHeight = (dstHeight / 32 - 1) * 32; dstHeight = (dstHeight / 32 - 1) * 32;
dstHeight = (std::max)(dstHeight, 32); dstHeight = (std::max)(dstHeight, 32);
} }
float scaleWidth = (float) dstWidth / (float) srcWidth; float scaleWidth = (float)dstWidth / (float)srcWidth;
float scaleHeight = (float) dstHeight / (float) srcHeight; float scaleHeight = (float)dstHeight / (float)srcHeight;
return {srcWidth, srcHeight, dstWidth, dstHeight, scaleWidth, scaleHeight}; return {srcWidth, srcHeight, dstWidth, dstHeight, scaleWidth, scaleHeight};
} }
ScaleParam getScaleParam(cv::Mat &src, const int targetSize) { ScaleParam getScaleParam(cv::Mat &src, const int targetSize)
{
int srcWidth, srcHeight, dstWidth, dstHeight; int srcWidth, srcHeight, dstWidth, dstHeight;
srcWidth = dstWidth = src.cols; srcWidth = dstWidth = src.cols;
srcHeight = dstHeight = src.rows; srcHeight = dstHeight = src.rows;
float ratio = 1.f; float ratio = 1.f;
if (srcWidth > srcHeight) { if (srcWidth > srcHeight)
{
ratio = float(targetSize) / float(srcWidth); ratio = float(targetSize) / float(srcWidth);
} else { }
else
{
ratio = float(targetSize) / float(srcHeight); ratio = float(targetSize) / float(srcHeight);
} }
dstWidth = int(float(srcWidth) * ratio); dstWidth = int(float(srcWidth) * ratio);
dstHeight = int(float(srcHeight) * ratio); dstHeight = int(float(srcHeight) * ratio);
if (dstWidth % 32 != 0) { if (dstWidth % 32 != 0)
{
dstWidth = (dstWidth / 32) * 32; dstWidth = (dstWidth / 32) * 32;
dstWidth = (std::max)(dstWidth, 32); dstWidth = (std::max)(dstWidth, 32);
} }
if (dstHeight % 32 != 0) { if (dstHeight % 32 != 0)
{
dstHeight = (dstHeight / 32) * 32; dstHeight = (dstHeight / 32) * 32;
dstHeight = (std::max)(dstHeight, 32); dstHeight = (std::max)(dstHeight, 32);
} }
float ratioWidth = (float) dstWidth / (float) srcWidth; float ratioWidth = (float)dstWidth / (float)srcWidth;
float ratioHeight = (float) dstHeight / (float) srcHeight; float ratioHeight = (float)dstHeight / (float)srcHeight;
return {srcWidth, srcHeight, dstWidth, dstHeight, ratioWidth, ratioHeight}; return {srcWidth, srcHeight, dstWidth, dstHeight, ratioWidth, ratioHeight};
} }
std::vector<cv::Point2f> getBox(const cv::RotatedRect &rect) { std::vector<cv::Point2f> getBox(const cv::RotatedRect &rect)
{
cv::Point2f vertices[4]; cv::Point2f vertices[4];
rect.points(vertices); rect.points(vertices);
//std::vector<cv::Point2f> ret(4); // std::vector<cv::Point2f> ret(4);
std::vector<cv::Point2f> ret2(vertices, vertices + sizeof(vertices) / sizeof(vertices[0])); std::vector<cv::Point2f> ret2(vertices, vertices + sizeof(vertices) / sizeof(vertices[0]));
//memcpy(vertices, &ret[0], ret.size() * sizeof(ret[0])); // memcpy(vertices, &ret[0], ret.size() * sizeof(ret[0]));
return ret2; return ret2;
} }
cv::Mat getRotateCropImage(const cv::Mat &src, std::vector<cv::Point> box) { cv::Mat getRotateCropImage(const cv::Mat &src, std::vector<cv::Point> box)
{
cv::Mat image; cv::Mat image;
src.copyTo(image); src.copyTo(image);
std::vector<cv::Point> points = box; std::vector<cv::Point> points = box;
@ -192,7 +203,8 @@ cv::Mat getRotateCropImage(const cv::Mat &src, std::vector<cv::Point> box) {
cv::Mat imgCrop; cv::Mat imgCrop;
image(cv::Rect(left, top, right - left, bottom - top)).copyTo(imgCrop); image(cv::Rect(left, top, right - left, bottom - top)).copyTo(imgCrop);
for (auto &point: points) { for (auto &point : points)
{
point.x -= left; point.x -= left;
point.y -= top; point.y -= top;
} }
@ -233,26 +245,34 @@ cv::Mat getRotateCropImage(const cv::Mat &src, std::vector<cv::Point> box) {
return partImg; return partImg;
} }
bool cvPointCompare(const cv::Point &a, const cv::Point &b) { bool cvPointCompare(const cv::Point &a, const cv::Point &b)
{
return a.x < b.x; return a.x < b.x;
} }
std::vector<cv::Point2f> getMinBoxes(const cv::RotatedRect &boxRect, float &maxSideLen) { std::vector<cv::Point2f> getMinBoxes(const cv::RotatedRect &boxRect, float &maxSideLen)
{
maxSideLen = std::max(boxRect.size.width, boxRect.size.height); maxSideLen = std::max(boxRect.size.width, boxRect.size.height);
std::vector<cv::Point2f> boxPoint = getBox(boxRect); std::vector<cv::Point2f> boxPoint = getBox(boxRect);
std::sort(boxPoint.begin(), boxPoint.end(), cvPointCompare); std::sort(boxPoint.begin(), boxPoint.end(), cvPointCompare);
int index1, index2, index3, index4; int index1, index2, index3, index4;
if (boxPoint[1].y > boxPoint[0].y) { if (boxPoint[1].y > boxPoint[0].y)
{
index1 = 0; index1 = 0;
index4 = 1; index4 = 1;
} else { }
else
{
index1 = 1; index1 = 1;
index4 = 0; index4 = 0;
} }
if (boxPoint[3].y > boxPoint[2].y) { if (boxPoint[3].y > boxPoint[2].y)
{
index2 = 2; index2 = 2;
index3 = 3; index3 = 3;
} else { }
else
{
index2 = 3; index2 = 3;
index3 = 2; index3 = 2;
} }
@ -273,7 +293,8 @@ inline T clamp(T x, T min, T max)
return min; return min;
return x; return x;
} }
float boxScoreFast(const std::vector<cv::Point2f> &boxes, const cv::Mat &pred) { float boxScoreFast(const std::vector<cv::Point2f> &boxes, const cv::Mat &pred)
{
int width = pred.cols; int width = pred.cols;
int height = pred.rows; int height = pred.rows;
@ -298,30 +319,33 @@ float boxScoreFast(const std::vector<cv::Point2f> &boxes, const cv::Mat &pred) {
cv::Mat croppedImg; cv::Mat croppedImg;
pred(cv::Rect(minX, minY, maxX - minX + 1, maxY - minY + 1)) pred(cv::Rect(minX, minY, maxX - minX + 1, maxY - minY + 1))
.copyTo(croppedImg); .copyTo(croppedImg);
auto score = (float) cv::mean(croppedImg, mask)[0]; auto score = (float)cv::mean(croppedImg, mask)[0];
return score; return score;
} }
float getContourArea(const std::vector<cv::Point2f> &box, float unClipRatio) { float getContourArea(const std::vector<cv::Point2f> &box, float unClipRatio)
{
size_t size = box.size(); size_t size = box.size();
float area = 0.0f; float area = 0.0f;
float dist = 0.0f; float dist = 0.0f;
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++)
{
area += box[i].x * box[(i + 1) % size].y - area += box[i].x * box[(i + 1) % size].y -
box[i].y * box[(i + 1) % size].x; box[i].y * box[(i + 1) % size].x;
dist += sqrtf((box[i].x - box[(i + 1) % size].x) * dist += sqrtf((box[i].x - box[(i + 1) % size].x) *
(box[i].x - box[(i + 1) % size].x) + (box[i].x - box[(i + 1) % size].x) +
(box[i].y - box[(i + 1) % size].y) * (box[i].y - box[(i + 1) % size].y) *
(box[i].y - box[(i + 1) % size].y)); (box[i].y - box[(i + 1) % size].y));
} }
area = fabs(float(area / 2.0)); area = fabs(float(area / 2.0));
return area * unClipRatio / dist; return area * unClipRatio / dist;
} }
cv::RotatedRect unClip(std::vector<cv::Point2f> box, float unClipRatio) { cv::RotatedRect unClip(std::vector<cv::Point2f> box, float unClipRatio)
{
float distance = getContourArea(box, unClipRatio); float distance = getContourArea(box, unClipRatio);
Clipper2Lib::ClipperOffset offset; Clipper2Lib::ClipperOffset offset;
@ -335,15 +359,20 @@ cv::RotatedRect unClip(std::vector<cv::Point2f> box, float unClipRatio) {
offset.Execute(distance, soln); offset.Execute(distance, soln);
std::vector<cv::Point2f> points; std::vector<cv::Point2f> points;
for (size_t j = 0; j < soln.size(); j++) { for (size_t j = 0; j < soln.size(); j++)
for (size_t i = 0; i < soln[soln.size() - 1].size(); i++) { {
for (size_t i = 0; i < soln[soln.size() - 1].size(); i++)
{
points.emplace_back(cv::Point2f{float(soln[j][i].x), float(soln[j][i].y)}); points.emplace_back(cv::Point2f{float(soln[j][i].x), float(soln[j][i].y)});
} }
} }
cv::RotatedRect res; cv::RotatedRect res;
if (points.empty()) { if (points.empty())
{
res = cv::RotatedRect(cv::Point2f(0, 0), cv::Size2f(1, 1), 0); res = cv::RotatedRect(cv::Point2f(0, 0), cv::Size2f(1, 1), 0);
} else { }
else
{
res = cv::minAreaRect(points); res = cv::minAreaRect(points);
} }
return res; return res;
@ -427,10 +456,6 @@ std::vector<TextLine> CrnnNet::getTextLines(std::vector<cv::Mat> &partImg)
return textLines; return textLines;
} }
DbNet::DbNet(const std::wstring &pathStr, int numOfThread) : CommonOnnxModel(pathStr, {0.485 * 255, 0.456 * 255, 0.406 * 255}, {1.0 / 0.229 / 255.0, 1.0 / 0.224 / 255.0, 1.0 / 0.225 / 255.0}, numOfThread)
{
}
std::vector<TextBox> findRsBoxes(const cv::Mat &predMat, const cv::Mat &dilateMat, ScaleParam &s, std::vector<TextBox> findRsBoxes(const cv::Mat &predMat, const cv::Mat &dilateMat, ScaleParam &s,
const float boxScoreThresh, const float unClipRatio) const float boxScoreThresh, const float unClipRatio)
{ {
@ -601,13 +626,15 @@ std::vector<cv::Mat> OcrLite::getPartImages(cv::Mat &src, std::vector<TextBox> &
return partImages; return partImages;
} }
cv::Mat matRotateClockWise180(cv::Mat src) { cv::Mat matRotateClockWise180(cv::Mat src)
{
flip(src, src, 0); flip(src, src, 0);
flip(src, src, 1); flip(src, src, 1);
return src; return src;
} }
cv::Mat matRotateClockWise90(cv::Mat src) { cv::Mat matRotateClockWise90(cv::Mat src)
{
transpose(src, src); transpose(src, src);
flip(src, src, 1); flip(src, src, 1);
return src; return src;