Added Traffic Light recognition to class

This commit is contained in:
Kristofers Solo
2022-12-17 16:51:13 +02:00
parent 0ee644320d
commit dced61a204
2 changed files with 29 additions and 33 deletions

View File

@@ -1,12 +1,14 @@
import cv2 import cv2
from loguru import logger from loguru import logger
from paths import HAAR_PATH
from TrafficLightDetector.color import Color from TrafficLightDetector.color import Color
class TrafficLightDetector: class TrafficLightDetector:
CASCADE = cv2.CascadeClassifier(str(HAAR_PATH))
FONT = cv2.FONT_HERSHEY_SIMPLEX FONT = cv2.FONT_HERSHEY_SIMPLEX
RADIUS = 5 RADIUS = 5
BOUNDARY = 4 / 10 BOUNDARY = 2
# HSV values # HSV values
RED_LOWER = ((160, 100, 100), (0, 100, 100)) RED_LOWER = ((160, 100, 100), (0, 100, 100))
RED_UPPER = ((180, 255, 255), (10, 255, 255)) RED_UPPER = ((180, 255, 255), (10, 255, 255))
@@ -20,7 +22,7 @@ class TrafficLightDetector:
YELLOW = (0, 175, 225) YELLOW = (0, 175, 225)
GREEN = (0, 150, 0) GREEN = (0, 150, 0)
def _set_image(self, image) -> None: def _set_image(self, image=None) -> None:
self.image = image self.image = image
self.image_copy = self.image self.image_copy = self.image
self.size = self.image.shape self.size = self.image.shape
@@ -30,12 +32,15 @@ class TrafficLightDetector:
self.green = Color("GREEN", self.GREEN, self.GREEN_LOWER, self.GREEN_UPPER, hsv, minDist=30, param2=5) self.green = Color("GREEN", self.GREEN, self.GREEN_LOWER, self.GREEN_UPPER, hsv, minDist=30, param2=5)
self.colors = [self.red, self.yellow, self.green] self.colors = [self.red, self.yellow, self.green]
def _find_traffic_lights(self) -> None:
gray = cv2.cvtColor(self.image, cv2.COLOR_BGR2GRAY)
# draw rectangle around traffic lights
for x, y, width, height in self.CASCADE.detectMultiScale(gray, 1.2, 5):
cv2.rectangle(self.image, (x, y), (x + width, y + height), (255, 0, 0), self.BOUNDARY)
def _draw_circle(self) -> None: def _draw_circle(self) -> None:
try:
for color in self.colors: for color in self.colors:
if color.circle is not None: if color.circle is not None:
logger.debug(f"{color.circle = }")
for values in color.circle[0, :]: for values in color.circle[0, :]:
if values[0] > self.size[1] or values[1] > self.size[0] or values[1] > self.size[0] * self.BOUNDARY: if values[0] > self.size[1] or values[1] > self.size[0] or values[1] > self.size[0] * self.BOUNDARY:
continue continue
@@ -52,8 +57,6 @@ class TrafficLightDetector:
cv2.circle(color.mask, (values[0], values[1]), values[2] + 30, (255, 255, 255), 2) cv2.circle(color.mask, (values[0], values[1]), values[2] + 30, (255, 255, 255), 2)
cv2.putText(self.image_copy, color.name, (values[0], values[1]), self.FONT, 1, color.color, 2, cv2.LINE_AA) cv2.putText(self.image_copy, color.name, (values[0], values[1]), self.FONT, 1, color.color, 2, cv2.LINE_AA)
self.signal = color.name self.signal = color.name
except AttributeError:
logger.warning("Image/frame was not specified")
def get_signal(self) -> str: def get_signal(self) -> str:
return self.signal return self.signal

View File

@@ -1,6 +1,4 @@
import cv2 import cv2
from loguru import logger
from paths import HAAR_PATH
from TrafficLightDetector.traffic_light_detector import TrafficLightDetector from TrafficLightDetector.traffic_light_detector import TrafficLightDetector
@@ -8,20 +6,15 @@ class TrafficLightDetectorWebcam(TrafficLightDetector):
def __init__(self) -> None: def __init__(self) -> None:
self.video_capture = cv2.VideoCapture(0) # Change number if webcam didn't detect self.video_capture = cv2.VideoCapture(0) # Change number if webcam didn't detect
self.lights_cascade = cv2.CascadeClassifier(str(HAAR_PATH))
def enable(self) -> None: def enable(self) -> None:
while True: while True:
_, frame = self.video_capture.read() _, frame = self.video_capture.read()
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
lights = self.lights_cascade.detectMultiScale(gray, 1.2, 5)
for x, y, w, h in lights:
cv2.rectangle(frame, (x, y), (x+w, y+h), (255, 0, 0), 5)
# self._set_image(frame) # self._set_image(frame)
# self._draw_circle() # self._draw_circle()
# cv2.imshow("Video", self.image_copy) self._find_traffic_lights()
cv2.imshow("Video", self.image_copy)
cv2.imshow("Video", frame) cv2.imshow("Video", frame)
if cv2.waitKey(1) & 0xFF == ord("q"): if cv2.waitKey(1) & 0xFF == ord("q"):
break break