import os
import cv2
import imutils
import time
import datetime
import ctypes
import logging
import multiprocessing as mp
import threading
import json
from contextlib import closing
import numpy as np
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
from flask import Flask, Response, make_response
import paho.mqtt.client as mqtt

RTSP_URL = os.getenv('RTSP_URL')

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = '/frozen_inference_graph.pb'
# List of the strings that are used to add the correct label to each box.
PATH_TO_LABELS = '/label_map.pbtext'

MQTT_HOST = os.getenv('MQTT_HOST')
MQTT_TOPIC_PREFIX = os.getenv('MQTT_TOPIC_PREFIX')
MQTT_OBJECT_CLASSES = os.getenv('MQTT_OBJECT_CLASSES')

# TODO: make dynamic?
NUM_CLASSES = 90

# colon-separated list of regions, each a comma-separated tuple of
# size,x_offset,y_offset,min_object_size,mask_filename (see main() below)
# REGIONS = "350,0,300,50:400,350,250,50:400,750,250,50"
# REGIONS = "400,350,250,50"
REGIONS = os.getenv('REGIONS')
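# example with the mask field included (hypothetical values and filenames):
# REGIONS = "300,0,0,3000,back-mask-1.bmp:300,300,0,3000,back-mask-2.bmp"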

DETECTED_OBJECTS = []

# Loading label map
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
                                                            use_display_name=True)
category_index = label_map_util.create_category_index(categories)
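# category_index maps each class id to a dict like {'id': 1, 'name': 'person'},
# the standard lookup structure produced by the TensorFlow Object Detection API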

def detect_objects(cropped_frame, sess, detection_graph, region_size, region_x_offset, region_y_offset, debug):
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_np_expanded = np.expand_dims(cropped_frame, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represents the level of confidence for each of the objects.
    # The score is shown on the result image, together with the class label.
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')

    # Actual detection.
    (boxes, scores, classes, num_detections) = sess.run(
        [boxes, scores, classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})

    if debug:
        if len([category_index.get(value) for index, value in enumerate(classes[0]) if scores[0, index] > 0.5]) > 0:
            vis_util.visualize_boxes_and_labels_on_image_array(
                cropped_frame,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=4)
            cv2.imwrite("/lab/debug/obj-{}-{}-{}.jpg".format(region_x_offset, region_y_offset, datetime.datetime.now().timestamp()), cropped_frame)

    # build a flat array of detected objects
    objects = []
    for index, value in enumerate(classes[0]):
        score = scores[0, index]
        if score > 0.5:
            # scale the normalized box back to full-frame pixel coordinates
            box = boxes[0, index].tolist()
            box[0] = (box[0] * region_size) + region_y_offset
            box[1] = (box[1] * region_size) + region_x_offset
            box[2] = (box[2] * region_size) + region_y_offset
            box[3] = (box[3] * region_size) + region_x_offset
            objects += [value, scores[0, index]] + box
        # only keep the first 10 objects
        if len(objects) == 60:
            break

    return objects
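
# the list returned by detect_objects is a flat sequence of 6 values per
# detection: [class_id, score, ymin, xmin, ymax, xmax, ...], capped at 10
# detections (60 values) so it fits the shared output arrays created in main()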

class ObjectParser(threading.Thread):
    def __init__(self, objects_changed, objects_parsed, object_arrays):
        threading.Thread.__init__(self)
        self._objects_changed = objects_changed
        self._objects_parsed = objects_parsed
        self._object_arrays = object_arrays

    def run(self):
        global DETECTED_OBJECTS
        while True:
            detected_objects = []
            # wait until object detection has run
            # TODO: what if something else changed while I was processing???
            with self._objects_changed:
                self._objects_changed.wait()

            for object_array in self._object_arrays:
                object_index = 0
                while object_index < 60 and object_array[object_index] > 0:
                    object_class = object_array[object_index]
                    detected_objects.append({
                        'name': str(category_index.get(object_class).get('name')),
                        'score': object_array[object_index+1],
                        'ymin': int(object_array[object_index+2]),
                        'xmin': int(object_array[object_index+3]),
                        'ymax': int(object_array[object_index+4]),
                        'xmax': int(object_array[object_index+5])
                    })
                    object_index += 6

            DETECTED_OBJECTS = detected_objects
            # notify that objects were parsed
            with self._objects_parsed:
                self._objects_parsed.notify_all()
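
# the MQTT publishers below only publish on change: motion as a single ON/OFF
# topic, and detected objects as a JSON payload keyed by class name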

class MqttMotionPublisher(threading.Thread):
    def __init__(self, client, topic_prefix, motion_changed, motion_flags):
        threading.Thread.__init__(self)
        self.client = client
        self.topic_prefix = topic_prefix
        self.motion_changed = motion_changed
        self.motion_flags = motion_flags

    def run(self):
        last_sent_motion = ""
        while True:
            with self.motion_changed:
                self.motion_changed.wait()

            # send message for motion
            motion_status = 'OFF'
            if any(obj.is_set() for obj in self.motion_flags):
                motion_status = 'ON'

            if last_sent_motion != motion_status:
                last_sent_motion = motion_status
                self.client.publish(self.topic_prefix+'/motion', motion_status, retain=False)

class MqttObjectPublisher(threading.Thread):
    def __init__(self, client, topic_prefix, objects_parsed, object_classes):
        threading.Thread.__init__(self)
        self.client = client
        self.topic_prefix = topic_prefix
        self.objects_parsed = objects_parsed
        self.object_classes = object_classes

    def run(self):
        global DETECTED_OBJECTS
        last_sent_payload = ""
        while True:
            # initialize the payload
            payload = {}
            for obj in self.object_classes:
                payload[obj] = []

            # wait until objects have been parsed
            with self.objects_parsed:
                self.objects_parsed.wait()

            # loop over detected objects and populate the payload
            detected_objects = DETECTED_OBJECTS.copy()
            for obj in detected_objects:
                if obj['name'] in self.object_classes:
                    payload[obj['name']].append(obj)

            # send message for objects if different
            new_payload = json.dumps(payload, sort_keys=True)
            if new_payload != last_sent_payload:
                last_sent_payload = new_payload
                self.client.publish(self.topic_prefix+'/objects', new_payload, retain=False)
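
# a payload published to <topic_prefix>/objects might look like this
# (hypothetical values):
# {"person": [{"name": "person", "score": 0.87, "xmax": 490, "xmin": 410, "ymax": 540, "ymin": 380}]}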

def main():
    # Parse selected regions
    regions = []
    for region_string in REGIONS.split(':'):
        region_parts = region_string.split(',')
        # load the region mask; black (0) pixels are excluded from motion detection
        region_mask_image = cv2.imread("/config/{}".format(region_parts[4]), cv2.IMREAD_GRAYSCALE)
        region_mask = np.where(region_mask_image==[0])
        regions.append({
            'size': int(region_parts[0]),
            'x_offset': int(region_parts[1]),
            'y_offset': int(region_parts[2]),
            'min_object_size': int(region_parts[3]),
            'mask': region_mask,
            # Event for motion detection signaling
            'motion_detected': mp.Event(),
            # create shared array for storing 10 detected objects
            # note: this must be a double even though the value you are storing
            # is a float. otherwise it stops updating the value in shared
            # memory. probably something to do with the size of the memory block
            'output_array': mp.Array(ctypes.c_double, 6*10)
        })

    # capture a single frame and check the frame shape so the correct array
    # size can be allocated in memory
    video = cv2.VideoCapture(RTSP_URL)
    ret, frame = video.read()
    if ret:
        frame_shape = frame.shape
    else:
        print("Unable to capture video stream")
        exit(1)
    video.release()

    # compute the flattened array length from the frame shape
    flat_array_length = frame_shape[0] * frame_shape[1] * frame_shape[2]
    # create shared array for storing the full frame image data
    shared_arr = mp.Array(ctypes.c_uint16, flat_array_length)
    # create shared value for storing the frame_time
    shared_frame_time = mp.Value('d', 0.0)
    # Lock to control access to the frame while writing
    frame_lock = mp.Lock()
    # Condition for notifying that a new frame is ready
    frame_ready = mp.Condition()
    # Condition for notifying that motion status changed globally
    motion_changed = mp.Condition()
    # Condition for notifying that object detection ran
    objects_changed = mp.Condition()
    # Condition for notifying that objects were parsed
    objects_parsed = mp.Condition()

    # shape the current frame so it can be treated as an image
    frame_arr = tonumpyarray(shared_arr).reshape(frame_shape)
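    # note: this is a live view over the shared buffer, not a copy, so the
    # Flask stream below always sees the latest frame written by fetch_frames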

    capture_process = mp.Process(target=fetch_frames, args=(shared_arr,
        shared_frame_time, frame_lock, frame_ready, frame_shape))
    capture_process.daemon = True

    detection_processes = []
    motion_processes = []
    for region in regions:
        detection_process = mp.Process(target=process_frames, args=(shared_arr,
            region['output_array'],
            shared_frame_time,
            frame_lock, frame_ready,
            region['motion_detected'],
            objects_changed,
            frame_shape,
            region['size'], region['x_offset'], region['y_offset'],
            False))
        detection_process.daemon = True
        detection_processes.append(detection_process)

        motion_process = mp.Process(target=detect_motion, args=(shared_arr,
            shared_frame_time,
            frame_lock, frame_ready,
            region['motion_detected'],
            motion_changed,
            frame_shape,
            region['size'], region['x_offset'], region['y_offset'],
            region['min_object_size'], region['mask'],
            True))
        motion_process.daemon = True
        motion_processes.append(motion_process)

    object_parser = ObjectParser(objects_changed, objects_parsed, [region['output_array'] for region in regions])
    object_parser.start()

    client = mqtt.Client()
    client.will_set(MQTT_TOPIC_PREFIX+'/available', payload='offline', qos=1, retain=True)
    client.connect(MQTT_HOST, 1883, 60)
    client.loop_start()
    client.publish(MQTT_TOPIC_PREFIX+'/available', 'online', retain=True)

    mqtt_publisher = MqttObjectPublisher(client, MQTT_TOPIC_PREFIX, objects_parsed,
        MQTT_OBJECT_CLASSES.split(','))
    mqtt_publisher.start()

    mqtt_motion_publisher = MqttMotionPublisher(client, MQTT_TOPIC_PREFIX, motion_changed,
        [region['motion_detected'] for region in regions])
    mqtt_motion_publisher.start()
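    # the last will registered above flips <topic_prefix>/available back to
    # 'offline' if this process disconnects uncleanly, so subscribers can track
    # whether detection is alive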

    capture_process.start()
    print("capture_process pid ", capture_process.pid)
    for detection_process in detection_processes:
        detection_process.start()
        print("detection_process pid ", detection_process.pid)
    for motion_process in motion_processes:
        motion_process.start()
        print("motion_process pid ", motion_process.pid)

    app = Flask(__name__)

    @app.route('/')
    def index():
        # return a multipart response
        return Response(imagestream(),
                        mimetype='multipart/x-mixed-replace; boundary=frame')
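    # multipart/x-mixed-replace keeps the HTTP connection open and replaces the
    # displayed content with each new part, so the browser renders the yielded
    # JPEGs below as a live MJPEG stream delimited by the 'frame' boundary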

    def imagestream():
        global DETECTED_OBJECTS
        while True:
            # max out at 5 FPS
            time.sleep(0.2)
            # make a copy of the current detected objects
            detected_objects = DETECTED_OBJECTS.copy()
            # lock and make a copy of the current frame
            with frame_lock:
                # the shared array is uint16; convert to uint8 for drawing and JPEG encoding
                frame = frame_arr.copy().astype('uint8')
            # convert to RGB for drawing
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # draw the bounding boxes on the screen
            for obj in detected_objects:
                vis_util.draw_bounding_box_on_image_array(frame,
                    obj['ymin'],
                    obj['xmin'],
                    obj['ymax'],
                    obj['xmax'],
                    color='red',
                    thickness=2,
                    display_str_list=["{}: {}%".format(obj['name'], int(obj['score']*100))],
                    use_normalized_coordinates=False)
            # draw the regions; green when motion is currently detected
            for region in regions:
                color = (255,255,255)
                if region['motion_detected'].is_set():
                    color = (0,255,0)
                cv2.rectangle(frame, (region['x_offset'], region['y_offset']),
                    (region['x_offset']+region['size'], region['y_offset']+region['size']),
                    color, 2)
            # convert back to BGR
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            # encode the image into a jpg
            ret, jpg = cv2.imencode('.jpg', frame)
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + jpg.tobytes() + b'\r\n\r\n')

    app.run(host='0.0.0.0', debug=False)

    capture_process.join()
    for detection_process in detection_processes:
        detection_process.join()
    for motion_process in motion_processes:
        motion_process.join()
    object_parser.join()
    mqtt_publisher.join()

# convert shared memory array into numpy array
def tonumpyarray(mp_arr):
    return np.frombuffer(mp_arr.get_obj(), dtype=np.uint16)
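# note: np.frombuffer wraps the existing buffer rather than copying, so the
# returned array is a live view; the dtype must match the ctypes type
# (c_uint16) the mp.Array was allocated with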

# fetch the frames as fast as possible, only decoding the frames when the
# detection_process has consumed the current frame
def fetch_frames(shared_arr, shared_frame_time, frame_lock, frame_ready, frame_shape):
    # convert shared memory array into numpy and shape into image array
    arr = tonumpyarray(shared_arr).reshape(frame_shape)

    # start the video capture
    video = cv2.VideoCapture(RTSP_URL)
    # keep the buffer small so we minimize old data
    video.set(cv2.CAP_PROP_BUFFERSIZE, 1)

    while True:
        # grab the frame, but don't decode it yet
        ret = video.grab()
        # snapshot the time the frame was grabbed
        frame_time = datetime.datetime.now()
        if ret:
            # go ahead and decode the current frame
            ret, frame = video.retrieve()
            if ret:
                # Lock access and update frame
                with frame_lock:
                    arr[:] = frame
                    shared_frame_time.value = frame_time.timestamp()
                # Notify with the condition that a new frame is ready
                with frame_ready:
                    frame_ready.notify_all()

    video.release()
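
# design note: read() is equivalent to grab() followed by retrieve(); splitting
# them lets the frame be timestamped at grab time, before paying the decode
# cost in retrieve()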

# do the actual object detection
def process_frames(shared_arr, shared_output_arr, shared_frame_time, frame_lock, frame_ready,
                   motion_detected, objects_changed, frame_shape, region_size, region_x_offset, region_y_offset,
                   debug):
    # shape shared input array into frame for processing
    arr = tonumpyarray(shared_arr).reshape(frame_shape)

    # Load a (frozen) Tensorflow model into memory before the processing loop
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
        sess = tf.Session(graph=detection_graph)

    frame_time = 0.0
    while True:
        # wait until motion is detected
        motion_detected.wait()

        # snapshot the time after the wait so the staleness check uses the current time
        now = datetime.datetime.now().timestamp()
        with frame_ready:
            # if there isn't a frame ready for processing or it is old, wait for a signal
            if shared_frame_time.value == frame_time or (now - shared_frame_time.value) > 0.5:
                frame_ready.wait()

        # make a copy of the cropped frame
        with frame_lock:
            cropped_frame = arr[region_y_offset:region_y_offset+region_size, region_x_offset:region_x_offset+region_size].copy()
            frame_time = shared_frame_time.value

        # convert to RGB
        cropped_frame_rgb = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
        # do the object detection
        objects = detect_objects(cropped_frame_rgb, sess, detection_graph, region_size, region_x_offset, region_y_offset, debug)
        # copy the detected objects to the output array, padding the array when needed
        shared_output_arr[:] = objects + [0.0] * (60 - len(objects))
        with objects_changed:
            objects_changed.notify_all()
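
# note: the shared output array is fixed at 60 doubles (10 detections x 6
# values); padding with 0.0 lets ObjectParser stop at the first zero class id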

# do the actual motion detection
def detect_motion(shared_arr, shared_frame_time, frame_lock, frame_ready, motion_detected, motion_changed,
                  frame_shape, region_size, region_x_offset, region_y_offset, min_motion_area, mask, debug):
    # shape shared input array into frame for processing
    arr = tonumpyarray(shared_arr).reshape(frame_shape)

    avg_frame = None
    avg_delta = None
    frame_time = 0.0
    motion_frames = 0
    while True:
        now = datetime.datetime.now().timestamp()

        with frame_ready:
            # if there isn't a frame ready for processing or it is old, wait for a signal
            if shared_frame_time.value == frame_time or (now - shared_frame_time.value) > 0.5:
                frame_ready.wait()

        # lock and make a copy of the cropped frame
        with frame_lock:
            cropped_frame = arr[region_y_offset:region_y_offset+region_size, region_x_offset:region_x_offset+region_size].copy().astype('uint8')
            frame_time = shared_frame_time.value

        # convert to grayscale
        gray = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2GRAY)
        # apply image mask: masked pixels are forced to white so they never produce deltas
        gray[mask] = [255]
        # apply gaussian blur
        gray = cv2.GaussianBlur(gray, (21, 21), 0)

        if avg_frame is None:
            avg_frame = gray.copy().astype("float")
            continue

        # look at the delta from the avg_frame
        frameDelta = cv2.absdiff(gray, cv2.convertScaleAbs(avg_frame))

        if avg_delta is None:
            avg_delta = frameDelta.copy().astype("float")

        # compute the average delta over the past few frames
        # the alpha value can be modified to configure how sensitive the motion detection is:
        # higher values mean the current frame impacts the delta a lot, and a single raindrop
        # may put it over the edge; too low and a fast moving person won't be detected as motion.
        # this also assumes that a person is in the same location across more than a single frame
        cv2.accumulateWeighted(frameDelta, avg_delta, 0.2)
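        # (cv2.accumulateWeighted maintains a running exponential average:
        #  avg_delta = (1 - alpha) * avg_delta + alpha * frameDelta, here with alpha = 0.2)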

        # compute the threshold image for the current frame
        current_thresh = cv2.threshold(frameDelta, 25, 255, cv2.THRESH_BINARY)[1]

        # black out everything in the avg_delta where there isn't motion in the current frame
        avg_delta_image = cv2.convertScaleAbs(avg_delta)
        avg_delta_image[np.where(current_thresh==[0])] = [0]

        # then look for deltas above the threshold, but only in areas where there is a delta
        # in the current frame. this prevents deltas from previous frames from being included
        thresh = cv2.threshold(avg_delta_image, 25, 255, cv2.THRESH_BINARY)[1]

        # dilate the thresholded image to fill in holes, then find contours
        # on the thresholded image
        thresh = cv2.dilate(thresh, None, iterations=2)
        cnts = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cnts = imutils.grab_contours(cnts)

        # if there are no contours, there is no motion
        if len(cnts) < 1:
            motion_frames = 0
            continue

        motion_found = False

        # loop over the contours
        for c in cnts:
            # if the contour is big enough, count it as motion
            contour_area = cv2.contourArea(c)
            if contour_area > min_motion_area:
                motion_found = True
                if debug:
                    cv2.drawContours(cropped_frame, [c], -1, (0, 255, 0), 2)
                    x, y, w, h = cv2.boundingRect(c)
                    cv2.putText(cropped_frame, str(contour_area), (x, y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 100, 0), 2)
                else:
                    # outside of debug mode, one qualifying contour is enough
                    break

        if motion_found:
            motion_frames += 1
            # if there have been enough consecutive motion frames, report motion
            if motion_frames >= 3:
                # only average in the current frame if the difference persists for at least 3 frames
                cv2.accumulateWeighted(gray, avg_frame, 0.01)
                motion_detected.set()
                with motion_changed:
                    motion_changed.notify_all()
        else:
            # when no motion, just keep averaging the frames together
            cv2.accumulateWeighted(gray, avg_frame, 0.01)
            motion_frames = 0
            motion_detected.clear()
            with motion_changed:
                motion_changed.notify_all()

        if debug and motion_frames >= 3:
            cv2.imwrite("/lab/debug/motion-{}-{}-{}.jpg".format(region_x_offset, region_y_offset, datetime.datetime.now().timestamp()), cropped_frame)
            cv2.imwrite("/lab/debug/avg_delta-{}-{}-{}.jpg".format(region_x_offset, region_y_offset, datetime.datetime.now().timestamp()), avg_delta_image)

if __name__ == '__main__':
    mp.freeze_support()
    main()