edgetpu.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. import datetime
  2. import hashlib
  3. import logging
  4. import multiprocessing as mp
  5. import os
  6. import queue
  7. import threading
  8. import signal
  9. from abc import ABC, abstractmethod
  10. from multiprocessing.connection import Connection
  11. from setproctitle import setproctitle
  12. from typing import Dict
  13. import numpy as np
  14. import tflite_runtime.interpreter as tflite
  15. from tflite_runtime.interpreter import load_delegate
  16. from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
  18. def load_labels(path, encoding='utf-8'):
  19. """Loads labels from file (with or without index numbers).
  20. Args:
  21. path: path to label file.
  22. encoding: label file encoding.
  23. Returns:
  24. Dictionary mapping indices to labels.
  25. """
  26. with open(path, 'r', encoding=encoding) as f:
  27. lines = f.readlines()
  28. if not lines:
  29. return {}
  30. if lines[0].split(' ', maxsplit=1)[0].isdigit():
  31. pairs = [line.split(' ', maxsplit=1) for line in lines]
  32. return {int(index): label.strip() for index, label in pairs}
  33. else:
  34. return {index: line.strip() for index, line in enumerate(lines)}
  35. class ObjectDetector(ABC):
  36. @abstractmethod
  37. def detect(self, tensor_input, threshold = .4):
  38. pass
  39. class LocalObjectDetector(ObjectDetector):
  40. def __init__(self, tf_device=None, num_threads=3, labels=None):
  41. self.fps = EventsPerSecond()
  42. if labels is None:
  43. self.labels = {}
  44. else:
  45. self.labels = load_labels(labels)
  46. device_config = {"device": "usb"}
  47. if not tf_device is None:
  48. device_config = {"device": tf_device}
  49. edge_tpu_delegate = None
  50. if tf_device != 'cpu':
  51. try:
  52. logger.info(f"Attempting to load TPU as {device_config['device']}")
  53. edge_tpu_delegate = load_delegate('libedgetpu.so.1.0', device_config)
  54. logger.info("TPU found")
  55. self.interpreter = tflite.Interpreter(
  56. model_path='/edgetpu_model.tflite',
  57. experimental_delegates=[edge_tpu_delegate])
  58. except ValueError:
  59. logger.info("No EdgeTPU detected.")
  60. raise
  61. else:
  62. self.interpreter = tflite.Interpreter(
  63. model_path='/cpu_model.tflite', num_threads=num_threads)
  64. self.interpreter.allocate_tensors()
  65. self.tensor_input_details = self.interpreter.get_input_details()
  66. self.tensor_output_details = self.interpreter.get_output_details()
  67. def detect(self, tensor_input, threshold=.4):
  68. detections = []
  69. raw_detections = self.detect_raw(tensor_input)
  70. for d in raw_detections:
  71. if d[1] < threshold:
  72. break
  73. detections.append((
  74. self.labels[int(d[0])],
  75. float(d[1]),
  76. (d[2], d[3], d[4], d[5])
  77. ))
  78. self.fps.update()
  79. return detections
  80. def detect_raw(self, tensor_input):
  81. self.interpreter.set_tensor(self.tensor_input_details[0]['index'], tensor_input)
  82. self.interpreter.invoke()
  83. boxes = np.squeeze(self.interpreter.get_tensor(self.tensor_output_details[0]['index']))
  84. label_codes = np.squeeze(self.interpreter.get_tensor(self.tensor_output_details[1]['index']))
  85. scores = np.squeeze(self.interpreter.get_tensor(self.tensor_output_details[2]['index']))
  86. detections = np.zeros((20,6), np.float32)
  87. for i, score in enumerate(scores):
  88. detections[i] = [label_codes[i], score, boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]]
  89. return detections
  90. def run_detector(name: str, detection_queue: mp.Queue, out_events: Dict[str, mp.Event], avg_speed, start, model_shape, tf_device, num_threads):
  91. threading.current_thread().name = f"detector:{name}"
  92. logger = logging.getLogger(f"detector.{name}")
  93. logger.info(f"Starting detection process: {os.getpid()}")
  94. setproctitle(f"frigate.detector.{name}")
  95. listen()
  96. stop_event = mp.Event()
  97. def receiveSignal(signalNumber, frame):
  98. stop_event.set()
  99. signal.signal(signal.SIGTERM, receiveSignal)
  100. signal.signal(signal.SIGINT, receiveSignal)
  101. frame_manager = SharedMemoryFrameManager()
  102. object_detector = LocalObjectDetector(tf_device=tf_device, num_threads=num_threads)
  103. outputs = {}
  104. for name in out_events.keys():
  105. out_shm = mp.shared_memory.SharedMemory(name=f"out-{name}", create=False)
  106. out_np = np.ndarray((20,6), dtype=np.float32, buffer=out_shm.buf)
  107. outputs[name] = {
  108. 'shm': out_shm,
  109. 'np': out_np
  110. }
  111. while True:
  112. if stop_event.is_set():
  113. break
  114. try:
  115. connection_id = detection_queue.get(timeout=5)
  116. except queue.Empty:
  117. continue
  118. input_frame = frame_manager.get(connection_id, (1,model_shape[0],model_shape[1],3))
  119. if input_frame is None:
  120. continue
  121. # detect and send the output
  122. start.value = datetime.datetime.now().timestamp()
  123. detections = object_detector.detect_raw(input_frame)
  124. duration = datetime.datetime.now().timestamp()-start.value
  125. outputs[connection_id]['np'][:] = detections[:]
  126. out_events[connection_id].set()
  127. start.value = 0.0
  128. avg_speed.value = (avg_speed.value*9 + duration)/10
  129. class EdgeTPUProcess():
  130. def __init__(self, name, detection_queue, out_events, model_shape, tf_device=None, num_threads=3):
  131. self.name = name
  132. self.out_events = out_events
  133. self.detection_queue = detection_queue
  134. self.avg_inference_speed = mp.Value('d', 0.01)
  135. self.detection_start = mp.Value('d', 0.0)
  136. self.detect_process = None
  137. self.model_shape = model_shape
  138. self.tf_device = tf_device
  139. self.num_threads = num_threads
  140. self.start_or_restart()
  141. def stop(self):
  142. self.detect_process.terminate()
  143. logging.info("Waiting for detection process to exit gracefully...")
  144. self.detect_process.join(timeout=30)
  145. if self.detect_process.exitcode is None:
  146. logging.info("Detection process didnt exit. Force killing...")
  147. self.detect_process.kill()
  148. self.detect_process.join()
  149. def start_or_restart(self):
  150. self.detection_start.value = 0.0
  151. if (not self.detect_process is None) and self.detect_process.is_alive():
  152. self.stop()
  153. self.detect_process = mp.Process(target=run_detector, name=f"detector:{self.name}", args=(self.name, self.detection_queue, self.out_events, self.avg_inference_speed, self.detection_start, self.model_shape, self.tf_device, self.num_threads))
  154. self.detect_process.daemon = True
  155. self.detect_process.start()
  156. class RemoteObjectDetector():
  157. def __init__(self, name, labels, detection_queue, event, model_shape):
  158. self.labels = load_labels(labels)
  159. self.name = name
  160. self.fps = EventsPerSecond()
  161. self.detection_queue = detection_queue
  162. self.event = event
  163. self.shm = mp.shared_memory.SharedMemory(name=self.name, create=False)
  164. self.np_shm = np.ndarray((1,model_shape[0],model_shape[1],3), dtype=np.uint8, buffer=self.shm.buf)
  165. self.out_shm = mp.shared_memory.SharedMemory(name=f"out-{self.name}", create=False)
  166. self.out_np_shm = np.ndarray((20,6), dtype=np.float32, buffer=self.out_shm.buf)
  167. def detect(self, tensor_input, threshold=.4):
  168. detections = []
  169. # copy input to shared memory
  170. self.np_shm[:] = tensor_input[:]
  171. self.event.clear()
  172. self.detection_queue.put(self.name)
  173. result = self.event.wait(timeout=10.0)
  174. # if it timed out
  175. if result is None:
  176. return detections
  177. for d in self.out_np_shm:
  178. if d[1] < threshold:
  179. break
  180. detections.append((
  181. self.labels[int(d[0])],
  182. float(d[1]),
  183. (d[2], d[3], d[4], d[5])
  184. ))
  185. self.fps.update()
  186. return detections
  187. def cleanup(self):
  188. self.shm.unlink()
  189. self.out_shm.unlink()