mirror of
https://github.com/kristoferssolo/Traffic-Light-Detector.git
synced 2026-03-22 00:36:22 +00:00
Load ssd coco once
This commit is contained in:
@@ -1,7 +1,5 @@
|
|||||||
"""This program uses a trained neural network to detect the color of a traffic light in images."""
|
"""This program uses a trained neural network to detect the color of a traffic light in images."""
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from detector.object_detection import load_ssd_coco, perform_object_detection
|
from detector.object_detection import load_ssd_coco, perform_object_detection
|
||||||
from detector.paths import IMAGES_IN_PATH, MODEL_PATH
|
from detector.paths import IMAGES_IN_PATH, MODEL_PATH
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -11,8 +9,10 @@ from tensorflow import keras
|
|||||||
@logger.catch
|
@logger.catch
|
||||||
def detect_traffic_light_color_image() -> None:
|
def detect_traffic_light_color_image() -> None:
|
||||||
model_traffic_lights_nn = keras.models.load_model(str(MODEL_PATH))
|
model_traffic_lights_nn = keras.models.load_model(str(MODEL_PATH))
|
||||||
|
# Load the SSD neural network that is trained on the COCO data set
|
||||||
|
model_ssd = load_ssd_coco()
|
||||||
|
|
||||||
# Go through all image files, and detect the traffic light color.
|
# Go through all image files, and detect the traffic light color.
|
||||||
for file in IMAGES_IN_PATH.iterdir():
|
for file in IMAGES_IN_PATH.iterdir():
|
||||||
image, out, file_name = perform_object_detection(load_ssd_coco(), file, save_annotated=True, model_traffic_lights=model_traffic_lights_nn)
|
image, out, file_name = perform_object_detection(model=model_ssd, file_name=file, save_annotated=True, model_traffic_lights=model_traffic_lights_nn)
|
||||||
logger.info(f"{file} {out}")
|
logger.info(f"Performed object detection on {file}")
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""This program extracts traffic lights from images."""
|
"""This program extracts traffic lights from images."""
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
from detector.object_detection import (
|
from detector.object_detection import (
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ def load_model(model_name: str) -> tf.saved_model.LoadOptions:
|
|||||||
# Download a file from a URL that is not already in the cache
|
# Download a file from a URL that is not already in the cache
|
||||||
model_dir = tf.keras.utils.get_file(fname=model_name, untar=True, origin=url)
|
model_dir = tf.keras.utils.get_file(fname=model_name, untar=True, origin=url)
|
||||||
|
|
||||||
logger.info(f"Model path: {model_dir}")
|
logger.info(f"Loaded model: {model_dir}")
|
||||||
|
|
||||||
return tf.saved_model.load(f"{model_dir}/saved_model")
|
return tf.saved_model.load(f"{model_dir}/saved_model")
|
||||||
|
|
||||||
@@ -85,61 +85,56 @@ def load_ssd_coco() -> tf.saved_model.LoadOptions:
|
|||||||
|
|
||||||
|
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def save_image_annotated(image_rgb, file_name: Path, output, model_traffic_lights=None) -> None:
|
def save_image_annotated(image_rgb, file_name: Path, output, model_traffic_lights) -> None:
|
||||||
"""Annotate the image with the object types, and generate cropped images of traffic lights."""
|
"""Annotate the image with the object types, and generate cropped images of traffic lights."""
|
||||||
output_file = IMAGES_OUT_PATH.joinpath(file_name.name)
|
output_file = IMAGES_OUT_PATH.joinpath(file_name.name)
|
||||||
|
|
||||||
# For each bounding box that was detected
|
# For each bounding box that was detected
|
||||||
for idx, (box, object_class) in enumerate(zip(output["boxes"], output["detection_classes"])):
|
for idx, (box, object_class) in enumerate(zip(output["boxes"], output["detection_classes"])):
|
||||||
|
|
||||||
color = LABELS.get(object_class, (255, 255, 255))
|
color = LABELS.get(object_class, None)
|
||||||
# How confident the object detection model is on the object's type
|
# How confident the object detection model is on the object's type
|
||||||
score: int = object_class * 100
|
score: int = object_class * 100
|
||||||
|
label_text = f"{LABEL_TEXT.get(object_class)} {score}"
|
||||||
# Extract the bounding box
|
|
||||||
box = output["boxes"][idx]
|
|
||||||
|
|
||||||
label_text = f"{object_class} {score}"
|
|
||||||
if object_class == LABEL_TRAFFIC_LIGHT:
|
if object_class == LABEL_TRAFFIC_LIGHT:
|
||||||
if model_traffic_lights is not None:
|
|
||||||
|
|
||||||
# Annotate the image and save it
|
# Annotate the image and save it
|
||||||
image_traffic_light = image_rgb[box["y"]:box["y2"], box["x"]:box["x2"]]
|
image_traffic_light = image_rgb[box.get("y"):box.get("y2"), box.get("x"):box.get("x2")]
|
||||||
image_inception = cv2.resize(image_traffic_light, (299, 299))
|
image_inception = cv2.resize(image_traffic_light, (299, 299))
|
||||||
|
|
||||||
# Uncomment this if you want to save a cropped image of the traffic light
|
# Uncomment this if you want to save a cropped image of the traffic light
|
||||||
image_inception = np.array([preprocess_input(image_inception)])
|
image_inception = np.array([preprocess_input(image_inception)])
|
||||||
|
|
||||||
prediction = model_traffic_lights.predict(image_inception)
|
prediction = model_traffic_lights.predict(image_inception)
|
||||||
label = np.argmax(prediction)
|
label = np.argmax(prediction)
|
||||||
score_light = int(np.max(prediction) * 100)
|
score_light = int(np.max(prediction) * 100)
|
||||||
|
|
||||||
if label == 0:
|
if label == 0:
|
||||||
label_text = f"Green {score_light}"
|
label_text = f"Green {score_light}"
|
||||||
elif label == 1:
|
elif label == 1:
|
||||||
label_text = f"Yellow {score_light}"
|
label_text = f"Yellow {score_light}"
|
||||||
elif label == 2:
|
elif label == 2:
|
||||||
label_text = f"Red {score_light}"
|
label_text = f"Red {score_light}"
|
||||||
else:
|
else:
|
||||||
label_text = "NO-LIGHT"
|
label_text = "NO-LIGHT"
|
||||||
|
|
||||||
# Draw the bounding box and object class label on the image, if the confidence score is above 50 and the box is not a duplicate
|
# Draw the bounding box and object class label on the image, if the confidence score is above 50 and the box is not a duplicate
|
||||||
if color and label_text and accept_box(output["boxes"], idx, 5) and score > 50:
|
if color and label_text and accept_box(output.get("boxes"), idx, 5) and score > 50:
|
||||||
cv2.rectangle(image_rgb, (box["x"], box["y"]), (box["x2"], box["y2"]), color, 2)
|
cv2.rectangle(image_rgb, (box.get("x"), box.get("y")), (box.get("x2"), box.get("y2")), color, 2)
|
||||||
cv2.putText(image_rgb, label_text, (box["x"], box["y"]), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
|
cv2.putText(image_rgb, label_text, (box.get("x"), box.get("y")), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
|
||||||
|
|
||||||
cv2.imwrite(str(output_file), cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
|
cv2.imwrite(str(output_file), cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
|
||||||
logger.info(output_file)
|
logger.info(output_file)
|
||||||
|
|
||||||
|
|
||||||
@logger.catch
|
@ logger.catch
|
||||||
def center(box: dict[str, float], coord_type: str) -> float:
|
def center(box: dict[str, float], coord_type: str) -> float:
|
||||||
"""Get center of the bounding box."""
|
"""Get center of the bounding box."""
|
||||||
return (box[coord_type] + box[coord_type + "2"]) / 2
|
return (box[coord_type] + box[coord_type + "2"]) / 2
|
||||||
|
|
||||||
|
|
||||||
@logger.catch
|
@ logger.catch
|
||||||
def perform_object_detection(model, file_name, save_annotated=False, model_traffic_lights=None):
|
def perform_object_detection(model, file_name: Path, save_annotated=False, model_traffic_lights=None):
|
||||||
"""Perform object detection on an image using the predefined neural network."""
|
"""Perform object detection on an image using the predefined neural network."""
|
||||||
# Store the image
|
# Store the image
|
||||||
image_bgr = cv2.imread(str(file_name))
|
image_bgr = cv2.imread(str(file_name))
|
||||||
@@ -150,21 +145,21 @@ def perform_object_detection(model, file_name, save_annotated=False, model_traff
|
|||||||
# Run the model
|
# Run the model
|
||||||
output = model(input_tensor)
|
output = model(input_tensor)
|
||||||
|
|
||||||
logger.info(f"Number detections: {output['num_detections']} {int(output['num_detections'])}")
|
logger.debug(f"Number detections: {output['num_detections']} {int(output['num_detections'])}")
|
||||||
|
|
||||||
# Convert the tensors to a NumPy array
|
# Convert the tensors to a NumPy array
|
||||||
num_detections = int(output.pop("num_detections"))
|
number_detections = int(output.pop("num_detections"))
|
||||||
output = {key: value[0, :num_detections].numpy() for key, value in output.items()}
|
output = {key: value[0, :number_detections].numpy() for key, value in output.items()}
|
||||||
output["num_detections"] = num_detections
|
output["num_detections"] = number_detections
|
||||||
|
|
||||||
logger.info(f"Detection classes: {output['detection_classes']}")
|
logger.debug(f"Detection classes: {output['detection_classes']}")
|
||||||
logger.info(f"Detection Boxes: {output['detection_boxes']}")
|
logger.debug(f"Detection Boxes: {output['detection_boxes']}")
|
||||||
|
|
||||||
# The detected classes need to be integers.
|
# The detected classes need to be integers.
|
||||||
output["detection_classes"] = output["detection_classes"].astype(np.int64)
|
output["detection_classes"] = output["detection_classes"].astype(np.int64)
|
||||||
output["boxes"] = [{"y": int(box[0] * image_rgb.shape[0]),
|
output["boxes"] = [{"y": int(box[0] * image_rgb.shape[0]),
|
||||||
"x": int(box[1] * image_rgb.shape[1]),
|
"x": int(box[1] * image_rgb.shape[1]),
|
||||||
"y2": int(box[2] * image_rgb.shape[0]),
|
"y2": int(box[2] * image_rgb.shape[0]),
|
||||||
"x2": int(box[3] * image_rgb.shape[1])}
|
"x2": int(box[3] * image_rgb.shape[1])}
|
||||||
for box in output["detection_boxes"]]
|
for box in output["detection_boxes"]]
|
||||||
|
|
||||||
@@ -174,7 +169,7 @@ def perform_object_detection(model, file_name, save_annotated=False, model_traff
|
|||||||
return image_rgb, output, file_name
|
return image_rgb, output, file_name
|
||||||
|
|
||||||
|
|
||||||
@logger.catch
|
@ logger.catch
|
||||||
def perform_object_detection_video(video_frame, model, model_traffic_lights):
|
def perform_object_detection_video(video_frame, model, model_traffic_lights):
|
||||||
"""Perform object detection on a video using the predefined neural network."""
|
"""Perform object detection on a video using the predefined neural network."""
|
||||||
|
|
||||||
@@ -195,7 +190,7 @@ def perform_object_detection_video(video_frame, model, model_traffic_lights):
|
|||||||
output["detection_classes"] = output["detection_classes"].astype(np.int64)
|
output["detection_classes"] = output["detection_classes"].astype(np.int64)
|
||||||
output["boxes"] = [{"y": int(box[0] * image_rgb.shape[0]),
|
output["boxes"] = [{"y": int(box[0] * image_rgb.shape[0]),
|
||||||
"x": int(box[1] * image_rgb.shape[1]),
|
"x": int(box[1] * image_rgb.shape[1]),
|
||||||
"y2": int(box[2] * image_rgb.shape[0]),
|
"y2": int(box[2] * image_rgb.shape[0]),
|
||||||
"x2": int(box[3] * image_rgb.shape[1])}
|
"x2": int(box[3] * image_rgb.shape[1])}
|
||||||
for box in output["detection_boxes"]]
|
for box in output["detection_boxes"]]
|
||||||
|
|
||||||
@@ -236,7 +231,7 @@ def perform_object_detection_video(video_frame, model, model_traffic_lights):
|
|||||||
return cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
|
return cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
|
||||||
@logger.catch
|
@ logger.catch
|
||||||
def double_shuffle(images: list[str], labels: list[int]) -> tuple[list[str], list[int]]:
|
def double_shuffle(images: list[str], labels: list[int]) -> tuple[list[str], list[int]]:
|
||||||
"""Shuffle the images to add some randomness."""
|
"""Shuffle the images to add some randomness."""
|
||||||
indexes = np.random.permutation(len(images))
|
indexes = np.random.permutation(len(images))
|
||||||
@@ -244,7 +239,7 @@ def double_shuffle(images: list[str], labels: list[int]) -> tuple[list[str], lis
|
|||||||
return [images[idx] for idx in indexes], [labels[idx] for idx in indexes]
|
return [images[idx] for idx in indexes], [labels[idx] for idx in indexes]
|
||||||
|
|
||||||
|
|
||||||
@logger.catch
|
@ logger.catch
|
||||||
def reverse_preprocess_inception(image_preprocessed):
|
def reverse_preprocess_inception(image_preprocessed):
|
||||||
"""Reverse the preprocessing process for an image that has been input to the Inception V3 model."""
|
"""Reverse the preprocessing process for an image that has been input to the Inception V3 model."""
|
||||||
image = image_preprocessed + 1 * 127.5
|
image = image_preprocessed + 1 * 127.5
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ to a directory. Also, the best neural network model is saved as traffic.h5.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
Reference in New Issue
Block a user