diff --git a/src/detector/train_traffic_light_color.py b/src/detector/train_traffic_light_color.py index 4681e08..55350f8 100644 --- a/src/detector/train_traffic_light_color.py +++ b/src/detector/train_traffic_light_color.py @@ -151,8 +151,8 @@ def train_traffic_light_color() -> None: labels.extend([3] * len(img_3_not_traffic_light)) # Create NumPy array - labels_np = np.ndarray(shape=(len(labels), 4)) - images_np = np.ndarray(shape=(len(labels), shape[0], shape[1], 3)) + labels_np: np.ndarray[int, np.dtype[np.generic]] = np.ndarray(shape=(len(labels), 4)) + images_np: np.ndarray[int, np.dtype[np.generic]] = np.ndarray(shape=(len(labels), shape[0], shape[1], 3)) # Create a list of all the images in the traffic lights data set img_all = [] @@ -177,7 +177,7 @@ def train_traffic_light_color() -> None: logger.info(f"Labels: {len(labels)}") # Perform one-hot encoding - for idx in range(len(labels_np)): + for idx, _ in enumerate(labels_np): # We have four integer labels, representing the different colors of the # traffic lights. labels_np[idx] = np.array(to_categorical(labels[idx], 4))