detect_objects.py

import os
import cv2
import imutils
import time
import datetime
import ctypes
import logging
import multiprocessing as mp
import threading
import json
from contextlib import closing
import numpy as np
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
from flask import Flask, Response, make_response
import paho.mqtt.client as mqtt

RTSP_URL = os.getenv('RTSP_URL')

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = '/frozen_inference_graph.pb'
# List of the strings that are used to add the correct label for each box.
PATH_TO_LABELS = '/label_map.pbtext'

MQTT_HOST = os.getenv('MQTT_HOST')
MQTT_TOPIC_PREFIX = os.getenv('MQTT_TOPIC_PREFIX')
MQTT_OBJECT_CLASSES = os.getenv('MQTT_OBJECT_CLASSES')

# TODO: make dynamic?
NUM_CLASSES = 90

# REGIONS = "350,0,300,50:400,350,250,50:400,750,250,50"
# REGIONS = "400,350,250,50"
REGIONS = os.getenv('REGIONS')
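# Assumed REGIONS format (inferred from the examples above and the parsing in main()):
# colon-separated regions, each "size,x_offset,y_offset,min_object_size". For example,
# "400,350,250,50" is a 400x400 square at (x=350, y=250) that ignores motion contours
# with an area smaller than 50.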

DETECTED_OBJECTS = []

# Loading label map
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
                                                            use_display_name=True)
category_index = label_map_util.create_category_index(categories)
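# category_index maps each integer class id to a dict like {'id': 1, 'name': 'person'}
# (per the TensorFlow Object Detection API label map utilities), which is how class ids
# coming out of the model are translated to readable names below.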

def detect_objects(cropped_frame, sess, detection_graph, region_size, region_x_offset, region_y_offset, debug):
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_np_expanded = np.expand_dims(cropped_frame, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represents the level of confidence for each of the objects.
    # The score is shown on the result image, together with the class label.
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    # Actual detection.
    (boxes, scores, classes, num_detections) = sess.run(
        [boxes, scores, classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})

    if debug:
        if len([category_index.get(value) for index, value in enumerate(classes[0]) if scores[0, index] > 0.5]) > 0:
            vis_util.visualize_boxes_and_labels_on_image_array(
                cropped_frame,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=4)
            cv2.imwrite("/lab/debug/obj-{}-{}-{}.jpg".format(region_x_offset, region_y_offset, datetime.datetime.now().timestamp()), cropped_frame)

    # build an array of detected objects
    objects = []
    for index, value in enumerate(classes[0]):
        score = scores[0, index]
        if score > 0.5:
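            # detection boxes are normalized [ymin, xmin, ymax, xmax]; scale them by the
            # region size and add the region offsets to get full-frame pixel coordinates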
            box = boxes[0, index].tolist()
            box[0] = (box[0] * region_size) + region_y_offset
            box[1] = (box[1] * region_size) + region_x_offset
            box[2] = (box[2] * region_size) + region_y_offset
            box[3] = (box[3] * region_size) + region_x_offset
            objects += [value, scores[0, index]] + box
        # only get the first 10 objects
        if len(objects) == 60:
            break

    return objects
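
# NOTE: detect_objects() returns a flat list of 6 values per detection
# ([class_id, score, ymin, xmin, ymax, xmax]). process_frames() pads that list with
# zeros to 60 entries (10 objects) before writing it into the shared output array,
# and ObjectParser below walks the same layout in steps of 6 until it hits a zero
# class id.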

class ObjectParser(threading.Thread):
    def __init__(self, objects_changed, object_arrays):
        threading.Thread.__init__(self)
        self._objects_changed = objects_changed
        self._object_arrays = object_arrays

    def run(self):
        global DETECTED_OBJECTS
        while True:
            detected_objects = []
            # wait until object detection has run
            with self._objects_changed:
                self._objects_changed.wait()

            for object_array in self._object_arrays:
                object_index = 0
                while object_index < 60 and object_array[object_index] > 0:
                    object_class = object_array[object_index]
                    detected_objects.append({
                        'name': str(category_index.get(object_class).get('name')),
                        'score': object_array[object_index+1],
                        'ymin': int(object_array[object_index+2]),
                        'xmin': int(object_array[object_index+3]),
                        'ymax': int(object_array[object_index+4]),
                        'xmax': int(object_array[object_index+5])
                    })
                    object_index += 6

            DETECTED_OBJECTS = detected_objects

class MqttMotionPublisher(threading.Thread):
    def __init__(self, client, topic_prefix, motion_changed, motion_flags):
        threading.Thread.__init__(self)
        self.client = client
        self.topic_prefix = topic_prefix
        self.motion_changed = motion_changed
        self.motion_flags = motion_flags

    def run(self):
        last_sent_motion = ""
        while True:
            with self.motion_changed:
                self.motion_changed.wait()

            # send message for motion
            motion_status = 'OFF'
            if any(obj.is_set() for obj in self.motion_flags):
                motion_status = 'ON'

            if last_sent_motion != motion_status:
                last_sent_motion = motion_status
                self.client.publish(self.topic_prefix+'/motion', motion_status, retain=False)
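
# Publishes the current detections to the {topic_prefix}/objects topic as JSON keyed
# by object class. Illustrative payload shape (values are made up, not from a real run):
#   {"car": [], "person": [{"name": "person", "score": 0.92, "xmax": 498, "xmin": 412, "ymax": 410, "ymin": 263}]}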
class MqttPublisher(threading.Thread):
    def __init__(self, client, topic_prefix, object_classes):
        threading.Thread.__init__(self)
        self.client = client
        self.topic_prefix = topic_prefix
        self.object_classes = object_classes

    def run(self):
        global DETECTED_OBJECTS
        last_sent_payload = ""
        while True:
            # initialize the payload
            payload = {}
            for obj in self.object_classes:
                payload[obj] = []

            # loop over detected objects and populate
            # the payload
            detected_objects = DETECTED_OBJECTS.copy()
            for obj in detected_objects:
                if obj['name'] in self.object_classes:
                    payload[obj['name']].append(obj)

            # send message for objects
            new_payload = json.dumps(payload, sort_keys=True)
            if new_payload != last_sent_payload:
                last_sent_payload = new_payload
                self.client.publish(self.topic_prefix+'/objects', new_payload, retain=False)

            time.sleep(0.1)

def main():
    # Parse selected regions
    regions = []
    for region_string in REGIONS.split(':'):
        region_parts = region_string.split(',')
        regions.append({
            'size': int(region_parts[0]),
            'x_offset': int(region_parts[1]),
            'y_offset': int(region_parts[2]),
            'min_object_size': int(region_parts[3]),
            # Event for motion detection signaling
            'motion_detected': mp.Event(),
            # create shared array for storing 10 detected objects
            # note: this must be a double even though the value you are storing
            # is a float. otherwise it stops updating the value in shared
            # memory. probably something to do with the size of the memory block
            'output_array': mp.Array(ctypes.c_double, 6*10)
        })

    # capture a single frame and check the frame shape so the correct array
    # size can be allocated in memory
    video = cv2.VideoCapture(RTSP_URL)
    ret, frame = video.read()
    if ret:
        frame_shape = frame.shape
    else:
        print("Unable to capture video stream")
        exit(1)
    video.release()

    # compute the flattened array length from the array shape
    flat_array_length = frame_shape[0] * frame_shape[1] * frame_shape[2]
    # create shared array for storing the full frame image data
    shared_arr = mp.Array(ctypes.c_uint16, flat_array_length)
    # create shared value for storing the frame_time
    shared_frame_time = mp.Value('d', 0.0)
    # Lock to control access to the frame while writing
    frame_lock = mp.Lock()
    # Condition for notifying that a new frame is ready
    frame_ready = mp.Condition()
    # Condition for notifying that motion status changed globally
    motion_changed = mp.Condition()
    # Condition for notifying that object detection ran
    objects_changed = mp.Condition()
    # shape current frame so it can be treated as an image
    frame_arr = tonumpyarray(shared_arr).reshape(frame_shape)

    capture_process = mp.Process(target=fetch_frames, args=(shared_arr,
        shared_frame_time, frame_lock, frame_ready, frame_shape))
    capture_process.daemon = True

    detection_processes = []
    motion_processes = []
    for region in regions:
        detection_process = mp.Process(target=process_frames, args=(shared_arr,
            region['output_array'],
            shared_frame_time,
            frame_lock, frame_ready,
            region['motion_detected'],
            objects_changed,
            frame_shape,
            region['size'], region['x_offset'], region['y_offset']))
        detection_process.daemon = True
        detection_processes.append(detection_process)

        motion_process = mp.Process(target=detect_motion, args=(shared_arr,
            shared_frame_time,
            frame_lock, frame_ready,
            region['motion_detected'],
            motion_changed,
            frame_shape,
            region['size'], region['x_offset'], region['y_offset'],
            region['min_object_size'],
            True))
        motion_process.daemon = True
        motion_processes.append(motion_process)

    object_parser = ObjectParser(objects_changed, [region['output_array'] for region in regions])
    object_parser.start()

    client = mqtt.Client()
    client.will_set(MQTT_TOPIC_PREFIX+'/available', payload='offline', qos=1, retain=True)
    client.connect(MQTT_HOST, 1883, 60)
    client.loop_start()
    client.publish(MQTT_TOPIC_PREFIX+'/available', 'online', retain=True)

    mqtt_publisher = MqttPublisher(client, MQTT_TOPIC_PREFIX,
                                   MQTT_OBJECT_CLASSES.split(','))
    mqtt_publisher.start()

    mqtt_motion_publisher = MqttMotionPublisher(client, MQTT_TOPIC_PREFIX, motion_changed,
                                                [region['motion_detected'] for region in regions])
    mqtt_motion_publisher.start()

    capture_process.start()
    print("capture_process pid ", capture_process.pid)
    for detection_process in detection_processes:
        detection_process.start()
        print("detection_process pid ", detection_process.pid)
    for motion_process in motion_processes:
        motion_process.start()
        print("motion_process pid ", motion_process.pid)
    app = Flask(__name__)

    @app.route('/')
    def index():
        # return a multipart response
        return Response(imagestream(),
                        mimetype='multipart/x-mixed-replace; boundary=frame')
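
    # imagestream() is a generator: each yielded chunk is one JPEG frame wrapped in the
    # 'frame' multipart boundary, which is what the multipart/x-mixed-replace mimetype
    # above tells the browser to expect for an MJPEG stream.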
    def imagestream():
        global DETECTED_OBJECTS
        while True:
            # max out at 5 FPS
            time.sleep(0.2)
            # make a copy of the current detected objects
            detected_objects = DETECTED_OBJECTS.copy()
            # lock and make a copy of the current frame
            with frame_lock:
                frame = frame_arr.copy()
            # convert to RGB for drawing
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # draw the bounding boxes on the screen
            for obj in detected_objects:
                vis_util.draw_bounding_box_on_image_array(frame,
                    obj['ymin'],
                    obj['xmin'],
                    obj['ymax'],
                    obj['xmax'],
                    color='red',
                    thickness=2,
                    display_str_list=["{}: {}%".format(obj['name'], int(obj['score']*100))],
                    use_normalized_coordinates=False)

            for region in regions:
                color = (255,255,255)
                if region['motion_detected'].is_set():
                    color = (0,255,0)
                cv2.rectangle(frame, (region['x_offset'], region['y_offset']),
                    (region['x_offset']+region['size'], region['y_offset']+region['size']),
                    color, 2)

            cv2.putText(frame, datetime.datetime.now().strftime("%H:%M:%S"), (1125, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
            # convert back to BGR
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            # encode the image into a jpg
            ret, jpg = cv2.imencode('.jpg', frame)
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + jpg.tobytes() + b'\r\n\r\n')

    app.run(host='0.0.0.0', debug=False)

    capture_process.join()
    for detection_process in detection_processes:
        detection_process.join()
    for motion_process in motion_processes:
        motion_process.join()
    object_parser.join()
    mqtt_publisher.join()

# convert shared memory array into numpy array
def tonumpyarray(mp_arr):
    return np.frombuffer(mp_arr.get_obj(), dtype=np.uint16)
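
# NOTE: np.frombuffer creates a view over the shared memory block rather than a copy,
# so reshaping the result and writing into it (as fetch_frames does) is visible to
# every process that was given the same mp.Array.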

# fetch the frames as fast as possible, only decoding the frames when the
# detection_process has consumed the current frame
def fetch_frames(shared_arr, shared_frame_time, frame_lock, frame_ready, frame_shape):
    # convert shared memory array into numpy and shape into image array
    arr = tonumpyarray(shared_arr).reshape(frame_shape)

    # start the video capture
    video = cv2.VideoCapture(RTSP_URL)
    # keep the buffer small so we minimize old data
    video.set(cv2.CAP_PROP_BUFFERSIZE, 1)

    while True:
        # grab the frame, but don't decode it yet
        ret = video.grab()
        # snapshot the time the frame was grabbed
        frame_time = datetime.datetime.now()
        if ret:
            # go ahead and decode the current frame
            ret, frame = video.retrieve()
            if ret:
                # Lock access and update frame
                with frame_lock:
                    arr[:] = frame
                    shared_frame_time.value = frame_time.timestamp()
                # Notify with the condition that a new frame is ready
                with frame_ready:
                    frame_ready.notify_all()

    video.release()

# do the actual object detection
def process_frames(shared_arr, shared_output_arr, shared_frame_time, frame_lock, frame_ready,
                   motion_detected, objects_changed, frame_shape, region_size, region_x_offset, region_y_offset):
    debug = True
    # shape shared input array into frame for processing
    arr = tonumpyarray(shared_arr).reshape(frame_shape)

    # Load a (frozen) Tensorflow model into memory before the processing loop
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
        sess = tf.Session(graph=detection_graph)

    frame_time = 0.0
    while True:
        now = datetime.datetime.now().timestamp()

        # wait until motion is detected
        motion_detected.wait()

        with frame_ready:
            # if there isn't a frame ready for processing or it is old, wait for a signal
            if shared_frame_time.value == frame_time or (now - shared_frame_time.value) > 0.5:
                frame_ready.wait()

        # make a copy of the cropped frame
        with frame_lock:
            cropped_frame = arr[region_y_offset:region_y_offset+region_size, region_x_offset:region_x_offset+region_size].copy()
            frame_time = shared_frame_time.value

        # convert to RGB
        cropped_frame_rgb = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
        # do the object detection
        objects = detect_objects(cropped_frame_rgb, sess, detection_graph, region_size, region_x_offset, region_y_offset, True)
        # copy the detected objects to the output array, filling the array when needed
        shared_output_arr[:] = objects + [0.0] * (60-len(objects))
        with objects_changed:
            objects_changed.notify_all()

# do the actual motion detection
def detect_motion(shared_arr, shared_frame_time, frame_lock, frame_ready, motion_detected, motion_changed,
                  frame_shape, region_size, region_x_offset, region_y_offset, min_motion_area, debug):
    # shape shared input array into frame for processing
    arr = tonumpyarray(shared_arr).reshape(frame_shape)

    avg_frame = None
    last_motion = -1
    frame_time = 0.0
    motion_frames = 0
    while True:
        now = datetime.datetime.now().timestamp()
        # if it has been long enough since the last motion, clear the flag
        if last_motion > 0 and (now - last_motion) > 2:
            last_motion = -1
            motion_detected.clear()
            with motion_changed:
                motion_changed.notify_all()

        with frame_ready:
            # if there isn't a frame ready for processing or it is old, wait for a signal
            if shared_frame_time.value == frame_time or (now - shared_frame_time.value) > 0.5:
                frame_ready.wait()

        # lock and make a copy of the cropped frame
        with frame_lock:
            cropped_frame = arr[region_y_offset:region_y_offset+region_size, region_x_offset:region_x_offset+region_size].copy().astype('uint8')
            frame_time = shared_frame_time.value

        # convert to grayscale
        gray = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2GRAY)
        # apply Gaussian blur
        gray = cv2.GaussianBlur(gray, (21, 21), 0)

        if avg_frame is None:
            avg_frame = gray.copy().astype("float")
            continue

        # look at the delta from the avg_frame
        cv2.accumulateWeighted(gray, avg_frame, 0.5)
        frameDelta = cv2.absdiff(gray, cv2.convertScaleAbs(avg_frame))
        thresh = cv2.threshold(frameDelta, 25, 255, cv2.THRESH_BINARY)[1]

        # dilate the thresholded image to fill in holes, then find contours
        # on thresholded image
        thresh = cv2.dilate(thresh, None, iterations=2)
        cnts = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE)
        cnts = imutils.grab_contours(cnts)

        # if there are no contours, there is no motion
        if len(cnts) < 1:
            motion_frames = 0
            continue

        motion_found = False
        # loop over the contours
        for c in cnts:
            # if the contour is big enough, count it as motion
            contour_area = cv2.contourArea(c)
            if contour_area > min_motion_area:
                motion_found = True
                if debug:
                    cv2.drawContours(cropped_frame, [c], -1, (0, 255, 0), 2)
                    x, y, w, h = cv2.boundingRect(c)
                    cv2.putText(cropped_frame, str(contour_area), (x, y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 100, 0), 2)
                else:
                    break

        if motion_found:
            motion_frames += 1
            # if there have been enough consecutive motion frames, report motion
            if motion_frames >= 3:
                motion_detected.set()
                with motion_changed:
                    motion_changed.notify_all()
                last_motion = now
        else:
            motion_frames = 0

        if debug and motion_frames > 0:
            cv2.imwrite("/lab/debug/motion-{}-{}-{}.jpg".format(region_x_offset, region_y_offset, datetime.datetime.now().timestamp()), cropped_frame)

if __name__ == '__main__':
    mp.freeze_support()
    main()