stt_server.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. from .install_packages import check_and_install_packages
  2. check_and_install_packages([
  3. {
  4. 'module_name': 'RealtimeSTT', # Import module
  5. 'attribute': 'AudioToTextRecorder', # Specific class to check
  6. 'install_name': 'RealtimeSTT', # Package name for pip install
  7. },
  8. {
  9. 'module_name': 'websockets', # Import module
  10. 'install_name': 'websockets', # Package name for pip install
  11. },
  12. {
  13. 'module_name': 'numpy', # Import module
  14. 'install_name': 'numpy', # Package name for pip install
  15. },
  16. {
  17. 'module_name': 'scipy.signal', # Submodule of scipy
  18. 'attribute': 'resample', # Specific function to check
  19. 'install_name': 'scipy', # Package name for pip install
  20. }
  21. ])
print("Starting server, please wait...")

import asyncio
import threading
import json
import websockets
from RealtimeSTT import AudioToTextRecorder
import numpy as np
from scipy.signal import resample

# Module-level state shared between the asyncio server and the recorder thread.
global_args = None                   # parsed CLI namespace (set in main_async)
recorder = None                      # AudioToTextRecorder (set in _recorder_thread)
recorder_config = {}                 # kwargs for AudioToTextRecorder (built in main_async)
recorder_ready = threading.Event()   # set once the recorder is initialized
client_websocket = None              # currently connected client, or None
stop_recorder = False                # tells the recorder thread to exit its loop
prev_text = ""                       # last realtime text, for the sentence-pause heuristic
  37. async def send_to_client(message):
  38. global client_websocket
  39. if client_websocket and client_websocket.open:
  40. try:
  41. await client_websocket.send(message)
  42. except websockets.exceptions.ConnectionClosed:
  43. print("Client websocket is closed, resetting client_websocket")
  44. client_websocket = None
  45. else:
  46. print("No client connected or connection is closed.")
  47. client_websocket = None # Ensure it resets
  48. def preprocess_text(text):
  49. # Remove leading whitespaces
  50. text = text.lstrip()
  51. # Remove starting ellipses if present
  52. if text.startswith("..."):
  53. text = text[3:]
  54. # Remove any leading whitespaces again after ellipses removal
  55. text = text.lstrip()
  56. # Uppercase the first letter
  57. if text:
  58. text = text[0].upper() + text[1:]
  59. return text
  60. def text_detected(text):
  61. global prev_text
  62. text = preprocess_text(text)
  63. sentence_end_marks = ['.', '!', '?', '。']
  64. if text.endswith("..."):
  65. recorder.post_speech_silence_duration = global_args.mid_sentence_detection_pause
  66. elif text and text[-1] in sentence_end_marks and prev_text and prev_text[-1] in sentence_end_marks:
  67. recorder.post_speech_silence_duration = global_args.end_of_sentence_detection_pause
  68. else:
  69. recorder.post_speech_silence_duration = global_args.unknown_sentence_detection_pause
  70. prev_text = text
  71. try:
  72. asyncio.new_event_loop().run_until_complete(
  73. send_to_client(
  74. json.dumps({
  75. 'type': 'realtime',
  76. 'text': text
  77. })
  78. )
  79. )
  80. except Exception as e:
  81. print(f"Error in text_detected while sending to client: {e}")
  82. print(f"\r{text}", flush=True, end='')
# Define the server's arguments
def parse_arguments():
    """Build and parse the STT server's command-line arguments.

    Returns the argparse.Namespace consumed by main_async (model selection,
    VAD sensitivities, and the three pause durations used by the
    sentence-detection heuristic in text_detected).

    NOTE(review): --enable_realtime_transcription and
    --silero_deactivity_detection use action='store_true' with default=True,
    so they can never be switched off from the command line — confirm whether
    that is intended before changing the interface.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Start the Speech-to-Text (STT) server with various configuration options.')

    parser.add_argument('--model', type=str, default='medium.en',
                        help='Path to the STT model or model size. Options: tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2 or any hugginface CTranslate2 stt model like deepdml/faster-whisper-large-v3-turbo-ct2. Default: medium.en')

    parser.add_argument('--realtime_model_type', type=str, default='tiny.en',
                        help='Model size for real-time transcription. Same options as --model. Used only if real-time transcription is enabled. Default: tiny.en')

    parser.add_argument('--language', type=str, default='en',
                        help='Language code for the STT model. Leave empty for auto-detection. Default: en')

    parser.add_argument('--input_device_index', type=int, default=1,
                        help='Index of the audio input device to use. Default: 1')

    parser.add_argument('--silero_sensitivity', type=float, default=0.05,
                        help='Sensitivity for Silero Voice Activity Detection (0 to 1). Lower values are less sensitive. Default: 0.05')

    parser.add_argument('--webrtc_sensitivity', type=float, default=3,
                        help='Sensitivity for WebRTC Voice Activity Detection (0 to 3). Higher values are less sensitive. Default: 3')

    parser.add_argument('--min_length_of_recording', type=float, default=1.1,
                        help='Minimum duration (in seconds) for a valid recording. Prevents excessively short recordings. Default: 1.1')

    parser.add_argument('--min_gap_between_recordings', type=float, default=0,
                        help='Minimum time (in seconds) between consecutive recordings. Prevents rapid successive recordings. Default: 0')

    parser.add_argument('--enable_realtime_transcription', action='store_true', default=True,
                        help='Enable continuous real-time transcription of audio. Default: True')

    parser.add_argument('--realtime_processing_pause', type=float, default=0.02,
                        help='Time interval (in seconds) between processing audio chunks for real-time transcription. Lower values increase responsiveness but may increase CPU load. Default: 0.02')

    parser.add_argument('--silero_deactivity_detection', action='store_true', default=True,
                        help='Use Silero model for end-of-speech detection. More robust against background noise but uses more GPU resources. Default: True')

    parser.add_argument('--early_transcription_on_silence', type=float, default=0.2,
                        help='Start transcription after specified seconds of silence. Should be lower than post_speech_silence_duration. Set to 0 to disable. Default: 0.2')

    parser.add_argument('--beam_size', type=int, default=5,
                        help='Beam size for the main transcription model. Larger values may improve accuracy but increase processing time. Default: 5')

    parser.add_argument('--beam_size_realtime', type=int, default=3,
                        help='Beam size for the real-time transcription model. Smaller than main beam_size for faster processing. Default: 3')

    parser.add_argument('--initial_prompt', type=str,
                        default='Add periods only for complete sentences. Use ellipsis (...) for unfinished thoughts or unclear endings. Examples: \n- Complete: "I went to the store."\n- Incomplete: "I think it was..."',
                        help='Initial prompt for the transcription model to guide its output format and style. Default provides instructions for sentence completion and ellipsis usage.')

    parser.add_argument('--end_of_sentence_detection_pause', type=float, default=0.45,
                        help='Duration of pause (in seconds) to consider as end of a sentence. Default: 0.45')

    parser.add_argument('--unknown_sentence_detection_pause', type=float, default=0.7,
                        help='Duration of pause (in seconds) to consider as an unknown or incomplete sentence. Default: 0.7')

    parser.add_argument('--mid_sentence_detection_pause', type=float, default=2.0,
                        help='Duration of pause (in seconds) to consider as a mid-sentence break. Default: 2.0')

    return parser.parse_args()
  125. def _recorder_thread():
  126. global recorder, prev_text, stop_recorder
  127. # print("Initializing RealtimeSTT...")
  128. print(f"Initializing RealtimeSTT server with parameters {recorder_config}")
  129. recorder = AudioToTextRecorder(**recorder_config)
  130. print("RealtimeSTT initialized")
  131. recorder_ready.set()
  132. def process_text(full_sentence):
  133. full_sentence = preprocess_text(full_sentence)
  134. prev_text = ""
  135. try:
  136. asyncio.new_event_loop().run_until_complete(
  137. send_to_client(
  138. json.dumps({
  139. 'type': 'fullSentence',
  140. 'text': full_sentence
  141. })
  142. )
  143. )
  144. except Exception as e:
  145. print(f"Error in _recorder_thread while sending to client: {e}")
  146. print(f"\rSentence: {full_sentence}")
  147. try:
  148. while not stop_recorder:
  149. recorder.text(process_text)
  150. except KeyboardInterrupt:
  151. print("Exiting application due to keyboard interrupt")
  152. def decode_and_resample(
  153. audio_data,
  154. original_sample_rate,
  155. target_sample_rate):
  156. # Decode 16-bit PCM data to numpy array
  157. if original_sample_rate == target_sample_rate:
  158. return audio_data
  159. audio_np = np.frombuffer(audio_data, dtype=np.int16)
  160. # Calculate the number of samples after resampling
  161. num_original_samples = len(audio_np)
  162. num_target_samples = int(num_original_samples * target_sample_rate /
  163. original_sample_rate)
  164. # Resample the audio
  165. resampled_audio = resample(audio_np, num_target_samples)
  166. return resampled_audio.astype(np.int16).tobytes()
  167. async def echo(websocket, path):
  168. print("Client connected")
  169. global client_websocket
  170. client_websocket = websocket
  171. recorder.post_speech_silence_duration = global_args.unknown_sentence_detection_pause
  172. try:
  173. async for message in websocket:
  174. if not recorder_ready.is_set():
  175. print("Recorder not ready")
  176. continue
  177. metadata_length = int.from_bytes(message[:4], byteorder='little')
  178. metadata_json = message[4:4+metadata_length].decode('utf-8')
  179. metadata = json.loads(metadata_json)
  180. sample_rate = metadata['sampleRate']
  181. chunk = message[4+metadata_length:]
  182. resampled_chunk = decode_and_resample(chunk, sample_rate, 16000)
  183. recorder.feed_audio(resampled_chunk)
  184. except websockets.exceptions.ConnectionClosed as e:
  185. print(f"Client disconnected: {e}")
  186. finally:
  187. print("Resetting client_websocket after disconnect")
  188. client_websocket = None # Reset websocket reference
async def main_async():
    """Parse arguments, start the websocket server and the recorder thread,
    then serve until the server closes; shut everything down on exit.

    Ordering matters: the server is started first so a client can connect
    while the (slow) recorder initialization runs on its own thread, and
    recorder_ready gates audio processing in the echo handler meanwhile.
    """
    global stop_recorder, recorder_config, global_args
    args = parse_arguments()
    global_args = args

    # Keyword arguments for AudioToTextRecorder, mostly passed straight
    # through from the CLI. use_microphone=False because audio arrives over
    # the websocket; text_detected receives realtime updates.
    recorder_config = {
        'model': args.model,
        'realtime_model_type': args.realtime_model_type,
        'language': args.language,
        'input_device_index': args.input_device_index,
        'silero_sensitivity': args.silero_sensitivity,
        'webrtc_sensitivity': args.webrtc_sensitivity,
        'post_speech_silence_duration': args.unknown_sentence_detection_pause,
        'min_length_of_recording': args.min_length_of_recording,
        'min_gap_between_recordings': args.min_gap_between_recordings,
        'enable_realtime_transcription': args.enable_realtime_transcription,
        'realtime_processing_pause': args.realtime_processing_pause,
        'silero_deactivity_detection': args.silero_deactivity_detection,
        'early_transcription_on_silence': args.early_transcription_on_silence,
        'beam_size': args.beam_size,
        'beam_size_realtime': args.beam_size_realtime,
        'initial_prompt': args.initial_prompt,
        'spinner': False,
        'use_microphone': False,
        'on_realtime_transcription_update': text_detected,
        'no_log_file': True,
    }

    start_server = await websockets.serve(echo, "localhost", 8011)

    recorder_thread = threading.Thread(target=_recorder_thread)
    recorder_thread.start()
    # NOTE(review): this blocking wait runs on the event-loop thread, so the
    # loop is stalled until the recorder finishes initializing — confirm this
    # is acceptable (connections are only accepted afterwards).
    recorder_ready.wait()

    print("Server started. Press Ctrl+C to stop the server.")

    try:
        await start_server.wait_closed()  # This will keep the server running
    except KeyboardInterrupt:
        print("Shutting down gracefully...")
    finally:
        # Shut down the recorder
        if recorder:
            stop_recorder = True
            recorder.abort()
            recorder.stop()
            recorder.shutdown()
            print("Recorder shut down")

            recorder_thread.join()
            print("Recorder thread finished")

        # Cancel all active tasks in the event loop
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        for task in tasks:
            task.cancel()

        # Run pending tasks and handle cancellation
        await asyncio.gather(*tasks, return_exceptions=True)

        print("All tasks cancelled, closing event loop now.")
  241. def main():
  242. try:
  243. asyncio.run(main_async())
  244. except KeyboardInterrupt:
  245. # Capture any final KeyboardInterrupt to prevent it from showing up in logs
  246. print("Server interrupted by user.")
  247. exit(0)
  248. if __name__ == '__main__':
  249. main()