This commit is contained in:
恍兮惚兮 2024-11-05 00:03:53 +08:00
parent 7cd7c77848
commit b493017148

View File

@ -7,7 +7,8 @@ typedef std::vector<cv::Point> TextBox;
typedef std::string TextLine;
typedef std::pair<TextBox, TextLine> TextBlock;
struct ScaleParam {
struct ScaleParam
{
int srcWidth;
int srcHeight;
int dstWidth;
@ -15,21 +16,6 @@ struct ScaleParam {
float ratioWidth;
float ratioHeight;
};
#define getinputoutputNames(Func1, vec, Func2) \
do \
{ \
Ort::AllocatorWithDefaultOptions allocator; \
const size_t numInputNodes = session->Func1(); \
\
vec.reserve(numInputNodes); \
std::vector<int64_t> input_node_dims; \
\
for (size_t i = 0; i < numInputNodes; i++) \
{ \
auto inputName = session->Func2(i, allocator); \
vec.push_back(std::move(inputName)); \
} \
} while (0);
class CommonOnnxModel
{
@ -65,6 +51,21 @@ class CommonOnnxModel
sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
}
template <typename T, typename Func, typename Func2>
void getinputoutputNames(T &vec, Func func, Func2 func2)
{
Ort::AllocatorWithDefaultOptions allocator;
const size_t numInputNodes = ((*session.get()).*func)();
vec.reserve(numInputNodes);
std::vector<int64_t> input_node_dims;
for (size_t i = 0; i < numInputNodes; i++)
{
auto inputName = ((*session.get()).*func2)(i, allocator);
vec.push_back(std::move(inputName));
}
}
public:
std::pair<std::vector<float>, std::vector<int64_t>> RunSession(cv::Mat src)
{
@ -90,19 +91,18 @@ public:
{
setNumThread(numOfThread);
session = std::make_unique<Ort::Session>(env, path.c_str(), sessionOptions);
getinputoutputNames(GetInputCount, inputNamesPtr, GetInputNameAllocated);
getinputoutputNames(GetOutputCount, outputNamesPtr, GetOutputNameAllocated);
getinputoutputNames(inputNamesPtr, &Ort::Session::GetInputCount, &Ort::Session::GetInputNameAllocated);
getinputoutputNames(outputNamesPtr, &Ort::Session::GetOutputCount, &Ort::Session::GetOutputNameAllocated);
}
};
class CrnnNet:public CommonOnnxModel{
class CrnnNet : public CommonOnnxModel
{
public:
CrnnNet(const std::wstring &pathStr, const std::wstring &keysPath, int numOfThread);
std::vector<TextLine> getTextLines(std::vector<cv::Mat> &partImg);
private:
const float meanValues[3] = {127.5, 127.5, 127.5};
const float normValues[3] = {1.0 / 127.5, 1.0 / 127.5, 1.0 / 127.5};
const int dstHeight = 48;
std::vector<std::string> keys;
@ -112,28 +112,31 @@ private:
TextLine getTextLine(const cv::Mat &src);
};
class DbNet:public CommonOnnxModel{
class DbNet : public CommonOnnxModel
{
public:
DbNet(const std::wstring &pathStr, int numOfThread);
DbNet(const std::wstring &pathStr, int numOfThread): CommonOnnxModel(pathStr, {0.485 * 255, 0.456 * 255, 0.406 * 255}, {1.0 / 0.229 / 255.0, 1.0 / 0.224 / 255.0, 1.0 / 0.225 / 255.0}, numOfThread)
{
}
std::vector<TextBox> getTextBoxes(cv::Mat &src, ScaleParam &s, float boxScoreThresh,
float boxThresh, float unClipRatio);
private:
const float meanValues[3] = {0.485 * 255, 0.456 * 255, 0.406 * 255};
const float normValues[3] = {1.0 / 0.229 / 255.0, 1.0 / 0.224 / 255.0, 1.0 / 0.225 / 255.0};
};
// onnxruntime init windows
ScaleParam getScaleParam(cv::Mat &src, const float scale) {
ScaleParam getScaleParam(cv::Mat &src, const float scale)
{
int srcWidth = src.cols;
int srcHeight = src.rows;
int dstWidth = int((float)srcWidth * scale);
int dstHeight = int((float)srcHeight * scale);
if (dstWidth % 32 != 0) {
if (dstWidth % 32 != 0)
{
dstWidth = (dstWidth / 32 - 1) * 32;
dstWidth = (std::max)(dstWidth, 32);
}
if (dstHeight % 32 != 0) {
if (dstHeight % 32 != 0)
{
dstHeight = (dstHeight / 32 - 1) * 32;
dstHeight = (std::max)(dstHeight, 32);
}
@ -142,24 +145,30 @@ ScaleParam getScaleParam(cv::Mat &src, const float scale) {
return {srcWidth, srcHeight, dstWidth, dstHeight, scaleWidth, scaleHeight};
}
ScaleParam getScaleParam(cv::Mat &src, const int targetSize) {
ScaleParam getScaleParam(cv::Mat &src, const int targetSize)
{
int srcWidth, srcHeight, dstWidth, dstHeight;
srcWidth = dstWidth = src.cols;
srcHeight = dstHeight = src.rows;
float ratio = 1.f;
if (srcWidth > srcHeight) {
if (srcWidth > srcHeight)
{
ratio = float(targetSize) / float(srcWidth);
} else {
}
else
{
ratio = float(targetSize) / float(srcHeight);
}
dstWidth = int(float(srcWidth) * ratio);
dstHeight = int(float(srcHeight) * ratio);
if (dstWidth % 32 != 0) {
if (dstWidth % 32 != 0)
{
dstWidth = (dstWidth / 32) * 32;
dstWidth = (std::max)(dstWidth, 32);
}
if (dstHeight % 32 != 0) {
if (dstHeight % 32 != 0)
{
dstHeight = (dstHeight / 32) * 32;
dstHeight = (std::max)(dstHeight, 32);
}
@ -168,7 +177,8 @@ ScaleParam getScaleParam(cv::Mat &src, const int targetSize) {
return {srcWidth, srcHeight, dstWidth, dstHeight, ratioWidth, ratioHeight};
}
std::vector<cv::Point2f> getBox(const cv::RotatedRect &rect) {
std::vector<cv::Point2f> getBox(const cv::RotatedRect &rect)
{
cv::Point2f vertices[4];
rect.points(vertices);
// std::vector<cv::Point2f> ret(4);
@ -177,7 +187,8 @@ std::vector<cv::Point2f> getBox(const cv::RotatedRect &rect) {
return ret2;
}
cv::Mat getRotateCropImage(const cv::Mat &src, std::vector<cv::Point> box) {
cv::Mat getRotateCropImage(const cv::Mat &src, std::vector<cv::Point> box)
{
cv::Mat image;
src.copyTo(image);
std::vector<cv::Point> points = box;
@ -192,7 +203,8 @@ cv::Mat getRotateCropImage(const cv::Mat &src, std::vector<cv::Point> box) {
cv::Mat imgCrop;
image(cv::Rect(left, top, right - left, bottom - top)).copyTo(imgCrop);
for (auto &point: points) {
for (auto &point : points)
{
point.x -= left;
point.y -= top;
}
@ -233,26 +245,34 @@ cv::Mat getRotateCropImage(const cv::Mat &src, std::vector<cv::Point> box) {
return partImg;
}
bool cvPointCompare(const cv::Point &a, const cv::Point &b) {
bool cvPointCompare(const cv::Point &a, const cv::Point &b)
{
return a.x < b.x;
}
std::vector<cv::Point2f> getMinBoxes(const cv::RotatedRect &boxRect, float &maxSideLen) {
std::vector<cv::Point2f> getMinBoxes(const cv::RotatedRect &boxRect, float &maxSideLen)
{
maxSideLen = std::max(boxRect.size.width, boxRect.size.height);
std::vector<cv::Point2f> boxPoint = getBox(boxRect);
std::sort(boxPoint.begin(), boxPoint.end(), cvPointCompare);
int index1, index2, index3, index4;
if (boxPoint[1].y > boxPoint[0].y) {
if (boxPoint[1].y > boxPoint[0].y)
{
index1 = 0;
index4 = 1;
} else {
}
else
{
index1 = 1;
index4 = 0;
}
if (boxPoint[3].y > boxPoint[2].y) {
if (boxPoint[3].y > boxPoint[2].y)
{
index2 = 2;
index3 = 3;
} else {
}
else
{
index2 = 3;
index3 = 2;
}
@ -273,7 +293,8 @@ inline T clamp(T x, T min, T max)
return min;
return x;
}
float boxScoreFast(const std::vector<cv::Point2f> &boxes, const cv::Mat &pred) {
float boxScoreFast(const std::vector<cv::Point2f> &boxes, const cv::Mat &pred)
{
int width = pred.cols;
int height = pred.rows;
@ -304,11 +325,13 @@ float boxScoreFast(const std::vector<cv::Point2f> &boxes, const cv::Mat &pred) {
return score;
}
float getContourArea(const std::vector<cv::Point2f> &box, float unClipRatio) {
float getContourArea(const std::vector<cv::Point2f> &box, float unClipRatio)
{
size_t size = box.size();
float area = 0.0f;
float dist = 0.0f;
for (size_t i = 0; i < size; i++) {
for (size_t i = 0; i < size; i++)
{
area += box[i].x * box[(i + 1) % size].y -
box[i].y * box[(i + 1) % size].x;
dist += sqrtf((box[i].x - box[(i + 1) % size].x) *
@ -321,7 +344,8 @@ float getContourArea(const std::vector<cv::Point2f> &box, float unClipRatio) {
return area * unClipRatio / dist;
}
cv::RotatedRect unClip(std::vector<cv::Point2f> box, float unClipRatio) {
cv::RotatedRect unClip(std::vector<cv::Point2f> box, float unClipRatio)
{
float distance = getContourArea(box, unClipRatio);
Clipper2Lib::ClipperOffset offset;
@ -335,15 +359,20 @@ cv::RotatedRect unClip(std::vector<cv::Point2f> box, float unClipRatio) {
offset.Execute(distance, soln);
std::vector<cv::Point2f> points;
for (size_t j = 0; j < soln.size(); j++) {
for (size_t i = 0; i < soln[soln.size() - 1].size(); i++) {
for (size_t j = 0; j < soln.size(); j++)
{
for (size_t i = 0; i < soln[soln.size() - 1].size(); i++)
{
points.emplace_back(cv::Point2f{float(soln[j][i].x), float(soln[j][i].y)});
}
}
cv::RotatedRect res;
if (points.empty()) {
if (points.empty())
{
res = cv::RotatedRect(cv::Point2f(0, 0), cv::Size2f(1, 1), 0);
} else {
}
else
{
res = cv::minAreaRect(points);
}
return res;
@ -427,10 +456,6 @@ std::vector<TextLine> CrnnNet::getTextLines(std::vector<cv::Mat> &partImg)
return textLines;
}
DbNet::DbNet(const std::wstring &pathStr, int numOfThread) : CommonOnnxModel(pathStr, {0.485 * 255, 0.456 * 255, 0.406 * 255}, {1.0 / 0.229 / 255.0, 1.0 / 0.224 / 255.0, 1.0 / 0.225 / 255.0}, numOfThread)
{
}
std::vector<TextBox> findRsBoxes(const cv::Mat &predMat, const cv::Mat &dilateMat, ScaleParam &s,
const float boxScoreThresh, const float unClipRatio)
{
@ -601,13 +626,15 @@ std::vector<cv::Mat> OcrLite::getPartImages(cv::Mat &src, std::vector<TextBox> &
return partImages;
}
cv::Mat matRotateClockWise180(cv::Mat src) {
cv::Mat matRotateClockWise180(cv::Mat src)
{
flip(src, src, 0);
flip(src, src, 1);
return src;
}
cv::Mat matRotateClockWise90(cv::Mat src) {
cv::Mat matRotateClockWise90(cv::Mat src)
{
transpose(src, src);
flip(src, src, 1);
return src;