edgetpu.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. import datetime
  2. import hashlib
  3. import logging
  4. import multiprocessing as mp
  5. import os
  6. import queue
  7. import threading
  8. import signal
  9. from abc import ABC, abstractmethod
  10. from multiprocessing.connection import Connection
  11. from typing import Dict
  12. import numpy as np
  13. import tflite_runtime.interpreter as tflite
  14. from tflite_runtime.interpreter import load_delegate
  15. from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen
  16. logger = logging.getLogger(__name__)
  17. def load_labels(path, encoding='utf-8'):
  18. """Loads labels from file (with or without index numbers).
  19. Args:
  20. path: path to label file.
  21. encoding: label file encoding.
  22. Returns:
  23. Dictionary mapping indices to labels.
  24. """
  25. with open(path, 'r', encoding=encoding) as f:
  26. lines = f.readlines()
  27. if not lines:
  28. return {}
  29. if lines[0].split(' ', maxsplit=1)[0].isdigit():
  30. pairs = [line.split(' ', maxsplit=1) for line in lines]
  31. return {int(index): label.strip() for index, label in pairs}
  32. else:
  33. return {index: line.strip() for index, line in enumerate(lines)}
  34. class ObjectDetector(ABC):
  35. @abstractmethod
  36. def detect(self, tensor_input, threshold = .4):
  37. pass
  38. class LocalObjectDetector(ObjectDetector):
  39. def __init__(self, tf_device=None, labels=None):
  40. self.fps = EventsPerSecond()
  41. if labels is None:
  42. self.labels = {}
  43. else:
  44. self.labels = load_labels(labels)
  45. device_config = {"device": "usb"}
  46. if not tf_device is None:
  47. device_config = {"device": tf_device}
  48. edge_tpu_delegate = None
  49. if tf_device != 'cpu':
  50. try:
  51. logger.info(f"Attempting to load TPU as {device_config['device']}")
  52. edge_tpu_delegate = load_delegate('libedgetpu.so.1.0', device_config)
  53. logger.info("TPU found")
  54. except ValueError:
  55. logger.info("No EdgeTPU detected. Falling back to CPU.")
  56. if edge_tpu_delegate is None:
  57. self.interpreter = tflite.Interpreter(
  58. model_path='/cpu_model.tflite')
  59. else:
  60. self.interpreter = tflite.Interpreter(
  61. model_path='/edgetpu_model.tflite',
  62. experimental_delegates=[edge_tpu_delegate])
  63. self.interpreter.allocate_tensors()
  64. self.tensor_input_details = self.interpreter.get_input_details()
  65. self.tensor_output_details = self.interpreter.get_output_details()
  66. def detect(self, tensor_input, threshold=.4):
  67. detections = []
  68. raw_detections = self.detect_raw(tensor_input)
  69. for d in raw_detections:
  70. if d[1] < threshold:
  71. break
  72. detections.append((
  73. self.labels[int(d[0])],
  74. float(d[1]),
  75. (d[2], d[3], d[4], d[5])
  76. ))
  77. self.fps.update()
  78. return detections
  79. def detect_raw(self, tensor_input):
  80. self.interpreter.set_tensor(self.tensor_input_details[0]['index'], tensor_input)
  81. self.interpreter.invoke()
  82. boxes = np.squeeze(self.interpreter.get_tensor(self.tensor_output_details[0]['index']))
  83. label_codes = np.squeeze(self.interpreter.get_tensor(self.tensor_output_details[1]['index']))
  84. scores = np.squeeze(self.interpreter.get_tensor(self.tensor_output_details[2]['index']))
  85. detections = np.zeros((20,6), np.float32)
  86. for i, score in enumerate(scores):
  87. detections[i] = [label_codes[i], score, boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]]
  88. return detections
  89. def run_detector(name: str, detection_queue: mp.Queue, out_events: Dict[str, mp.Event], avg_speed, start, tf_device):
  90. threading.current_thread().name = f"detector:{name}"
  91. logger = logging.getLogger(f"detector.{name}")
  92. logger.info(f"Starting detection process: {os.getpid()}")
  93. listen()
  94. stop_event = mp.Event()
  95. def receiveSignal(signalNumber, frame):
  96. stop_event.set()
  97. signal.signal(signal.SIGTERM, receiveSignal)
  98. signal.signal(signal.SIGINT, receiveSignal)
  99. frame_manager = SharedMemoryFrameManager()
  100. object_detector = LocalObjectDetector(tf_device=tf_device)
  101. outputs = {}
  102. for name in out_events.keys():
  103. out_shm = mp.shared_memory.SharedMemory(name=f"out-{name}", create=False)
  104. out_np = np.ndarray((20,6), dtype=np.float32, buffer=out_shm.buf)
  105. outputs[name] = {
  106. 'shm': out_shm,
  107. 'np': out_np
  108. }
  109. while True:
  110. if stop_event.is_set():
  111. break
  112. try:
  113. connection_id = detection_queue.get(timeout=5)
  114. except queue.Empty:
  115. continue
  116. input_frame = frame_manager.get(connection_id, (1,300,300,3))
  117. if input_frame is None:
  118. continue
  119. # detect and send the output
  120. start.value = datetime.datetime.now().timestamp()
  121. detections = object_detector.detect_raw(input_frame)
  122. duration = datetime.datetime.now().timestamp()-start.value
  123. outputs[connection_id]['np'][:] = detections[:]
  124. out_events[connection_id].set()
  125. start.value = 0.0
  126. avg_speed.value = (avg_speed.value*9 + duration)/10
  127. class EdgeTPUProcess():
  128. def __init__(self, name, detection_queue, out_events, tf_device=None):
  129. self.name = name
  130. self.out_events = out_events
  131. self.detection_queue = detection_queue
  132. self.avg_inference_speed = mp.Value('d', 0.01)
  133. self.detection_start = mp.Value('d', 0.0)
  134. self.detect_process = None
  135. self.tf_device = tf_device
  136. self.start_or_restart()
  137. def stop(self):
  138. self.detect_process.terminate()
  139. logging.info("Waiting for detection process to exit gracefully...")
  140. self.detect_process.join(timeout=30)
  141. if self.detect_process.exitcode is None:
  142. logging.info("Detection process didnt exit. Force killing...")
  143. self.detect_process.kill()
  144. self.detect_process.join()
  145. def start_or_restart(self):
  146. self.detection_start.value = 0.0
  147. if (not self.detect_process is None) and self.detect_process.is_alive():
  148. self.stop()
  149. self.detect_process = mp.Process(target=run_detector, name=f"detector:{self.name}", args=(self.name, self.detection_queue, self.out_events, self.avg_inference_speed, self.detection_start, self.tf_device))
  150. self.detect_process.daemon = True
  151. self.detect_process.start()
  152. class RemoteObjectDetector():
  153. def __init__(self, name, labels, detection_queue, event):
  154. self.labels = load_labels(labels)
  155. self.name = name
  156. self.fps = EventsPerSecond()
  157. self.detection_queue = detection_queue
  158. self.event = event
  159. self.shm = mp.shared_memory.SharedMemory(name=self.name, create=False)
  160. self.np_shm = np.ndarray((1,300,300,3), dtype=np.uint8, buffer=self.shm.buf)
  161. self.out_shm = mp.shared_memory.SharedMemory(name=f"out-{self.name}", create=False)
  162. self.out_np_shm = np.ndarray((20,6), dtype=np.float32, buffer=self.out_shm.buf)
  163. def detect(self, tensor_input, threshold=.4):
  164. detections = []
  165. # copy input to shared memory
  166. self.np_shm[:] = tensor_input[:]
  167. self.event.clear()
  168. self.detection_queue.put(self.name)
  169. result = self.event.wait(timeout=10.0)
  170. # if it timed out
  171. if result is None:
  172. return detections
  173. for d in self.out_np_shm:
  174. if d[1] < threshold:
  175. break
  176. detections.append((
  177. self.labels[int(d[0])],
  178. float(d[1]),
  179. (d[2], d[3], d[4], d[5])
  180. ))
  181. self.fps.update()
  182. return detections
  183. def cleanup(self):
  184. self.shm.unlink()
  185. self.out_shm.unlink()