edgetpu.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
import datetime
import logging
import multiprocessing as mp
import os
import queue
import signal
import threading
from abc import ABC, abstractmethod
from multiprocessing import shared_memory
from typing import Dict

import numpy as np
import tflite_runtime.interpreter as tflite
from setproctitle import setproctitle
from tflite_runtime.interpreter import load_delegate

from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen
# Module-wide logger. run_detector() creates its own per-detector logger
# ("detector.<name>") inside the detector process instead of using this one.
logger = logging.getLogger(__name__)
  16. def load_labels(path, encoding="utf-8"):
  17. """Loads labels from file (with or without index numbers).
  18. Args:
  19. path: path to label file.
  20. encoding: label file encoding.
  21. Returns:
  22. Dictionary mapping indices to labels.
  23. """
  24. with open(path, "r", encoding=encoding) as f:
  25. lines = f.readlines()
  26. if not lines:
  27. return {}
  28. if lines[0].split(" ", maxsplit=1)[0].isdigit():
  29. pairs = [line.split(" ", maxsplit=1) for line in lines]
  30. return {int(index): label.strip() for index, label in pairs}
  31. else:
  32. return {index: line.strip() for index, line in enumerate(lines)}
  33. class ObjectDetector(ABC):
  34. @abstractmethod
  35. def detect(self, tensor_input, threshold=0.4):
  36. pass
  37. class LocalObjectDetector(ObjectDetector):
  38. def __init__(self, tf_device=None, num_threads=3, labels=None):
  39. self.fps = EventsPerSecond()
  40. if labels is None:
  41. self.labels = {}
  42. else:
  43. self.labels = load_labels(labels)
  44. device_config = {"device": "usb"}
  45. if not tf_device is None:
  46. device_config = {"device": tf_device}
  47. edge_tpu_delegate = None
  48. if tf_device != "cpu":
  49. try:
  50. logger.info(f"Attempting to load TPU as {device_config['device']}")
  51. edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config)
  52. logger.info("TPU found")
  53. self.interpreter = tflite.Interpreter(
  54. model_path="/edgetpu_model.tflite",
  55. experimental_delegates=[edge_tpu_delegate],
  56. )
  57. except ValueError:
  58. logger.info("No EdgeTPU detected.")
  59. raise
  60. else:
  61. self.interpreter = tflite.Interpreter(
  62. model_path="/cpu_model.tflite", num_threads=num_threads
  63. )
  64. self.interpreter.allocate_tensors()
  65. self.tensor_input_details = self.interpreter.get_input_details()
  66. self.tensor_output_details = self.interpreter.get_output_details()
  67. def detect(self, tensor_input, threshold=0.4):
  68. detections = []
  69. raw_detections = self.detect_raw(tensor_input)
  70. for d in raw_detections:
  71. if d[1] < threshold:
  72. break
  73. detections.append(
  74. (self.labels[int(d[0])], float(d[1]), (d[2], d[3], d[4], d[5]))
  75. )
  76. self.fps.update()
  77. return detections
  78. def detect_raw(self, tensor_input):
  79. self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor_input)
  80. self.interpreter.invoke()
  81. boxes = np.squeeze(
  82. self.interpreter.get_tensor(self.tensor_output_details[0]["index"])
  83. )
  84. label_codes = np.squeeze(
  85. self.interpreter.get_tensor(self.tensor_output_details[1]["index"])
  86. )
  87. scores = np.squeeze(
  88. self.interpreter.get_tensor(self.tensor_output_details[2]["index"])
  89. )
  90. detections = np.zeros((20, 6), np.float32)
  91. for i, score in enumerate(scores):
  92. detections[i] = [
  93. label_codes[i],
  94. score,
  95. boxes[i][0],
  96. boxes[i][1],
  97. boxes[i][2],
  98. boxes[i][3],
  99. ]
  100. return detections
  101. def run_detector(
  102. name: str,
  103. detection_queue: mp.Queue,
  104. out_events: Dict[str, mp.Event],
  105. avg_speed,
  106. start,
  107. model_shape,
  108. tf_device,
  109. num_threads,
  110. ):
  111. threading.current_thread().name = f"detector:{name}"
  112. logger = logging.getLogger(f"detector.{name}")
  113. logger.info(f"Starting detection process: {os.getpid()}")
  114. setproctitle(f"frigate.detector.{name}")
  115. listen()
  116. stop_event = mp.Event()
  117. def receiveSignal(signalNumber, frame):
  118. stop_event.set()
  119. signal.signal(signal.SIGTERM, receiveSignal)
  120. signal.signal(signal.SIGINT, receiveSignal)
  121. frame_manager = SharedMemoryFrameManager()
  122. object_detector = LocalObjectDetector(tf_device=tf_device, num_threads=num_threads)
  123. outputs = {}
  124. for name in out_events.keys():
  125. out_shm = mp.shared_memory.SharedMemory(name=f"out-{name}", create=False)
  126. out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
  127. outputs[name] = {"shm": out_shm, "np": out_np}
  128. while not stop_event.is_set():
  129. try:
  130. connection_id = detection_queue.get(timeout=5)
  131. except queue.Empty:
  132. continue
  133. input_frame = frame_manager.get(
  134. connection_id, (1, model_shape[0], model_shape[1], 3)
  135. )
  136. if input_frame is None:
  137. continue
  138. # detect and send the output
  139. start.value = datetime.datetime.now().timestamp()
  140. detections = object_detector.detect_raw(input_frame)
  141. duration = datetime.datetime.now().timestamp() - start.value
  142. outputs[connection_id]["np"][:] = detections[:]
  143. out_events[connection_id].set()
  144. start.value = 0.0
  145. avg_speed.value = (avg_speed.value * 9 + duration) / 10
  146. class EdgeTPUProcess:
  147. def __init__(
  148. self,
  149. name,
  150. detection_queue,
  151. out_events,
  152. model_shape,
  153. tf_device=None,
  154. num_threads=3,
  155. ):
  156. self.name = name
  157. self.out_events = out_events
  158. self.detection_queue = detection_queue
  159. self.avg_inference_speed = mp.Value("d", 0.01)
  160. self.detection_start = mp.Value("d", 0.0)
  161. self.detect_process = None
  162. self.model_shape = model_shape
  163. self.tf_device = tf_device
  164. self.num_threads = num_threads
  165. self.start_or_restart()
  166. def stop(self):
  167. self.detect_process.terminate()
  168. logging.info("Waiting for detection process to exit gracefully...")
  169. self.detect_process.join(timeout=30)
  170. if self.detect_process.exitcode is None:
  171. logging.info("Detection process didnt exit. Force killing...")
  172. self.detect_process.kill()
  173. self.detect_process.join()
  174. def start_or_restart(self):
  175. self.detection_start.value = 0.0
  176. if (not self.detect_process is None) and self.detect_process.is_alive():
  177. self.stop()
  178. self.detect_process = mp.Process(
  179. target=run_detector,
  180. name=f"detector:{self.name}",
  181. args=(
  182. self.name,
  183. self.detection_queue,
  184. self.out_events,
  185. self.avg_inference_speed,
  186. self.detection_start,
  187. self.model_shape,
  188. self.tf_device,
  189. self.num_threads,
  190. ),
  191. )
  192. self.detect_process.daemon = True
  193. self.detect_process.start()
  194. class RemoteObjectDetector:
  195. def __init__(self, name, labels, detection_queue, event, model_shape):
  196. self.labels = load_labels(labels)
  197. self.name = name
  198. self.fps = EventsPerSecond()
  199. self.detection_queue = detection_queue
  200. self.event = event
  201. self.shm = mp.shared_memory.SharedMemory(name=self.name, create=False)
  202. self.np_shm = np.ndarray(
  203. (1, model_shape[0], model_shape[1], 3), dtype=np.uint8, buffer=self.shm.buf
  204. )
  205. self.out_shm = mp.shared_memory.SharedMemory(
  206. name=f"out-{self.name}", create=False
  207. )
  208. self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf)
  209. def detect(self, tensor_input, threshold=0.4):
  210. detections = []
  211. # copy input to shared memory
  212. self.np_shm[:] = tensor_input[:]
  213. self.event.clear()
  214. self.detection_queue.put(self.name)
  215. result = self.event.wait(timeout=10.0)
  216. # if it timed out
  217. if result is None:
  218. return detections
  219. for d in self.out_np_shm:
  220. if d[1] < threshold:
  221. break
  222. detections.append(
  223. (self.labels[int(d[0])], float(d[1]), (d[2], d[3], d[4], d[5]))
  224. )
  225. self.fps.update()
  226. return detections
  227. def cleanup(self):
  228. self.shm.unlink()
  229. self.out_shm.unlink()