Posted on 2023-06-06 17:57
Table of contents
Generating a confusion matrix after training is complete
ImageNet data format: generating the confusion matrix
Non-ImageNet data format: defining the dataset class and the data-import method
Non-ImageNet data format: complete program for generating the confusion matrix
Confusion matrix: a common tool for evaluating the performance of classification models. From it you can calculate metrics such as accuracy, precision, recall and F1-score. To generate one, compare the model's predictions on the test set with the true labels, count for each category how many samples were predicted correctly and how many were predicted as each of the other categories, and arrange these counts in a matrix.
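As a minimal sketch of the counting step described above, the following pure-Python snippet builds a confusion matrix from two short label lists. The function names and the sample labels are made up for illustration; rows are true classes and columns are predicted classes, which is also the convention of sklearn.metrics.confusion_matrix.

```python
def build_confusion_matrix(y_true, y_pred, num_classes):
    """Return a num_classes x num_classes list of lists.

    Entry [i][j] counts samples whose true class is i and predicted class is j.
    """
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def accuracy(matrix):
    """Correct predictions lie on the diagonal of the matrix."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total if total else 0.0


# Hypothetical labels for a 3-class problem (illustration only)
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

cm = build_confusion_matrix(y_true, y_pred, num_classes=3)
for row in cm:
    print(row)
print("accuracy:", accuracy(cm))
```

The off-diagonal entries show which classes get confused with which, something a single accuracy number cannot tell you.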
The following Python code generates the confusion matrix. In it, data_path is the dataset path and model_path is the model path; both need to be modified according to your actual setup. The code uses ImageFolder to import the dataset directly, so no custom dataset class has to be defined: just pass the dataset root directory and the preprocessing transform to ImageFolder. Finally, the confusion matrix is generated and saved as a CSV file.
- import torch
- import torchvision.datasets as datasets
- import torchvision.transforms as transforms
- from sklearn.metrics import confusion_matrix
- import pandas as pd
-
- # Select the device
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- # Define the preprocessing transforms
- transform = transforms.Compose([
-     transforms.Resize((224, 224)),
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- ])
-
- # Load the dataset
- data_path = "path/to/dataset"
- dataset = datasets.ImageFolder(root=data_path, transform=transform)
-
- # Load the model (assumes the whole model object was saved with torch.save(model, ...);
- # on PyTorch 2.6 and later, also pass weights_only=False to load a full model object)
- model_path = "path/to/model.pth"
- model = torch.load(model_path, map_location=device)
- model.to(device)
- model.eval()
-
- # Collect the predictions and the true labels
- labels = []
- preds = []
- with torch.no_grad():  # no gradients are needed for evaluation
-     for inputs, target in dataset:
-         inputs = inputs.unsqueeze(0).to(device)  # add a batch dimension
-         outputs = model(inputs)
-         _, predicted = torch.max(outputs, 1)
-         labels.append(target)  # ImageFolder returns the label as a plain int
-         preds.append(predicted.item())
-
- # Build the confusion matrix (rows: true classes, columns: predicted classes)
- classes = dataset.classes
- cm = confusion_matrix(labels, preds, labels=list(range(len(classes))))
- cm_df = pd.DataFrame(cm, index=classes, columns=classes)
-
- # Save as a CSV file
- cm_df.to_csv("confusion_matrix.csv")
- print("Confusion matrix saved as confusion_matrix.csv")
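ImageFolder expects one subdirectory per class under the root directory and numbers the classes by the alphabetically sorted folder names, so the rows and columns of the saved matrix follow that order. The sketch below imitates this assignment in plain Python; list_class_folders is a stand-in written for this illustration (not a torchvision function), and the folder names are made up.

```python
import os
import tempfile


def list_class_folders(root):
    """Imitate how class indices are derived: sorted subdirectory names."""
    classes = sorted(
        entry.name for entry in os.scandir(root) if entry.is_dir()
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


# Build a throwaway directory tree with hypothetical class folders
with tempfile.TemporaryDirectory() as root:
    for name in ("dog", "cat", "bird"):
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = list_class_folders(root)

print(classes)
print(class_to_idx)
```

If the class order used during training differs from this sorted order, the predicted indices will not line up with the labels, so it is worth checking dataset.class_to_idx before trusting the matrix.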
For data that is not in the ImageNet folder format, define a dataset class and a label-parsing method for importing the data:
- import os
- import numpy as np
- import torch
- from torch.utils.data import Dataset
- from torchvision import transforms
- from PIL import Image
-
- class CustomDataset(Dataset):
-     def __init__(self, data_dir, transform=None):
-         self.data_dir = data_dir
-         self.transform = transform
-         # Sorted so that the sample order is the same on every run
-         self.img_files = sorted(os.listdir(data_dir))
-
-     def __len__(self):
-         return len(self.img_files)
-
-     def __getitem__(self, index):
-         img_path = os.path.join(self.data_dir, self.img_files[index])
-         img = Image.open(img_path).convert('RGB')
-         label = self.get_label_from_filename(self.img_files[index])
-
-         if self.transform:
-             img = self.transform(img)
-
-         return img, label
-
-     def get_label_from_filename(self, filename):
-         label = filename.split('.')[0]  # take the part before the first '.', e.g. "3_0012.jpg" -> "3_0012"
-         label = label.split('_')[0]  # keep only the label before the first '_'
-         return int(label)
-
- # Load the dataset and apply preprocessing
- data_dir = "your_data_dir"
- transform = transforms.Compose([
-     transforms.Resize((224, 224)),  # resize images to 224x224
-     transforms.ToTensor(),  # convert to a Tensor and scale pixel values to [0, 1]
-     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize the image
- ])
- dataset = CustomDataset(data_dir, transform=transform)
Notes:
1. The custom dataset obtains each image's class label by splitting the file name on "_". Check whether your file names actually contain the label value. If they do not, you need to modify the code yourself. If they do, make sure the index used after split matches the position of the label in the resulting list.
2. Mind the indentation in the complete program that follows. Python code is concise, but it is indentation-sensitive, so check the indentation carefully, especially from the line that loads the model with torch.load onward.
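To make note 1 concrete, here is a small sketch of a label parser in which the separator and the position of the label are parameters. The function name label_from_filename and the sample file names are hypothetical; adjust sep and position to match your own naming scheme.

```python
import os


def label_from_filename(filename, sep="_", position=0):
    """Extract an integer label from a file name such as "3_0012.jpg"."""
    stem = os.path.splitext(os.path.basename(filename))[0]
    parts = stem.split(sep)
    try:
        return int(parts[position])
    except (IndexError, ValueError) as exc:
        # Fail loudly instead of silently producing a wrong label
        raise ValueError(f"cannot read a label from {filename!r}") from exc


print(label_from_filename("3_0012.jpg"))                  # label is the first field
print(label_from_filename("img_0012_7.png", position=2))  # label is the third field
```

Raising an error for names that carry no label makes a naming problem visible at load time rather than as a strange-looking confusion matrix later.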
- import os
- import torch
- import torchvision.transforms as transforms
- from torch.utils.data import DataLoader
- from sklearn.metrics import confusion_matrix
- from PIL import Image
- import pandas as pd
-
-
- # Custom dataset class
- class MyDataset(torch.utils.data.Dataset):
-     def __init__(self, root_dir, transform=None):
-         self.root_dir = root_dir
-         self.transform = transform
-         self.img_list = os.listdir(root_dir)
-
-     def __len__(self):
-         return len(self.img_list)
-
-     def __getitem__(self, idx):
-         img_name = os.path.join(self.root_dir, self.img_list[idx])
-         image = Image.open(img_name).convert('RGB')
-         if self.transform:
-             image = self.transform(image)
-         label = int(self.img_list[idx].split('_')[0])  # get the label from the file name
-         return image, label
-
- # Load the model (assumes the whole model object was saved with torch.save(model, ...))
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- model = torch.load('model.pth', map_location=device)
- model.to(device)
- model.eval()
-
- # Define the dataset (use the same normalization that was used during training)
- transform = transforms.Compose([
-     transforms.Resize((224, 224)),
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
- ])
- dataset = MyDataset(root_dir='path/to/dataset', transform=transform)
- dataloader = DataLoader(dataset, batch_size=16, shuffle=False)
-
- # Predictions and true labels
- y_pred = []
- y_true = []
- with torch.no_grad():
-     for images, labels in dataloader:
-         outputs = model(images.to(device))
-         _, predicted = torch.max(outputs, 1)
-         y_pred.extend(predicted.cpu().numpy())
-         y_true.extend(labels.cpu().numpy())
-
- # Build the confusion matrix
- cm = confusion_matrix(y_true, y_pred)
-
- # Save the confusion matrix as a CSV file (no header row, no index column)
- pd.DataFrame(cm).to_csv('confusion_matrix.csv', index=False, header=False)
-
- # Print the confusion matrix
- print(cm)
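Once the matrix has been saved, per-class precision, recall and F1-score can be read off it. The sketch below is one way to do that with only the standard library, assuming a CSV without header or index as written by the program above. The inline csv_text stands in for the real file and its numbers are made up, and per_class_metrics is a hypothetical helper name.

```python
import csv
import io


def per_class_metrics(matrix):
    """Return a list of (precision, recall, f1) tuples, one per class.

    Rows are true classes and columns are predicted classes.
    """
    n = len(matrix)
    results = []
    for k in range(n):
        tp = matrix[k][k]
        predicted_k = sum(matrix[i][k] for i in range(n))  # column sum
        actual_k = sum(matrix[k])                          # row sum
        precision = tp / predicted_k if predicted_k else 0.0
        recall = tp / actual_k if actual_k else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        results.append((precision, recall, f1))
    return results


# Stand-in for open("confusion_matrix.csv"); the numbers are made up
csv_text = "8,2\n1,9\n"
matrix = [[int(v) for v in row] for row in csv.reader(io.StringIO(csv_text))]

for k, (p, r, f1) in enumerate(per_class_metrics(matrix)):
    print(f"class {k}: precision={p:.3f} recall={r:.3f} f1={f1:.3f}")
```

Precision for a class divides the diagonal entry by its column sum, and recall divides it by its row sum, so both depend on keeping the row/column convention straight.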
Author:Ineverleft
link:http://www.pythonblackhole.com/blog/article/83281/08d1c666d155603ffbc1/
source:python black hole net
Please indicate the source for any form of reprinting. If any infringement is discovered, it will be held legally responsible.