UNet_UAE_for_Lane_Detection/utils/callbacks.py

211 lines
8.3 KiB
Python
Raw Permalink Normal View History

2024-08-23 19:42:44 +08:00
import os
import matplotlib
import torch
import torch.nn.functional as F
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signal
import cv2
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from .utils import cvtColor, preprocess_input, resize_image
from .utils_metrics import compute_mIoU
class LossHistory():
    """Record per-epoch losses to text files, TensorBoard and a PNG plot.

    All artefacts are written under ``log_dir``: ``epoch_loss.txt``,
    ``epoch_val_loss.txt`` (only when ``val_loss_flag`` is true),
    ``epoch_loss.png`` and whatever the ``SummaryWriter`` emits.
    """

    def __init__(self, log_dir, model, input_shape, val_loss_flag=True):
        """Create the log directory and the TensorBoard writer.

        Args:
            log_dir: Directory that receives every log artefact.
            model: Network handed to ``SummaryWriter.add_graph``.
            input_shape: Indexable whose first two items are used as the
                last two dimensions of the dummy input for ``add_graph``.
            val_loss_flag: When true, validation loss is tracked as well.
        """
        self.log_dir = log_dir
        self.val_loss_flag = val_loss_flag
        self.losses = []
        if self.val_loss_flag:
            self.val_loss = []

        # exist_ok avoids a FileExistsError when the directory is reused.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception:
            # Logging the graph is best-effort: a failure here must not
            # prevent training, but KeyboardInterrupt/SystemExit propagate.
            pass

    def append_loss(self, epoch, loss, val_loss=None):
        """Store the losses of one epoch and refresh every log artefact.

        Args:
            epoch: Step value passed to ``SummaryWriter.add_scalar``.
            loss: Training loss of the epoch.
            val_loss: Validation loss; only used when ``val_loss_flag`` is set.
        """
        # The directory may have been removed since construction.
        os.makedirs(self.log_dir, exist_ok=True)

        self.losses.append(loss)
        if self.val_loss_flag:
            self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        if self.val_loss_flag:
            with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
                f.write(str(val_loss))
                f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        if self.val_loss_flag:
            self.writer.add_scalar('val_loss', val_loss, epoch)

        self.loss_plot()

    @staticmethod
    def _smooth(values):
        """Return a Savitzky-Golay smoothed copy of ``values`` or ``None``.

        The window is 5 samples for fewer than 25 points and 15 otherwise,
        with a polynomial order of 3. ``None`` is returned when the filter
        cannot be applied (for example when there are fewer points than the
        window), so callers can simply skip the smoothed curve.
        """
        window = 5 if len(values) < 25 else 15
        try:
            return scipy.signal.savgol_filter(values, window, 3)
        except Exception:
            # Smoothing is purely cosmetic; never let it break logging.
            return None

    def loss_plot(self):
        """Render the raw and smoothed loss curves to ``epoch_loss.png``."""
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
        if self.val_loss_flag:
            plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')

        smooth_train = self._smooth(self.losses)
        if smooth_train is not None:
            plt.plot(iters, smooth_train, 'green', linestyle='--', linewidth=2,
                     label='smooth train loss')
            # The validation curve is only smoothed once the training curve
            # could be smoothed.
            if self.val_loss_flag:
                smooth_val = self._smooth(self.val_loss)
                if smooth_val is not None:
                    plt.plot(iters, smooth_val, '#8B4513', linestyle='--', linewidth=2,
                             label='smooth val loss')

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        # Release figure resources so repeated calls do not accumulate figures.
        plt.cla()
        plt.close("all")
class EvalCallback():
    """Periodically compute mIoU for a segmentation network and log it.

    For every evaluated epoch the predictions are written as PNG files to a
    temporary directory, scored with ``compute_mIoU``, appended to
    ``epoch_miou.txt`` and plotted to ``epoch_miou.png`` inside ``log_dir``.
    """

    def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda,
                 miou_out_path=".temp_miou_out", eval_flag=True, period=1):
        """Store the evaluation settings and seed the mIoU log with 0.

        Args:
            net: Network used for prediction until ``on_epoch_end`` replaces it.
            input_shape: Indexable ``(height, width)`` used for resizing and
                for cropping the prediction.
            num_classes: Class count forwarded to ``compute_mIoU``.
            image_ids: Iterable of strings; only the first
                whitespace-separated token of each entry is kept.
            dataset_path: Root directory containing the ``VOC2007`` folder.
            log_dir: Directory receiving ``epoch_miou.txt`` / ``epoch_miou.png``.
            cuda: When true, inputs are moved to the GPU before inference.
            miou_out_path: Temporary directory for predicted masks; it is
                removed after each evaluation.
            eval_flag: Master switch for evaluation.
            period: Evaluate on epochs where ``epoch % period == 0``.
        """
        super(EvalCallback, self).__init__()

        self.net = net
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.dataset_path = dataset_path
        self.log_dir = log_dir
        self.cuda = cuda
        self.miou_out_path = miou_out_path
        self.eval_flag = eval_flag
        self.period = period

        self.image_ids = [image_id.split()[0] for image_id in image_ids]
        self.mious = [0]
        self.epoches = [0]
        if self.eval_flag:
            # Make sure the log directory exists before appending to the log.
            os.makedirs(self.log_dir, exist_ok=True)
            with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    def get_miou_png(self, image):
        """Predict a class-index mask for ``image``.

        Args:
            image: Input image; it is passed through ``cvtColor`` first.

        Returns:
            A ``PIL.Image`` built from the ``uint8`` per-pixel argmax of the
            network output, resized back to the input's height and width.
        """
        # Convert the image to RGB so that grayscale inputs do not fail at
        # prediction time. Only RGB prediction is supported; every other
        # image type is converted to RGB.
        image = cvtColor(image)
        orininal_h, orininal_w = np.array(image).shape[:2]

        # Pad the image with gray bars for a distortion-free resize
        # (a plain resize would also work for recognition).
        image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))

        # Move channels first and add the batch dimension.
        image_data = np.expand_dims(
            np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            # Run the network and keep the first item of the batch.
            pr = self.net(images)[0]

            # Per-pixel class probabilities, channels last.
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()

            # Cut away the gray-bar padding.
            top = int((self.input_shape[0] - nh) // 2)
            bottom = int((self.input_shape[0] - nh) // 2 + nh)
            left = int((self.input_shape[1] - nw) // 2)
            right = int((self.input_shape[1] - nw) // 2 + nw)
            pr = pr[top:bottom, left:right]

            # Resize the probabilities back to the original image size.
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR)

            # Pick the most likely class for every pixel.
            pr = pr.argmax(axis=-1)

        return Image.fromarray(np.uint8(pr))

    def on_epoch_end(self, epoch, model_eval):
        """Evaluate ``model_eval`` when ``epoch`` falls on the evaluation period.

        Args:
            epoch: Current epoch number.
            model_eval: Network to evaluate; replaces ``self.net``.
        """
        if not (epoch % self.period == 0 and self.eval_flag):
            return

        self.net = model_eval
        gt_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/")
        pred_dir = os.path.join(self.miou_out_path, 'detection-results')
        # pred_dir lives inside miou_out_path, so one call creates both.
        os.makedirs(pred_dir, exist_ok=True)

        print("Get miou.")
        for image_id in tqdm(self.image_ids):
            # Read the image from disk; the context manager releases the
            # file handle as soon as the prediction has been produced.
            image_path = os.path.join(self.dataset_path, "VOC2007/JPEGImages/" + image_id + ".jpg")
            with Image.open(image_path) as image:
                prediction = self.get_miou_png(image)
            # Store the predicted mask for the mIoU computation.
            prediction.save(os.path.join(pred_dir, image_id + ".png"))

        print("Calculate miou.")
        # Run the mIoU computation.
        _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None)
        temp_miou = np.nanmean(IoUs) * 100

        self.mious.append(temp_miou)
        self.epoches.append(epoch)

        with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
            f.write(str(temp_miou))
            f.write("\n")

        plt.figure()
        plt.plot(self.epoches, self.mious, 'red', linewidth=2, label='train miou')

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Miou')
        plt.title('A Miou Curve')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_miou.png"))
        plt.cla()
        plt.close("all")

        print("Get miou done.")
        # The predicted masks are only needed for the score above.
        shutil.rmtree(self.miou_out_path)