栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

torch

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

torch

#include "torch/script.h" // One-stop header.
#include <opencv2/opencv.hpp>

#include <io.h> // _findfirst / _findnext / _findclose (MSVC CRT, Windows-only)

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

using namespace cv;
using namespace std;

// TorchScript model run by classify(); main() loads it from "model-12-6.pt"
// and moves it to device_type.
torch::jit::script::Module model_high;
// NOTE(review): declared but never loaded or used in the visible code.
torch::jit::script::Module model_low;
torch::DeviceType device_type = at::kCPU; // inference device; main() switches it to at::kCUDA when CUDA is available
// Target sample height / width in pixels, used by resizeMat, readd and img_is_split.
int h = 70;
int w = 70;
void getAllFiles(string path, vector& files)
{
	// 文件句柄
	intptr_t hFile = 0;
	// 文件信息
	struct _finddata_t fileinfo;

	string p;

	if ((hFile = _findfirst(p.assign(path).append("\*").c_str(), &fileinfo)) != -1) {
		do {
			if (strcmp(fileinfo.name, ".") == 0 || strcmp(fileinfo.name, "..") == 0) {
				continue;
			}
			// 保存文件的全路径
			files.push_back(fileinfo.name);

		} while (_findnext(hFile, &fileinfo) == 0); //寻找下一个,成功返回0,否则-1

		_findclose(hFile);
	}
}
// Centre-crop `image` to at most h x w pixels (h, w are the globals), then
// pad it with zero-valued borders so that the result is exactly h x w.
// Padding is split evenly; when the difference is odd, the extra pixel goes
// to the bottom / right side.
Mat resizeMat(Mat image) {
	// An axis longer than the target gets a centred window of target size;
	// an axis already within the target is kept whole.
	const bool tooTall = image.rows > h;
	const bool tooWide = image.cols > w;
	const int cropRows = tooTall ? h : image.rows;
	const int cropCols = tooWide ? w : image.cols;
	const int rowOffset = tooTall ? (image.rows - h) / 2 : 0;
	const int colOffset = tooWide ? (image.cols - w) / 2 : 0;

	// Rect takes (x, y, width, height) = (column, row, #cols, #rows).
	Mat cropped = image(Rect(colOffset, rowOffset, cropCols, cropRows));

	const int padTop = (h - cropped.rows) / 2;
	const int padLeft = (w - cropped.cols) / 2;
	const int padBottom = h - padTop - cropped.rows;
	const int padRight = w - padLeft - cropped.cols;

	Mat padded;
	copyMakeBorder(cropped, padded, padTop, padBottom, padLeft, padRight,
		BORDER_CONSTANT, Scalar(0, 0, 0));
	return padded;
}
// Dump a matrix to a tab-separated text file, one matrix row per line
// (each value is followed by a tab).
//
// The matrix is converted to double first, so any depth (8U, 16U, 32F, 64F,
// ...) is read correctly. Multi-channel matrices have their channel values
// written interleaved (c0 c1 c2 c0 c1 c2 ...) on each line.
// Prints "File Not Opened" to stdout and returns if the file cannot be created.
void writeMatToFile(cv::Mat& m, const char* filename)
{
	std::ofstream fout(filename);

	if (!fout)
	{
		std::cout << "File Not Opened" << std::endl;
		return;
	}

	// at<T>() requires T to match the Mat depth exactly. Converting once up front
	// removes that assumption. Float data prints the same as before because
	// operator<< already promotes float to double.
	cv::Mat values;
	m.convertTo(values, CV_64F);
	const int valuesPerRow = values.cols * values.channels();

	for (int i = 0; i < values.rows; i++)
	{
		const double* row = values.ptr<double>(i);
		for (int j = 0; j < valuesPerRow; j++)
		{
			fout << row[j] << '\t'; // was "t": the backslash of "\t" was missing
		}
		fout << '\n'; // no per-row flush; the destructor flushes and closes
	}
}

// Run the global model_high on one H x W x C 8-bit image and return the
// softmax probability of class index 1.
//
// The pixels are wrapped with torch::from_blob as torch::kByte, so `image`
// must hold 8-bit data (img_is_split produces CV_8UC3). The raw model output
// and the softmax probabilities are printed to stdout.
// Assumes the model output is 2-D (batch x classes) with at least 2 classes.
float classify(Mat image) {
	// from_blob reads image.data as one dense H*W*C block. ROIs / sub-matrices
	// carry row padding, so take a dense copy in that case.
	if (!image.isContinuous()) {
		image = image.clone();
	}
	// Inference only: no autograd graph is needed.
	torch::autograd::AutoGradMode grad_guard(false);

	torch::Tensor tensor_image = torch::from_blob(image.data, { image.rows, image.cols, image.channels() }, torch::kByte).to(device_type);
	tensor_image = tensor_image.permute({ 2, 0, 1 }); // HWC -> CHW
	tensor_image = tensor_image.toType(torch::kFloat);
	// NOTE(review): 8-bit pixels are divided by 255*255 = 65025, not by 255, so
	// inputs land in [0, ~0.0039]. Left unchanged because it must match the
	// model's training preprocessing; confirm against the training code.
	tensor_image = tensor_image.div(255 * 255);
	tensor_image = tensor_image.unsqueeze(0); // add batch dimension -> 1 x C x H x W

	std::vector<torch::jit::IValue> inputs;
	inputs.push_back(tensor_image);
	torch::Tensor output = model_high.forward(inputs).toTensor();
	cout << output << endl;
	output = torch::softmax(output, 1);
	output = output.to(at::kCPU); // accessor<> needs CPU memory
	cout << output << endl;
	auto x_data = output.accessor<float, 2>();
	return x_data[0][1];
}
// Compute the per-pixel "R" map from a high/low image pair and their blank
// reference images, writing the result to *R (its old content is replaced).
//
// With lh = log((high_blank + 1) / (high + 1)) and
//      ll = log((low_blank  + 1) / (low  + 1))   (element-wise),
// this stores  R = |lh| + 1 / |ll| + 1.
// NOTE(review): by C++ precedence that is |lh| + (1 / |ll|) + 1, not the ratio
// (|lh| + 1) / (|ll| + 1). Kept as-is because the mode-2 range in maxminscale
// may have been tuned on this output; confirm which formula training used.
//
// All inputs must share one size and type; img_is_split passes CV_64FC1.
void Rtexture(Mat* high_blank, Mat* low_blank, Mat* high, Mat* low, Mat* R) {
	Mat highLog = (*high_blank + 1) / (*high + 1);
	cv::log(highLog, highLog);

	Mat lowLog = (*low_blank + 1) / (*low + 1);
	cv::log(lowLog, lowLog);

	*R = cv::abs(highLog) + 1 / cv::abs(lowLog) + 1;
}
// Tile `img` into a grid of row_num x col_num copies.
// The result has img.rows * row_num rows, img.cols * col_num columns and
// img's type. Non-positive counts yield an empty Mat.
// cv::repeat does exactly this tiling; it replaces the manual copy loop that
// relied on the legacy C-API CvRect/cvRect and a no-op colRange(0, cols).
Mat img_joint(Mat img, int row_num, int col_num) {
	if (row_num <= 0 || col_num <= 0) {
		return Mat(); // cv::repeat asserts on non-positive counts
	}
	Mat out;
	cv::repeat(img, row_num, col_num, out);
	return out;
}
// Tile `img` `num` times along one axis, then crop back to the target size.
//   r_c == 0 : stack copies vertically (num x 1 grid) and keep the first h rows.
//   otherwise: stack copies horizontally (1 x num grid) and keep the first w columns.
// The returned Mat is an ROI view into the tiled image.
Mat readd(Mat img, int num, int r_c) {
	const bool stackRows = (r_c == 0);
	Mat tiled = stackRows ? img_joint(img, num, 1) : img_joint(img, 1, num);
	const Rect keep = stackRows ? Rect(0, 0, tiled.cols, h) : Rect(0, 0, w, tiled.rows);
	return tiled(keep);
}
// Linearly map `img` from an input range onto [0, 255] and clip to that
// interval: values above 255 become 255, values at or below 0 become 0.
//   mode == 2 : input range [0, 17000]
//   otherwise : input range [6000, 45000]
// The result keeps img's type; img_is_split passes CV_32FC1.
//
// The result is written to a new Mat. Previously `img` (a shallow header
// sharing the caller's pixels) was assigned to and thresholded in place,
// which silently overwrote the caller's data.
Mat maxminscale(Mat img, int mode) {
	const double lo = (mode != 2) ? 6000.0 : 0.0;
	const double hi = (mode != 2) ? 45000.0 : 17000.0;
	const float k = static_cast<float>(255.0 / (hi - lo));

	Mat dst = (img - lo) * k; // fresh allocation, img's buffer untouched
	threshold(dst, dst, 255, 255, THRESH_TRUNC); // clip above 255
	threshold(dst, dst, 0, 255, THRESH_TOZERO);  // clip at or below 0 to 0
	return dst;
}
// Build the 3-channel 8-bit classifier input from a high/low image pair.
//
// Both images are first brought to exactly h x w (the globals):
//  * an axis longer than the target keeps its first target/2 and last
//    target - target/2 pixels (the middle is discarded);
//  * an axis shorter than the target is filled by tiling (readd) and cropped.
// If at most 5% of the resulting high-image pixels are non-zero, or either
// input is empty, a 1x1 CV_8U placeholder is returned (predict() checks
// rows == 1). Otherwise the result is CV_8UC3 with channels
// [maxminscale(high), maxminscale(low), maxminscale(Rtexture)].
//
// Assumes single-channel inputs and that `low` is at least as large as
// `high` (low is cropped with high's dimensions). TODO confirm with callers.
Mat img_is_split(Mat high, Mat low) {
	// Constant reference values passed to Rtexture as high_blank / low_blank.
	// Presumably the unattenuated (no object) reading of each image; TODO confirm.
	const double kBlankHigh = 48815;
	const double kBlankLow = 48312;
	const Mat placeholder = Mat(Size(1, 1), CV_8U, Scalar(0));

	// An empty image would divide by zero in the tiling step below.
	if (high.empty() || low.empty()) {
		return placeholder;
	}

	Mat new_high, new_low;
	high.copyTo(new_high);
	low.copyTo(new_low);

	// Too tall: keep top h/2 rows + bottom h - h/2 rows (sums to h even for odd h).
	if (new_high.rows > h) {
		const int topRows = h / 2;
		const int bottomRows = h - topRows;
		const int bottomStart = new_high.rows - bottomRows;
		Mat top_h = new_high(Rect(0, 0, new_high.cols, topRows));
		Mat bottom_h = new_high(Rect(0, bottomStart, new_high.cols, bottomRows));
		Mat top_l = new_low(Rect(0, 0, new_high.cols, topRows));
		Mat bottom_l = new_low(Rect(0, bottomStart, new_high.cols, bottomRows));
		vconcat(top_h, bottom_h, new_high);
		vconcat(top_l, bottom_l, new_low);
	}
	// Too wide: keep left w/2 columns + right w - w/2 columns.
	if (new_high.cols > w) {
		const int leftCols = w / 2;
		const int rightCols = w - leftCols;
		const int rightStart = new_high.cols - rightCols;
		Mat left_h = new_high(Rect(0, 0, leftCols, new_high.rows));
		Mat right_h = new_high(Rect(rightStart, 0, rightCols, new_high.rows));
		Mat left_l = new_low(Rect(0, 0, leftCols, new_high.rows));
		Mat right_l = new_low(Rect(rightStart, 0, rightCols, new_high.rows));
		hconcat(left_h, right_h, new_high);
		hconcat(left_l, right_l, new_low);
	}
	// Too short / too narrow: tile enough copies, then crop back to h / w.
	if (new_high.rows < h) {
		const int row_num = (h + new_high.rows - 1) / new_high.rows; // ceil(h / rows)
		new_high = readd(new_high, row_num, 0);
		new_low = readd(new_low, row_num, 0);
	}
	if (new_high.cols < w) {
		const int col_num = (w + new_high.cols - 1) / new_high.cols; // ceil(w / cols)
		new_high = readd(new_high, col_num, 1);
		new_low = readd(new_low, col_num, 1);
	}

	// Reject mostly-empty samples (only the count is needed, so countNonZero
	// replaces collecting every non-zero location with findNonZero).
	if (float(cv::countNonZero(new_high)) / (h * w) <= 0.05) {
		return placeholder;
	}

	// Same shape as the working images. Size is (width, height), so the old
	// Size(rows, cols) produced a transposed matrix whenever h != w.
	Mat blank_high(new_high.size(), CV_64FC1, Scalar(kBlankHigh));
	Mat blank_low(new_high.size(), CV_64FC1, Scalar(kBlankLow));
	new_high.convertTo(new_high, CV_64FC1);
	new_low.convertTo(new_low, CV_64FC1);

	Mat R;
	Rtexture(&blank_high, &blank_low, &new_high, &new_low, &R);
	new_high.convertTo(new_high, CV_32FC1);
	new_low.convertTo(new_low, CV_32FC1);
	R.convertTo(R, CV_32FC1);

	vector<Mat> channels;
	channels.push_back(maxminscale(new_high, 0));
	channels.push_back(maxminscale(new_low, 1));
	channels.push_back(maxminscale(R, 2));
	Mat MultiImage;
	merge(channels, MultiImage);
	MultiImage.convertTo(MultiImage, CV_8UC3, 1); // saturating cast to 8-bit
	return MultiImage;
}
// Classify one high/low image pair and print the result to stdout.
// Nothing is printed, and the classifier is not run, when the high image is
// larger than 130 px in either dimension, or when img_is_split rejects the
// pair (signalled by a 1-row placeholder). Otherwise the class-1 score from
// classify() is printed with a flag that is 0 when the score exceeds 0.4 and
// 1 otherwise, followed by the current clock() value.
void predict(Mat imgh,Mat imgl) {
	if (std::max(imgh.rows, imgh.cols) > 130) {
		return;
	}

	Mat sample = img_is_split(imgh, imgl);
	if (sample.rows == 1) {
		return;
	}

	const float vote = classify(sample);
	const int Pen_flag = (vote > 0.4) ? 0 : 1;
	cout << "模型判断完毕模型求和" << vote << "喷吹:" << Pen_flag << endl;
	cout << "model-" << clock() << endl;
}
int main()
{
	std::cout << "CUDA:   " << torch::cuda::is_available() << std::endl;
	std::cout << "CUDNN:  " << torch::cuda::cudnn_is_available() << std::endl;
	std::cout << "GPU(s): " << torch::cuda::device_count() << std::endl;
	if (torch::cuda::is_available())
		device_type = at::kCUDA;
	model_high = torch::jit::load("model-12-6.pt");
	model_high.to(device_type);
	//for (int i = 0; i < 5; i++) {
	//	Mat image = Mat::eye(h, w, CV_16UC1);
	//	classify(image, 1);
	//	classify(image, 0);
	//}
	cout << "加载模型成功" << endl;
	//assert(module != nullptr);  

	//vector shifiles, meifiles;
	string shipath_h = "shi//High";
	string shipath_l = "shi//Low";
	string meipath = "mei";
	//getAllFiles(shipath_h, shifiles);
	
 	Mat high = cv::imread("74292_0_cluster1.png",-1);
	Mat low = cv::imread("74292_0_cluster2.png", -1);
	predict(high, low);
}

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/657460.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号