"""Training callbacks: loss logging/plotting and periodic mIoU evaluation."""
import os

import matplotlib
import torch
import torch.nn.functional as F

# Select the non-interactive Agg backend; kept ahead of the pyplot import.
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signal

import cv2
import shutil
import numpy as np

from PIL import Image
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from .utils import cvtColor, preprocess_input, resize_image
from .utils_metrics import compute_mIoU


class LossHistory():
    """Record per-epoch losses to text files, TensorBoard and a PNG plot.

    Everything is written under ``log_dir``:
    ``epoch_loss.txt``, ``epoch_val_loss.txt`` (only when ``val_loss_flag``
    is true), ``epoch_loss.png`` and the TensorBoard event files.
    """

    def __init__(self, log_dir, model, input_shape, val_loss_flag=True):
        """Create the log directory and the TensorBoard writer.

        Args:
            log_dir: Directory that receives all log artefacts.
            model: Model passed to ``SummaryWriter.add_graph``.
            input_shape: Indexable whose first two items are used as the
                last two dimensions of the dummy input for graph tracing.
            val_loss_flag: When true, validation losses are tracked too.
        """
        self.log_dir = log_dir
        self.val_loss_flag = val_loss_flag

        self.losses = []
        if self.val_loss_flag:
            self.val_loss = []

        # exist_ok: re-using an existing directory must not abort training.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception:
            # Graph export is best-effort; a failure here must not stop
            # training. KeyboardInterrupt/SystemExit still propagate.
            pass

    def append_loss(self, epoch, loss, val_loss=None):
        """Record the losses of one epoch and refresh the loss plot.

        Args:
            epoch: Step value passed to TensorBoard.
            loss: Training loss of the epoch.
            val_loss: Validation loss of the epoch. It is stored, written
                and logged as-is whenever ``val_loss_flag`` is true, so
                callers are expected to supply it in that case.
        """
        os.makedirs(self.log_dir, exist_ok=True)

        self.losses.append(loss)
        if self.val_loss_flag:
            self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss) + "\n")
        if self.val_loss_flag:
            with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
                f.write(str(val_loss) + "\n")

        self.writer.add_scalar('loss', loss, epoch)
        if self.val_loss_flag:
            self.writer.add_scalar('val_loss', val_loss, epoch)

        self.loss_plot()

    def loss_plot(self):
        """Plot raw and smoothed loss curves to ``epoch_loss.png``."""
        iters = range(len(self.losses))

        plt.figure()
        try:
            plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
            if self.val_loss_flag:
                plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')

            try:
                # Savitzky-Golay window length: short while few epochs exist.
                num = 5 if len(self.losses) < 25 else 15

                plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green',
                         linestyle='--', linewidth=2, label='smooth train loss')
                if self.val_loss_flag:
                    plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513',
                             linestyle='--', linewidth=2, label='smooth val loss')
            except Exception:
                # Smoothing is optional: the filter rejects inputs shorter
                # than the window, so the smoothed curves are simply
                # skipped until enough epochs have been recorded.
                pass

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend(loc="upper right")

            plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.cla()
            plt.close("all")


class EvalCallback():
    """Periodically predict on a fixed image list and track mIoU.

    Results are appended to ``epoch_miou.txt`` and plotted to
    ``epoch_miou.png`` inside ``log_dir``.
    """

    def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda,
                 miou_out_path=".temp_miou_out", eval_flag=True, period=1):
        """Store the evaluation configuration.

        Args:
            net: Network used for prediction (replaced on every evaluation
                by the model passed to ``on_epoch_end``).
            input_shape: Indexable; item 0 and item 1 are used as the
                network input height and width respectively.
            num_classes: Number of classes forwarded to ``compute_mIoU``.
            image_ids: Iterable of strings; only the first
                whitespace-separated token of each entry is kept.
            dataset_path: Root directory containing the ``VOC2007`` folder.
            log_dir: Directory that receives the mIoU log and plot.
            cuda: When true, inputs are moved to the GPU before inference.
            miou_out_path: Temporary directory for predicted masks; it is
                deleted at the end of every evaluation.
            eval_flag: Master switch for evaluation.
            period: Evaluate on epochs where ``epoch % period == 0``.
        """
        super(EvalCallback, self).__init__()

        self.net = net
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.dataset_path = dataset_path
        self.log_dir = log_dir
        self.cuda = cuda
        self.miou_out_path = miou_out_path
        self.eval_flag = eval_flag
        self.period = period

        self.image_ids = [image_id.split()[0] for image_id in image_ids]
        # Both series start with a 0 entry so the curve begins at the origin.
        self.mious = [0]
        self.epoches = [0]
        if self.eval_flag:
            # Make sure the log directory exists before the first write.
            os.makedirs(self.log_dir, exist_ok=True)
            with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
                f.write(str(0) + "\n")

    def get_miou_png(self, image):
        """Predict the per-pixel class map of ``image``.

        Args:
            image: Input image as accepted by the project's ``cvtColor``.

        Returns:
            A PIL image of the original height/width whose pixel values are
            the arg-max class indices, cast to ``uint8`` (indices above 255
            would therefore wrap).
        """
        # Convert to RGB so that grayscale inputs do not break prediction;
        # only RGB prediction is supported, every other mode is converted.
        image = cvtColor(image)
        original_h, original_w = np.array(image).shape[:2]

        # Resize to the network input size. Per the original author's note
        # this pads with gray bars to avoid distortion; nw/nh are the
        # resized content width/height returned by resize_image.
        image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))

        # HWC -> CHW, then add the batch dimension.
        image_data = np.expand_dims(
            np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            # Forward pass; keep the first (only) item of the batch.
            pr = self.net(images)[0]

            # Class scores per pixel: move the leading axis last and apply
            # softmax over it (presumably (C, H, W) -> (H, W, C)).
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()

            # Cut away the padded border so only the image content remains.
            top = int((self.input_shape[0] - nh) // 2)
            bottom = int((self.input_shape[0] - nh) // 2 + nh)
            left = int((self.input_shape[1] - nw) // 2)
            right = int((self.input_shape[1] - nw) // 2 + nw)
            pr = pr[top:bottom, left:right]

            # Scale the scores back to the original image size.
            pr = cv2.resize(pr, (original_w, original_h), interpolation=cv2.INTER_LINEAR)

            # Pick the most likely class for every pixel.
            pr = pr.argmax(axis=-1)

        image = Image.fromarray(np.uint8(pr))
        return image

    def _save_miou_plot(self):
        """Plot the recorded mIoU values to ``epoch_miou.png``."""
        plt.figure()
        try:
            plt.plot(self.epoches, self.mious, 'red', linewidth=2, label='train miou')

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Miou')
            plt.title('A Miou Curve')
            plt.legend(loc="upper right")

            plt.savefig(os.path.join(self.log_dir, "epoch_miou.png"))
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.cla()
            plt.close("all")

    def on_epoch_end(self, epoch, model_eval):
        """Evaluate mIoU with ``model_eval`` if this epoch is scheduled.

        Args:
            epoch: Current epoch number; evaluation runs when evaluation is
                enabled and ``epoch % period == 0``.
            model_eval: Model used for prediction during this evaluation.
        """
        if not self.eval_flag or epoch % self.period != 0:
            return

        self.net = model_eval
        gt_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/")
        pred_dir = os.path.join(self.miou_out_path, 'detection-results')
        os.makedirs(self.miou_out_path, exist_ok=True)
        os.makedirs(pred_dir, exist_ok=True)

        print("Get miou.")
        for image_id in tqdm(self.image_ids):
            # Read the image from disk; the context manager closes the
            # underlying file once the prediction has been computed.
            image_path = os.path.join(self.dataset_path, "VOC2007/JPEGImages/" + image_id + ".jpg")
            with Image.open(image_path) as image:
                # Predicted class mask for this image.
                prediction = self.get_miou_png(image)
            prediction.save(os.path.join(pred_dir, image_id + ".png"))

        print("Calculate miou.")
        # Run the mIoU computation on the saved predictions.
        _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None)
        temp_miou = np.nanmean(IoUs) * 100

        self.mious.append(temp_miou)
        self.epoches.append(epoch)

        with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
            f.write(str(temp_miou) + "\n")

        self._save_miou_plot()

        print("Get miou done.")
        # Remove the temporary prediction directory.
        shutil.rmtree(self.miou_out_path)