object_processing.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. import json
  2. import hashlib
  3. import datetime
  4. import time
  5. import copy
  6. import cv2
  7. import threading
  8. import queue
  9. import copy
  10. import numpy as np
  11. from collections import Counter, defaultdict
  12. import itertools
  13. import matplotlib.pyplot as plt
  14. from frigate.util import draw_box_with_label, SharedMemoryFrameManager
  15. from frigate.edgetpu import load_labels
  16. from typing import Callable, Dict
  17. from statistics import mean, median
  18. PATH_TO_LABELS = '/labelmap.txt'
  19. LABELS = load_labels(PATH_TO_LABELS)
  20. cmap = plt.cm.get_cmap('tab10', len(LABELS.keys()))
  21. COLOR_MAP = {}
  22. for key, val in LABELS.items():
  23. COLOR_MAP[val] = tuple(int(round(255 * c)) for c in cmap(key)[:3])
  24. def zone_filtered(obj, object_config):
  25. object_name = obj['label']
  26. if object_name in object_config:
  27. obj_settings = object_config[object_name]
  28. # if the min area is larger than the
  29. # detected object, don't add it to detected objects
  30. if obj_settings.get('min_area',-1) > obj['area']:
  31. return True
  32. # if the detected object is larger than the
  33. # max area, don't add it to detected objects
  34. if obj_settings.get('max_area', 24000000) < obj['area']:
  35. return True
  36. # if the score is lower than the threshold, skip
  37. if obj_settings.get('threshold', 0) > obj['computed_score']:
  38. return True
  39. return False
  40. # Maintains the state of a camera
  41. class CameraState():
  42. def __init__(self, name, config, frame_manager):
  43. self.name = name
  44. self.config = config
  45. self.frame_manager = frame_manager
  46. self.best_objects = {}
  47. self.object_status = defaultdict(lambda: 'OFF')
  48. self.tracked_objects = {}
  49. self.zone_objects = defaultdict(lambda: [])
  50. self._current_frame = np.zeros(self.config['frame_shape'], np.uint8)
  51. self.current_frame_lock = threading.Lock()
  52. self.current_frame_time = 0.0
  53. self.previous_frame_id = None
  54. self.callbacks = defaultdict(lambda: [])
  55. def get_current_frame(self):
  56. with self.current_frame_lock:
  57. return np.copy(self._current_frame)
  58. def false_positive(self, obj):
  59. # once a true positive, always a true positive
  60. if not obj.get('false_positive', True):
  61. return False
  62. threshold = self.config['objects'].get('filters', {}).get(obj['label'], {}).get('threshold', 0.85)
  63. if obj['computed_score'] < threshold:
  64. return True
  65. return False
  66. def compute_score(self, obj):
  67. scores = obj['score_history'][:]
  68. # pad with zeros if you dont have at least 3 scores
  69. if len(scores) < 3:
  70. scores += [0.0]*(3 - len(scores))
  71. return median(scores)
  72. def on(self, event_type: str, callback: Callable[[Dict], None]):
  73. self.callbacks[event_type].append(callback)
  74. def update(self, frame_time, tracked_objects):
  75. self.current_frame_time = frame_time
  76. # get the new frame and delete the old frame
  77. frame_id = f"{self.name}{frame_time}"
  78. with self.current_frame_lock:
  79. self._current_frame = self.frame_manager.get(frame_id, (self.config['frame_shape'][0]*3//2, self.config['frame_shape'][1]))
  80. if not self.previous_frame_id is None:
  81. self.frame_manager.delete(self.previous_frame_id)
  82. self.previous_frame_id = frame_id
  83. current_ids = tracked_objects.keys()
  84. previous_ids = self.tracked_objects.keys()
  85. removed_ids = list(set(previous_ids).difference(current_ids))
  86. new_ids = list(set(current_ids).difference(previous_ids))
  87. updated_ids = list(set(current_ids).intersection(previous_ids))
  88. for id in new_ids:
  89. self.tracked_objects[id] = tracked_objects[id]
  90. self.tracked_objects[id]['zones'] = []
  91. # start the score history
  92. self.tracked_objects[id]['score_history'] = [self.tracked_objects[id]['score']]
  93. # calculate if this is a false positive
  94. self.tracked_objects[id]['computed_score'] = self.compute_score(self.tracked_objects[id])
  95. self.tracked_objects[id]['false_positive'] = self.false_positive(self.tracked_objects[id])
  96. # call event handlers
  97. for c in self.callbacks['start']:
  98. c(self.name, tracked_objects[id])
  99. for id in updated_ids:
  100. self.tracked_objects[id].update(tracked_objects[id])
  101. # if the object is not in the current frame, add a 0.0 to the score history
  102. if self.tracked_objects[id]['frame_time'] != self.current_frame_time:
  103. self.tracked_objects[id]['score_history'].append(0.0)
  104. else:
  105. self.tracked_objects[id]['score_history'].append(self.tracked_objects[id]['score'])
  106. # only keep the last 10 scores
  107. if len(self.tracked_objects[id]['score_history']) > 10:
  108. self.tracked_objects[id]['score_history'] = self.tracked_objects[id]['score_history'][-10:]
  109. # calculate if this is a false positive
  110. self.tracked_objects[id]['computed_score'] = self.compute_score(self.tracked_objects[id])
  111. self.tracked_objects[id]['false_positive'] = self.false_positive(self.tracked_objects[id])
  112. # call event handlers
  113. for c in self.callbacks['update']:
  114. c(self.name, self.tracked_objects[id])
  115. for id in removed_ids:
  116. # publish events to mqtt
  117. self.tracked_objects[id]['end_time'] = frame_time
  118. for c in self.callbacks['end']:
  119. c(self.name, self.tracked_objects[id])
  120. del self.tracked_objects[id]
  121. # check to see if the objects are in any zones
  122. for obj in self.tracked_objects.values():
  123. current_zones = []
  124. bottom_center = (obj['centroid'][0], obj['box'][3])
  125. # check each zone
  126. for name, zone in self.config['zones'].items():
  127. contour = zone['contour']
  128. # check if the object is in the zone and not filtered
  129. if (cv2.pointPolygonTest(contour, bottom_center, False) >= 0
  130. and not zone_filtered(obj, zone.get('filters', {}))):
  131. current_zones.append(name)
  132. obj['zones'] = current_zones
  133. # draw on the frame
  134. if not self._current_frame is None:
  135. # draw the bounding boxes on the frame
  136. for obj in self.tracked_objects.values():
  137. thickness = 2
  138. color = COLOR_MAP[obj['label']]
  139. if obj['frame_time'] != frame_time:
  140. thickness = 1
  141. color = (255,0,0)
  142. # draw the bounding boxes on the frame
  143. box = obj['box']
  144. draw_box_with_label(self._current_frame, box[0], box[1], box[2], box[3], obj['label'], f"{int(obj['score']*100)}% {int(obj['area'])}", thickness=thickness, color=color)
  145. # draw the regions on the frame
  146. region = obj['region']
  147. cv2.rectangle(self._current_frame, (region[0], region[1]), (region[2], region[3]), (0,255,0), 1)
  148. if self.config['snapshots']['show_timestamp']:
  149. time_to_show = datetime.datetime.fromtimestamp(frame_time).strftime("%m/%d/%Y %H:%M:%S")
  150. cv2.putText(self._current_frame, time_to_show, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.8, color=(255, 255, 255), thickness=2)
  151. if self.config['snapshots']['draw_zones']:
  152. for name, zone in self.config['zones'].items():
  153. thickness = 8 if any([name in obj['zones'] for obj in self.tracked_objects.values()]) else 2
  154. cv2.drawContours(self._current_frame, [zone['contour']], -1, zone['color'], thickness)
  155. # maintain best objects
  156. for obj in self.tracked_objects.values():
  157. object_type = obj['label']
  158. # if the object wasn't seen on the current frame, skip it
  159. if obj['frame_time'] != self.current_frame_time or obj['false_positive']:
  160. continue
  161. obj_copy = copy.deepcopy(obj)
  162. if object_type in self.best_objects:
  163. current_best = self.best_objects[object_type]
  164. now = datetime.datetime.now().timestamp()
  165. # if the object is a higher score than the current best score
  166. # or the current object is older than desired, use the new object
  167. if obj_copy['score'] > current_best['score'] or (now - current_best['frame_time']) > self.config.get('best_image_timeout', 60):
  168. obj_copy['frame'] = np.copy(self._current_frame)
  169. self.best_objects[object_type] = obj_copy
  170. for c in self.callbacks['snapshot']:
  171. c(self.name, self.best_objects[object_type])
  172. else:
  173. obj_copy['frame'] = np.copy(self._current_frame)
  174. self.best_objects[object_type] = obj_copy
  175. for c in self.callbacks['snapshot']:
  176. c(self.name, self.best_objects[object_type])
  177. # update overall camera state for each object type
  178. obj_counter = Counter()
  179. for obj in self.tracked_objects.values():
  180. if not obj['false_positive']:
  181. obj_counter[obj['label']] += 1
  182. # report on detected objects
  183. for obj_name, count in obj_counter.items():
  184. new_status = 'ON' if count > 0 else 'OFF'
  185. if new_status != self.object_status[obj_name]:
  186. self.object_status[obj_name] = new_status
  187. for c in self.callbacks['object_status']:
  188. c(self.name, obj_name, new_status)
  189. # expire any objects that are ON and no longer detected
  190. expired_objects = [obj_name for obj_name, status in self.object_status.items() if status == 'ON' and not obj_name in obj_counter]
  191. for obj_name in expired_objects:
  192. self.object_status[obj_name] = 'OFF'
  193. for c in self.callbacks['object_status']:
  194. c(self.name, obj_name, 'OFF')
  195. for c in self.callbacks['snapshot']:
  196. c(self.name, self.best_objects[obj_name])
class TrackedObjectProcessor(threading.Thread):
    """Consumes tracker results from a queue, updates per-camera state,
    and publishes object events, snapshots, and zone status over MQTT.
    """

    def __init__(self, camera_config, client, topic_prefix, tracked_objects_queue, event_queue, stop_event):
        """Build one CameraState per camera, wire its event callbacks to
        MQTT/event-queue publishers, and precompute zone colors/contours.
        """
        threading.Thread.__init__(self)
        self.camera_config = camera_config
        self.client = client
        self.topic_prefix = topic_prefix
        self.tracked_objects_queue = tracked_objects_queue
        self.event_queue = event_queue
        self.stop_event = stop_event
        self.camera_states: Dict[str, CameraState] = {}
        self.frame_manager = SharedMemoryFrameManager()

        # callbacks below close over `self`; they are registered on every
        # CameraState and invoked from its update() pass

        def start(camera, obj):
            # publish events to mqtt
            self.client.publish(f"{self.topic_prefix}/{camera}/events/start", json.dumps(obj), retain=False)
            self.event_queue.put(('start', camera, obj))

        def update(camera, obj):
            # no per-update publishing
            pass

        def end(camera, obj):
            self.client.publish(f"{self.topic_prefix}/{camera}/events/end", json.dumps(obj), retain=False)
            self.event_queue.put(('end', camera, obj))

        def snapshot(camera, obj):
            # publish a (optionally cropped/resized) JPEG of the best frame
            if not 'frame' in obj:
                return
            best_frame = cv2.cvtColor(obj['frame'], cv2.COLOR_RGB2BGR)
            mqtt_config = self.camera_config[camera].get('mqtt', {'crop_to_region': False})
            if mqtt_config.get('crop_to_region'):
                region = obj['region']
                best_frame = best_frame[region[1]:region[3], region[0]:region[2]]
            if 'snapshot_height' in mqtt_config:
                # resize to the configured height, preserving aspect ratio
                height = int(mqtt_config['snapshot_height'])
                width = int(height*best_frame.shape[1]/best_frame.shape[0])
                best_frame = cv2.resize(best_frame, dsize=(width, height), interpolation=cv2.INTER_AREA)
            ret, jpg = cv2.imencode('.jpg', best_frame)
            if ret:
                jpg_bytes = jpg.tobytes()
                # retained so subscribers always see the latest snapshot
                self.client.publish(f"{self.topic_prefix}/{camera}/{obj['label']}/snapshot", jpg_bytes, retain=True)

        def object_status(camera, object_name, status):
            self.client.publish(f"{self.topic_prefix}/{camera}/{object_name}", status, retain=False)

        for camera in self.camera_config.keys():
            camera_state = CameraState(camera, self.camera_config[camera], self.frame_manager)
            camera_state.on('start', start)
            camera_state.on('update', update)
            camera_state.on('end', end)
            camera_state.on('snapshot', snapshot)
            camera_state.on('object_status', object_status)
            self.camera_states[camera] = camera_state

        # per-camera scratch data with safe defaults
        self.camera_data = defaultdict(lambda: {
            'best_objects': {},
            'object_status': defaultdict(lambda: defaultdict(lambda: 'OFF')),
            'tracked_objects': {},
            'current_frame': np.zeros((720,1280,3), np.uint8),
            'current_frame_time': 0.0,
            'object_id': None
        })
        # {
        #   'zone_name': {
        #       'person': ['camera_1', 'camera_2']
        #   }
        # }
        self.zone_data = defaultdict(lambda: defaultdict(lambda: set()))

        # set colors for zones
        all_zone_names = set([zone for config in self.camera_config.values() for zone in config['zones'].keys()])
        zone_colors = {}
        colors = plt.cm.get_cmap('tab10', len(all_zone_names))
        for i, zone in enumerate(all_zone_names):
            zone_colors[zone] = tuple(int(round(255 * c)) for c in colors(i)[:3])

        # create zone contours
        for camera_config in self.camera_config.values():
            for zone_name, zone_config in camera_config['zones'].items():
                zone_config['color'] = zone_colors[zone_name]
                coordinates = zone_config['coordinates']
                if isinstance(coordinates, list):
                    # list form: each entry is an "x,y" string
                    zone_config['contour'] = np.array([[int(p.split(',')[0]), int(p.split(',')[1])] for p in coordinates])
                elif isinstance(coordinates, str):
                    # string form: flat "x1,y1,x2,y2,..." list
                    points = coordinates.split(',')
                    zone_config['contour'] = np.array([[int(points[i]), int(points[i+1])] for i in range(0, len(points), 2)])
                else:
                    # NOTE(review): `camera` here is the last key from the
                    # registration loop above, not this zone's camera —
                    # the message may name the wrong camera; verify
                    print(f"Unable to parse zone coordinates for {zone_name} - {camera}")

    def get_best(self, camera, label):
        """Return the best tracked object for ``label`` on ``camera``,
        or an empty dict if none has been recorded yet."""
        best_objects = self.camera_states[camera].best_objects
        if label in best_objects:
            return best_objects[label]
        else:
            return {}

    def get_current_frame(self, camera):
        """Return a copy of the camera's latest annotated frame."""
        return self.camera_states[camera].get_current_frame()

    def run(self):
        """Main loop: drain the tracked-objects queue until stop_event is
        set, updating camera state and publishing zone ON/OFF transitions."""
        while True:
            if self.stop_event.is_set():
                print(f"Exiting object processor...")
                break
            try:
                # 10s timeout so the stop_event is re-checked periodically
                camera, frame_time, current_tracked_objects = self.tracked_objects_queue.get(True, 10)
            except queue.Empty:
                continue
            camera_state = self.camera_states[camera]
            camera_state.update(frame_time, current_tracked_objects)
            # update zone status for each label
            for zone in camera_state.config['zones'].keys():
                # get labels for current camera and all labels in current zone
                labels_for_camera = set([obj['label'] for obj in camera_state.tracked_objects.values() if zone in obj['zones'] and not obj['false_positive']])
                labels_to_check = labels_for_camera | set(self.zone_data[zone].keys())
                # for each label in zone
                for label in labels_to_check:
                    # set of camera names currently seeing this label in this zone
                    camera_list = self.zone_data[zone][label]
                    # remove or add the camera to the list for the current label
                    previous_state = len(camera_list) > 0
                    if label in labels_for_camera:
                        camera_list.add(camera_state.name)
                    elif camera_state.name in camera_list:
                        camera_list.remove(camera_state.name)
                    new_state = len(camera_list) > 0
                    # if the value is changing, send over MQTT
                    if previous_state == False and new_state == True:
                        self.client.publish(f"{self.topic_prefix}/{zone}/{label}", 'ON', retain=False)
                    elif previous_state == True and new_state == False:
                        self.client.publish(f"{self.topic_prefix}/{zone}/{label}", 'OFF', retain=False)