UNet_UAE_for_Lane_Detection/VOCdevkit/VOC2007/transform.py

import os
import random
import uuid
import cv2
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count
from tqdm import tqdm
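
# Convert a directory of raw images and their "*_bin.png" masks into the
# VOC2007-style layout used by the rest of this repo: resized JPEGs go to
# JPEGImages/, binarised masks (255 -> 1) go to SegmentationClass/, and the
# sampled file IDs are split 80/20 into ImageSets/Segmentation/train.txt and val.txt.
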
def trans_255_1(image):
    """Map white pixels (255) to class id 1, leaving all other values untouched."""
    if len(image.shape) == 2:  # single-channel image
        mask = image == 255
        image[mask] = 1
    elif len(image.shape) == 3 and image.shape[2] == 3:  # three-channel image
        white = np.array([255, 255, 255])
        mask = np.all(image == white, axis=-1)
        image[mask] = [1, 1, 1]
    else:
        raise ValueError("Unsupported image format!")
    return image


def resize_image_and_mask(image, mask, target_size):
    # INTER_AREA gives smoother downscaling for the image; INTER_NEAREST keeps mask labels intact.
    resized_image = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)
    resized_mask = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
    return resized_image, resized_mask


def process_image(image_file, image_dir, mask_dir, output_mask_dir, output_image_dir, target_size):
    """Resize one image/mask pair, binarise the mask and save both under a fresh UUID."""
    image_path = os.path.join(image_dir, image_file)
    mask_path = os.path.join(mask_dir, image_file.rsplit('.', 1)[0] + '_bin.png')
    image = cv2.imread(image_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if image is None or mask is None:
        raise ValueError(f"Image or mask not found for {image_file}")
    resized_image, resized_mask = resize_image_and_mask(image, mask, target_size)
    resized_mask = trans_255_1(resized_mask)
    unique_id = str(uuid.uuid4())
    mask_output_path = os.path.join(output_mask_dir, unique_id + '.png')
    image_output_path = os.path.join(output_image_dir, unique_id + '.jpg')
    cv2.imwrite(mask_output_path, resized_mask)
    cv2.imwrite(image_output_path, resized_image)
    return unique_id


def mask_to_unet(image_dir, mask_dir, output_mask_dir, output_image_dir, train_txt, val_txt, target_size,
                 num_images=2000):
    """Sample num_images image/mask pairs, convert them in parallel and append an 80/20 train/val split."""
    image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    if len(image_files) < num_images:
        raise ValueError(f"Not enough images in directory to sample {num_images} images.")
    # Randomly select num_images files
    image_files = random.sample(image_files, num_images)
    args = [(image_file, image_dir, mask_dir, output_mask_dir, output_image_dir, target_size)
            for image_file in image_files]
    results = []
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = [executor.submit(process_image, *arg) for arg in args]
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing images"):
            try:
                unique_id = future.result()
                results.append(unique_id)
            except Exception as e:
                print(f"Error processing image: {e}")
    random.shuffle(results)
    split_point = int(len(results) * 0.8)
    train_ids, val_ids = results[:split_point], results[split_point:]
    with open(train_txt, 'a') as train_file:
        for uid in train_ids:
            train_file.write(f"{uid}\n")
    with open(val_txt, 'a') as val_file:
        for uid in val_ids:
            val_file.write(f"{uid}\n")
    print(f"Training set file names written to {train_txt}")
    print(f"Validation set file names written to {val_txt}")


if __name__ == "__main__":
    target_size = (1696, 864)
    image_dir = r"E:\git\unet_seg\unet\VOCdevkit\VOC2007\original_data\dataset_A\train\img"
    mask_dir = r"E:\git\unet_seg\unet\VOCdevkit\VOC2007\original_data\dataset_A\train\label"
    output_mask_dir = "SegmentationClass"
    output_image_dir = "JPEGImages"
    output_txt_dir = './ImageSets/Segmentation'
    train_txt = os.path.join(output_txt_dir, 'train.txt')
    val_txt = os.path.join(output_txt_dir, 'val.txt')
    os.makedirs(output_mask_dir, exist_ok=True)
    os.makedirs(output_image_dir, exist_ok=True)
    os.makedirs(output_txt_dir, exist_ok=True)
    # Note: len(os.listdir(image_dir)) counts every entry in the folder, while mask_to_unet
    # samples only .png/.jpg/.jpeg files, so any non-image entries will trigger its size check.
    mask_to_unet(image_dir, mask_dir, output_mask_dir, output_image_dir, train_txt, val_txt, target_size,
                 num_images=len(os.listdir(image_dir)))
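
    # Optional sanity check (a minimal sketch, not part of the original pipeline):
    # read one converted mask back and confirm it only contains the class ids 0 and 1.
    sample_masks = os.listdir(output_mask_dir)
    if sample_masks:
        sample = cv2.imread(os.path.join(output_mask_dir, sample_masks[0]), cv2.IMREAD_GRAYSCALE)
        print("Unique mask values:", np.unique(sample))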