瀏覽代碼

allow for custom object detection model via configuration

Jason Hunter 3 年之前
父節點
當前提交
a7b7a45b23
共有 3 個文件被更改,包括 16 次插入5 次删除
  1. 3 0
      frigate/app.py
  2. 3 1
      frigate/config.py
  3. 10 4
      frigate/edgetpu.py

+ 3 - 0
frigate/app.py

@@ -170,6 +170,7 @@ class FrigateApp:
         self.mqtt_relay.start()
         self.mqtt_relay.start()
 
 
     def start_detectors(self):
     def start_detectors(self):
+        model_path = self.config.model.path
         model_shape = (self.config.model.height, self.config.model.width)
         model_shape = (self.config.model.height, self.config.model.width)
         for name in self.config.cameras.keys():
         for name in self.config.cameras.keys():
             self.detection_out_events[name] = mp.Event()
             self.detection_out_events[name] = mp.Event()
@@ -199,6 +200,7 @@ class FrigateApp:
                     name,
                     name,
                     self.detection_queue,
                     self.detection_queue,
                     self.detection_out_events,
                     self.detection_out_events,
+                    model_path,
                     model_shape,
                     model_shape,
                     "cpu",
                     "cpu",
                     detector.num_threads,
                     detector.num_threads,
@@ -208,6 +210,7 @@ class FrigateApp:
                     name,
                     name,
                     self.detection_queue,
                     self.detection_queue,
                     self.detection_out_events,
                     self.detection_out_events,
+                    model_path,
                     model_shape,
                     model_shape,
                     detector.device,
                     detector.device,
                     detector.num_threads,
                     detector.num_threads,

+ 3 - 1
frigate/config.py

@@ -603,6 +603,8 @@ class DatabaseConfig(FrigateBaseModel):
 
 
 
 
 class ModelConfig(FrigateBaseModel):
 class ModelConfig(FrigateBaseModel):
+    path: Optional[str] = Field(title="Custom Object detection model path.")
+    labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
     width: int = Field(default=320, title="Object detection model input width.")
     width: int = Field(default=320, title="Object detection model input width.")
     height: int = Field(default=320, title="Object detection model input height.")
     height: int = Field(default=320, title="Object detection model input height.")
     labelmap: Dict[int, str] = Field(
     labelmap: Dict[int, str] = Field(
@@ -623,7 +625,7 @@ class ModelConfig(FrigateBaseModel):
         super().__init__(**config)
         super().__init__(**config)
 
 
         self._merged_labelmap = {
         self._merged_labelmap = {
-            **load_labels("/labelmap.txt"),
+            **load_labels(config.get("labelmap_path", "/labelmap.txt")),
             **config.get("labelmap", {}),
             **config.get("labelmap", {}),
         }
         }
 
 

+ 10 - 4
frigate/edgetpu.py

@@ -45,7 +45,7 @@ class ObjectDetector(ABC):
 
 
 
 
 class LocalObjectDetector(ObjectDetector):
 class LocalObjectDetector(ObjectDetector):
-    def __init__(self, tf_device=None, num_threads=3, labels=None):
+    def __init__(self, tf_device=None, model_path=None, num_threads=3, labels=None):
         self.fps = EventsPerSecond()
         self.fps = EventsPerSecond()
         if labels is None:
         if labels is None:
             self.labels = {}
             self.labels = {}
@@ -64,7 +64,7 @@ class LocalObjectDetector(ObjectDetector):
                 edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config)
                 edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config)
                 logger.info("TPU found")
                 logger.info("TPU found")
                 self.interpreter = tflite.Interpreter(
                 self.interpreter = tflite.Interpreter(
-                    model_path="/edgetpu_model.tflite",
+                    model_path=model_path or "/edgetpu_model.tflite",
                     experimental_delegates=[edge_tpu_delegate],
                     experimental_delegates=[edge_tpu_delegate],
                 )
                 )
             except ValueError:
             except ValueError:
@@ -77,7 +77,7 @@ class LocalObjectDetector(ObjectDetector):
                 "CPU detectors are not recommended and should only be used for testing or for trial purposes."
                 "CPU detectors are not recommended and should only be used for testing or for trial purposes."
             )
             )
             self.interpreter = tflite.Interpreter(
             self.interpreter = tflite.Interpreter(
-                model_path="/cpu_model.tflite", num_threads=num_threads
+                model_path=model_path or "/cpu_model.tflite", num_threads=num_threads
             )
             )
 
 
         self.interpreter.allocate_tensors()
         self.interpreter.allocate_tensors()
@@ -133,6 +133,7 @@ def run_detector(
     out_events: Dict[str, mp.Event],
     out_events: Dict[str, mp.Event],
     avg_speed,
     avg_speed,
     start,
     start,
+    model_path,
     model_shape,
     model_shape,
     tf_device,
     tf_device,
     num_threads,
     num_threads,
@@ -152,7 +153,9 @@ def run_detector(
     signal.signal(signal.SIGINT, receiveSignal)
     signal.signal(signal.SIGINT, receiveSignal)
 
 
     frame_manager = SharedMemoryFrameManager()
     frame_manager = SharedMemoryFrameManager()
-    object_detector = LocalObjectDetector(tf_device=tf_device, num_threads=num_threads)
+    object_detector = LocalObjectDetector(
+        tf_device=tf_device, model_path=model_path, num_threads=num_threads
+    )
 
 
     outputs = {}
     outputs = {}
     for name in out_events.keys():
     for name in out_events.keys():
@@ -189,6 +192,7 @@ class EdgeTPUProcess:
         name,
         name,
         detection_queue,
         detection_queue,
         out_events,
         out_events,
+        model_path,
         model_shape,
         model_shape,
         tf_device=None,
         tf_device=None,
         num_threads=3,
         num_threads=3,
@@ -199,6 +203,7 @@ class EdgeTPUProcess:
         self.avg_inference_speed = mp.Value("d", 0.01)
         self.avg_inference_speed = mp.Value("d", 0.01)
         self.detection_start = mp.Value("d", 0.0)
         self.detection_start = mp.Value("d", 0.0)
         self.detect_process = None
         self.detect_process = None
+        self.model_path = model_path
         self.model_shape = model_shape
         self.model_shape = model_shape
         self.tf_device = tf_device
         self.tf_device = tf_device
         self.num_threads = num_threads
         self.num_threads = num_threads
@@ -226,6 +231,7 @@ class EdgeTPUProcess:
                 self.out_events,
                 self.out_events,
                 self.avg_inference_speed,
                 self.avg_inference_speed,
                 self.detection_start,
                 self.detection_start,
+                self.model_path,
                 self.model_shape,
                 self.model_shape,
                 self.tf_device,
                 self.tf_device,
                 self.num_threads,
                 self.num_threads,