From 6327799556b9761674128014df2a49a1895126c1 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Sat, 10 Dec 2022 16:20:06 +0200 Subject: [PATCH] Created traffic light extraction program --- src/detector/extract_traffic_lights.py | 83 ++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 src/detector/extract_traffic_lights.py diff --git a/src/detector/extract_traffic_lights.py b/src/detector/extract_traffic_lights.py new file mode 100644 index 0000000..862b4ff --- /dev/null +++ b/src/detector/extract_traffic_lights.py @@ -0,0 +1,83 @@ +"""This program extracts traffic lights from images.""" + +import logging +from pathlib import Path + +import cv2 +from detector.object_detection import ( + LABEL_TRAFFIC_LIGHT, + load_ssd_coco, + perform_object_detection, +) +from detector.paths import CROPPED_IMAGES_PATH, INPUT_PATH, LOGS_PATH + +# Set up logging +logger = logging.getLogger(__name__) +handler = logging.FileHandler(str(Path.joinpath(LOGS_PATH, f"{__name__}.log"))) +formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +def extract_traffic_lights() -> None: + files = Path.iterdir(INPUT_PATH) + + # Load the object detection model + this_model = load_ssd_coco() + + # Keep track of the number of traffic lights found + traffic_light_count = 0 + + # Keep track of the number of image files that were processed + file_count = 0 + + # Display a count of the number of images we need to process + # logger.info(f"Number of Images: {len(files)}") + + # Go through each image file, one at a time + for file in files: + + # Detect objects in the image + # img_rgb is the original image in RGB format + # out is a dictionary containing the results of object detection + # file_name is the name of the file + img_rgb, out, file_name = perform_object_detection(model=this_model, file_name=file, save_annotated=None, model_traffic_lights=None) + + # Every 10 files that are processed + if file_count % 10 == 0: + + # Display a count of the number of files that have been processed + logger.info(f"Images processed: {file_count}") + + # Display the total number of traffic lights that have been identified so far + logger.info(f"Number of Traffic lights identified: {traffic_light_count}") + + # Increment the number of files by 1 + file_count += 1 + + # For each traffic light (i.e. bounding box) that was detected + for idx, _ in enumerate(out["boxes"]): + + # Extract the type of object that was detected + obj_class = out["detection_classes"][idx] + + # If the object that was detected is a traffic light + if obj_class == LABEL_TRAFFIC_LIGHT: + + # Extract the coordinates of the bounding box + box = out["boxes"][idx] + + # Extract (i.e. crop) the traffic light from the image + traffic_light = img_rgb[box["y"]:box["y2"], box["x"]:box["x2"]] + + # Convert the traffic light from RGB format into BGR format + traffic_light = cv2.cvtColor(traffic_light, cv2.COLOR_RGB2BGR) + + # Store the cropped image in a folder named 'traffic_light_cropped' + cv2.imwrite(str(Path.joinpath(CROPPED_IMAGES_PATH, f"{traffic_light_count}.jpg")), traffic_light) + + # Increment the number of traffic lights by 1 + traffic_light_count += 1 + + # Display the total number of traffic lights identified + logger.info(f"Number of Traffic lights identified: {traffic_light_count}")