detect_objects.py

import os
import cv2
import time
import datetime
import ctypes
import logging
import multiprocessing as mp
from contextlib import closing
import numpy as np
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
from flask import Flask, Response, make_response

RTSP_URL = os.getenv('RTSP_URL')

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = '/frozen_inference_graph.pb'
# List of the strings that are used to add the correct label to each box.
PATH_TO_LABELS = '/label_map.pbtext'

# TODO: make dynamic?
NUM_CLASSES = 90

# Loading label map
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
                                                            use_display_name=True)
category_index = label_map_util.create_category_index(categories)
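# category_index maps the integer class id returned by the model to a dict
# containing the human readable 'name', which is what gets drawn on the boxes below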

def detect_objects(cropped_frame, full_frame, sess, detection_graph):
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_np_expanded = np.expand_dims(cropped_frame, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represents the level of confidence for each of the objects.
    # The score is shown on the result image, together with the class label.
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')
    # Actual detection.
    (boxes, scores, classes, num_detections) = sess.run(
        [boxes, scores, classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})
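
    # the outputs are batched: boxes is [1, num, 4] with coordinates normalized
    # as [ymin, xmin, ymax, xmax], and scores come back ordered from highest
    # to lowest confidence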
    # build an array of detected objects
    objects = []
    for index, value in enumerate(classes[0]):
        object_dict = {}
        if scores[0, index] > 0.1:
            object_dict[(category_index.get(value)).get('name').encode('utf8')] = \
                scores[0, index]
            objects.append(object_dict)

    squeezed_boxes = np.squeeze(boxes)
    squeezed_scores = np.squeeze(scores)
    full_frame_shape = full_frame.shape
    cropped_frame_shape = cropped_frame.shape
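
    # the model returns coordinates normalized to the cropped frame, so to draw
    # them on the full frame we convert to pixels, shift by the crop origin
    # (x=1300, y=200, matching the frame[200:500, 1300:1600] crop used by the
    # capture process), then renormalize against the full frame dimensions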
    if len(objects) > 0:
        # reposition bounding box based on full frame
        for i, box in enumerate(squeezed_boxes):
            if box[2] > 0:
                squeezed_boxes[i][0] = ((box[0] * cropped_frame_shape[0]) + 200)/full_frame_shape[0]  # ymin
                squeezed_boxes[i][1] = ((box[1] * cropped_frame_shape[1]) + 1300)/full_frame_shape[1] # xmin
                squeezed_boxes[i][2] = ((box[2] * cropped_frame_shape[0]) + 200)/full_frame_shape[0]  # ymax
                squeezed_boxes[i][3] = ((box[3] * cropped_frame_shape[1]) + 1300)/full_frame_shape[1] # xmax
        # draw boxes for detected objects on image
        vis_util.visualize_boxes_and_labels_on_image_array(
            full_frame,
            squeezed_boxes,
            np.squeeze(classes).astype(np.int32),
            squeezed_scores,
            category_index,
            use_normalized_coordinates=True,
            line_thickness=4,
            min_score_thresh=.1)
    # cv2.rectangle(full_frame, (800, 100), (1250, 550), (255,0,0), 2)
    return objects, full_frame

def main():
    # capture a single frame and check the frame shape so the correct array
    # size can be allocated in memory
    video = cv2.VideoCapture(RTSP_URL)
    ret, frame = video.read()
    if ret:
        frame_shape = frame.shape
    else:
        print("Unable to capture video stream")
        exit(1)
    video.release()

    # create shared value for storing the time the frame was captured
    # note: this must be a double even though the value you are storing
    # is a float. otherwise it stops updating the value in shared
    # memory. probably something to do with the size of the memory block
    shared_frame_time = mp.Value('d', 0.0)
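    # shared_frame_time doubles as a handshake flag: 0.0 means the last frame
    # has been consumed and the capture process should publish a new one, and
    # any other value is the capture timestamp of a frame waiting to be processed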

    # compute the flattened array length from the array shape
    flat_array_length = frame_shape[0] * frame_shape[1] * frame_shape[2]
    # create shared array for storing the full frame image data
    shared_arr = mp.Array(ctypes.c_uint16, flat_array_length)
    # create shared array for storing the cropped frame image data
    # TODO: make dynamic
    shared_cropped_arr = mp.Array(ctypes.c_uint16, 300*300*3)
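    # the 300x300 crop size presumably matches the fixed input resolution of
    # the detection model (e.g. an SSD variant), so the crop can be fed to the
    # graph without resizing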
    # create shared array for passing the image data from detect_objects to flask
    shared_output_arr = mp.Array(ctypes.c_uint16, flat_array_length)
    # create a numpy array with the image shape from the shared memory array
    # this is used by flask to output an mjpeg stream
    frame_output_arr = tonumpyarray(shared_output_arr).reshape(frame_shape)

    capture_process = mp.Process(target=fetch_frames, args=(shared_arr, shared_cropped_arr, shared_frame_time, frame_shape))
    capture_process.daemon = True
    detection_process = mp.Process(target=process_frames, args=(shared_arr, shared_cropped_arr, shared_output_arr, shared_frame_time, frame_shape))
    detection_process.daemon = True

    capture_process.start()
    print("capture_process pid ", capture_process.pid)
    detection_process.start()
    print("detection_process pid ", detection_process.pid)

    app = Flask(__name__)

    @app.route('/')
    def index():
        # return a multipart response
        return Response(imagestream(),
                        mimetype='multipart/x-mixed-replace; boundary=frame')
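
    # multipart/x-mixed-replace tells the browser to replace the previous part
    # with each new jpeg as it arrives, which is how an mjpeg stream works; the
    # 'frame' boundary must match the b'--frame' delimiter in the generator below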
    def imagestream():
        while True:
            # max out at 5 FPS
            time.sleep(0.2)
            # convert back to BGR
            # frame_bgr = cv2.cvtColor(frame_output_arr, cv2.COLOR_RGB2BGR)
            # encode the image into a jpg
            ret, jpg = cv2.imencode('.jpg', frame_output_arr)
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + jpg.tobytes() + b'\r\n\r\n')

    app.run(host='0.0.0.0', debug=False)

    capture_process.join()
    detection_process.join()

# convert shared memory array into numpy array
def tonumpyarray(mp_arr):
    return np.frombuffer(mp_arr.get_obj(), dtype=np.uint16)
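    # note: np.frombuffer returns a view of the shared buffer rather than a copy,
    # so the reshaped arrays in each process read and write the same pixel memory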

# fetch the frames as fast as possible, only decoding the frames when the
# detection_process has consumed the current frame
def fetch_frames(shared_arr, shared_cropped_arr, shared_frame_time, frame_shape):
    # convert shared memory array into numpy and shape into image array
    arr = tonumpyarray(shared_arr).reshape(frame_shape)
    cropped_frame = tonumpyarray(shared_cropped_arr).reshape(300,300,3)

    # start the video capture
    video = cv2.VideoCapture(RTSP_URL)
    # keep the buffer small so we minimize old data
    video.set(cv2.CAP_PROP_BUFFERSIZE, 1)

    while True:
        # grab the frame, but don't decode it yet
        ret = video.grab()
        # snapshot the time the frame was grabbed
        frame_time = datetime.datetime.now()
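        # grab() pulls the next frame off the stream without paying the decode
        # cost; the expensive retrieve()/decode only happens below when the
        # detection process is ready, so skipped frames are nearly free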
        if ret:
            # if the detection_process is ready for the next frame decode it
            # otherwise skip this frame and move on to the next one
            if shared_frame_time.value == 0.0:
                # go ahead and decode the current frame
                ret, frame = video.retrieve()
                if ret:
                    # copy the frame into the numpy array
                    # Position 1
                    # cropped_frame[:] = frame[270:720, 550:1000]
                    # Position 2
                    # frame_cropped = frame[270:720, 100:550]
                    # Car
                    cropped_frame[:] = frame[200:500, 1300:1600]
                    arr[:] = frame
                    # signal to the detection_process by setting the shared_frame_time
                    shared_frame_time.value = frame_time.timestamp()

    video.release()

# do the actual object detection
def process_frames(shared_arr, shared_cropped_arr, shared_output_arr, shared_frame_time, frame_shape):
    # shape shared input array into frame for processing
    arr = tonumpyarray(shared_arr).reshape(frame_shape)
    shared_cropped_frame = tonumpyarray(shared_cropped_arr).reshape(300,300,3)
    # shape shared output array into a frame so the overlay can be copied into it
    output_arr = tonumpyarray(shared_output_arr).reshape(frame_shape)

    # Load a (frozen) Tensorflow model into memory before the processing loop
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
        sess = tf.Session(graph=detection_graph)
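
    # the session is created once here and reused for every frame, since
    # rebuilding the graph and session per frame would dominate the runtime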
    no_frames_available = -1
    while True:
        # if there isn't a frame ready for processing
        if shared_frame_time.value == 0.0:
            # save the first time there were no frames available
            if no_frames_available == -1:
                no_frames_available = datetime.datetime.now().timestamp()
            # if there haven't been any frames available in 30 seconds,
            # sleep to avoid using so much cpu if the camera feed is down
            if no_frames_available > 0 and (datetime.datetime.now().timestamp() - no_frames_available) > 30:
                time.sleep(1)
                print("sleeping because no frames have been available in a while")
            else:
                # rest a little bit to avoid maxing out the CPU
                time.sleep(0.01)
            continue

        # we got a valid frame, so reset the timer
        no_frames_available = -1

        # if the frame is more than 0.5 seconds old, discard it
        if (datetime.datetime.now().timestamp() - shared_frame_time.value) > 0.5:
            # signal that we need a new frame
            shared_frame_time.value = 0.0
            # rest a little bit to avoid maxing out the CPU
            time.sleep(0.01)
            continue
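
        # copy out of shared memory before running inference so the capture
        # process can start writing the next frame without clobbering this one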
        frame = arr.copy()
        cropped_frame = shared_cropped_frame.copy()
        frame_time = shared_frame_time.value
        # signal that the frame has been used so a new one will be ready
        shared_frame_time.value = 0.0

        # convert the crop to RGB, since the model expects RGB and OpenCV captures BGR
        cropped_frame_rgb = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
        # do the object detection
        objects, frame_overlay = detect_objects(cropped_frame_rgb, frame, sess, detection_graph)
        # copy the output frame with the bounding boxes to the output array
        output_arr[:] = frame_overlay
        if len(objects) > 0:
            print(objects)

if __name__ == '__main__':
    mp.freeze_support()
    main()