  1. """
  2. Speech-to-Text (STT) Server with Real-Time Transcription and WebSocket Interface
  3. This server provides real-time speech-to-text (STT) transcription using the RealtimeSTT library. It allows clients to connect via WebSocket to send audio data and receive real-time transcription updates. The server supports configurable audio recording parameters, voice activity detection (VAD), and wake word detection. It is designed to handle continuous transcription as well as post-recording processing, enabling real-time feedback with the option to improve final transcription quality after the complete sentence is recognized.
  4. ### Features:
  5. - Real-time transcription using pre-configured or user-defined STT models.
  6. - WebSocket-based communication for control and data handling.
  7. - Flexible recording and transcription options, including configurable pauses for sentence detection.
  8. - Supports Silero and WebRTC VAD for robust voice activity detection.
  9. ### Starting the Server:
  10. You can start the server using the command-line interface (CLI) command `stt-server`, passing the desired configuration options.
  11. ```bash
  12. stt-server [OPTIONS]
  13. ```
  14. ### Available Parameters:
  15. - `--model` (str, default: 'medium.en'): Path to the STT model or model size. Options: tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, or any huggingface CTranslate2 STT model like `deepdml/faster-whisper-large-v3-turbo-ct2`.
  16. - `--realtime_model_type` (str, default: 'tiny.en'): Model size for real-time transcription. Same options as `--model`.
  17. - `--language` (str, default: 'en'): Language code for the STT model. Leave empty for auto-detection.
  18. - `--input_device_index` (int, default: 1): Index of the audio input device to use.
  19. - `--silero_sensitivity` (float, default: 0.05): Sensitivity for Silero Voice Activity Detection (VAD). Lower values are less sensitive.
  20. - `--webrtc_sensitivity` (int, default: 3): Sensitivity for WebRTC VAD. Higher values are less sensitive.
  21. - `--min_length_of_recording` (float, default: 1.1): Minimum duration (in seconds) for a valid recording. Prevents short recordings.
  22. - `--min_gap_between_recordings` (float, default: 0): Minimum time (in seconds) between consecutive recordings.
  23. - `--enable_realtime_transcription` (flag, default: True): Enable real-time transcription of audio.
  24. - `--realtime_processing_pause` (float, default: 0.02): Time interval (in seconds) between processing audio chunks for real-time transcription. Lower values increase responsiveness.
  25. - `--silero_deactivity_detection` (flag, default: True): Use Silero model for end-of-speech detection.
  26. - `--early_transcription_on_silence` (float, default: 0.2): Start transcription after specified seconds of silence.
  27. - `--beam_size` (int, default: 5): Beam size for the main transcription model.
  28. - `--beam_size_realtime` (int, default: 3): Beam size for the real-time transcription model.
  29. - `--initial_prompt` (str, default: '...'): Initial prompt for the transcription model to guide its output format and style.
  30. - `--end_of_sentence_detection_pause` (float, default: 0.45): Duration of pause (in seconds) to consider as the end of a sentence.
  31. - `--unknown_sentence_detection_pause` (float, default: 0.7): Duration of pause (in seconds) to consider as an unknown or incomplete sentence.
  32. - `--mid_sentence_detection_pause` (float, default: 2.0): Duration of pause (in seconds) to consider as a mid-sentence break.
  33. - `--control_port` (int, default: 8011): Port for the control WebSocket connection.
  34. - `--data_port` (int, default: 8012): Port for the data WebSocket connection.
  35. ### WebSocket Interface:
  36. The server supports two WebSocket connections:
  37. 1. **Control WebSocket**: Used to send and receive commands, such as setting parameters or calling recorder methods.
  38. 2. **Data WebSocket**: Used to send audio data for transcription and receive real-time transcription updates.
  39. The server will broadcast real-time transcription updates to all connected clients on the data WebSocket.
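### Example Client:
A minimal client sketch, not part of this module. The ports match the defaults above; the 48000 Hz sample rate and the 10 ms silence chunk are placeholder assumptions, and a real client would stream microphone audio instead:

```python
import asyncio
import json
import websockets

async def demo_client():
    # Control channel: JSON commands such as set_parameter / get_parameter / call_method.
    async with websockets.connect("ws://localhost:8011") as control:
        await control.send(json.dumps({
            "command": "set_parameter",
            "parameter": "silero_sensitivity",
            "value": 0.1,
        }))
        print(await control.recv())  # e.g. {"status": "success", ...}

    # Data channel: binary frames laid out as
    # [4-byte little-endian metadata length][metadata JSON][16-bit PCM audio].
    async with websockets.connect("ws://localhost:8012") as data:
        metadata = json.dumps({"sampleRate": 48000}).encode("utf-8")
        pcm_chunk = b"\x00\x00" * 480  # 10 ms of int16 silence at 48 kHz (placeholder)
        await data.send(len(metadata).to_bytes(4, "little") + metadata + pcm_chunk)
        print(await data.recv())  # JSON updates, e.g. 'realtime' or 'fullSentence'

asyncio.run(demo_client())
```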
  40. """

import asyncio
import sys

if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

from .install_packages import check_and_install_packages

check_and_install_packages([
    {
        'module_name': 'RealtimeSTT',        # Import module
        'attribute': 'AudioToTextRecorder',  # Specific class to check
        'install_name': 'RealtimeSTT',       # Package name for pip install
    },
    {
        'module_name': 'websockets',         # Import module
        'install_name': 'websockets',        # Package name for pip install
    },
    {
        'module_name': 'numpy',              # Import module
        'install_name': 'numpy',             # Package name for pip install
    },
    {
        'module_name': 'scipy.signal',       # Submodule of scipy
        'attribute': 'resample',             # Specific function to check
        'install_name': 'scipy',             # Package name for pip install
    }
])

print("Starting server, please wait...")

import threading
import json
import websockets
from RealtimeSTT import AudioToTextRecorder
import numpy as np
from scipy.signal import resample

global_args = None
recorder = None
recorder_config = {}
recorder_ready = threading.Event()
stop_recorder = False
prev_text = ""

# Define allowed methods and parameters for security
allowed_methods = [
    'set_microphone',
    'abort',
    'stop',
    'clear_audio_queue',
    'wakeup',
    'shutdown',
    'text',
]
allowed_parameters = [
    'silero_sensitivity',
    'wake_word_activation_delay',
    'post_speech_silence_duration',
    'listen_start',
    'recording_stop_time',
    'last_transcription_bytes',
    'last_transcription_bytes_b64',
]
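
# Example control-channel payloads accepted by control_handler below
# (illustrative values):
#   {"command": "set_parameter", "parameter": "silero_sensitivity", "value": 0.1}
#   {"command": "get_parameter", "parameter": "post_speech_silence_duration"}
#   {"command": "call_method", "method": "stop", "args": [], "kwargs": {}}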

# Queues and connections for control and data
control_connections = set()
data_connections = set()
control_queue = asyncio.Queue()
audio_queue = asyncio.Queue()


def preprocess_text(text):
    # Remove leading whitespaces
    text = text.lstrip()
    # Remove starting ellipses if present
    if text.startswith("..."):
        text = text[3:]
    # Remove any leading whitespaces again after ellipses removal
    text = text.lstrip()
    # Uppercase the first letter
    if text:
        text = text[0].upper() + text[1:]
    return text
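
# For example (illustrative input): preprocess_text("  ...and then we left")
# returns "And then we left".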

def text_detected(text, loop):
    global prev_text
    text = preprocess_text(text)
    sentence_end_marks = ['.', '!', '?', '。']
    if text.endswith("..."):
        recorder.post_speech_silence_duration = global_args.mid_sentence_detection_pause
    elif text and text[-1] in sentence_end_marks and prev_text and prev_text[-1] in sentence_end_marks:
        recorder.post_speech_silence_duration = global_args.end_of_sentence_detection_pause
    else:
        recorder.post_speech_silence_duration = global_args.unknown_sentence_detection_pause
    prev_text = text

    # Put the message in the audio queue to be sent to clients
    message = json.dumps({
        'type': 'realtime',
        'text': text
    })
    asyncio.run_coroutine_threadsafe(audio_queue.put(message), loop)
    print(f"\r{text}", flush=True, end='')


def on_recording_start(loop):
    # Send a message to the client indicating recording has started
    message = json.dumps({
        'type': 'recording_start'
    })
    asyncio.run_coroutine_threadsafe(audio_queue.put(message), loop)


def on_recording_stop(loop):
    # Send a message to the client indicating recording has stopped
    message = json.dumps({
        'type': 'recording_stop'
    })
    asyncio.run_coroutine_threadsafe(audio_queue.put(message), loop)


def on_vad_detect_start(loop):
    # Send a message to the client when voice activity detection starts
    message = json.dumps({
        'type': 'vad_detect_start'
    })
    asyncio.run_coroutine_threadsafe(audio_queue.put(message), loop)


def on_wakeword_detection_start(loop):
    # Send a message to the client when wake word detection starts
    message = json.dumps({
        'type': 'wakeword_detection_start'
    })
    asyncio.run_coroutine_threadsafe(audio_queue.put(message), loop)


def on_wakeword_detection_end(loop):
    # Send a message to the client when wake word detection ends
    message = json.dumps({
        'type': 'wakeword_detection_end'
    })
    asyncio.run_coroutine_threadsafe(audio_queue.put(message), loop)


def on_transcription_start(loop):
    # Send a message to the client when transcription starts
    message = json.dumps({
        'type': 'transcription_start'
    })
    asyncio.run_coroutine_threadsafe(audio_queue.put(message), loop)


def on_realtime_transcription_update(text, loop):
    # Send real-time transcription updates to the client
    text = preprocess_text(text)
    message = json.dumps({
        'type': 'realtime_update',
        'text': text
    })
    asyncio.run_coroutine_threadsafe(audio_queue.put(message), loop)


def on_recorded_chunk(chunk):
    # Process each recorded audio chunk (optional implementation)
    pass

# Define the server's arguments
def parse_arguments():
    import argparse
    parser = argparse.ArgumentParser(description='Start the Speech-to-Text (STT) server with various configuration options.')
    parser.add_argument('--model', type=str, default='large-v2',
                        help='Path to the STT model or model size. Options: tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, or any Hugging Face CTranslate2 STT model like deepdml/faster-whisper-large-v3-turbo-ct2. Default: large-v2')
    parser.add_argument('--realtime_model_type', type=str, default='tiny.en',
                        help='Model size for real-time transcription. Same options as --model. Used only if real-time transcription is enabled. Default: tiny.en')
    parser.add_argument('--language', type=str, default='en',
                        help='Language code for the STT model. Leave empty for auto-detection. Default: en')
    parser.add_argument('--input_device_index', type=int, default=1,
                        help='Index of the audio input device to use. Default: 1')
    parser.add_argument('--silero_sensitivity', type=float, default=0.05,
                        help='Sensitivity for Silero Voice Activity Detection (0 to 1). Lower values are less sensitive. Default: 0.05')
    parser.add_argument('--webrtc_sensitivity', type=int, default=3,
                        help='Sensitivity for WebRTC Voice Activity Detection (0 to 3). Higher values are less sensitive. Default: 3')
    parser.add_argument('--min_length_of_recording', type=float, default=1.1,
                        help='Minimum duration (in seconds) for a valid recording. Prevents excessively short recordings. Default: 1.1')
    parser.add_argument('--min_gap_between_recordings', type=float, default=0,
                        help='Minimum time (in seconds) between consecutive recordings. Prevents rapid successive recordings. Default: 0')
    parser.add_argument('--enable_realtime_transcription', action='store_true', default=True,
                        help='Enable continuous real-time transcription of audio. Default: True')
    parser.add_argument('--realtime_processing_pause', type=float, default=0.02,
                        help='Time interval (in seconds) between processing audio chunks for real-time transcription. Lower values increase responsiveness but may increase CPU load. Default: 0.02')
    parser.add_argument('--silero_deactivity_detection', action='store_true', default=True,
                        help='Use the Silero model for end-of-speech detection. More robust against background noise but uses more GPU resources. Default: True')
    parser.add_argument('--early_transcription_on_silence', type=float, default=0.2,
                        help='Start transcription after the specified seconds of silence. Should be lower than post_speech_silence_duration. Set to 0 to disable. Default: 0.2')
    parser.add_argument('--beam_size', type=int, default=5,
                        help='Beam size for the main transcription model. Larger values may improve accuracy but increase processing time. Default: 5')
    parser.add_argument('--beam_size_realtime', type=int, default=3,
                        help='Beam size for the real-time transcription model. Smaller than the main beam_size for faster processing. Default: 3')
    parser.add_argument('--initial_prompt', type=str,
                        default='End incomplete sentences with ellipses.\nExamples:\nComplete: The sky is blue.\nIncomplete: When the sky...\nComplete: She walked home.\nIncomplete: Because he...',
                        help='Initial prompt for the transcription model to guide its output format and style. The default provides instructions for sentence completion and ellipsis usage.')
    parser.add_argument('--end_of_sentence_detection_pause', type=float, default=0.45,
                        help='Duration of pause (in seconds) to consider as the end of a sentence. Default: 0.45')
    parser.add_argument('--unknown_sentence_detection_pause', type=float, default=0.7,
                        help='Duration of pause (in seconds) to consider as an unknown or incomplete sentence. Default: 0.7')
    parser.add_argument('--mid_sentence_detection_pause', type=float, default=2.0,
                        help='Duration of pause (in seconds) to consider as a mid-sentence break. Default: 2.0')
    parser.add_argument('--control_port', type=int, default=8011,
                        help='Port for the control WebSocket connection. Default: 8011')
    parser.add_argument('--data_port', type=int, default=8012,
                        help='Port for the data WebSocket connection. Default: 8012')
    return parser.parse_args()

def _recorder_thread(loop):
    global recorder, prev_text, stop_recorder
    print(f"Initializing RealtimeSTT server with parameters {recorder_config}")
    recorder = AudioToTextRecorder(**recorder_config)
    print("RealtimeSTT initialized")
    recorder_ready.set()

    def process_text(full_sentence):
        full_sentence = preprocess_text(full_sentence)
        message = json.dumps({
            'type': 'fullSentence',
            'text': full_sentence
        })
        # Use the passed event loop here
        asyncio.run_coroutine_threadsafe(audio_queue.put(message), loop)
        print(f"\rSentence: {full_sentence}")

    try:
        while not stop_recorder:
            recorder.text(process_text)
    except KeyboardInterrupt:
        print("Exiting application due to keyboard interrupt")


def decode_and_resample(
        audio_data,
        original_sample_rate,
        target_sample_rate):
    # Nothing to do if the sample rates already match
    if original_sample_rate == target_sample_rate:
        return audio_data

    # Decode 16-bit PCM data to a numpy array
    audio_np = np.frombuffer(audio_data, dtype=np.int16)

    # Calculate the number of samples after resampling
    num_original_samples = len(audio_np)
    num_target_samples = int(num_original_samples * target_sample_rate /
                             original_sample_rate)

    # Resample the audio
    resampled_audio = resample(audio_np, num_target_samples)
    return resampled_audio.astype(np.int16).tobytes()
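
# Worked example (illustrative numbers): a 1024-sample chunk captured at
# 48000 Hz and resampled to the recorder's 16000 Hz yields
# int(1024 * 16000 / 48000) = 341 samples.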

async def control_handler(websocket, path):
    print("Control client connected")
    global recorder
    control_connections.add(websocket)
    try:
        async for message in websocket:
            if not recorder_ready.is_set():
                print("Recorder not ready")
                continue
            if isinstance(message, str):
                # Handle text message (command)
                try:
                    command_data = json.loads(message)
                    command = command_data.get("command")
                    if command == "set_parameter":
                        parameter = command_data.get("parameter")
                        value = command_data.get("value")
                        if parameter in allowed_parameters and hasattr(recorder, parameter):
                            setattr(recorder, parameter, value)
                            print(f"Set recorder.{parameter} to {value}")
                            # Optionally send a response back to the client
                            await websocket.send(json.dumps({"status": "success", "message": f"Parameter {parameter} set to {value}"}))
                        elif parameter not in allowed_parameters:
                            print(f"Parameter {parameter} is not allowed (set_parameter)")
                            await websocket.send(json.dumps({"status": "error", "message": f"Parameter {parameter} is not allowed (set_parameter)"}))
                        else:
                            print(f"Parameter {parameter} does not exist (set_parameter)")
                            await websocket.send(json.dumps({"status": "error", "message": f"Parameter {parameter} does not exist (set_parameter)"}))
                    elif command == "get_parameter":
                        parameter = command_data.get("parameter")
                        if parameter in allowed_parameters and hasattr(recorder, parameter):
                            value = getattr(recorder, parameter)
                            await websocket.send(json.dumps({"status": "success", "parameter": parameter, "value": value}))
                        elif parameter not in allowed_parameters:
                            print(f"Parameter {parameter} is not allowed (get_parameter)")
                            await websocket.send(json.dumps({"status": "error", "message": f"Parameter {parameter} is not allowed (get_parameter)"}))
                        else:
                            print(f"Parameter {parameter} does not exist (get_parameter)")
                            await websocket.send(json.dumps({"status": "error", "message": f"Parameter {parameter} does not exist (get_parameter)"}))
                    elif command == "call_method":
                        method_name = command_data.get("method")
                        if method_name in allowed_methods:
                            method = getattr(recorder, method_name, None)
                            if method and callable(method):
                                args = command_data.get("args", [])
                                kwargs = command_data.get("kwargs", {})
                                method(*args, **kwargs)
                                print(f"Called method recorder.{method_name}")
                                await websocket.send(json.dumps({"status": "success", "message": f"Method {method_name} called"}))
                            else:
                                print(f"Recorder does not have method {method_name}")
                                await websocket.send(json.dumps({"status": "error", "message": f"Recorder does not have method {method_name}"}))
                        else:
                            print(f"Method {method_name} is not allowed")
                            await websocket.send(json.dumps({"status": "error", "message": f"Method {method_name} is not allowed"}))
                    else:
                        print(f"Unknown command: {command}")
                        await websocket.send(json.dumps({"status": "error", "message": f"Unknown command {command}"}))
                except json.JSONDecodeError:
                    print("Received invalid JSON command")
                    await websocket.send(json.dumps({"status": "error", "message": "Invalid JSON command"}))
            else:
                print("Received unknown message type on control connection")
    except websockets.exceptions.ConnectionClosed as e:
        print(f"Control client disconnected: {e}")
    finally:
        control_connections.remove(websocket)

async def data_handler(websocket, path):
    print("Data client connected")
    data_connections.add(websocket)
    try:
        while True:
            message = await websocket.recv()
            if isinstance(message, bytes):
                # Handle binary message (audio data). Each frame is laid out as:
                # [4-byte little-endian metadata length][metadata JSON][16-bit PCM chunk]
                metadata_length = int.from_bytes(message[:4], byteorder='little')
                metadata_json = message[4:4 + metadata_length].decode('utf-8')
                metadata = json.loads(metadata_json)
                sample_rate = metadata['sampleRate']
                chunk = message[4 + metadata_length:]
                resampled_chunk = decode_and_resample(chunk, sample_rate, 16000)
                recorder.feed_audio(resampled_chunk)
            else:
                print("Received non-binary message on data connection")
    except websockets.exceptions.ConnectionClosed as e:
        print(f"Data client disconnected: {e}")
    finally:
        data_connections.remove(websocket)
        recorder.clear_audio_queue()  # Ensure audio queue is cleared if client disconnects

async def broadcast_audio_messages():
    while True:
        message = await audio_queue.get()
        for conn in list(data_connections):
            try:
                # print(f"Sending message: {message}")
                await conn.send(message)
            except websockets.exceptions.ConnectionClosed:
                # discard() avoids a KeyError if the handler already removed the connection
                data_connections.discard(conn)

# Helper function to create event loop bound closures for callbacks
def make_callback(loop, callback):
    def inner_callback(*args, **kwargs):
        callback(*args, **kwargs, loop=loop)
    return inner_callback

async def main_async():
    global stop_recorder, recorder_config, global_args
    args = parse_arguments()
    global_args = args

    # Get the event loop here and pass it to the recorder thread
    loop = asyncio.get_event_loop()

    recorder_config = {
        'model': args.model,
        'realtime_model_type': args.realtime_model_type,
        'language': args.language,
        'input_device_index': args.input_device_index,
        'silero_sensitivity': args.silero_sensitivity,
        'webrtc_sensitivity': args.webrtc_sensitivity,
        'post_speech_silence_duration': args.unknown_sentence_detection_pause,
        'min_length_of_recording': args.min_length_of_recording,
        'min_gap_between_recordings': args.min_gap_between_recordings,
        'enable_realtime_transcription': args.enable_realtime_transcription,
        'realtime_processing_pause': args.realtime_processing_pause,
        'silero_deactivity_detection': args.silero_deactivity_detection,
        'early_transcription_on_silence': args.early_transcription_on_silence,
        'beam_size': args.beam_size,
        'beam_size_realtime': args.beam_size_realtime,
        'initial_prompt': args.initial_prompt,
        'spinner': False,
        'use_microphone': False,
        'on_realtime_transcription_update': make_callback(loop, text_detected),
        'on_recording_start': make_callback(loop, on_recording_start),
        'on_recording_stop': make_callback(loop, on_recording_stop),
        'on_vad_detect_start': make_callback(loop, on_vad_detect_start),
        'on_wakeword_detection_start': make_callback(loop, on_wakeword_detection_start),
        'on_wakeword_detection_end': make_callback(loop, on_wakeword_detection_end),
        'on_transcription_start': make_callback(loop, on_transcription_start),
        'no_log_file': True,
    }

    control_server = await websockets.serve(control_handler, "localhost", args.control_port)
    data_server = await websockets.serve(data_handler, "localhost", args.data_port)
    print(f"Control server started on ws://localhost:{args.control_port}")
    print(f"Data server started on ws://localhost:{args.data_port}")

    # Task to broadcast audio messages
    broadcast_task = asyncio.create_task(broadcast_audio_messages())

    recorder_thread = threading.Thread(target=_recorder_thread, args=(loop,))
    recorder_thread.start()
    recorder_ready.wait()

    print("Server started. Press Ctrl+C to stop the server.")

    try:
        await asyncio.gather(control_server.wait_closed(), data_server.wait_closed(), broadcast_task)
    except KeyboardInterrupt:
        print("Shutting down gracefully...")
    finally:
        # Shut down the recorder
        if recorder:
            stop_recorder = True
            recorder.abort()
            recorder.stop()
            recorder.shutdown()
            print("Recorder shut down")

        recorder_thread.join()
        print("Recorder thread finished")

        # Cancel all active tasks in the event loop
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        for task in tasks:
            task.cancel()

        # Run pending tasks and handle cancellation
        await asyncio.gather(*tasks, return_exceptions=True)
        print("All tasks cancelled, closing event loop now.")

def main():
    try:
        asyncio.run(main_async())
    except KeyboardInterrupt:
        # Capture any final KeyboardInterrupt to prevent it from showing up in logs
        print("Server interrupted by user.")
        sys.exit(0)


if __name__ == '__main__':
    main()