192 lines
6.5 KiB
Python
192 lines
6.5 KiB
Python
|
import os
|
||
|
import streamlit as st
|
||
|
import cv2
|
||
|
import tempfile
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
from PIL.Image import Image
|
||
|
from torchvision import transforms
|
||
|
from PIL import Image
|
||
|
from unet import Unet
|
||
|
from nets.U_ConvAutoencoder import U_ConvAutoencoder
|
||
|
from typing import Tuple, List
|
||
|
|
||
|
# Constants and configuration
# Settings passed to ``Unet`` when prediction is enabled in the UI.
DEFAULTS = dict(
    model_path='model_data/8414_8376.pth',
    num_classes=2,
    backbone="vgg",
    input_shape=[1696, 864],
    mix_type=1,
    # Evaluated once at import time.
    cuda=torch.cuda.is_available(),
)
|
||
|
|
||
|
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing used by ``PreCA.load_image``: resize to a fixed size, then
# convert the PIL image to a tensor.
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(size=(1728, 3392)),
        transforms.ToTensor(),
    ]
)
|
||
|
|
||
|
|
||
|
class PreCA:
    """Class-level holder and helpers for the ``U_ConvAutoencoder`` model.

    The model is stored on the class itself rather than on instances, so all
    callers share a single loaded network. Every method is a class or static
    method; nothing in this file instantiates the class.
    """

    # Shared model; ``None`` until ``initialize_model`` succeeds and again
    # after ``unload_model``.
    model: U_ConvAutoencoder = None

    @classmethod
    def initialize_model(cls, u_ca_path: str) -> None:
        """Load weights from ``u_ca_path`` and publish the model in eval mode.

        Args:
            u_ca_path: Path to a state-dict checkpoint readable by
                ``torch.load``.

        Raises:
            Any exception raised by ``torch.load`` or ``load_state_dict``
            (for example a missing or incompatible checkpoint). In that case
            ``cls.model`` keeps its previous value.
        """
        # Build and configure the network completely before assigning it to
        # the class, so a failed load never leaves a randomly initialised,
        # train-mode model visible to ``infer``.
        model = U_ConvAutoencoder().to(DEVICE)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(u_ca_path, map_location=DEVICE)
        model.load_state_dict(state_dict)
        model.eval()
        cls.model = model

    @classmethod
    def unload_model(cls) -> None:
        """Drop the shared model and ask PyTorch to release cached CUDA memory."""
        cls.model = None
        torch.cuda.empty_cache()

    @classmethod
    def load_image(cls, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image into a batched tensor on ``DEVICE``.

        The image is converted to single-channel mode ``"L"``, passed through
        ``TRANSFORM`` and given a leading batch dimension of size one.
        """
        image = image.convert("L")
        image = TRANSFORM(image).unsqueeze(0)
        return image.to(DEVICE)

    @staticmethod
    def ca_smooth(image: Image.Image) -> Image.Image:
        """Clean up a mask image and binarise it.

        Applies a morphological close with a 3x3 rectangular kernel, a
        Gaussian blur, and a binary threshold (values above 126 become 255,
        the rest 0), then returns the result as a PIL image.
        """
        image_cv2 = np.array(image)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        closed_image = cv2.morphologyEx(image_cv2, cv2.MORPH_CLOSE, kernel)
        # NOTE(review): a 1x1 Gaussian kernel presumably leaves the image
        # unchanged; kept as in the original pipeline. TODO confirm intent.
        blurred = cv2.GaussianBlur(closed_image, (1, 1), 0)
        th = cv2.threshold(blurred, 126, 255, cv2.THRESH_BINARY)[1]
        return Image.fromarray(th)

    @classmethod
    def infer(cls, image: Image.Image) -> Image.Image:
        """Run the shared model on ``image`` and return the output as a PIL image.

        The output is resized with nearest-neighbour sampling to the fixed
        size ``(3384, 1710)`` passed to ``PIL.Image.resize``.

        Raises:
            RuntimeError: If no model has been loaded via ``initialize_model``.
        """
        # Fail with a clear message instead of an opaque "NoneType" error.
        if cls.model is None:
            raise RuntimeError(
                "PreCA model is not loaded; call PreCA.initialize_model() first."
            )

        image_tensor = cls.load_image(image)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            output = cls.model(image_tensor)
        # Remove the batch dimension and move to the CPU for PIL conversion.
        output = output.squeeze(0).cpu()
        output_image = transforms.ToPILImage()(output)
        return output_image.resize((3384, 1710), Image.NEAREST)
|
||
|
|
||
|
|
||
|
class PreUnet:
    """Helpers that chain the U-Net detector with ``PreCA`` post-processing."""

    @staticmethod
    def calculate_metrics(pred_image: Image.Image, true_image: Image.Image, threshold: int = 1) -> Tuple[int, int, int]:
        """Count true-positive, false-positive and false-negative pixels.

        Both images are converted to greyscale and binarised: pixel values
        below ``threshold`` become 0, all others 255. A pixel is "positive"
        when its binarised value is 255.

        Returns:
            ``(TP, FP, FN)`` as produced by ``np.sum`` over the boolean masks.
        """

        def binarise(img):
            # Map every grey level to either 0 or 255, then expose as an array.
            grey = img.convert('L')
            return np.array(grey.point(lambda value: 0 if value < threshold else 255))

        predicted = binarise(pred_image)
        expected = binarise(true_image)

        predicted_on = predicted == 255
        predicted_off = predicted == 0
        expected_on = expected == 255
        expected_off = expected == 0

        true_positive = np.sum(predicted_on & expected_on)
        false_positive = np.sum(predicted_on & expected_off)
        false_negative = np.sum(predicted_off & expected_on)

        return true_positive, false_positive, false_negative

    @staticmethod
    def apply_mask(original_image: Image.Image, mask_image: Image.Image) -> Image.Image:
        """Overlay ``mask_image`` on ``original_image`` in green.

        Both images are converted to RGB and resized (nearest-neighbour) to
        ``(3384, 1710)``. Wherever the mask pixel is pure white
        ``(255, 255, 255)``, the corresponding output pixel is set to
        ``(0, 255, 0)``.
        """
        target_size = (3384, 1710)
        base_pixels = np.array(original_image.convert("RGB").resize(target_size, Image.NEAREST))
        mask_pixels = np.array(mask_image.convert("RGB").resize(target_size, Image.NEAREST))

        # True where all three channels of the mask equal 255.
        is_white = (mask_pixels == 255).all(axis=-1)
        base_pixels[is_white] = (0, 255, 0)

        return Image.fromarray(base_pixels)

    @classmethod
    def process_image(cls, image: Image.Image, unet):
        """Run detection, refinement and smoothing on ``image``.

        Returns:
            A pair ``(overlay, smoothed_mask)``: the input with the smoothed
            mask painted on it, and the smoothed mask itself.
        """
        detected = unet.detect_image(image)
        refined = PreCA.infer(detected)
        smoothed = PreCA.ca_smooth(refined)
        overlay = cls.apply_mask(image, smoothed)
        return overlay, smoothed
|
||
|
|
||
|
|
||
|
def main_page():
    """Render the Streamlit page and dispatch image or video processing.

    Builds the title, a placeholder for results, and the sidebar controls.
    Loads or unloads the models depending on the prediction checkbox, then
    starts image processing on button click and video processing whenever a
    video file is uploaded.
    """
    sidebar = st.sidebar

    st.title('自动驾驶车道线自动检测与增强')
    stframe = st.empty()
    sidebar.subheader("参数设置")

    is_pre = sidebar.checkbox('开启预测')
    if is_pre:
        # Prediction enabled: build the detector and load the autoencoder.
        unet = Unet(DEFAULTS)
        PreCA.initialize_model('weights/best_conv_autoencoder1.pth')
    else:
        unet = None
        PreCA.unload_model()

    sidebar.subheader("图像检测")
    image_dir_path = sidebar.text_input('请输入图像文件夹路径:')
    is_get_iou = sidebar.checkbox('开启计算IOU')
    # The label-directory input is only shown when IoU computation is on.
    if is_get_iou:
        label_dir_path = sidebar.text_input('请输入标签文件夹路径:')
    else:
        label_dir_path = None

    if sidebar.button("开始预测"):
        process_images(image_dir_path, label_dir_path, unet, is_pre, is_get_iou, stframe)

    sidebar.subheader("视频检测")
    uploaded_video = sidebar.file_uploader("上传视频:", type=['mp4', 'mov', 'avi', 'mkv', 'flv', 'wmv', 'm4v'])

    if uploaded_video is None:
        return
    process_video(uploaded_video, unet, is_pre, stframe)
|
||
|
|
||
|
|
||
|
def process_images(image_dir_path, label_dir_path, unet, is_pre, is_get_iou, stframe):
    """Display every image in a directory, optionally with prediction and IoU.

    Args:
        image_dir_path: Directory containing the input images.
        label_dir_path: Directory containing ``<image stem>_bin.png`` label
            files, or ``None`` / empty when IoU is not requested.
        unet: Detector passed through to ``PreUnet.process_image``.
        is_pre: When true, run prediction and show input and result together;
            otherwise only show the input image.
        is_get_iou: When true (and ``label_dir_path`` is set), compute and
            display the IoU of the smoothed mask against the label.
        stframe: Streamlit placeholder used to display images.

    Files are visited in sorted name order so the display sequence is
    deterministic. An invalid image directory or a missing label file is
    reported in the UI instead of raising.
    """
    # An empty or wrong path from the text input would make os.listdir raise.
    if not image_dir_path or not os.path.isdir(image_dir_path):
        # Message: "Invalid image directory path".
        st.error(f'图像文件夹路径无效: {image_dir_path}')
        return

    image_suffixes = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')
    iou_text = st.empty()

    for img_name in sorted(os.listdir(image_dir_path)):
        if not img_name.lower().endswith(image_suffixes):
            continue

        image_path = os.path.join(image_dir_path, img_name)
        # Context manager closes the underlying file once the image is shown.
        with Image.open(image_path) as image:
            if not is_pre:
                stframe.image(image, width=1024)
                continue
            result_image, smoothed_image = PreUnet.process_image(image, unet)
            stframe.image([image, result_image], width=640)

        if not (is_get_iou and label_dir_path):
            continue

        label_name = f"{os.path.splitext(img_name)[0]}_bin.png"
        label_path = os.path.join(label_dir_path, label_name)
        if not os.path.isfile(label_path):
            # Message: "Label file not found". Skip rather than abort the batch.
            iou_text.text(f'未找到标签文件: {label_path}')
            continue

        with Image.open(label_path) as label:
            TP, FP, FN = PreUnet.calculate_metrics(smoothed_image, label)

        union = TP + FP + FN
        if not union:
            # Neither prediction nor label has foreground pixels, so IoU is
            # undefined; avoid a 0/0 division.
            # Message: "Current IoU: cannot be computed (no foreground pixels)".
            iou_text.text('当前IOU: 无法计算(预测与标签均无前景像素)')
            continue

        iou = TP / union
        iou_text.text(f'当前IOU: {iou}')
|
||
|
|
||
|
|
||
|
def process_video(uploaded_video, unet, is_pre, stframe):
    """Play an uploaded video frame by frame, optionally with prediction.

    The upload is copied to a temporary file so OpenCV can open it by path.
    Playback resumes from ``st.session_state.frame_pos``, which is updated
    after every frame read.

    Args:
        uploaded_video: File-like object providing the video bytes via
            ``read()``.
        unet: Detector passed through to ``PreUnet.process_image``.
        is_pre: When true, show the processed frame; otherwise the raw frame.
        stframe: Streamlit placeholder used to display frames.

    The capture is released and the temporary file deleted even when frame
    processing raises or the script is interrupted.
    """
    # delete=False: the file must outlive this ``with`` so OpenCV can reopen
    # it by name; it is removed explicitly in the ``finally`` below.
    with tempfile.NamedTemporaryFile(delete=False) as tfile:
        tfile.write(uploaded_video.read())
    video_path = tfile.name

    try:
        cap = cv2.VideoCapture(video_path)
        try:
            if 'frame_pos' not in st.session_state:
                st.session_state.frame_pos = 0
            # NOTE(review): frame_pos is not reset when a different video is
            # uploaded, so playback resumes at the old position. Confirm
            # whether that is intended.
            cap.set(cv2.CAP_PROP_POS_FRAMES, st.session_state.frame_pos)

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                st.session_state.frame_pos = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                # OpenCV delivers BGR; PIL and Streamlit expect RGB.
                frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

                if is_pre:
                    processed_frame, _ = PreUnet.process_image(frame, unet)
                    stframe.image(processed_frame, width=1024, use_column_width=False)
                else:
                    stframe.image(frame, width=1024, use_column_width=False)
        finally:
            cap.release()
    finally:
        # Best-effort cleanup: a failure to delete must not mask the real error.
        try:
            os.remove(video_path)
        except OSError:
            pass
|
||
|
|
||
|
|
||
|
# Script entry point: build the page only when this file is run directly.
if __name__ == "__main__":
    main_page()
|