object_processing.py

import json
import hashlib
import datetime
import time
import copy
import cv2
import threading
import queue
import numpy as np
from collections import Counter, defaultdict
import itertools
import pyarrow.plasma as plasma
import matplotlib.pyplot as plt
from frigate.util import draw_box_with_label, PlasmaFrameManager
from frigate.edgetpu import load_labels
from typing import Callable, Dict
from statistics import mean, median

PATH_TO_LABELS = '/labelmap.txt'

LABELS = load_labels(PATH_TO_LABELS)
cmap = plt.cm.get_cmap('tab10', len(LABELS.keys()))

COLOR_MAP = {}
for key, val in LABELS.items():
    COLOR_MAP[val] = tuple(int(round(255 * c)) for c in cmap(key)[:3])

def zone_filtered(obj, object_filters):
    object_name = obj['label']

    if object_name in object_filters:
        obj_settings = object_filters[object_name]

        # if the min area is larger than the
        # detected object, don't add it to detected objects
        if obj_settings.get('min_area', -1) > obj['area']:
            return True

        # if the detected object is larger than the
        # max area, don't add it to detected objects
        if obj_settings.get('max_area', 24000000) < obj['area']:
            return True

        # if the score is lower than the threshold, skip
        if obj_settings.get('threshold', 0) > obj['computed_score']:
            return True

    return False
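
# A minimal sketch (not part of the original file) of the shapes zone_filtered() expects.
# Only the dict keys are taken from the function above; the label and numbers are
# illustrative assumptions.
def _zone_filtered_example():
    example_obj = {'label': 'person', 'area': 3500, 'computed_score': 0.91}
    example_filters = {'person': {'min_area': 1000, 'max_area': 100000, 'threshold': 0.7}}
    # 3500 is within [1000, 100000] and 0.91 >= 0.7, so the object is not filtered out
    return zone_filtered(example_obj, example_filters)  # -> False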

# Maintains the state of a camera
class CameraState():
    def __init__(self, name, config, frame_manager):
        self.name = name
        self.config = config
        self.frame_manager = frame_manager
        self.best_objects = {}
        self.object_status = defaultdict(lambda: 'OFF')
        self.tracked_objects = {}
        self.zone_objects = defaultdict(lambda: [])
        self.current_frame = np.zeros((720, 1280, 3), np.uint8)
        self.current_frame_time = 0.0
        self.previous_frame_id = None
        self.callbacks = defaultdict(lambda: [])

    def false_positive(self, obj):
        # once a true positive, always a true positive
        if not obj.get('false_positive', True):
            return False

        threshold = self.config['objects'].get('filters', {}).get(obj['label'], {}).get('threshold', 0.85)
        if obj['computed_score'] < threshold:
            return True
        return False

    def compute_score(self, obj):
        scores = obj['score_history'][:]
        # pad with zeros if you don't have at least 3 scores
        if len(scores) < 3:
            scores += [0.0] * (3 - len(scores))

        return median(scores)
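
    # Example (assumed values, not part of the original file): a brand new object with a single
    # detection of 0.9 is padded to [0.9, 0.0, 0.0], so its median score is 0.0 and it starts
    # out as a false positive until it has been detected on enough frames.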

    def on(self, event_type: str, callback: Callable[[Dict], None]):
        self.callbacks[event_type].append(callback)

    def update(self, frame_time, tracked_objects):
        self.current_frame_time = frame_time
        # get the new frame and delete the old frame
        frame_id = f"{self.name}{frame_time}"
        self.current_frame = self.frame_manager.get(frame_id)
        if self.previous_frame_id is not None:
            self.frame_manager.delete(self.previous_frame_id)
        self.previous_frame_id = frame_id

        current_ids = tracked_objects.keys()
        previous_ids = self.tracked_objects.keys()
        removed_ids = list(set(previous_ids).difference(current_ids))
        new_ids = list(set(current_ids).difference(previous_ids))
        updated_ids = list(set(current_ids).intersection(previous_ids))

        for id in new_ids:
            self.tracked_objects[id] = tracked_objects[id]
            self.tracked_objects[id]['zones'] = []

            # start the score history
            self.tracked_objects[id]['score_history'] = [self.tracked_objects[id]['score']]

            # calculate if this is a false positive
            self.tracked_objects[id]['computed_score'] = self.compute_score(self.tracked_objects[id])
            self.tracked_objects[id]['false_positive'] = self.false_positive(self.tracked_objects[id])

            # call event handlers
            for c in self.callbacks['start']:
                c(self.name, tracked_objects[id])

        for id in updated_ids:
            self.tracked_objects[id].update(tracked_objects[id])

            # if the object is not in the current frame, add a 0.0 to the score history
            if self.tracked_objects[id]['frame_time'] != self.current_frame_time:
                self.tracked_objects[id]['score_history'].append(0.0)
            else:
                self.tracked_objects[id]['score_history'].append(self.tracked_objects[id]['score'])
            # only keep the last 10 scores
            if len(self.tracked_objects[id]['score_history']) > 10:
                self.tracked_objects[id]['score_history'] = self.tracked_objects[id]['score_history'][-10:]

            # calculate if this is a false positive
            self.tracked_objects[id]['computed_score'] = self.compute_score(self.tracked_objects[id])
            self.tracked_objects[id]['false_positive'] = self.false_positive(self.tracked_objects[id])

            # call event handlers
            for c in self.callbacks['update']:
                c(self.name, self.tracked_objects[id])

        for id in removed_ids:
            # publish events to mqtt
            self.tracked_objects[id]['end_time'] = frame_time
            for c in self.callbacks['end']:
                c(self.name, self.tracked_objects[id])
            del self.tracked_objects[id]

        # check to see if the objects are in any zones
        for obj in self.tracked_objects.values():
            current_zones = []
            bottom_center = (obj['centroid'][0], obj['box'][3])
            # check each zone
            for name, zone in self.config['zones'].items():
                contour = zone['contour']
                # check if the object is in the zone and not filtered
                if (cv2.pointPolygonTest(contour, bottom_center, False) >= 0
                        and not zone_filtered(obj, zone.get('filters', {}))):
                    current_zones.append(name)
            obj['zones'] = current_zones

        # draw on the frame
        if self.current_frame is not None:
            # draw the bounding boxes on the frame
            for obj in self.tracked_objects.values():
                thickness = 2
                color = COLOR_MAP[obj['label']]

                if obj['frame_time'] != frame_time:
                    thickness = 1
                    color = (255, 0, 0)

                # draw the bounding boxes on the frame
                box = obj['box']
                draw_box_with_label(self.current_frame, box[0], box[1], box[2], box[3], obj['label'], f"{int(obj['score']*100)}% {int(obj['area'])}", thickness=thickness, color=color)
                # draw the regions on the frame
                region = obj['region']
                cv2.rectangle(self.current_frame, (region[0], region[1]), (region[2], region[3]), (0, 255, 0), 1)

            if self.config['snapshots']['show_timestamp']:
                time_to_show = datetime.datetime.fromtimestamp(frame_time).strftime("%m/%d/%Y %H:%M:%S")
                cv2.putText(self.current_frame, time_to_show, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.8, color=(255, 255, 255), thickness=2)

            if self.config['snapshots']['draw_zones']:
                for name, zone in self.config['zones'].items():
                    thickness = 8 if any([name in obj['zones'] for obj in self.tracked_objects.values()]) else 2
                    cv2.drawContours(self.current_frame, [zone['contour']], -1, zone['color'], thickness)

        # maintain best objects
        for obj in self.tracked_objects.values():
            object_type = obj['label']
            # if the object wasn't seen on the current frame, skip it
            if obj['frame_time'] != self.current_frame_time or obj['false_positive']:
                continue
            if object_type in self.best_objects:
                current_best = self.best_objects[object_type]
                now = datetime.datetime.now().timestamp()
                # if the object is a higher score than the current best score
                # or the current best object is more than 1 minute old, use the new object
                if obj['score'] > current_best['score'] or (now - current_best['frame_time']) > 60:
                    obj['frame'] = np.copy(self.current_frame)
                    self.best_objects[object_type] = obj
                    for c in self.callbacks['snapshot']:
                        c(self.name, self.best_objects[object_type])
            else:
                obj['frame'] = np.copy(self.current_frame)
                self.best_objects[object_type] = obj
                for c in self.callbacks['snapshot']:
                    c(self.name, self.best_objects[object_type])

        # update overall camera state for each object type
        obj_counter = Counter()
        for obj in self.tracked_objects.values():
            if not obj['false_positive']:
                obj_counter[obj['label']] += 1

        # report on detected objects
        for obj_name, count in obj_counter.items():
            new_status = 'ON' if count > 0 else 'OFF'
            if new_status != self.object_status[obj_name]:
                self.object_status[obj_name] = new_status
                for c in self.callbacks['object_status']:
                    c(self.name, obj_name, new_status)

        # expire any objects that are ON and no longer detected
        expired_objects = [obj_name for obj_name, status in self.object_status.items() if status == 'ON' and obj_name not in obj_counter]
        for obj_name in expired_objects:
            self.object_status[obj_name] = 'OFF'
            for c in self.callbacks['object_status']:
                c(self.name, obj_name, 'OFF')
            for c in self.callbacks['snapshot']:
                c(self.name, self.best_objects[obj_name])
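
# A minimal sketch (not part of the original file) of how CameraState's callback hooks are used.
# The config and frame manager here are placeholders; TrackedObjectProcessor below does the
# real wiring against MQTT and the event queue.
def _camera_state_example(config, frame_manager):
    state = CameraState('example_camera', config, frame_manager)

    def on_start(camera, obj):
        print(f"{camera}: started tracking a {obj['label']} (score {obj['score']:.2f})")

    state.on('start', on_start)
    return state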

class TrackedObjectProcessor(threading.Thread):
    def __init__(self, camera_config, zone_config, client, topic_prefix, tracked_objects_queue, event_queue, stop_event):
        threading.Thread.__init__(self)
        self.camera_config = camera_config
        self.zone_config = zone_config
        self.client = client
        self.topic_prefix = topic_prefix
        self.tracked_objects_queue = tracked_objects_queue
        self.event_queue = event_queue
        self.stop_event = stop_event
        self.camera_states: Dict[str, CameraState] = {}
        self.plasma_client = PlasmaFrameManager(self.stop_event)

        def start(camera, obj):
            # publish events to mqtt
            self.client.publish(f"{self.topic_prefix}/{camera}/events/start", json.dumps({x: obj[x] for x in obj if x not in ['frame']}), retain=False)
            self.event_queue.put(('start', camera, obj))

        def update(camera, obj):
            pass

        def end(camera, obj):
            self.client.publish(f"{self.topic_prefix}/{camera}/events/end", json.dumps({x: obj[x] for x in obj if x not in ['frame']}), retain=False)
            self.event_queue.put(('end', camera, obj))

        def snapshot(camera, obj):
            best_frame = cv2.cvtColor(obj['frame'], cv2.COLOR_RGB2BGR)
            ret, jpg = cv2.imencode('.jpg', best_frame)
            if ret:
                jpg_bytes = jpg.tobytes()
                self.client.publish(f"{self.topic_prefix}/{camera}/{obj['label']}/snapshot", jpg_bytes, retain=True)

        def object_status(camera, object_name, status):
            self.client.publish(f"{self.topic_prefix}/{camera}/{object_name}", status, retain=False)

        for camera in self.camera_config.keys():
            camera_state = CameraState(camera, self.camera_config[camera], self.plasma_client)
            camera_state.on('start', start)
            camera_state.on('update', update)
            camera_state.on('end', end)
            camera_state.on('snapshot', snapshot)
            camera_state.on('object_status', object_status)
            self.camera_states[camera] = camera_state

        self.camera_data = defaultdict(lambda: {
            'best_objects': {},
            'object_status': defaultdict(lambda: defaultdict(lambda: 'OFF')),
            'tracked_objects': {},
            'current_frame': np.zeros((720, 1280, 3), np.uint8),
            'current_frame_time': 0.0,
            'object_id': None
        })

        # {
        #   'zone_name': {
        #       'person': ['camera_1', 'camera_2']
        #   }
        # }
        self.zone_data = defaultdict(lambda: defaultdict(lambda: set()))

        # set colors for zones
        zone_colors = {}
        colors = plt.cm.get_cmap('tab10', len(self.zone_config.keys()))
        for i, zone in enumerate(self.zone_config.keys()):
            zone_colors[zone] = tuple(int(round(255 * c)) for c in colors(i)[:3])

        # create zone contours
        for zone_name, config in zone_config.items():
            for camera, camera_zone_config in config.items():
                camera_zone = {}
                camera_zone['color'] = zone_colors[zone_name]
                # carry any configured zone filters through so zone_filtered() can apply them
                camera_zone['filters'] = camera_zone_config.get('filters', {})
                coordinates = camera_zone_config['coordinates']
                if isinstance(coordinates, list):
                    camera_zone['contour'] = np.array([[int(p.split(',')[0]), int(p.split(',')[1])] for p in coordinates])
                elif isinstance(coordinates, str):
                    points = coordinates.split(',')
                    camera_zone['contour'] = np.array([[int(points[i]), int(points[i+1])] for i in range(0, len(points), 2)])
                else:
                    print(f"Unable to parse zone coordinates for {zone_name} - {camera}")
                    # skip zones without a usable contour rather than fail later in update()
                    continue
                self.camera_config[camera]['zones'][zone_name] = camera_zone
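
    # Note (illustrative, not part of the original file): the two coordinate formats accepted above
    # are a list of "x,y" strings, e.g. ['100,200', '300,200', '300,400'], or a single flat string,
    # e.g. '100,200,300,200,300,400'. Both produce the same Nx2 contour array.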

    def get_best(self, camera, label):
        best_objects = self.camera_states[camera].best_objects
        if label in best_objects:
            return best_objects[label]['frame']
        else:
            return None

    def get_current_frame(self, camera):
        return self.camera_states[camera].current_frame

    def run(self):
        while True:
            if self.stop_event.is_set():
                print("Exiting object processor...")
                break

            try:
                camera, frame_time, current_tracked_objects = self.tracked_objects_queue.get(True, 10)
            except queue.Empty:
                continue

            camera_state = self.camera_states[camera]

            camera_state.update(frame_time, current_tracked_objects)

            # update zone status for each label
            for zone in camera_state.config['zones'].keys():
                # get labels for current camera and all labels in current zone
                labels_for_camera = set([obj['label'] for obj in camera_state.tracked_objects.values() if zone in obj['zones'] and not obj['false_positive']])
                labels_to_check = labels_for_camera | set(self.zone_data[zone].keys())
                # for each label in zone
                for label in labels_to_check:
                    camera_list = self.zone_data[zone][label]
                    # remove or add the camera to the list for the current label
                    previous_state = len(camera_list) > 0
                    if label in labels_for_camera:
                        camera_list.add(camera_state.name)
                    elif camera_state.name in camera_list:
                        camera_list.remove(camera_state.name)
                    new_state = len(camera_list) > 0
                    # if the value is changing, send over MQTT
                    if not previous_state and new_state:
                        self.client.publish(f"{self.topic_prefix}/{zone}/{label}", 'ON', retain=False)
                    elif previous_state and not new_state:
                        self.client.publish(f"{self.topic_prefix}/{zone}/{label}", 'OFF', retain=False)
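
# A minimal sketch (not part of the original file) of how this thread might be constructed by the
# main application. camera_config, zone_config and mqtt_client are placeholders for objects built
# elsewhere in frigate; only the constructor signature comes from the class above, and 'frigate'
# is just an example MQTT topic prefix.
def _processor_example(camera_config, zone_config, mqtt_client, stop_event):
    tracked_objects_queue = queue.Queue()  # filled by the per-camera tracking loop
    event_queue = queue.Queue()            # consumed by the event/clip handling
    processor = TrackedObjectProcessor(camera_config, zone_config, mqtt_client, 'frigate',
                                       tracked_objects_queue, event_queue, stop_event)
    processor.start()
    return processor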