detect_objects.py

import os
import cv2
import imutils
import time
import datetime
import ctypes
import logging
import multiprocessing as mp
import threading
import json
from contextlib import closing
import numpy as np
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
from flask import Flask, Response, make_response
import paho.mqtt.client as mqtt
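
# Overview: a capture process copies decoded RTSP frames into shared memory; for each
# configured region, a motion detection process and a TensorFlow object detection process
# work on the shared frame; an ObjectParser thread collects detections into DETECTED_OBJECTS,
# an MqttPublisher thread reports objects/motion over MQTT, and a Flask endpoint streams an
# annotated MJPEG debug view. Configuration comes from the environment variables read below.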
RTSP_URL = os.getenv('RTSP_URL')

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = '/frozen_inference_graph.pb'
# List of the strings that are used to add the correct label for each box.
PATH_TO_LABELS = '/label_map.pbtext'

MQTT_HOST = os.getenv('MQTT_HOST')
MQTT_TOPIC_PREFIX = os.getenv('MQTT_TOPIC_PREFIX')
MQTT_OBJECT_CLASSES = os.getenv('MQTT_OBJECT_CLASSES')

# TODO: make dynamic?
NUM_CLASSES = 90
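
# REGIONS is a ':'-separated list of square regions, each given as
# "size,x_offset,y_offset,min_object_size" in pixels (see main() for parsing), e.g.: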
# REGIONS = "350,0,300,50:400,350,250,50:400,750,250,50"
# REGIONS = "400,350,250,50"
REGIONS = os.getenv('REGIONS')

DETECTED_OBJECTS = []

# Loading label map
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
                                                            use_display_name=True)
category_index = label_map_util.create_category_index(categories)

def detect_objects(cropped_frame, sess, detection_graph, region_size, region_x_offset, region_y_offset, debug):
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_np_expanded = np.expand_dims(cropped_frame, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represents the level of confidence for each of the objects.
    # Score is shown on the result image, together with the class label.
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    # Actual detection.
    (boxes, scores, classes, num_detections) = sess.run(
        [boxes, scores, classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})

    if debug:
        if len([category_index.get(value) for index, value in enumerate(classes[0]) if scores[0, index] > 0.5]) > 0:
            vis_util.visualize_boxes_and_labels_on_image_array(
                cropped_frame,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=4)
            cv2.imwrite("/lab/debug/obj-{}-{}-{}.jpg".format(region_x_offset, region_y_offset, datetime.datetime.now().timestamp()), cropped_frame)

    # build an array of detected objects
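    # each detection occupies 6 slots: class id, score, ymin, xmin, ymax, xmax,
    # with the box coordinates scaled back into full-frame pixel coordinates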
    objects = []
    for index, value in enumerate(classes[0]):
        score = scores[0, index]
        if score > 0.5:
            box = boxes[0, index].tolist()
            box[0] = (box[0] * region_size) + region_y_offset
            box[1] = (box[1] * region_size) + region_x_offset
            box[2] = (box[2] * region_size) + region_y_offset
            box[3] = (box[3] * region_size) + region_x_offset
            objects += [value, scores[0, index]] + box
        # only get the first 10 objects
        if len(objects) == 60:
            break

    return objects

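# ObjectParser waits on the objects_changed condition, then unpacks the flat per-region
# output arrays back into dictionaries and replaces the global DETECTED_OBJECTS list.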
class ObjectParser(threading.Thread):
    def __init__(self, objects_changed, object_arrays):
        threading.Thread.__init__(self)
        self._objects_changed = objects_changed
        self._object_arrays = object_arrays

    def run(self):
        global DETECTED_OBJECTS
        while True:
            detected_objects = []
            # wait until object detection has run
            with self._objects_changed:
                self._objects_changed.wait()
            for object_array in self._object_arrays:
                object_index = 0
                while(object_index < 60 and object_array[object_index] > 0):
                    object_class = object_array[object_index]
                    detected_objects.append({
                        'name': str(category_index.get(object_class).get('name')),
                        'score': object_array[object_index+1],
                        'ymin': int(object_array[object_index+2]),
                        'xmin': int(object_array[object_index+3]),
                        'ymax': int(object_array[object_index+4]),
                        'xmax': int(object_array[object_index+5])
                    })
                    object_index += 6
            DETECTED_OBJECTS = detected_objects

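# MqttPublisher keeps the MQTT connection alive (with an offline last-will on topic_prefix/available),
# publishes the grouped detections to topic_prefix/objects whenever they change, and publishes
# ON/OFF to topic_prefix/motion based on the per-region motion flags.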
class MqttPublisher(threading.Thread):
    def __init__(self, host, topic_prefix, object_classes, motion_flags):
        threading.Thread.__init__(self)
        self.client = mqtt.Client()
        self.client.will_set(topic_prefix+'/available', payload='offline', qos=1, retain=True)
        self.client.connect(host, 1883, 60)
        self.client.loop_start()
        self.client.publish(topic_prefix+'/available', 'online', retain=True)
        self.topic_prefix = topic_prefix
        self.object_classes = object_classes
        self.motion_flags = motion_flags

    def run(self):
        global DETECTED_OBJECTS
        last_sent_payload = ""
        last_motion = ""
        while True:
            # initialize the payload
            payload = {}
            for obj in self.object_classes:
                payload[obj] = []
            # loop over detected objects and populate
            # the payload
            detected_objects = DETECTED_OBJECTS.copy()
            for obj in detected_objects:
                if obj['name'] in self.object_classes:
                    payload[obj['name']].append(obj)
            # send message for objects
            new_payload = json.dumps(payload, sort_keys=True)
            if new_payload != last_sent_payload:
                last_sent_payload = new_payload
                self.client.publish(self.topic_prefix+'/objects', new_payload, retain=False)
            # send message for motion
            motion_status = 'OFF'
            if any(obj.is_set() for obj in self.motion_flags):
                motion_status = 'ON'
            if motion_status != last_motion:
                last_motion = motion_status
                self.client.publish(self.topic_prefix+'/motion', motion_status, retain=False)
            time.sleep(0.1)

def main():
    # Parse selected regions
    regions = []
    for region_string in REGIONS.split(':'):
        region_parts = region_string.split(',')
        regions.append({
            'size': int(region_parts[0]),
            'x_offset': int(region_parts[1]),
            'y_offset': int(region_parts[2]),
            'min_object_size': int(region_parts[3]),
            # Event for motion detection signaling
            'motion_detected': mp.Event(),
            # create shared array for storing 10 detected objects
            # note: this must be a double even though the value you are storing
            #       is a float. otherwise it stops updating the value in shared
            #       memory. probably something to do with the size of the memory block
            'output_array': mp.Array(ctypes.c_double, 6*10)
        })

    # capture a single frame and check the frame shape so the correct array
    # size can be allocated in memory
    video = cv2.VideoCapture(RTSP_URL)
    ret, frame = video.read()
    if ret:
        frame_shape = frame.shape
    else:
        print("Unable to capture video stream")
        exit(1)
    video.release()

    # compute the flattened array length from the array shape
    flat_array_length = frame_shape[0] * frame_shape[1] * frame_shape[2]
    # create shared array for storing the full frame image data
    shared_arr = mp.Array(ctypes.c_uint16, flat_array_length)
    # create shared value for storing the frame_time
    shared_frame_time = mp.Value('d', 0.0)
    # Lock to control access to the frame while writing
    frame_lock = mp.Lock()
    # Condition for notifying that a new frame is ready
    frame_ready = mp.Condition()
    # Condition for notifying that object detection ran
    objects_changed = mp.Condition()

    # shape current frame so it can be treated as an image
    frame_arr = tonumpyarray(shared_arr).reshape(frame_shape)

    capture_process = mp.Process(target=fetch_frames, args=(shared_arr,
        shared_frame_time, frame_lock, frame_ready, frame_shape))
    capture_process.daemon = True

    detection_processes = []
    motion_processes = []
    for region in regions:
        detection_process = mp.Process(target=process_frames, args=(shared_arr,
            region['output_array'],
            shared_frame_time,
            frame_lock, frame_ready,
            region['motion_detected'],
            objects_changed,
            frame_shape,
            region['size'], region['x_offset'], region['y_offset']))
        detection_process.daemon = True
        detection_processes.append(detection_process)

        motion_process = mp.Process(target=detect_motion, args=(shared_arr,
            shared_frame_time,
            frame_lock, frame_ready,
            region['motion_detected'],
            frame_shape,
            region['size'], region['x_offset'], region['y_offset'],
            region['min_object_size'],
            True))
        motion_process.daemon = True
        motion_processes.append(motion_process)

    object_parser = ObjectParser(objects_changed, [region['output_array'] for region in regions])
    object_parser.start()

    mqtt_publisher = MqttPublisher(MQTT_HOST, MQTT_TOPIC_PREFIX,
                                   MQTT_OBJECT_CLASSES.split(','),
                                   [region['motion_detected'] for region in regions])
    mqtt_publisher.start()

    capture_process.start()
    print("capture_process pid ", capture_process.pid)
    for detection_process in detection_processes:
        detection_process.start()
        print("detection_process pid ", detection_process.pid)
    for motion_process in motion_processes:
        motion_process.start()
        print("motion_process pid ", motion_process.pid)

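    # Flask serves an MJPEG debug stream of the latest frame with detected objects,
    # region boundaries, and a timestamp drawn on top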
    app = Flask(__name__)

    @app.route('/')
    def index():
        # return a multipart response
        return Response(imagestream(),
                        mimetype='multipart/x-mixed-replace; boundary=frame')

    def imagestream():
        global DETECTED_OBJECTS
        while True:
            # max out at 5 FPS
            time.sleep(0.2)
            # make a copy of the current detected objects
            detected_objects = DETECTED_OBJECTS.copy()
            # lock and make a copy of the current frame
            with frame_lock:
                frame = frame_arr.copy()
            # convert to RGB for drawing
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # draw the bounding boxes on the screen
            for obj in detected_objects:
                vis_util.draw_bounding_box_on_image_array(frame,
                    obj['ymin'],
                    obj['xmin'],
                    obj['ymax'],
                    obj['xmax'],
                    color='red',
                    thickness=2,
                    display_str_list=["{}: {}%".format(obj['name'], int(obj['score']*100))],
                    use_normalized_coordinates=False)
            for region in regions:
                color = (255,255,255)
                if region['motion_detected'].is_set():
                    color = (0,255,0)
                cv2.rectangle(frame, (region['x_offset'], region['y_offset']),
                    (region['x_offset']+region['size'], region['y_offset']+region['size']),
                    color, 2)
            cv2.putText(frame, datetime.datetime.now().strftime("%H:%M:%S"), (1125, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
            # convert back to BGR
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            # encode the image into a jpg
            ret, jpg = cv2.imencode('.jpg', frame)
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + jpg.tobytes() + b'\r\n\r\n')

    app.run(host='0.0.0.0', debug=False)

    capture_process.join()
    for detection_process in detection_processes:
        detection_process.join()
    for motion_process in motion_processes:
        motion_process.join()
    object_parser.join()
    mqtt_publisher.join()

# convert shared memory array into numpy array
def tonumpyarray(mp_arr):
    return np.frombuffer(mp_arr.get_obj(), dtype=np.uint16)

# fetch the frames as fast as possible, only decoding the frames when the
# detection_process has consumed the current frame
def fetch_frames(shared_arr, shared_frame_time, frame_lock, frame_ready, frame_shape):
    # convert shared memory array into numpy and shape into image array
    arr = tonumpyarray(shared_arr).reshape(frame_shape)

    # start the video capture
    video = cv2.VideoCapture(RTSP_URL)
    # keep the buffer small so we minimize old data
    video.set(cv2.CAP_PROP_BUFFERSIZE, 1)

    while True:
        # grab the frame, but don't decode it yet
        ret = video.grab()
        # snapshot the time the frame was grabbed
        frame_time = datetime.datetime.now()
        if ret:
            # go ahead and decode the current frame
            ret, frame = video.retrieve()
            if ret:
                # Lock access and update frame
                with frame_lock:
                    arr[:] = frame
                    shared_frame_time.value = frame_time.timestamp()
                # Notify with the condition that a new frame is ready
                with frame_ready:
                    frame_ready.notify_all()

    video.release()

# do the actual object detection
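# each detection process loads the frozen graph once, waits for its region's motion flag
# and a fresh shared frame, runs inference on the cropped region, and writes the flat
# results into its shared output array before notifying objects_changed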
def process_frames(shared_arr, shared_output_arr, shared_frame_time, frame_lock, frame_ready,
                   motion_detected, objects_changed, frame_shape, region_size, region_x_offset, region_y_offset):
    debug = True
    # shape shared input array into frame for processing
    arr = tonumpyarray(shared_arr).reshape(frame_shape)

    # Load a (frozen) Tensorflow model into memory before the processing loop
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
        sess = tf.Session(graph=detection_graph)

    frame_time = 0.0
    while True:
        now = datetime.datetime.now().timestamp()
        # wait until motion is detected
        motion_detected.wait()

        with frame_ready:
            # if there isn't a frame ready for processing or it is old, wait for a signal
            if shared_frame_time.value == frame_time or (now - shared_frame_time.value) > 0.5:
                frame_ready.wait()

        # make a copy of the cropped frame
        with frame_lock:
            cropped_frame = arr[region_y_offset:region_y_offset+region_size, region_x_offset:region_x_offset+region_size].copy()
            frame_time = shared_frame_time.value

        # convert to RGB
        cropped_frame_rgb = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
        # do the object detection
        objects = detect_objects(cropped_frame_rgb, sess, detection_graph, region_size, region_x_offset, region_y_offset, True)
        # copy the detected objects to the output array, filling the array when needed
        shared_output_arr[:] = objects + [0.0] * (60-len(objects))
        with objects_changed:
            objects_changed.notify_all()

# do the actual motion detection
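# motion is detected with simple background subtraction: each frame is blurred, compared against
# a running weighted average, thresholded and dilated, and contours larger than min_motion_area
# count as motion; the motion_detected event is set after 3 consecutive motion frames and
# cleared after 2 seconds without motion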
def detect_motion(shared_arr, shared_frame_time, frame_lock, frame_ready, motion_detected, frame_shape,
                  region_size, region_x_offset, region_y_offset, min_motion_area, debug):
    # shape shared input array into frame for processing
    arr = tonumpyarray(shared_arr).reshape(frame_shape)

    avg_frame = None
    last_motion = -1
    frame_time = 0.0
    motion_frames = 0
    while True:
        now = datetime.datetime.now().timestamp()
        # if it has been long enough since the last motion, clear the flag
        if last_motion > 0 and (now - last_motion) > 2:
            last_motion = -1
            motion_detected.clear()

        with frame_ready:
            # if there isn't a frame ready for processing or it is old, wait for a signal
            if shared_frame_time.value == frame_time or (now - shared_frame_time.value) > 0.5:
                frame_ready.wait()

        # lock and make a copy of the cropped frame
        with frame_lock:
            cropped_frame = arr[region_y_offset:region_y_offset+region_size, region_x_offset:region_x_offset+region_size].copy().astype('uint8')
            frame_time = shared_frame_time.value

        # convert to grayscale
        gray = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2GRAY)
        # apply gaussian blur
        gray = cv2.GaussianBlur(gray, (21, 21), 0)

        if avg_frame is None:
            avg_frame = gray.copy().astype("float")
            continue

        # look at the delta from the avg_frame
        cv2.accumulateWeighted(gray, avg_frame, 0.5)
        frameDelta = cv2.absdiff(gray, cv2.convertScaleAbs(avg_frame))
        thresh = cv2.threshold(frameDelta, 25, 255, cv2.THRESH_BINARY)[1]

        # dilate the thresholded image to fill in holes, then find contours
        # on thresholded image
        thresh = cv2.dilate(thresh, None, iterations=2)
        cnts = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE)
        cnts = imutils.grab_contours(cnts)

        # if there are no contours, there is no motion
        if len(cnts) < 1:
            motion_frames = 0
            continue

        motion_found = False
        # loop over the contours
        for c in cnts:
            # if the contour is big enough, count it as motion
            contour_area = cv2.contourArea(c)
            if contour_area > min_motion_area:
                motion_found = True
                if debug:
                    cv2.drawContours(cropped_frame, [c], -1, (0, 255, 0), 2)
                    x, y, w, h = cv2.boundingRect(c)
                    cv2.putText(cropped_frame, str(contour_area), (x, y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 100, 0), 2)
                else:
                    break

        if motion_found:
            motion_frames += 1
            # if there have been enough consecutive motion frames, report motion
            if motion_frames >= 3:
                motion_detected.set()
                last_motion = now
        else:
            motion_frames = 0

        if debug and motion_frames > 0:
            cv2.imwrite("/lab/debug/motion-{}-{}-{}.jpg".format(region_x_offset, region_y_offset, datetime.datetime.now().timestamp()), cropped_frame)

if __name__ == '__main__':
    mp.freeze_support()
    main()