Browse Source

allow for custom object detection model via configuration

Jason Hunter 3 years ago
parent
commit
a7b7a45b23
3 changed files with 16 additions and 5 deletions
  1. 3 0
      frigate/app.py
  2. 3 1
      frigate/config.py
  3. 10 4
      frigate/edgetpu.py

+ 3 - 0
frigate/app.py

@@ -170,6 +170,7 @@ class FrigateApp:
         self.mqtt_relay.start()
 
     def start_detectors(self):
+        model_path = self.config.model.path
         model_shape = (self.config.model.height, self.config.model.width)
         for name in self.config.cameras.keys():
             self.detection_out_events[name] = mp.Event()
@@ -199,6 +200,7 @@ class FrigateApp:
                     name,
                     self.detection_queue,
                     self.detection_out_events,
+                    model_path,
                     model_shape,
                     "cpu",
                     detector.num_threads,
@@ -208,6 +210,7 @@ class FrigateApp:
                     name,
                     self.detection_queue,
                     self.detection_out_events,
+                    model_path,
                     model_shape,
                     detector.device,
                     detector.num_threads,

+ 3 - 1
frigate/config.py

@@ -603,6 +603,8 @@ class DatabaseConfig(FrigateBaseModel):
 
 
 class ModelConfig(FrigateBaseModel):
+    path: Optional[str] = Field(title="Custom Object detection model path.")
+    labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
     width: int = Field(default=320, title="Object detection model input width.")
     height: int = Field(default=320, title="Object detection model input height.")
     labelmap: Dict[int, str] = Field(
@@ -623,7 +625,7 @@ class ModelConfig(FrigateBaseModel):
         super().__init__(**config)
 
         self._merged_labelmap = {
-            **load_labels("/labelmap.txt"),
+            **load_labels(config.get("labelmap_path", "/labelmap.txt")),
             **config.get("labelmap", {}),
         }
 

+ 10 - 4
frigate/edgetpu.py

@@ -45,7 +45,7 @@ class ObjectDetector(ABC):
 
 
 class LocalObjectDetector(ObjectDetector):
-    def __init__(self, tf_device=None, num_threads=3, labels=None):
+    def __init__(self, tf_device=None, model_path=None, num_threads=3, labels=None):
         self.fps = EventsPerSecond()
         if labels is None:
             self.labels = {}
@@ -64,7 +64,7 @@ class LocalObjectDetector(ObjectDetector):
                 edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config)
                 logger.info("TPU found")
                 self.interpreter = tflite.Interpreter(
-                    model_path="/edgetpu_model.tflite",
+                    model_path=model_path or "/edgetpu_model.tflite",
                     experimental_delegates=[edge_tpu_delegate],
                 )
             except ValueError:
@@ -77,7 +77,7 @@ class LocalObjectDetector(ObjectDetector):
                 "CPU detectors are not recommended and should only be used for testing or for trial purposes."
             )
             self.interpreter = tflite.Interpreter(
-                model_path="/cpu_model.tflite", num_threads=num_threads
+                model_path=model_path or "/cpu_model.tflite", num_threads=num_threads
             )
 
         self.interpreter.allocate_tensors()
@@ -133,6 +133,7 @@ def run_detector(
     out_events: Dict[str, mp.Event],
     avg_speed,
     start,
+    model_path,
     model_shape,
     tf_device,
     num_threads,
@@ -152,7 +153,9 @@ def run_detector(
     signal.signal(signal.SIGINT, receiveSignal)
 
     frame_manager = SharedMemoryFrameManager()
-    object_detector = LocalObjectDetector(tf_device=tf_device, num_threads=num_threads)
+    object_detector = LocalObjectDetector(
+        tf_device=tf_device, model_path=model_path, num_threads=num_threads
+    )
 
     outputs = {}
     for name in out_events.keys():
@@ -189,6 +192,7 @@ class EdgeTPUProcess:
         name,
         detection_queue,
         out_events,
+        model_path,
         model_shape,
         tf_device=None,
         num_threads=3,
@@ -199,6 +203,7 @@ class EdgeTPUProcess:
         self.avg_inference_speed = mp.Value("d", 0.01)
         self.detection_start = mp.Value("d", 0.0)
         self.detect_process = None
+        self.model_path = model_path
         self.model_shape = model_shape
         self.tf_device = tf_device
         self.num_threads = num_threads
@@ -226,6 +231,7 @@ class EdgeTPUProcess:
                 self.out_events,
                 self.avg_inference_speed,
                 self.detection_start,
+                self.model_path,
                 self.model_shape,
                 self.tf_device,
                 self.num_threads,