131 lines
5.3 KiB
Python
131 lines
5.3 KiB
Python
import colorsys
|
|
import copy
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from PIL import Image
|
|
from torch import nn
|
|
|
|
from nets.unet import Unet as unet
|
|
from utils.utils import cvtColor, preprocess_input, resize_image
|
|
|
|
|
|
class Unet(object):
    """Inference wrapper around the project's U-Net segmentation network.

    Configuration keys (see ``_defaults``):
        model_path:  path handed to ``torch.load`` to read the weights.
        num_classes: number of classes; passed to the network constructor
                     and used to size the colour palette.
        backbone:    backbone name passed to the network constructor.
        input_shape: ``[height, width]`` of the network input (index 0 is
                     used for the row crop, index 1 for the column crop).
        mix_type:    output mode of ``detect_image``: 0 blends the colour
                     map with the input image, 1 returns the colour map
                     only, 2 keeps the input pixels whose predicted class
                     is not 0 and zeroes the rest.
        cuda:        run the network on the GPU when one is available.
    """

    _defaults = {
        "model_path": None,
        "num_classes": 2,
        "backbone": "vgg",
        "input_shape": [1696, 864],
        "mix_type": 1,
        "cuda": True,
    }

    def __init__(self, _defaults=None, **kwargs):
        """Build the configuration, the colour palette and the network.

        Args:
            _defaults: optional dict of configuration values. It is merged
                over the class-level ``_defaults`` so a partial dict no
                longer leaves attributes undefined.
            **kwargs: individual attribute overrides applied last.
        """
        self._defaults = self._resolve_defaults(_defaults)
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)

        self.colors = self._build_palette(self.num_classes)
        self.generate()

    @classmethod
    def _resolve_defaults(cls, overrides):
        """Return a private copy of the class defaults updated with *overrides*."""
        # Deep copy so mutable values (e.g. the input_shape list) are not
        # shared between instances or with the class attribute.
        merged = copy.deepcopy(cls._defaults)
        if overrides:
            merged.update(overrides)
        return merged

    @staticmethod
    def _build_palette(num_classes):
        """Return one ``(r, g, b)`` tuple per class for drawing the mask."""
        if num_classes <= 2:
            # Binary case: black for class 0, white for class 1.
            return [(0, 0, 0), (255, 255, 255)]
        # Otherwise spread the hues evenly around the HSV colour wheel.
        rgb_floats = (
            colorsys.hsv_to_rgb(index / num_classes, 1., 1.)
            for index in range(num_classes)
        )
        return [(int(r * 255), int(g * 255), int(b * 255)) for r, g, b in rgb_floats]

    def generate(self, onnx=False):
        """Instantiate the network and load its weights from ``model_path``.

        Args:
            onnx: when true the network is left unwrapped and is not moved
                to the GPU.

        Raises:
            ValueError: if ``model_path`` has not been configured.
        """
        if self.model_path is None:
            raise ValueError(
                "model_path is not set; pass the path of a trained weights file."
            )

        self.net = unet(num_classes=self.num_classes, backbone=self.backbone)

        cuda_available = torch.cuda.is_available()
        device = torch.device('cuda' if cuda_available else 'cpu')
        # NOTE(review): torch.load unpickles the file; only load weights
        # that come from a trusted source.
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))

        if self.cuda and not cuda_available:
            # The weights were already mapped to the CPU above; calling
            # .cuda() here would raise, so fall back instead of crashing.
            print('CUDA was requested but is not available; running on the CPU.')
            self.cuda = False

        if not onnx and self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

    def _count_classes(self, pr, name_classes=None):
        """Print a per-class pixel table for *pr* and return the counts.

        Args:
            pr: integer array of predicted class indices.
            name_classes: optional sequence of class labels; the class
                index is printed when a label is missing.

        Returns:
            Float array of length ``num_classes`` with the pixel counts.
        """
        classes_nums = np.zeros([self.num_classes])
        total_points_num = pr.size
        print('-' * 63)
        print("|%25s | %15s | %15s|" % ("Key", "Value", "Ratio"))
        print('-' * 63)
        for i in range(self.num_classes):
            num = np.sum(pr == i)
            ratio = num / total_points_num * 100
            if num > 0:
                has_label = name_classes is not None and i < len(name_classes)
                label = name_classes[i] if has_label else i
                print("|%25s | %15s | %14.2f%%|" % (str(label), str(num), ratio))
                print('-' * 63)
            classes_nums[i] = num
        print("classes_nums:", classes_nums)
        return classes_nums

    def _colorize(self, pr):
        """Map an array of class indices to a uint8 RGB array via ``self.colors``."""
        # Integer-array indexing appends the colour axis: (H, W) -> (H, W, 3).
        return np.array(self.colors, np.uint8)[pr]

    def detect_image(self, image, count=False, name_classes=None):
        """Segment *image* and return the result rendered per ``mix_type``.

        Args:
            image: input image; it is passed through the project's
                ``cvtColor`` helper first (presumably a PIL image, since
                the result is later given to ``Image.blend`` - confirm).
            count: when true, print per-class pixel counts and ratios.
            name_classes: optional class labels for the printed table.

        Returns:
            A PIL image for ``mix_type`` 0, 1 or 2; for any other value the
            colour-converted input image is returned unchanged.
        """
        image = cvtColor(image)
        # Keep an untouched copy for the blend / masking modes below.
        old_img = copy.deepcopy(image)
        original_h, original_w = np.array(image).shape[:2]

        input_h, input_w = self.input_shape[0], self.input_shape[1]
        # resize_image returns the resized picture plus the size (nw, nh) of
        # the real content inside it (presumably letterboxed - confirm).
        image_data, nw, nh = resize_image(image, (input_w, input_h))
        # Normalise, move channels first and add the batch dimension.
        image_data = np.expand_dims(
            np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0
        )

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            pr = self.net(images)[0]
            # Channels last, then per-pixel class probabilities.
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()

        # Cut away the border, which is assumed to be centred.
        top = int((input_h - nh) // 2)
        bottom = int((input_h - nh) // 2 + nh)
        left = int((input_w - nw) // 2)
        right = int((input_w - nw) // 2 + nw)
        pr = pr[top:bottom, left:right]

        # Back to the original resolution, then pick the best class per pixel.
        pr = cv2.resize(pr, (original_w, original_h), interpolation=cv2.INTER_LINEAR)
        pr = pr.argmax(axis=-1)

        if count:
            self._count_classes(pr, name_classes)

        if self.mix_type == 0:
            # Blend the colour map with the original image.
            seg_img = Image.fromarray(self._colorize(pr))
            image = Image.blend(old_img, seg_img, 0.7)
        elif self.mix_type == 1:
            # Colour map only.
            image = Image.fromarray(self._colorize(pr))
        elif self.mix_type == 2:
            # Keep original pixels whose class is not 0, zero the rest.
            seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8')
            image = Image.fromarray(np.uint8(seg_img))

        return image