UNet_UAE_for_Lane_Detection/utils/utils_metrics.py

182 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import csv
import os
from os.path import join
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
def f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5):
    """Compute the class-averaged F-beta score (the Dice coefficient for ``beta=1``).

    Args:
        inputs: Raw network output unpacked as ``(n, c, h, w)``; softmax is
            applied over the channel axis inside this function.
        target: Label tensor unpacked as ``(n, ht, wt, ct)``. Its last channel
            is excluded from scoring, so ``ct - 1`` has to match ``c``
            (presumably that channel marks ignored pixels - confirm against
            the data loader).
        beta: Weight of recall relative to precision in the F-beta formula.
        smooth: Small constant added to numerator and denominator so that a
            class with no positives does not divide by zero.
        threhold: Probability a class must exceed to count as predicted.
            The misspelled name is kept for backward compatibility.

    Returns:
        A scalar tensor holding the mean of the per-class scores.
    """
    n, c, h, w = inputs.size()
    _, ht, wt, ct = target.size()

    # Resize when EITHER spatial dimension differs from the target. The former
    # ``and`` skipped the resize when only one dimension mismatched, which then
    # failed on the element-wise product below.
    if h != ht or w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    # (n, c, h, w) -> (n, h*w, c), then per-pixel class probabilities.
    temp_inputs = torch.softmax(inputs.permute(0, 2, 3, 1).reshape(n, -1, c), -1)
    # reshape (rather than view) also accepts non-contiguous targets.
    temp_target = target.reshape(n, -1, ct)

    # Binarise the probabilities, then count per-class TP / FP / FN over the
    # batch and pixel axes.
    temp_inputs = torch.gt(temp_inputs, threhold).float()
    scored_target = temp_target[..., :-1]
    tp = torch.sum(scored_target * temp_inputs, dim=(0, 1))
    fp = torch.sum(temp_inputs, dim=(0, 1)) - tp
    fn = torch.sum(scored_target, dim=(0, 1)) - tp

    beta_sq = beta ** 2
    score = ((1 + beta_sq) * tp + smooth) / ((1 + beta_sq) * tp + beta_sq * fn + fp + smooth)
    return torch.mean(score)
# Let the label image be W pixels wide and H pixels high.
def fast_hist(a, b, n):
    """Build an ``n x n`` confusion matrix from flattened labels and predictions.

    Args:
        a: Labels flattened to a 1-D array of shape ``(H*W,)``.
        b: Predictions flattened to a 1-D array of shape ``(H*W,)``.
        n: Number of classes.

    Returns:
        An ``(n, n)`` array whose entry ``[i, j]`` counts pixels labelled ``i``
        and predicted ``j``; the diagonal holds the correctly classified pixels.
    """
    # Only pixels whose label lies in [0, n) take part in the count.
    valid = np.logical_and(a >= 0, a < n)
    # Encode every (label, prediction) pair as one integer in [0, n*n) so a
    # single bincount yields the whole matrix.
    pair_index = a[valid].astype(int) * n + b[valid]
    counts = np.bincount(pair_index, minlength=n * n)
    return counts.reshape(n, n)
def per_class_iu(hist):
    """Return the per-class intersection-over-union of a confusion matrix.

    The denominator is clamped to at least 1, so a class with no pixels at
    all scores 0 instead of dividing by zero.
    """
    intersection = np.diag(hist)
    union = hist.sum(1) + hist.sum(0) - intersection
    return intersection / np.maximum(union, 1)
def per_class_PA_Recall(hist):
    """Return per-class recall (pixel accuracy): diagonal over row totals.

    Row totals are clamped to at least 1 so an empty row yields 0 rather
    than a division by zero.
    """
    correct = np.diag(hist)
    row_totals = hist.sum(1)
    return correct / np.maximum(row_totals, 1)
def per_class_Precision(hist):
    """Return per-class precision: diagonal over column totals.

    Column totals are clamped to at least 1 so an empty column yields 0
    rather than a division by zero.
    """
    correct = np.diag(hist)
    column_totals = hist.sum(0)
    return correct / np.maximum(column_totals, 1)
def per_Accuracy(hist):
    """Return overall accuracy: the diagonal sum over the sum of all entries.

    The total is clamped to at least 1 so an all-zero matrix yields 0.
    """
    correct = np.sum(np.diag(hist))
    total = np.sum(hist)
    return correct / np.maximum(total, 1)
def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None):
    """Accumulate a confusion matrix over image pairs and report segmentation metrics.

    For every name in ``png_name_list`` the files ``<gt_dir>/<name>.png`` and
    ``<pred_dir>/<name>.png`` are loaded and compared pixel by pixel. Pairs
    whose pixel counts differ are reported and skipped.

    Args:
        gt_dir: Directory holding the ground-truth label PNGs.
        pred_dir: Directory holding the predicted segmentation PNGs.
        png_name_list: Image base names (without the ``.png`` extension).
        num_classes: Number of classes; the confusion matrix is
            ``num_classes x num_classes``.
        name_classes: Optional class names. When given, progress is printed
            every 10 images and per-class metrics are printed at the end.

    Returns:
        A tuple ``(hist, IoUs, PA_Recall, Precision)``: the integer confusion
        matrix followed by the per-class IoU, recall and precision arrays.
    """
    print('Num classes', num_classes)

    # Confusion matrix accumulated over every evaluated image, starting at zero.
    hist = np.zeros((num_classes, num_classes))

    # Paths of the ground-truth labels and of the predicted segmentations.
    gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list]
    pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list]

    # Process each (prediction, label) pair.
    for ind, (gt_path, pred_path) in enumerate(zip(gt_imgs, pred_imgs)):
        # Context managers close the file handles as soon as the pixel data
        # has been copied into the arrays.
        with Image.open(pred_path) as pred_img:
            pred = np.array(pred_img)
        with Image.open(gt_path) as gt_img:
            label = np.array(gt_img)

        # A pair with differing pixel counts cannot be compared; skip it.
        if label.size != pred.size:
            print(
                'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
                    label.size, pred.size, gt_path, pred_path))
            continue

        # Add this image's num_classes x num_classes confusion matrix.
        hist += fast_hist(label.flatten(), pred.flatten(), num_classes)

        # Every 10 images, print the running class-averaged metrics.
        if name_classes is not None and ind > 0 and ind % 10 == 0:
            print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format(
                    ind,
                    len(gt_imgs),
                    100 * np.nanmean(per_class_iu(hist)),
                    100 * np.nanmean(per_class_PA_Recall(hist)),
                    100 * per_Accuracy(hist)
                )
            )

    # Per-class metrics over the whole evaluation set.
    IoUs = per_class_iu(hist)
    PA_Recall = per_class_PA_Recall(hist)
    Precision = per_class_Precision(hist)

    # Print the metrics class by class.
    if name_classes is not None:
        for ind_class in range(num_classes):
            print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2))
                  + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2))
                  + '; Precision-' + str(round(Precision[ind_class] * 100, 2)))

    # Class-averaged metrics over all images; NaN entries are ignored.
    print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2)))

    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the same dtype.
    return np.array(hist, int), IoUs, PA_Recall, Precision
def adjust_axes(r, t, fig, axes):
    """Stretch the upper x-limit of ``axes`` to make room for the text ``t``.

    Args:
        r: Renderer passed to ``t.get_window_extent``.
        t: Text object whose rendered width is measured.
        fig: Figure supplying the DPI and the current width in inches.
        axes: Axes whose x-limits are rescaled in place.
    """
    extent = t.get_window_extent(renderer=r)
    fig_width = fig.get_figwidth()
    # The extent width divided by the DPI gives the text width in inches; the
    # ratio is how much wider the figure would need to be to fit that text.
    stretch = (fig_width + extent.width / fig.dpi) / fig_width
    lower, upper = axes.get_xlim()
    axes.set_xlim([lower, upper * stretch])
def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True):
    """Draw a horizontal bar chart of per-class values and save it to disk.

    Args:
        values: One numeric value per class; each becomes a bar.
        name_classes: Tick labels for the bars, in the same order as ``values``.
        plot_title: Title of the chart.
        x_label: Label of the x-axis.
        output_path: File path the figure is saved to.
        tick_font_size: Font size for ticks and axis label; the title uses
            this size plus 2.
        plt_show: When true, the figure is also displayed before being closed.
    """
    fig = plt.gcf()
    axes = plt.gca()
    positions = range(len(values))

    plt.barh(positions, values, color='royalblue')
    plt.title(plot_title, fontsize=tick_font_size + 2)
    plt.xlabel(x_label, fontsize=tick_font_size)
    plt.yticks(positions, name_classes, fontsize=tick_font_size)

    renderer = fig.canvas.get_renderer()
    last_index = len(values) - 1
    for index, value in enumerate(values):
        # Values below 1.0 are printed with two decimals; others verbatim.
        caption = " {0:.2f}".format(value) if value < 1.0 else " " + str(value)
        text = plt.text(value, index, caption, color='royalblue', va='center', fontweight='bold')
        # Widen the x-axis once, based on the caption of the final bar.
        if index == last_index:
            adjust_axes(renderer, text, fig, axes)

    fig.tight_layout()
    fig.savefig(output_path)
    if plt_show:
        plt.show()
    plt.close()
def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size = 12):
    """Save the metric bar charts and the confusion matrix CSV under ``miou_out_path``.

    Four charts are written: ``mIoU.png``, ``mPA.png``, ``Recall.png`` and
    ``Precision.png``. The mPA and Recall charts are both drawn from
    ``PA_Recall``. Only the first chart is also shown on screen. The
    confusion matrix is written to ``confusion_matrix.csv`` with the class
    names as header row and as first column.

    Args:
        miou_out_path: Output directory for all generated files.
        hist: Confusion matrix, iterated row by row.
        IoUs: Per-class IoU values.
        PA_Recall: Per-class recall / pixel-accuracy values.
        Precision: Per-class precision values.
        name_classes: Class names, in the same order as the metric arrays.
        tick_font_size: Font size forwarded to ``draw_plot_func``.
    """
    # (values, title template, x-axis label, file name, log prefix, show on screen)
    charts = (
        (IoUs, "mIoU = {0:.2f}%", "Intersection over Union", "mIoU.png", "Save mIoU out to ", True),
        (PA_Recall, "mPA = {0:.2f}%", "Pixel Accuracy", "mPA.png", "Save mPA out to ", False),
        (PA_Recall, "mRecall = {0:.2f}%", "Recall", "Recall.png", "Save Recall out to ", False),
        (Precision, "mPrecision = {0:.2f}%", "Precision", "Precision.png", "Save Precision out to ", False),
    )
    for scores, title_template, axis_label, file_name, log_prefix, on_screen in charts:
        title = title_template.format(np.nanmean(scores)*100)
        chart_path = os.path.join(miou_out_path, file_name)
        draw_plot_func(scores, name_classes, title, axis_label,
                       chart_path, tick_font_size = tick_font_size, plt_show = on_screen)
        print(log_prefix + chart_path)

    csv_path = os.path.join(miou_out_path, "confusion_matrix.csv")
    with open(csv_path, 'w', newline='') as f:
        # Assemble the whole table before writing anything.
        header = [' '] + [str(c) for c in name_classes]
        body = [[name_classes[i]] + [str(x) for x in row] for i, row in enumerate(hist)]
        csv.writer(f).writerows([header] + body)
    print("Save confusion_matrix out to " + csv_path)