object_processing.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. import json
  2. import hashlib
  3. import datetime
  4. import time
  5. import copy
  6. import cv2
  7. import threading
  8. import queue
  9. import copy
  10. import numpy as np
  11. from collections import Counter, defaultdict
  12. import itertools
  13. import matplotlib.pyplot as plt
  14. from frigate.util import draw_box_with_label, SharedMemoryFrameManager
  15. from frigate.edgetpu import load_labels
  16. from typing import Callable, Dict
  17. from statistics import mean, median
  18. PATH_TO_LABELS = '/labelmap.txt'
  19. LABELS = load_labels(PATH_TO_LABELS)
  20. cmap = plt.cm.get_cmap('tab10', len(LABELS.keys()))
  21. COLOR_MAP = {}
  22. for key, val in LABELS.items():
  23. COLOR_MAP[val] = tuple(int(round(255 * c)) for c in cmap(key)[:3])
  24. def zone_filtered(obj, object_config):
  25. object_name = obj['label']
  26. if object_name in object_config:
  27. obj_settings = object_config[object_name]
  28. # if the min area is larger than the
  29. # detected object, don't add it to detected objects
  30. if obj_settings.get('min_area',-1) > obj['area']:
  31. return True
  32. # if the detected object is larger than the
  33. # max area, don't add it to detected objects
  34. if obj_settings.get('max_area', 24000000) < obj['area']:
  35. return True
  36. # if the score is lower than the threshold, skip
  37. if obj_settings.get('threshold', 0) > obj['computed_score']:
  38. return True
  39. return False
  40. # Maintains the state of a camera
  41. class CameraState():
  42. def __init__(self, name, config, frame_manager):
  43. self.name = name
  44. self.config = config
  45. self.frame_manager = frame_manager
  46. self.best_objects = {}
  47. self.object_status = defaultdict(lambda: 'OFF')
  48. self.tracked_objects = {}
  49. self.zone_objects = defaultdict(lambda: [])
  50. self.current_frame = np.zeros(self.config['frame_shape'], np.uint8)
  51. self.current_frame_time = 0.0
  52. self.previous_frame_id = None
  53. self.callbacks = defaultdict(lambda: [])
  54. def false_positive(self, obj):
  55. # once a true positive, always a true positive
  56. if not obj.get('false_positive', True):
  57. return False
  58. threshold = self.config['objects'].get('filters', {}).get(obj['label'], {}).get('threshold', 0.85)
  59. if obj['computed_score'] < threshold:
  60. return True
  61. return False
  62. def compute_score(self, obj):
  63. scores = obj['score_history'][:]
  64. # pad with zeros if you dont have at least 3 scores
  65. if len(scores) < 3:
  66. scores += [0.0]*(3 - len(scores))
  67. return median(scores)
  68. def on(self, event_type: str, callback: Callable[[Dict], None]):
  69. self.callbacks[event_type].append(callback)
  70. def update(self, frame_time, tracked_objects):
  71. self.current_frame_time = frame_time
  72. # get the new frame and delete the old frame
  73. frame_id = f"{self.name}{frame_time}"
  74. self.current_frame = self.frame_manager.get(frame_id, self.config['frame_shape'])
  75. if not self.previous_frame_id is None:
  76. self.frame_manager.delete(self.previous_frame_id)
  77. self.previous_frame_id = frame_id
  78. current_ids = tracked_objects.keys()
  79. previous_ids = self.tracked_objects.keys()
  80. removed_ids = list(set(previous_ids).difference(current_ids))
  81. new_ids = list(set(current_ids).difference(previous_ids))
  82. updated_ids = list(set(current_ids).intersection(previous_ids))
  83. for id in new_ids:
  84. self.tracked_objects[id] = tracked_objects[id]
  85. self.tracked_objects[id]['zones'] = []
  86. # start the score history
  87. self.tracked_objects[id]['score_history'] = [self.tracked_objects[id]['score']]
  88. # calculate if this is a false positive
  89. self.tracked_objects[id]['computed_score'] = self.compute_score(self.tracked_objects[id])
  90. self.tracked_objects[id]['false_positive'] = self.false_positive(self.tracked_objects[id])
  91. # call event handlers
  92. for c in self.callbacks['start']:
  93. c(self.name, tracked_objects[id])
  94. for id in updated_ids:
  95. self.tracked_objects[id].update(tracked_objects[id])
  96. # if the object is not in the current frame, add a 0.0 to the score history
  97. if self.tracked_objects[id]['frame_time'] != self.current_frame_time:
  98. self.tracked_objects[id]['score_history'].append(0.0)
  99. else:
  100. self.tracked_objects[id]['score_history'].append(self.tracked_objects[id]['score'])
  101. # only keep the last 10 scores
  102. if len(self.tracked_objects[id]['score_history']) > 10:
  103. self.tracked_objects[id]['score_history'] = self.tracked_objects[id]['score_history'][-10:]
  104. # calculate if this is a false positive
  105. self.tracked_objects[id]['computed_score'] = self.compute_score(self.tracked_objects[id])
  106. self.tracked_objects[id]['false_positive'] = self.false_positive(self.tracked_objects[id])
  107. # call event handlers
  108. for c in self.callbacks['update']:
  109. c(self.name, self.tracked_objects[id])
  110. for id in removed_ids:
  111. # publish events to mqtt
  112. self.tracked_objects[id]['end_time'] = frame_time
  113. for c in self.callbacks['end']:
  114. c(self.name, self.tracked_objects[id])
  115. del self.tracked_objects[id]
  116. # check to see if the objects are in any zones
  117. for obj in self.tracked_objects.values():
  118. current_zones = []
  119. bottom_center = (obj['centroid'][0], obj['box'][3])
  120. # check each zone
  121. for name, zone in self.config['zones'].items():
  122. contour = zone['contour']
  123. # check if the object is in the zone and not filtered
  124. if (cv2.pointPolygonTest(contour, bottom_center, False) >= 0
  125. and not zone_filtered(obj, zone.get('filters', {}))):
  126. current_zones.append(name)
  127. obj['zones'] = current_zones
  128. # draw on the frame
  129. if not self.current_frame is None:
  130. # draw the bounding boxes on the frame
  131. for obj in self.tracked_objects.values():
  132. thickness = 2
  133. color = COLOR_MAP[obj['label']]
  134. if obj['frame_time'] != frame_time:
  135. thickness = 1
  136. color = (255,0,0)
  137. # draw the bounding boxes on the frame
  138. box = obj['box']
  139. draw_box_with_label(self.current_frame, box[0], box[1], box[2], box[3], obj['label'], f"{int(obj['score']*100)}% {int(obj['area'])}", thickness=thickness, color=color)
  140. # draw the regions on the frame
  141. region = obj['region']
  142. cv2.rectangle(self.current_frame, (region[0], region[1]), (region[2], region[3]), (0,255,0), 1)
  143. if self.config['snapshots']['show_timestamp']:
  144. time_to_show = datetime.datetime.fromtimestamp(frame_time).strftime("%m/%d/%Y %H:%M:%S")
  145. cv2.putText(self.current_frame, time_to_show, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.8, color=(255, 255, 255), thickness=2)
  146. if self.config['snapshots']['draw_zones']:
  147. for name, zone in self.config['zones'].items():
  148. thickness = 8 if any([name in obj['zones'] for obj in self.tracked_objects.values()]) else 2
  149. cv2.drawContours(self.current_frame, [zone['contour']], -1, zone['color'], thickness)
  150. # maintain best objects
  151. for obj in self.tracked_objects.values():
  152. object_type = obj['label']
  153. # if the object wasn't seen on the current frame, skip it
  154. if obj['frame_time'] != self.current_frame_time or obj['false_positive']:
  155. continue
  156. obj_copy = copy.deepcopy(obj)
  157. if object_type in self.best_objects:
  158. current_best = self.best_objects[object_type]
  159. now = datetime.datetime.now().timestamp()
  160. # if the object is a higher score than the current best score
  161. # or the current object is older than desired, use the new object
  162. if obj_copy['score'] > current_best['score'] or (now - current_best['frame_time']) > self.config.get('best_image_timeout', 60):
  163. obj_copy['frame'] = np.copy(self.current_frame)
  164. self.best_objects[object_type] = obj_copy
  165. for c in self.callbacks['snapshot']:
  166. c(self.name, self.best_objects[object_type])
  167. else:
  168. obj_copy['frame'] = np.copy(self.current_frame)
  169. self.best_objects[object_type] = obj_copy
  170. for c in self.callbacks['snapshot']:
  171. c(self.name, self.best_objects[object_type])
  172. # update overall camera state for each object type
  173. obj_counter = Counter()
  174. for obj in self.tracked_objects.values():
  175. if not obj['false_positive']:
  176. obj_counter[obj['label']] += 1
  177. # report on detected objects
  178. for obj_name, count in obj_counter.items():
  179. new_status = 'ON' if count > 0 else 'OFF'
  180. if new_status != self.object_status[obj_name]:
  181. self.object_status[obj_name] = new_status
  182. for c in self.callbacks['object_status']:
  183. c(self.name, obj_name, new_status)
  184. # expire any objects that are ON and no longer detected
  185. expired_objects = [obj_name for obj_name, status in self.object_status.items() if status == 'ON' and not obj_name in obj_counter]
  186. for obj_name in expired_objects:
  187. self.object_status[obj_name] = 'OFF'
  188. for c in self.callbacks['object_status']:
  189. c(self.name, obj_name, 'OFF')
  190. for c in self.callbacks['snapshot']:
  191. c(self.name, self.best_objects[obj_name])
  192. class TrackedObjectProcessor(threading.Thread):
  193. def __init__(self, camera_config, client, topic_prefix, tracked_objects_queue, event_queue, stop_event):
  194. threading.Thread.__init__(self)
  195. self.camera_config = camera_config
  196. self.client = client
  197. self.topic_prefix = topic_prefix
  198. self.tracked_objects_queue = tracked_objects_queue
  199. self.event_queue = event_queue
  200. self.stop_event = stop_event
  201. self.camera_states: Dict[str, CameraState] = {}
  202. self.frame_manager = SharedMemoryFrameManager()
  203. def start(camera, obj):
  204. # publish events to mqtt
  205. self.client.publish(f"{self.topic_prefix}/{camera}/events/start", json.dumps(obj), retain=False)
  206. self.event_queue.put(('start', camera, obj))
  207. def update(camera, obj):
  208. pass
  209. def end(camera, obj):
  210. self.client.publish(f"{self.topic_prefix}/{camera}/events/end", json.dumps(obj), retain=False)
  211. self.event_queue.put(('end', camera, obj))
  212. def snapshot(camera, obj):
  213. if not 'frame' in obj:
  214. return
  215. best_frame = cv2.cvtColor(obj['frame'], cv2.COLOR_RGB2BGR)
  216. mqtt_config = self.camera_config[camera].get('mqtt', {'crop_to_region': False})
  217. if mqtt_config.get('crop_to_region'):
  218. region = obj['region']
  219. best_frame = best_frame[region[1]:region[3], region[0]:region[2]]
  220. if 'snapshot_height' in mqtt_config:
  221. height = int(mqtt_config['snapshot_height'])
  222. width = int(height*best_frame.shape[1]/best_frame.shape[0])
  223. best_frame = cv2.resize(best_frame, dsize=(width, height), interpolation=cv2.INTER_AREA)
  224. ret, jpg = cv2.imencode('.jpg', best_frame)
  225. if ret:
  226. jpg_bytes = jpg.tobytes()
  227. self.client.publish(f"{self.topic_prefix}/{camera}/{obj['label']}/snapshot", jpg_bytes, retain=True)
  228. def object_status(camera, object_name, status):
  229. self.client.publish(f"{self.topic_prefix}/{camera}/{object_name}", status, retain=False)
  230. for camera in self.camera_config.keys():
  231. camera_state = CameraState(camera, self.camera_config[camera], self.frame_manager)
  232. camera_state.on('start', start)
  233. camera_state.on('update', update)
  234. camera_state.on('end', end)
  235. camera_state.on('snapshot', snapshot)
  236. camera_state.on('object_status', object_status)
  237. self.camera_states[camera] = camera_state
  238. self.camera_data = defaultdict(lambda: {
  239. 'best_objects': {},
  240. 'object_status': defaultdict(lambda: defaultdict(lambda: 'OFF')),
  241. 'tracked_objects': {},
  242. 'current_frame': np.zeros((720,1280,3), np.uint8),
  243. 'current_frame_time': 0.0,
  244. 'object_id': None
  245. })
  246. # {
  247. # 'zone_name': {
  248. # 'person': ['camera_1', 'camera_2']
  249. # }
  250. # }
  251. self.zone_data = defaultdict(lambda: defaultdict(lambda: set()))
  252. # set colors for zones
  253. all_zone_names = set([zone for config in self.camera_config.values() for zone in config['zones'].keys()])
  254. zone_colors = {}
  255. colors = plt.cm.get_cmap('tab10', len(all_zone_names))
  256. for i, zone in enumerate(all_zone_names):
  257. zone_colors[zone] = tuple(int(round(255 * c)) for c in colors(i)[:3])
  258. # create zone contours
  259. for camera_config in self.camera_config.values():
  260. for zone_name, zone_config in camera_config['zones'].items():
  261. zone_config['color'] = zone_colors[zone_name]
  262. coordinates = zone_config['coordinates']
  263. if isinstance(coordinates, list):
  264. zone_config['contour'] = np.array([[int(p.split(',')[0]), int(p.split(',')[1])] for p in coordinates])
  265. elif isinstance(coordinates, str):
  266. points = coordinates.split(',')
  267. zone_config['contour'] = np.array([[int(points[i]), int(points[i+1])] for i in range(0, len(points), 2)])
  268. else:
  269. print(f"Unable to parse zone coordinates for {zone_name} - {camera}")
  270. def get_best(self, camera, label):
  271. best_objects = self.camera_states[camera].best_objects
  272. if label in best_objects:
  273. return best_objects[label]
  274. else:
  275. return {}
  276. def get_current_frame(self, camera):
  277. return self.camera_states[camera].current_frame
  278. def run(self):
  279. while True:
  280. if self.stop_event.is_set():
  281. print(f"Exiting object processor...")
  282. break
  283. try:
  284. camera, frame_time, current_tracked_objects = self.tracked_objects_queue.get(True, 10)
  285. except queue.Empty:
  286. continue
  287. camera_state = self.camera_states[camera]
  288. camera_state.update(frame_time, current_tracked_objects)
  289. # update zone status for each label
  290. for zone in camera_state.config['zones'].keys():
  291. # get labels for current camera and all labels in current zone
  292. labels_for_camera = set([obj['label'] for obj in camera_state.tracked_objects.values() if zone in obj['zones'] and not obj['false_positive']])
  293. labels_to_check = labels_for_camera | set(self.zone_data[zone].keys())
  294. # for each label in zone
  295. for label in labels_to_check:
  296. camera_list = self.zone_data[zone][label]
  297. # remove or add the camera to the list for the current label
  298. previous_state = len(camera_list) > 0
  299. if label in labels_for_camera:
  300. camera_list.add(camera_state.name)
  301. elif camera_state.name in camera_list:
  302. camera_list.remove(camera_state.name)
  303. new_state = len(camera_list) > 0
  304. # if the value is changing, send over MQTT
  305. if previous_state == False and new_state == True:
  306. self.client.publish(f"{self.topic_prefix}/{zone}/{label}", 'ON', retain=False)
  307. elif previous_state == True and new_state == False:
  308. self.client.publish(f"{self.topic_prefix}/{zone}/{label}", 'OFF', retain=False)