
Image classification model evaluation: drawing a confusion matrix (confusion_matrix) with Python

posted on 2023-06-06 17:57


Table of contents

After training is complete, generate a confusion matrix!

ImageNet data format: generate the confusion matrix!

Non-ImageNet data format: define the dataset class and the data-loading method!

Non-ImageNet data format: complete program for generating the confusion matrix!


Confusion matrix: a common tool for evaluating the performance of a classification model, from which metrics such as accuracy, precision, recall and F1-score can be calculated. To generate one, compare the model's predictions on the test set with the true labels, count for each class how many samples were predicted correctly and how many were assigned to each wrong class, and arrange these counts in a matrix (in scikit-learn's convention, rows are true classes and columns are predicted classes).
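To make the link between the matrix and those metrics concrete, here is a minimal sketch that derives accuracy and per-class precision, recall and F1 directly from a confusion matrix with NumPy. It assumes the scikit-learn layout (rows = true class, columns = predicted class); the function name metrics_from_confusion_matrix and the 3-class example counts are made up for illustration.

```python
import numpy as np


def metrics_from_confusion_matrix(cm):
    """Derive accuracy and per-class precision, recall and F1 from a confusion matrix.

    Assumes rows are true classes and columns are predicted classes
    (the layout returned by sklearn.metrics.confusion_matrix).
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)                   # correct predictions for each class
    predicted_totals = cm.sum(axis=0)  # column sums: how often each class was predicted
    true_totals = cm.sum(axis=1)       # row sums: how many samples truly belong to each class

    accuracy = tp.sum() / cm.sum()
    # np.divide with where= avoids division by zero for classes never predicted / never present
    precision = np.divide(tp, predicted_totals, out=np.zeros_like(tp), where=predicted_totals > 0)
    recall = np.divide(tp, true_totals, out=np.zeros_like(tp), where=true_totals > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return accuracy, precision, recall, f1


# Hypothetical 3-class example
cm = [[5, 1, 0],
      [2, 3, 1],
      [0, 0, 4]]
acc, p, r, f1 = metrics_from_confusion_matrix(cm)
print(f"accuracy={acc:.3f}")
print("precision:", np.round(p, 3))
print("recall:   ", np.round(r, 3))
print("f1:       ", np.round(f1, 3))
```

Precision divides the diagonal by column sums (everything predicted as that class), recall divides by row sums (everything that truly is that class); swapping the two axes is the most common mistake when reading a confusion matrix.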

Python code for generating the confusion matrix:

After training is complete, generate a confusion matrix!

ImageNet data format: generate the confusion matrix!

Here, data_path is the dataset path and model_path is the model path; modify both to match your setup. The code uses ImageFolder to load the dataset directly, without defining a custom dataset class: just pass ImageFolder the root directory of the dataset (one subfolder per class) and the preprocessing transform. Finally, the confusion matrix is generated and saved as a CSV file.
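The rows and columns of the saved CSV follow dataset.classes, so it helps to know how ImageFolder numbers the classes: it takes the subfolder names under the root, sorts them, and assigns indices 0, 1, 2, ... in that order. The sketch below mimics that rule in plain Python on a throwaway directory tree; folder_class_to_idx and the class names are hypothetical, and in real code you would simply read dataset.class_to_idx.

```python
import os
import tempfile


def folder_class_to_idx(root):
    """Mimic ImageFolder's labelling rule: sorted subfolder names -> indices 0..n-1."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


# Build a temporary ImageNet-style tree: one subfolder per class (hypothetical names)
with tempfile.TemporaryDirectory() as root:
    for name in ["dog", "cat", "bird"]:
        os.makedirs(os.path.join(root, name))
    mapping = folder_class_to_idx(root)

print(mapping)
```

Because the order is alphabetical rather than creation order, a model trained with a different folder naming (or a different class list) will produce a confusion matrix whose labels do not line up with its outputs.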

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from sklearn.metrics import confusion_matrix
import pandas as pd

# Set the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define the preprocessing (should match what was used during training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the dataset (one subfolder per class)
data_path = "path/to/dataset"
dataset = datasets.ImageFolder(root=data_path, transform=transform)

# Load the model (assumes it was saved with torch.save(model, ...), i.e. the whole model object)
model_path = "path/to/model.pth"
# weights_only=False is needed on PyTorch >= 2.6 to unpickle a full model; only use it on files you trust
model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
model.eval()

# Collect predictions and ground-truth labels
labels = []
preds = []
with torch.no_grad():  # no gradients needed for evaluation
    for inputs, target in dataset:
        inputs = inputs.unsqueeze(0).to(device)  # add a batch dimension
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        labels.append(target)  # ImageFolder returns the label as a plain int
        preds.append(predicted.item())

# Generate the confusion matrix (rows = true class, columns = predicted class)
classes = dataset.classes
# Passing labels= keeps the matrix square even if some class never appears
cm = confusion_matrix(labels, preds, labels=list(range(len(classes))))
cm_df = pd.DataFrame(cm, index=classes, columns=classes)

# Save as a CSV file
cm_df.to_csv("confusion_matrix.csv")
print("Confusion matrix saved as confusion_matrix.csv")
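The program above only saves the matrix as a CSV file. To actually draw it, one option is an annotated heatmap with matplotlib; the sketch below is not part of the original code, and the function name plot_confusion_matrix, the colour map and the output file name are assumptions. It expects a DataFrame shaped like cm_df (rows = true class, columns = predicted class).

```python
import os

import matplotlib
matplotlib.use("Agg")  # render straight to a file; no display needed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_confusion_matrix(cm_df, out_path="confusion_matrix.png"):
    """Draw a confusion-matrix DataFrame as an annotated heatmap and save it as an image."""
    cm = cm_df.to_numpy()
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(cm, cmap="Blues")
    fig.colorbar(im, ax=ax)

    # Class names on both axes
    ax.set_xticks(np.arange(len(cm_df.columns)))
    ax.set_xticklabels(cm_df.columns, rotation=45, ha="right")
    ax.set_yticks(np.arange(len(cm_df.index)))
    ax.set_yticklabels(cm_df.index)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")

    # Write each count in its cell; use white text on dark cells for readability
    threshold = cm.max() / 2
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, str(cm[i, j]), ha="center", va="center",
                    color="white" if cm[i, j] > threshold else "black")

    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    return os.path.abspath(out_path)


# With the CSV written above, the first column holds the class names:
# cm_df = pd.read_csv("confusion_matrix.csv", index_col=0)
demo = pd.DataFrame([[5, 1], [2, 7]], index=["cat", "dog"], columns=["cat", "dog"])
path = plot_confusion_matrix(demo, "demo_confusion_matrix.png")
print("saved to", path)
```

scikit-learn's ConfusionMatrixDisplay produces a similar figure; the manual version just makes each step visible.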

Non-ImageNet data format: define the dataset class and the data-loading method!

Code that defines the dataset class and loads the data:

import os
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class CustomDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        # Sort for a deterministic order and keep only files (skip subfolders)
        self.img_files = sorted(
            f for f in os.listdir(data_dir)
            if os.path.isfile(os.path.join(data_dir, f))
        )

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, index):
        img_path = os.path.join(self.data_dir, self.img_files[index])
        img = Image.open(img_path).convert('RGB')
        label = self.get_label_from_filename(self.img_files[index])
        if self.transform:
            img = self.transform(img)
        return img, label

    def get_label_from_filename(self, filename):
        label = filename.split('.')[0]  # drop the extension; assumes names like "label_imageid.jpg"
        label = label.split('_')[0]     # keep only the label part
        return int(label)

# Load the dataset and apply preprocessing
data_dir = "your_data_dir"
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize images to 224x224
    transforms.ToTensor(),          # convert to a tensor and scale pixel values to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
])
dataset = CustomDataset(data_dir, transform=transform)
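get_label_from_filename only works when the label is the first "_"-separated field. If your file names put the label elsewhere, a small helper with a configurable separator and position can replace it. The sketch below is hypothetical (label_from_filename is not from the original code); it strips only the last extension with rsplit, so names containing extra dots still parse, and it raises a clear error instead of silently mislabelling.

```python
def label_from_filename(filename, sep="_", position=0):
    """Hypothetical helper: read an integer class label from a file name.

    e.g. "3_0001.jpg" with sep="_" and position=0 gives 3;
    position=-1 takes the last field, as in "img_0001_7.png".
    """
    stem = filename.rsplit(".", 1)[0]  # "3_0001.jpg" -> "3_0001"
    parts = stem.split(sep)            # "3_0001" -> ["3", "0001"]
    try:
        return int(parts[position])
    except (IndexError, ValueError) as err:
        raise ValueError(
            f"cannot read a label from {filename!r} (sep={sep!r}, position={position})"
        ) from err


print(label_from_filename("3_0001.jpg"))
print(label_from_filename("img_0001_7.png", position=-1))
```

Inside CustomDataset you could call it as `label_from_filename(self.img_files[index], position=...)` with whatever position matches your naming scheme.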

Non-ImageNet data format: complete program for generating the confusion matrix!

Notes:

1. The custom dataset gets each image's class label from its file name by splitting on "_". Check whether your file names actually contain the label value; if they don't, you will need to modify the code yourself. If they do, make sure the index you take from the list returned by split() is the position where the label sits in your file names.

2. Mind the indentation when copying the code: Python is concise, but indentation errors are easy to make. Everything after the MyDataset class definition (loading the model with torch.load('model.pth'), building the dataset and DataLoader, and the prediction loop) is top-level code and must not be indented inside the class.

import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
from PIL import Image
import pandas as pd

# Custom dataset class
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.img_list = sorted(os.listdir(root_dir))  # sorted for a deterministic order

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.img_list[idx])
        image = Image.open(img_name).convert('RGB')
        if self.transform:
            image = self.transform(image)
        label = int(self.img_list[idx].split('_')[0])  # get the label from the file name
        return image, label

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the model (whole model object; weights_only=False is needed on PyTorch >= 2.6)
model = torch.load('model.pth', map_location=device, weights_only=False)
model.to(device)
model.eval()

# Define the dataset (the Normalize values should match those used in training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
dataset = MyDataset(root_dir='path/to/dataset', transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=False)

# Predictions and ground-truth labels
y_pred = []
y_true = []
with torch.no_grad():
    for images, labels in dataloader:
        outputs = model(images.to(device))
        _, predicted = torch.max(outputs, 1)
        y_pred.extend(predicted.cpu().numpy())
        y_true.extend(labels.numpy())

# Generate the confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Save the confusion matrix as a CSV file
pd.DataFrame(cm).to_csv('confusion_matrix.csv', index=False, header=False)
# Print the confusion matrix
print(cm)
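Once y_true and y_pred are collected, scikit-learn can also report the precision, recall and F1-score mentioned at the start, per class, without reading them off the matrix by hand. A minimal sketch using sklearn.metrics.classification_report; the short label lists below are hypothetical stand-ins for the ones gathered by the loop above.

```python
from sklearn.metrics import classification_report

# Hypothetical stand-ins for the y_true / y_pred lists built by the prediction loop
y_true = [0, 0, 1, 1, 1, 2]
y_pred = [0, 1, 1, 1, 0, 2]

# Text table with per-class precision, recall, F1-score and support
print(classification_report(y_true, y_pred, digits=3))

# The same numbers as a nested dict (keys are the labels as strings), handy for logging
report = classification_report(y_true, y_pred, output_dict=True)
print("accuracy:", report["accuracy"])
```

Passing target_names=[...] replaces the numeric keys with class names, which is useful when the labels come from file names as in this program.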



Author: Ineverleft

link: http://www.pythonblackhole.com/blog/article/83281/08d1c666d155603ffbc1/

source: python black hole net

Please indicate the source for any form of reprinting. If any infringement is discovered, it will be held legally responsible.
