Load ssd coco once

This commit is contained in:
Kristofers Solo
2022-12-14 13:23:46 +02:00
parent 1b03cc4586
commit 35cbabc9b4
4 changed files with 41 additions and 49 deletions

View File

@@ -1,7 +1,5 @@
"""This program uses a trained neural network to detect the color of a traffic light in images.""" """This program uses a trained neural network to detect the color of a traffic light in images."""
from pathlib import Path
from detector.object_detection import load_ssd_coco, perform_object_detection from detector.object_detection import load_ssd_coco, perform_object_detection
from detector.paths import IMAGES_IN_PATH, MODEL_PATH from detector.paths import IMAGES_IN_PATH, MODEL_PATH
from loguru import logger from loguru import logger
@@ -11,8 +9,10 @@ from tensorflow import keras
@logger.catch @logger.catch
def detect_traffic_light_color_image() -> None: def detect_traffic_light_color_image() -> None:
model_traffic_lights_nn = keras.models.load_model(str(MODEL_PATH)) model_traffic_lights_nn = keras.models.load_model(str(MODEL_PATH))
# Load the SSD neural network that is trained on the COCO data set
model_ssd = load_ssd_coco()
# Go through all image files, and detect the traffic light color. # Go through all image files, and detect the traffic light color.
for file in IMAGES_IN_PATH.iterdir(): for file in IMAGES_IN_PATH.iterdir():
image, out, file_name = perform_object_detection(load_ssd_coco(), file, save_annotated=True, model_traffic_lights=model_traffic_lights_nn) image, out, file_name = perform_object_detection(model=model_ssd, file_name=file, save_annotated=True, model_traffic_lights=model_traffic_lights_nn)
logger.info(f"{file} {out}") logger.info(f"Performed object detection on {file}")

View File

@@ -1,6 +1,5 @@
"""This program extracts traffic lights from images.""" """This program extracts traffic lights from images."""
from pathlib import Path
import cv2 import cv2
from detector.object_detection import ( from detector.object_detection import (

View File

@@ -62,7 +62,7 @@ def load_model(model_name: str) -> tf.saved_model.LoadOptions:
# Download a file from a URL that is not already in the cache # Download a file from a URL that is not already in the cache
model_dir = tf.keras.utils.get_file(fname=model_name, untar=True, origin=url) model_dir = tf.keras.utils.get_file(fname=model_name, untar=True, origin=url)
logger.info(f"Model path: {model_dir}") logger.info(f"Loaded model: {model_dir}")
return tf.saved_model.load(f"{model_dir}/saved_model") return tf.saved_model.load(f"{model_dir}/saved_model")
@@ -85,61 +85,56 @@ def load_ssd_coco() -> tf.saved_model.LoadOptions:
@logger.catch @logger.catch
def save_image_annotated(image_rgb, file_name: Path, output, model_traffic_lights=None) -> None: def save_image_annotated(image_rgb, file_name: Path, output, model_traffic_lights) -> None:
"""Annotate the image with the object types, and generate cropped images of traffic lights.""" """Annotate the image with the object types, and generate cropped images of traffic lights."""
output_file = IMAGES_OUT_PATH.joinpath(file_name.name) output_file = IMAGES_OUT_PATH.joinpath(file_name.name)
# For each bounding box that was detected # For each bounding box that was detected
for idx, (box, object_class) in enumerate(zip(output["boxes"], output["detection_classes"])): for idx, (box, object_class) in enumerate(zip(output["boxes"], output["detection_classes"])):
color = LABELS.get(object_class, (255, 255, 255)) color = LABELS.get(object_class, None)
# How confident the object detection model is on the object's type # How confident the object detection model is on the object's type
score: int = object_class * 100 score: int = object_class * 100
label_text = f"{LABEL_TEXT.get(object_class)} {score}"
# Extract the bounding box
box = output["boxes"][idx]
label_text = f"{object_class} {score}"
if object_class == LABEL_TRAFFIC_LIGHT: if object_class == LABEL_TRAFFIC_LIGHT:
if model_traffic_lights is not None:
# Annotate the image and save it # Annotate the image and save it
image_traffic_light = image_rgb[box["y"]:box["y2"], box["x"]:box["x2"]] image_traffic_light = image_rgb[box.get("y"):box.get("y2"), box.get("x"):box.get("x2")]
image_inception = cv2.resize(image_traffic_light, (299, 299)) image_inception = cv2.resize(image_traffic_light, (299, 299))
# Uncomment this if you want to save a cropped image of the traffic light # Uncomment this if you want to save a cropped image of the traffic light
image_inception = np.array([preprocess_input(image_inception)]) image_inception = np.array([preprocess_input(image_inception)])
prediction = model_traffic_lights.predict(image_inception) prediction = model_traffic_lights.predict(image_inception)
label = np.argmax(prediction) label = np.argmax(prediction)
score_light = int(np.max(prediction) * 100) score_light = int(np.max(prediction) * 100)
if label == 0: if label == 0:
label_text = f"Green {score_light}" label_text = f"Green {score_light}"
elif label == 1: elif label == 1:
label_text = f"Yellow {score_light}" label_text = f"Yellow {score_light}"
elif label == 2: elif label == 2:
label_text = f"Red {score_light}" label_text = f"Red {score_light}"
else: else:
label_text = "NO-LIGHT" label_text = "NO-LIGHT"
# Draw the bounding box and object class label on the image, if the confidence score is above 50 and the box is not a duplicate # Draw the bounding box and object class label on the image, if the confidence score is above 50 and the box is not a duplicate
if color and label_text and accept_box(output["boxes"], idx, 5) and score > 50: if color and label_text and accept_box(output.get("boxes"), idx, 5) and score > 50:
cv2.rectangle(image_rgb, (box["x"], box["y"]), (box["x2"], box["y2"]), color, 2) cv2.rectangle(image_rgb, (box.get("x"), box.get("y")), (box.get("x2"), box.get("y2")), color, 2)
cv2.putText(image_rgb, label_text, (box["x"], box["y"]), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) cv2.putText(image_rgb, label_text, (box.get("x"), box.get("y")), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
cv2.imwrite(str(output_file), cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)) cv2.imwrite(str(output_file), cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
logger.info(output_file) logger.info(output_file)
@logger.catch @ logger.catch
def center(box: dict[str, float], coord_type: str) -> float: def center(box: dict[str, float], coord_type: str) -> float:
"""Get center of the bounding box.""" """Get center of the bounding box."""
return (box[coord_type] + box[coord_type + "2"]) / 2 return (box[coord_type] + box[coord_type + "2"]) / 2
@logger.catch @ logger.catch
def perform_object_detection(model, file_name, save_annotated=False, model_traffic_lights=None): def perform_object_detection(model, file_name: Path, save_annotated=False, model_traffic_lights=None):
"""Perform object detection on an image using the predefined neural network.""" """Perform object detection on an image using the predefined neural network."""
# Store the image # Store the image
image_bgr = cv2.imread(str(file_name)) image_bgr = cv2.imread(str(file_name))
@@ -150,21 +145,21 @@ def perform_object_detection(model, file_name, save_annotated=False, model_traff
# Run the model # Run the model
output = model(input_tensor) output = model(input_tensor)
logger.info(f"Number detections: {output['num_detections']} {int(output['num_detections'])}") logger.debug(f"Number detections: {output['num_detections']} {int(output['num_detections'])}")
# Convert the tensors to a NumPy array # Convert the tensors to a NumPy array
num_detections = int(output.pop("num_detections")) number_detections = int(output.pop("num_detections"))
output = {key: value[0, :num_detections].numpy() for key, value in output.items()} output = {key: value[0, :number_detections].numpy() for key, value in output.items()}
output["num_detections"] = num_detections output["num_detections"] = number_detections
logger.info(f"Detection classes: {output['detection_classes']}") logger.debug(f"Detection classes: {output['detection_classes']}")
logger.info(f"Detection Boxes: {output['detection_boxes']}") logger.debug(f"Detection Boxes: {output['detection_boxes']}")
# The detected classes need to be integers. # The detected classes need to be integers.
output["detection_classes"] = output["detection_classes"].astype(np.int64) output["detection_classes"] = output["detection_classes"].astype(np.int64)
output["boxes"] = [{"y": int(box[0] * image_rgb.shape[0]), output["boxes"] = [{"y": int(box[0] * image_rgb.shape[0]),
"x": int(box[1] * image_rgb.shape[1]), "x": int(box[1] * image_rgb.shape[1]),
"y2": int(box[2] * image_rgb.shape[0]), "y2": int(box[2] * image_rgb.shape[0]),
"x2": int(box[3] * image_rgb.shape[1])} "x2": int(box[3] * image_rgb.shape[1])}
for box in output["detection_boxes"]] for box in output["detection_boxes"]]
@@ -174,7 +169,7 @@ def perform_object_detection(model, file_name, save_annotated=False, model_traff
return image_rgb, output, file_name return image_rgb, output, file_name
@logger.catch @ logger.catch
def perform_object_detection_video(video_frame, model, model_traffic_lights): def perform_object_detection_video(video_frame, model, model_traffic_lights):
"""Perform object detection on a video using the predefined neural network.""" """Perform object detection on a video using the predefined neural network."""
@@ -195,7 +190,7 @@ def perform_object_detection_video(video_frame, model, model_traffic_lights):
output["detection_classes"] = output["detection_classes"].astype(np.int64) output["detection_classes"] = output["detection_classes"].astype(np.int64)
output["boxes"] = [{"y": int(box[0] * image_rgb.shape[0]), output["boxes"] = [{"y": int(box[0] * image_rgb.shape[0]),
"x": int(box[1] * image_rgb.shape[1]), "x": int(box[1] * image_rgb.shape[1]),
"y2": int(box[2] * image_rgb.shape[0]), "y2": int(box[2] * image_rgb.shape[0]),
"x2": int(box[3] * image_rgb.shape[1])} "x2": int(box[3] * image_rgb.shape[1])}
for box in output["detection_boxes"]] for box in output["detection_boxes"]]
@@ -236,7 +231,7 @@ def perform_object_detection_video(video_frame, model, model_traffic_lights):
return cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) return cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
@logger.catch @ logger.catch
def double_shuffle(images: list[str], labels: list[int]) -> tuple[list[str], list[int]]: def double_shuffle(images: list[str], labels: list[int]) -> tuple[list[str], list[int]]:
"""Shuffle the images to add some randomness.""" """Shuffle the images to add some randomness."""
indexes = np.random.permutation(len(images)) indexes = np.random.permutation(len(images))
@@ -244,7 +239,7 @@ def double_shuffle(images: list[str], labels: list[int]) -> tuple[list[str], lis
return [images[idx] for idx in indexes], [labels[idx] for idx in indexes] return [images[idx] for idx in indexes], [labels[idx] for idx in indexes]
@logger.catch @ logger.catch
def reverse_preprocess_inception(image_preprocessed): def reverse_preprocess_inception(image_preprocessed):
"""Reverse the preprocessing process for an image that has been input to the Inception V3 model.""" """Reverse the preprocessing process for an image that has been input to the Inception V3 model."""
image = image_preprocessed + 1 * 127.5 image = image_preprocessed + 1 * 127.5

View File

@@ -5,8 +5,6 @@ to a directory. Also, the best neural network model is saved as traffic.h5.
""" """
import collections import collections
from pathlib import Path
import cv2 import cv2
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np