detect_objects.py 7.8 KB

import os
import cv2
import imutils
import time
import datetime
import ctypes
import logging
import multiprocessing as mp
import queue
import threading
import json
import yaml
from contextlib import closing
import numpy as np
from object_detection.utils import visualization_utils as vis_util
from flask import Flask, Response, make_response, send_file
import paho.mqtt.client as mqtt

from frigate.util import tonumpyarray
from frigate.mqtt import MqttMotionPublisher, MqttObjectPublisher
from frigate.objects import ObjectParser, ObjectCleaner, BestPersonFrame
from frigate.motion import detect_motion
from frigate.video import fetch_frames, FrameTracker
from frigate.object_detection import FramePrepper, PreppedQueueProcessor
with open('/config/config.yml') as f:
    # use safe_load instead of load
    CONFIG = yaml.safe_load(f)

rtsp_camera = CONFIG['cameras']['back']['rtsp']
if rtsp_camera['password'].startswith('$'):
    rtsp_camera['password'] = os.getenv(rtsp_camera['password'][1:])
RTSP_URL = 'rtsp://{}:{}@{}:{}{}'.format(rtsp_camera['user'],
    rtsp_camera['password'], rtsp_camera['host'], rtsp_camera['port'],
    rtsp_camera['path'])

MQTT_HOST = CONFIG['mqtt']['host']
MQTT_PORT = CONFIG.get('mqtt', {}).get('port', 1883)
MQTT_TOPIC_PREFIX = CONFIG['mqtt']['topic_prefix'] + '/back'
MQTT_USER = CONFIG.get('mqtt', {}).get('user')
MQTT_PASS = CONFIG.get('mqtt', {}).get('password')

WEB_PORT = CONFIG.get('web_port', 5000)
DEBUG = (CONFIG.get('debug', '0') == '1')
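
# A minimal sketch of the /config/config.yml this script expects, inferred from the
# keys read above and in main() below; all values here are hypothetical examples.
#
#   mqtt:
#     host: mqtt.local
#     port: 1883
#     topic_prefix: frigate
#     user: frigate            # optional
#     password: hunter2        # optional
#   cameras:
#     back:
#       rtsp:
#         user: admin
#         password: $RTSP_PASSWORD   # a leading $ means "read from this environment variable"
#         host: 192.168.1.10
#         port: 554
#         path: /live
#       regions:
#         - size: 300
#           x_offset: 0
#           y_offset: 300
#   web_port: 5000
#   debug: '0'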

def main():
    DETECTED_OBJECTS = []
    recent_frames = {}

    # Parse selected regions
    regions = CONFIG['cameras']['back']['regions']

    # capture a single frame and check the frame shape so the correct array
    # size can be allocated in memory
    video = cv2.VideoCapture(RTSP_URL)
    ret, frame = video.read()
    if ret:
        frame_shape = frame.shape
    else:
        print("Unable to capture video stream")
        exit(1)
    video.release()

    # compute the flattened array length from the array shape
    flat_array_length = frame_shape[0] * frame_shape[1] * frame_shape[2]
    # create shared array for storing the full frame image data
    shared_arr = mp.Array(ctypes.c_uint8, flat_array_length)
    # create shared value for storing the frame_time
    shared_frame_time = mp.Value('d', 0.0)
    # Lock to control access to the frame
    frame_lock = mp.Lock()
    # Condition for notifying that a new frame is ready
    frame_ready = mp.Condition()
    # Condition for notifying that objects were parsed
    objects_parsed = mp.Condition()
    # Queue for detected objects
    object_queue = queue.Queue()
    # Queue for prepped frames
    prepped_frame_queue = queue.Queue(len(regions)*2)

    # shape current frame so it can be treated as an image
    frame_arr = tonumpyarray(shared_arr).reshape(frame_shape)
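    # NOTE: tonumpyarray (from frigate.util) presumably wraps the mp.Array buffer with
    # np.frombuffer, so frame_arr is a zero-copy view of the same shared memory the
    # capture process writes into; reshape only changes how that buffer is indexed.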

    # start the process to capture frames from the RTSP stream and store in a shared array
    capture_process = mp.Process(target=fetch_frames, args=(shared_arr,
        shared_frame_time, frame_lock, frame_ready, frame_shape, RTSP_URL))
    capture_process.daemon = True

    # for each region, start a separate thread to resize the region and prep for detection
    detection_prep_threads = []
    for region in regions:
        detection_prep_threads.append(FramePrepper(
            frame_arr,
            shared_frame_time,
            frame_ready,
            frame_lock,
            region['size'], region['x_offset'], region['y_offset'],
            prepped_frame_queue
        ))

    prepped_queue_processor = PreppedQueueProcessor(
        prepped_frame_queue,
        object_queue
    )
    prepped_queue_processor.start()

    # start a thread to store recent motion frames for processing
    frame_tracker = FrameTracker(frame_arr, shared_frame_time, frame_ready, frame_lock,
        recent_frames)
    frame_tracker.start()

    # start a thread to store the highest scoring recent person frame
    best_person_frame = BestPersonFrame(objects_parsed, recent_frames, DETECTED_OBJECTS)
    best_person_frame.start()

    # start a thread to parse objects from the queue
    object_parser = ObjectParser(object_queue, objects_parsed, DETECTED_OBJECTS, regions)
    object_parser.start()

    # start a thread to expire objects from the detected objects list
    object_cleaner = ObjectCleaner(objects_parsed, DETECTED_OBJECTS)
    object_cleaner.start()
    # connect to mqtt and setup last will
    def on_connect(client, userdata, flags, rc):
        print("On connect called")
        # publish a message to signal that the service is running
        client.publish(MQTT_TOPIC_PREFIX+'/available', 'online', retain=True)
    client = mqtt.Client()
    client.on_connect = on_connect
    client.will_set(MQTT_TOPIC_PREFIX+'/available', payload='offline', qos=1, retain=True)
    if MQTT_USER is not None:
        client.username_pw_set(MQTT_USER, password=MQTT_PASS)
    client.connect(MQTT_HOST, MQTT_PORT, 60)
    client.loop_start()

    # start a thread to publish object scores (currently only person)
    mqtt_publisher = MqttObjectPublisher(client, MQTT_TOPIC_PREFIX, objects_parsed, DETECTED_OBJECTS)
    mqtt_publisher.start()
    # start the process of capturing frames
    capture_process.start()
    print("capture_process pid ", capture_process.pid)

    # start the object detection prep threads
    for detection_prep_thread in detection_prep_threads:
        detection_prep_thread.start()

    # create a flask app that encodes frames as mjpeg on demand
    app = Flask(__name__)

    @app.route('/best_person.jpg')
    def best_person():
        frame = np.zeros(frame_shape, np.uint8) if best_person_frame.best_frame is None else best_person_frame.best_frame
        ret, jpg = cv2.imencode('.jpg', frame)
        response = make_response(jpg.tobytes())
        response.headers['Content-Type'] = 'image/jpeg'
        return response
    @app.route('/')
    def index():
        # return a multipart response
        return Response(imagestream(),
            mimetype='multipart/x-mixed-replace; boundary=frame')

    def imagestream():
        while True:
            # max out at 5 FPS
            time.sleep(0.2)
            # make a copy of the current detected objects
            detected_objects = DETECTED_OBJECTS.copy()
            # lock and make a copy of the current frame
            with frame_lock:
                frame = frame_arr.copy()
            # convert to RGB for drawing
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # draw the bounding boxes on the screen
            for obj in detected_objects:
                vis_util.draw_bounding_box_on_image_array(frame,
                    obj['ymin'],
                    obj['xmin'],
                    obj['ymax'],
                    obj['xmax'],
                    color='red',
                    thickness=2,
                    display_str_list=["{}: {}%".format(obj['name'], int(obj['score']*100))],
                    use_normalized_coordinates=False)
            for region in regions:
                color = (255,255,255)
                cv2.rectangle(frame, (region['x_offset'], region['y_offset']),
                    (region['x_offset']+region['size'], region['y_offset']+region['size']),
                    color, 2)
            # convert back to BGR
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            # encode the image into a jpg
            ret, jpg = cv2.imencode('.jpg', frame)
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + jpg.tobytes() + b'\r\n\r\n')

    app.run(host='0.0.0.0', port=WEB_PORT, debug=False)
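    # NOTE: with the defaults above, the live MJPEG stream is served at http://<host>:5000/
    # and the latest "best person" snapshot at http://<host>:5000/best_person.jpg
    # (host is deployment-specific; the port comes from web_port in the config).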
    capture_process.join()
    for detection_prep_thread in detection_prep_threads:
        detection_prep_thread.join()
    frame_tracker.join()
    best_person_frame.join()
    object_parser.join()
    object_cleaner.join()
    mqtt_publisher.join()

if __name__ == '__main__':
    main()