# detect_objects.py
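# Reads frames from an RTSP camera stream into shared memory, runs per-region
# motion detection and TensorFlow object detection in separate processes,
# publishes the results over MQTT, and serves an annotated MJPEG stream via Flask.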

import os
import cv2
import imutils
import time
import datetime
import ctypes
import logging
import multiprocessing as mp
import threading
import json
from contextlib import closing
import numpy as np
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
from flask import Flask, Response, make_response
import paho.mqtt.client as mqtt

RTSP_URL = os.getenv('RTSP_URL')

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = '/frozen_inference_graph.pb'
# List of the strings that are used to add the correct label for each box.
PATH_TO_LABELS = '/label_map.pbtext'

MQTT_HOST = os.getenv('MQTT_HOST')
MQTT_TOPIC_PREFIX = os.getenv('MQTT_TOPIC_PREFIX')
MQTT_OBJECT_CLASSES = os.getenv('MQTT_OBJECT_CLASSES')

# TODO: make dynamic?
NUM_CLASSES = 90

# REGIONS = "350,0,300,50:400,350,250,50:400,750,250,50"
# REGIONS = "400,350,250,50"
REGIONS = os.getenv('REGIONS')

DETECTED_OBJECTS = []

# Loading label map
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
                                                            use_display_name=True)
category_index = label_map_util.create_category_index(categories)
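
# detect_objects() returns a flat list of up to 10 detections, 6 values each:
# [class_id, score, ymin, xmin, ymax, xmax, ...], with the box coordinates
# translated from the cropped region back into full-frame pixel coordinates.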
def detect_objects(cropped_frame, sess, detection_graph, region_size, region_x_offset, region_y_offset, debug):
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_np_expanded = np.expand_dims(cropped_frame, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represents the level of confidence for each of the objects.
    # Score is shown on the result image, together with the class label.
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    # Actual detection.
    (boxes, scores, classes, num_detections) = sess.run(
        [boxes, scores, classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})

    if debug:
        if len([category_index.get(value) for index, value in enumerate(classes[0]) if scores[0, index] > 0.5]) > 0:
            vis_util.visualize_boxes_and_labels_on_image_array(
                cropped_frame,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=4)
            cv2.imwrite("/lab/debug/obj-{}-{}-{}.jpg".format(region_x_offset, region_y_offset, datetime.datetime.now().timestamp()), cropped_frame)

    # build an array of detected objects
    objects = []
    for index, value in enumerate(classes[0]):
        score = scores[0, index]
        if score > 0.5:
            box = boxes[0, index].tolist()
            box[0] = (box[0] * region_size) + region_y_offset
            box[1] = (box[1] * region_size) + region_x_offset
            box[2] = (box[2] * region_size) + region_y_offset
            box[3] = (box[3] * region_size) + region_x_offset
            objects += [value, scores[0, index]] + box
            # only get the first 10 objects
            if len(objects) == 60:
                break

    return objects
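
# Background thread that converts the flat per-region shared arrays into a list of
# object dicts and keeps the global DETECTED_OBJECTS list up to date.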
class ObjectParser(threading.Thread):
    def __init__(self, object_arrays):
        threading.Thread.__init__(self)
        self._object_arrays = object_arrays

    def run(self):
        global DETECTED_OBJECTS
        while True:
            detected_objects = []
            for object_array in self._object_arrays:
                object_index = 0
                while object_index < 60 and object_array[object_index] > 0:
                    object_class = object_array[object_index]
                    detected_objects.append({
                        'name': str(category_index.get(object_class).get('name')),
                        'score': object_array[object_index+1],
                        'ymin': int(object_array[object_index+2]),
                        'xmin': int(object_array[object_index+3]),
                        'ymax': int(object_array[object_index+4]),
                        'xmax': int(object_array[object_index+5])
                    })
                    object_index += 6
            DETECTED_OBJECTS = detected_objects
            time.sleep(0.1)
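
# Background thread that publishes the current detections (as JSON, grouped by class)
# and the combined motion state to MQTT, with an availability topic set via a last will.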
class MqttPublisher(threading.Thread):
    def __init__(self, host, topic_prefix, object_classes, motion_flags):
        threading.Thread.__init__(self)
        self.client = mqtt.Client()
        self.client.will_set(topic_prefix+'/available', payload='offline', qos=1, retain=True)
        self.client.connect(host, 1883, 60)
        self.client.loop_start()
        self.client.publish(topic_prefix+'/available', 'online', retain=True)
        self.topic_prefix = topic_prefix
        self.object_classes = object_classes
        self.motion_flags = motion_flags

    def run(self):
        global DETECTED_OBJECTS
        last_sent_payload = ""
        last_motion = ""
        while True:
            # initialize the payload
            payload = {}
            for obj in self.object_classes:
                payload[obj] = []
            # loop over detected objects and populate
            # the payload
            detected_objects = DETECTED_OBJECTS.copy()
            for obj in detected_objects:
                if obj['name'] in self.object_classes:
                    payload[obj['name']].append(obj)
            # send message for objects
            new_payload = json.dumps(payload, sort_keys=True)
            if new_payload != last_sent_payload:
                last_sent_payload = new_payload
                self.client.publish(self.topic_prefix+'/objects', new_payload, retain=False)
            # send message for motion
            motion_status = 'OFF'
            if any(obj.value == 1 for obj in self.motion_flags):
                motion_status = 'ON'
            if motion_status != last_motion:
                last_motion = motion_status
                self.client.publish(self.topic_prefix+'/motion', motion_status, retain=False)
            time.sleep(0.1)
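
# main() wires everything together: it parses the configured regions, allocates the
# shared frame buffer, starts one capture process plus one detection and one motion
# process per region, starts the parser and MQTT threads, and serves the annotated
# stream with Flask.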
def main():
    # Parse selected regions
    regions = []
    for region_string in REGIONS.split(':'):
        region_parts = region_string.split(',')
        regions.append({
            'size': int(region_parts[0]),
            'x_offset': int(region_parts[1]),
            'y_offset': int(region_parts[2]),
            'min_object_size': int(region_parts[3]),
            # shared value for signaling to the capture process that we are ready for the next frame
            # (1 for ready 0 for not ready)
            'ready_for_frame': mp.Value('i', 1),
            # shared value for motion detection signal (1 for motion 0 for no motion)
            'motion_detected': mp.Value('i', 0),
            # create shared array for storing 10 detected objects
            # note: this must be a double even though the value you are storing
            # is a float. otherwise it stops updating the value in shared
            # memory. probably something to do with the size of the memory block
            'output_array': mp.Array(ctypes.c_double, 6*10)
        })

    # capture a single frame and check the frame shape so the correct array
    # size can be allocated in memory
    video = cv2.VideoCapture(RTSP_URL)
    ret, frame = video.read()
    if ret:
        frame_shape = frame.shape
    else:
        print("Unable to capture video stream")
        exit(1)
    video.release()

    # compute the flattened array length from the array shape
    flat_array_length = frame_shape[0] * frame_shape[1] * frame_shape[2]
    # create shared array for storing the full frame image data
    shared_arr = mp.Array(ctypes.c_uint16, flat_array_length)
    # create shared value for storing the frame_time
    shared_frame_time = mp.Value('d', 0.0)
    # Lock to control access to the frame while writing
    frame_lock = mp.Lock()
    # Condition for notifying that a new frame is ready
    frame_ready = mp.Condition()
    # shape current frame so it can be treated as an image
    frame_arr = tonumpyarray(shared_arr).reshape(frame_shape)

    capture_process = mp.Process(target=fetch_frames, args=(shared_arr,
        shared_frame_time, frame_lock, frame_ready, frame_shape))
    capture_process.daemon = True

    detection_processes = []
    for index, region in enumerate(regions):
        detection_process = mp.Process(target=process_frames, args=(shared_arr,
            region['output_array'],
            shared_frame_time,
            region['motion_detected'],
            frame_shape,
            region['size'], region['x_offset'], region['y_offset']))
        detection_process.daemon = True
        detection_processes.append(detection_process)

    motion_processes = []
    for index, region in enumerate(regions):
        motion_process = mp.Process(target=detect_motion, args=(shared_arr,
            shared_frame_time,
            frame_lock, frame_ready,
            region['motion_detected'],
            frame_shape,
            region['size'], region['x_offset'], region['y_offset'],
            region['min_object_size'],
            True))
        motion_process.daemon = True
        motion_processes.append(motion_process)

    object_parser = ObjectParser([region['output_array'] for region in regions])
    object_parser.start()

    mqtt_publisher = MqttPublisher(MQTT_HOST, MQTT_TOPIC_PREFIX,
        MQTT_OBJECT_CLASSES.split(','),
        [region['motion_detected'] for region in regions])
    mqtt_publisher.start()

    capture_process.start()
    print("capture_process pid ", capture_process.pid)
    for detection_process in detection_processes:
        detection_process.start()
        print("detection_process pid ", detection_process.pid)
    for motion_process in motion_processes:
        motion_process.start()
        print("motion_process pid ", motion_process.pid)

    app = Flask(__name__)

    @app.route('/')
    def index():
        # return a multipart response
        return Response(imagestream(),
                        mimetype='multipart/x-mixed-replace; boundary=frame')

    def imagestream():
        global DETECTED_OBJECTS
        while True:
            # max out at 5 FPS
            time.sleep(0.2)
            # make a copy of the current detected objects
            detected_objects = DETECTED_OBJECTS.copy()
            # lock and make a copy of the current frame
            frame_lock.acquire()
            frame = frame_arr.copy()
            frame_lock.release()
            # convert to RGB for drawing
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # draw the bounding boxes on the screen
            for obj in detected_objects:
                vis_util.draw_bounding_box_on_image_array(frame,
                    obj['ymin'],
                    obj['xmin'],
                    obj['ymax'],
                    obj['xmax'],
                    color='red',
                    thickness=2,
                    display_str_list=["{}: {}%".format(obj['name'], int(obj['score']*100))],
                    use_normalized_coordinates=False)

            for region in regions:
                color = (255,255,255)
                if region['motion_detected'].value == 1:
                    color = (0,255,0)
                cv2.rectangle(frame, (region['x_offset'], region['y_offset']),
                    (region['x_offset']+region['size'], region['y_offset']+region['size']),
                    color, 2)

            cv2.putText(frame, datetime.datetime.now().strftime("%H:%M:%S"), (1125, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
            # convert back to BGR
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            # encode the image into a jpg
            ret, jpg = cv2.imencode('.jpg', frame)
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + jpg.tobytes() + b'\r\n\r\n')

    app.run(host='0.0.0.0', debug=False)

    capture_process.join()
    for detection_process in detection_processes:
        detection_process.join()
    for motion_process in motion_processes:
        motion_process.join()
    object_parser.join()
    mqtt_publisher.join()

# convert shared memory array into numpy array
def tonumpyarray(mp_arr):
    return np.frombuffer(mp_arr.get_obj(), dtype=np.uint16)

# fetch the frames as fast as possible, only decoding the frames when the
# detection_process has consumed the current frame
def fetch_frames(shared_arr, shared_frame_time, frame_lock, frame_ready, frame_shape):
    # convert shared memory array into numpy and shape into image array
    arr = tonumpyarray(shared_arr).reshape(frame_shape)

    # start the video capture
    video = cv2.VideoCapture(RTSP_URL)
    # keep the buffer small so we minimize old data
    video.set(cv2.CAP_PROP_BUFFERSIZE, 1)

    while True:
        # grab the frame, but don't decode it yet
        ret = video.grab()
        # snapshot the time the frame was grabbed
        frame_time = datetime.datetime.now()
        if ret:
            # go ahead and decode the current frame
            ret, frame = video.retrieve()
            if ret:
                # Lock access and update frame
                frame_lock.acquire()
                arr[:] = frame
                shared_frame_time.value = frame_time.timestamp()
                frame_lock.release()
                # Notify with the condition that a new frame is ready
                with frame_ready:
                    frame_ready.notify_all()

    video.release()
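
# Each region gets its own copy of this process: it loads the frozen TensorFlow graph
# into its own session and only runs inference while the region's motion flag is set,
# writing results into the region's shared output array.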
# do the actual object detection
def process_frames(shared_arr, shared_output_arr, shared_frame_time, shared_motion, frame_shape, region_size, region_x_offset, region_y_offset):
    debug = True
    # shape shared input array into frame for processing
    arr = tonumpyarray(shared_arr).reshape(frame_shape)

    # Load a (frozen) Tensorflow model into memory before the processing loop
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
        sess = tf.Session(graph=detection_graph)

    no_frames_available = -1
    frame_time = 0.0
    while True:
        now = datetime.datetime.now().timestamp()

        # if there is no motion detected
        if shared_motion.value == 0:
            time.sleep(0.1)
            continue

        # if there isn't a new frame ready for processing
        if shared_frame_time.value == frame_time:
            # save the first time there were no frames available
            if no_frames_available == -1:
                no_frames_available = now
            # if there haven't been any frames available in 30 seconds,
            # sleep to avoid using so much cpu if the camera feed is down
            if no_frames_available > 0 and (now - no_frames_available) > 30:
                time.sleep(1)
                print("sleeping because no frames have been available in a while")
            else:
                # rest a little bit to avoid maxing out the CPU
                time.sleep(0.1)
            continue

        # we got a valid frame, so reset the timer
        no_frames_available = -1

        # if the frame is more than 0.5 second old, ignore it
        if (now - shared_frame_time.value) > 0.5:
            # rest a little bit to avoid maxing out the CPU
            time.sleep(0.1)
            continue

        # make a copy of the cropped frame
        cropped_frame = arr[region_y_offset:region_y_offset+region_size, region_x_offset:region_x_offset+region_size].copy()
        frame_time = shared_frame_time.value

        # convert to RGB
        cropped_frame_rgb = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
        # do the object detection
        objects = detect_objects(cropped_frame_rgb, sess, detection_graph, region_size, region_x_offset, region_y_offset, True)
        # copy the detected objects to the output array, filling the array when needed
        shared_output_arr[:] = objects + [0.0] * (60-len(objects))
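
# Each region also gets its own copy of this process: it maintains a running average
# background with cv2.accumulateWeighted, thresholds the per-frame delta, and sets the
# shared motion flag once enough consecutive frames contain a large enough contour.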
# do the actual motion detection
def detect_motion(shared_arr, shared_frame_time, frame_lock, frame_ready, shared_motion, frame_shape, region_size, region_x_offset, region_y_offset, min_motion_area, debug):
    # shape shared input array into frame for processing
    arr = tonumpyarray(shared_arr).reshape(frame_shape)

    no_frames_available = -1
    avg_frame = None
    last_motion = -1
    frame_time = 0.0
    motion_frames = 0
    while True:
        now = datetime.datetime.now().timestamp()

        # if it has been long enough since the last motion, clear the flag
        if last_motion > 0 and (now - last_motion) > 2:
            last_motion = -1
            shared_motion.value = 0

        with frame_ready:
            # if there isn't a frame ready for processing or it is old, wait for a signal
            if shared_frame_time.value == frame_time or (now - shared_frame_time.value) > 0.5:
                frame_ready.wait()

        # lock and make a copy of the cropped frame
        frame_lock.acquire()
        cropped_frame = arr[region_y_offset:region_y_offset+region_size, region_x_offset:region_x_offset+region_size].copy().astype('uint8')
        frame_time = shared_frame_time.value
        frame_lock.release()

        # convert to grayscale
        gray = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2GRAY)
        # apply gaussian blur
        gray = cv2.GaussianBlur(gray, (21, 21), 0)

        if avg_frame is None:
            avg_frame = gray.copy().astype("float")
            continue

        # look at the delta from the avg_frame
        cv2.accumulateWeighted(gray, avg_frame, 0.5)
        frameDelta = cv2.absdiff(gray, cv2.convertScaleAbs(avg_frame))
        thresh = cv2.threshold(frameDelta, 25, 255, cv2.THRESH_BINARY)[1]

        # dilate the thresholded image to fill in holes, then find contours
        # on thresholded image
        thresh = cv2.dilate(thresh, None, iterations=2)
        cnts = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE)
        cnts = imutils.grab_contours(cnts)

        # if there are no contours, there is no motion
        if len(cnts) < 1:
            motion_frames = 0
            continue

        motion_found = False

        # loop over the contours
        for c in cnts:
            # if the contour is big enough, count it as motion
            contour_area = cv2.contourArea(c)
            if contour_area > min_motion_area:
                motion_found = True
                if debug:
                    cv2.drawContours(cropped_frame, [c], -1, (0, 255, 0), 2)
                    x, y, w, h = cv2.boundingRect(c)
                    cv2.putText(cropped_frame, str(contour_area), (x, y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 100, 0), 2)
                else:
                    break

        if motion_found:
            motion_frames += 1
            # if there have been enough consecutive motion frames, report motion
            if motion_frames >= 3:
                shared_motion.value = 1
                last_motion = now
        else:
            motion_frames = 0

        if debug and motion_frames > 0:
            cv2.imwrite("/lab/debug/motion-{}-{}-{}.jpg".format(region_x_offset, region_y_offset, datetime.datetime.now().timestamp()), cropped_frame)

if __name__ == '__main__':
    mp.freeze_support()
    main()