栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch预测结果的精度、召回率、F1值、准确率计算(简单实现)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch预测结果的精度、召回率、F1值、准确率计算(简单实现)

1. 导入必要的库函数

import os
import json
import torch
from PIL import Image
from torchvision import transforms
# 导入自己的模型
from model_v3 import mobilenet_v3_small

2. 文件存储格式

imgs文件下下是需要预测的图像

label.txt为每幅图对应的标签

### 这个无所谓,能正常读取就行

 3. 标签读取辅助函数

targets_path = "test/label.txt"
    with open(targets_path, 'r') as file:
        targets = file.readlines()

 4. 精度、召回率、F1值、准确率计算函数

# 获取标签
target = int(targets[i])
# 输出结果
output = torch.squeeze(model(img.to(device))).cpu()
# 取出最大值的索引
predict = torch.softmax(output, dim=0)
# 得到数值
predict_cla = torch.argmax(predict).item()
# 0为正类, 1为负类
if predict_cla == 0 and target == 0:
    TP += 1
if predict_cla == 1 and target == 1:
    TN += 1
if predict_cla == 0 and target == 1:
    FP += 1
    print(str(imgs_path) + " " + str(file) + " is predicted wrong")
if predict_cla == 1 and target == 0:
    FN += 1
    print(str(imgs_path) + " " + str(file) + " is predicted wrong")

# 根据公式写的,如果有问题欢迎大家支持,一起学习
P = TP / (TP + FP + esp)
R = TP / (TP + FN + esp)
F1 = 2 * P * R / (P + R + esp)
acc = (TP + TN) / (TP + TN + FP + FN + esp)

5. 完整实现,修改路径可以用于自己的模型

import os
import json
import torch
from PIL import Image
from torchvision import transforms
from model_v3 import mobilenet_v3_small


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


    imgs_path = "I:/ZTC950V763_211118/CV/imgs/"
    targets_path = "I:/ZTC950V763_211118/CV/y_CV.txt"
    with open(targets_path, 'r') as file:
        targets = file.readlines()
    TP, TN, FP, FN = 0, 0, 0, 0
    esp = 1e-6
    i = 0
    for _, __, files in os.walk(imgs_path):
        for file in files:
            img = Image.open(imgs_path + str(file))
            target = int(targets[i])
            # plt.imshow(img)
            # [N, C, H, W]
            img = data_transform(img)
            # expand batch dimension
            img = torch.unsqueeze(img, dim=0)
            # read class_indict
            json_path = 'class_indices.json'
            assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

            json_file = open(json_path, "r")
            class_indict = json.load(json_file)

            # create model
            model = mobilenet_v3_small(num_classes=2).to(device)
            # load model weights
            model_weight_path = "C:/Users/00769111/PycharmProjects/mobilenet_juanyang/weights/No_freeze_MobileNetV3.pth"
            model.load_state_dict(torch.load(model_weight_path, map_location=device))
            model.eval()
            with torch.no_grad():
                # predict class
                output = torch.squeeze(model(img.to(device))).cpu()
                predict = torch.softmax(output, dim=0)
                predict_cla = torch.argmax(predict).item()
                if predict_cla == 0 and target == 0:
                    TP += 1
                if predict_cla == 1 and target == 1:
                    TN += 1
                if predict_cla == 0 and target == 1:
                    FP += 1
                    print(str(imgs_path) + " " + str(file) + " is predicted wrong")
                if predict_cla == 1 and target == 0:
                    FN += 1
                    print(str(imgs_path) + " " + str(file) + " is predicted wrong")
            i += 1
            if i % 200 == 0:
                P = TP / (TP + FP + esp)
                R = TP / (TP + FN + esp)
                F1 = 2 * P * R / (P + R + esp)
                acc = (TP + TN) / (TP + TN + FP + FN + esp)
                print(f"精度为: {P}n")
                print(f"召回率为: {R}n")
                print(f"F1值为: {F1}n")
                print(f"准确率为: {acc}")
        P = TP / (TP + FP + esp)
        R = TP / (TP + FN + esp)
        F1 = 2 * P * R / (P + R + esp)
        acc = (TP + TN) / (TP + TN + FP + FN + esp)
        print("结果汇总n")
        print(f"精度为: {P}n")
        print(f"召回率为: {R}n")
        print(f"F1值为: {F1}n")
        print(f"准确率为: {acc}")


if __name__ == '__main__':
    main()

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/581112.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号