# stt_cli_client.py
  1. from urllib.parse import urlparse
  2. from scipy import signal
  3. from queue import Queue
  4. import numpy as np
  5. import subprocess
  6. import threading
  7. import websocket
  8. import argparse
  9. import pyaudio
  10. import logging
  11. import struct
  12. import socket
  13. import shutil
  14. import queue
  15. import json
  16. import time
  17. import sys
  18. import os
# Presumably silences ALSA diagnostics. NOTE(review): this is assigned after
# `pyaudio` has already been imported above - confirm it still takes effect.
os.environ['ALSA_LOG_LEVEL'] = 'none'

# Constants
CHUNK = 1024  # frames requested per stream read; also passed as frames_per_buffer
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 44100  # not referenced in the visible code; setup_audio opens the stream at 16000 Hz
DEFAULT_CONTROL_URL = "ws://127.0.0.1:8011"
DEFAULT_DATA_URL = "ws://127.0.0.1:8012"

# Initialize colorama
from colorama import init, Fore, Style
init()

# Stop websocket from spamming the log
websocket.enableTrace(False)
  32. class STTWebSocketClient:
  33. def __init__(self, control_url, data_url, debug=False, file_output=None, norealtime=False):
  34. self.control_url = control_url
  35. self.data_url = data_url
  36. self.control_ws = None
  37. self.data_ws_app = None
  38. self.data_ws_connected = None # WebSocket object that will be used for sending
  39. self.is_running = True
  40. self.debug = debug
  41. self.file_output = file_output
  42. self.last_text = ""
  43. self.console_width = shutil.get_terminal_size().columns
  44. self.recording_indicator = "🔴"
  45. self.norealtime = norealtime
  46. self.connection_established = threading.Event()
  47. self.message_queue = Queue()
  48. self.commands = Queue()
  49. self.stop_event = threading.Event()
  50. # Audio attributes
  51. self.audio_interface = None
  52. self.stream = None
  53. self.device_sample_rate = None
  54. self.input_device_index = None
  55. # Threads
  56. self.control_ws_thread = None
  57. self.data_ws_thread = None
  58. self.recording_thread = None
  59. def debug_print(self, message):
  60. if self.debug:
  61. print(message, file=sys.stderr)
  62. def connect(self):
  63. if not self.ensure_server_running():
  64. self.debug_print("Cannot start STT server. Exiting.")
  65. return False
  66. try:
  67. # Connect to control WebSocket
  68. self.control_ws = websocket.WebSocketApp(self.control_url,
  69. on_message=self.on_control_message,
  70. on_error=self.on_error,
  71. on_close=self.on_close,
  72. on_open=self.on_control_open)
  73. self.control_ws_thread = threading.Thread(target=self.control_ws.run_forever)
  74. self.control_ws_thread.daemon = False # Set to False to ensure proper shutdown
  75. self.control_ws_thread.start()
  76. # Connect to data WebSocket
  77. self.data_ws_app = websocket.WebSocketApp(self.data_url,
  78. on_message=self.on_data_message,
  79. on_error=self.on_error,
  80. on_close=self.on_close,
  81. on_open=self.on_data_open)
  82. self.data_ws_thread = threading.Thread(target=self.data_ws_app.run_forever)
  83. self.data_ws_thread.daemon = False # Set to False to ensure proper shutdown
  84. self.data_ws_thread.start()
  85. # Wait for the connections to be established
  86. if not self.connection_established.wait(timeout=10):
  87. self.debug_print("Timeout while connecting to the server.")
  88. return False
  89. self.debug_print("WebSocket connections established successfully.")
  90. return True
  91. except Exception as e:
  92. self.debug_print(f"Error while connecting to the server: {e}")
  93. return False
  94. def on_control_open(self, ws):
  95. self.debug_print("Control WebSocket connection opened.")
  96. self.connection_established.set()
  97. self.start_command_processor()
  98. def on_data_open(self, ws):
  99. self.debug_print("Data WebSocket connection opened.")
  100. self.data_ws_connected = ws # Store the connected websocket object for sending data
  101. self.start_recording()
  102. def on_error(self, ws, error):
  103. self.debug_print(f"WebSocket error: {error}")
  104. def on_close(self, ws, close_status_code, close_msg):
  105. self.debug_print(f"WebSocket connection closed: {close_status_code} - {close_msg}")
  106. self.is_running = False
  107. self.stop_event.set()
  108. def is_server_running(self):
  109. parsed_url = urlparse(self.control_url)
  110. host = parsed_url.hostname
  111. port = parsed_url.port or 80
  112. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  113. return s.connect_ex((host, port)) == 0
  114. def ask_to_start_server(self):
  115. response = input("Would you like to start the STT server now? (y/n): ").strip().lower()
  116. return response == 'y' or response == 'yes'
  117. def start_server(self):
  118. if os.name == 'nt': # Windows
  119. subprocess.Popen('start /min cmd /c stt-server', shell=True)
  120. else: # Unix-like systems
  121. terminal_emulators = [
  122. 'gnome-terminal',
  123. 'x-terminal-emulator',
  124. 'konsole',
  125. 'xfce4-terminal',
  126. 'lxterminal',
  127. 'xterm',
  128. 'mate-terminal',
  129. 'terminator',
  130. 'tilix',
  131. 'alacritty',
  132. 'urxvt',
  133. 'eterm',
  134. 'rxvt',
  135. 'kitty',
  136. 'hyper'
  137. ]
  138. terminal = None
  139. for term in terminal_emulators:
  140. if shutil.which(term):
  141. terminal = term
  142. break
  143. if terminal:
  144. terminal_exec_options = {
  145. 'x-terminal-emulator': ['--'],
  146. 'gnome-terminal': ['--'],
  147. 'mate-terminal': ['--'],
  148. 'terminator': ['--'],
  149. 'tilix': ['--'],
  150. 'konsole': ['-e'],
  151. 'xfce4-terminal': ['-e'],
  152. 'lxterminal': ['-e'],
  153. 'alacritty': ['-e'],
  154. 'xterm': ['-e'],
  155. 'rxvt': ['-e'],
  156. 'urxvt': ['-e'],
  157. 'eterm': ['-e'],
  158. 'kitty': [],
  159. 'hyper': ['--command']
  160. }
  161. exec_option = terminal_exec_options.get(terminal, None)
  162. if exec_option is not None:
  163. subprocess.Popen([terminal] + exec_option + ['stt-server'], start_new_session=True)
  164. print(f"STT server started in a new terminal window using {terminal}.", file=sys.stderr)
  165. else:
  166. print(f"Unsupported terminal emulator '{terminal}'. Please start the STT server manually.", file=sys.stderr)
  167. else:
  168. print("No supported terminal emulator found. Please start the STT server manually.", file=sys.stderr)
  169. def ensure_server_running(self):
  170. if not self.is_server_running():
  171. print("STT server is not running.", file=sys.stderr)
  172. if self.ask_to_start_server():
  173. self.start_server()
  174. print("Waiting for STT server to start...", file=sys.stderr)
  175. for _ in range(20): # Wait up to 20 seconds
  176. if self.is_server_running():
  177. print("STT server started successfully.", file=sys.stderr)
  178. time.sleep(2) # Give the server a moment to fully initialize
  179. return True
  180. time.sleep(1)
  181. print("Failed to start STT server.", file=sys.stderr)
  182. return False
  183. else:
  184. print("STT server is required. Please start it manually.", file=sys.stderr)
  185. return False
  186. return True
  187. def on_control_message(self, ws, message):
  188. try:
  189. data = json.loads(message)
  190. if 'status' in data:
  191. if data['status'] == 'success':
  192. if 'parameter' in data and 'value' in data:
  193. print(f"Parameter {data['parameter']} = {data['value']}")
  194. elif data['status'] == 'error':
  195. print(f"Server Error: {data.get('message', '')}")
  196. else:
  197. self.debug_print(f"Unknown control message format: {data}")
  198. except json.JSONDecodeError:
  199. self.debug_print(f"Received non-JSON control message: {message}")
  200. except Exception as e:
  201. self.debug_print(f"Error processing control message: {e}")
  202. def on_data_message(self, ws, message):
  203. try:
  204. data = json.loads(message)
  205. message_type = data.get('type')
  206. if message_type == 'realtime':
  207. if data['text'] != self.last_text:
  208. self.last_text = data['text']
  209. if not self.norealtime:
  210. self.update_progress_bar(self.last_text)
  211. elif message_type == 'fullSentence':
  212. if self.file_output:
  213. sys.stderr.write('\r\033[K')
  214. sys.stderr.write(data['text'])
  215. sys.stderr.write('\n')
  216. sys.stderr.flush()
  217. print(data['text'], file=self.file_output)
  218. self.file_output.flush() # Ensure it's written immediately
  219. else:
  220. self.finish_progress_bar()
  221. print(f"{data['text']}")
  222. self.stop()
  223. elif message_type in {
  224. 'vad_detect_start',
  225. 'vad_detect_stop',
  226. 'recording_start',
  227. 'recording_stop',
  228. 'wakeword_detected',
  229. 'wakeword_detection_start',
  230. 'wakeword_detection_end',
  231. 'transcription_start'}:
  232. pass # Known message types, no action needed
  233. else:
  234. self.debug_print(f"Unknown data message format: {data}")
  235. except json.JSONDecodeError:
  236. self.debug_print(f"Received non-JSON data message: {message}")
  237. except Exception as e:
  238. self.debug_print(f"Error processing data message: {e}")
  239. def show_initial_indicator(self):
  240. if self.norealtime:
  241. return
  242. initial_text = f"{self.recording_indicator}\b\b"
  243. sys.stderr.write(initial_text)
  244. sys.stderr.flush()
  245. def update_progress_bar(self, text):
  246. try:
  247. available_width = self.console_width - 5 # Adjust for progress bar decorations
  248. sys.stderr.write('\r\033[K') # Clear the current line
  249. words = text.split()
  250. last_chars = ""
  251. for word in reversed(words):
  252. if len(last_chars) + len(word) + 1 > available_width:
  253. break
  254. last_chars = word + " " + last_chars
  255. last_chars = last_chars.strip()
  256. colored_text = f"{Fore.YELLOW}{last_chars}{Style.RESET_ALL}{self.recording_indicator}\b\b"
  257. sys.stderr.write(colored_text)
  258. sys.stderr.flush()
  259. except Exception as e:
  260. self.debug_print(f"Error updating progress bar: {e}")
  261. def finish_progress_bar(self):
  262. try:
  263. sys.stderr.write('\r\033[K')
  264. sys.stderr.flush()
  265. except Exception as e:
  266. self.debug_print(f"Error finishing progress bar: {e}")
  267. def stop(self):
  268. self.finish_progress_bar()
  269. self.is_running = False
  270. self.stop_event.set()
  271. self.debug_print("Stopping client and cleaning up resources.")
  272. if self.control_ws:
  273. self.control_ws.close()
  274. if self.data_ws_connected:
  275. self.data_ws_connected.close()
  276. # Join threads to ensure they finish before exiting
  277. if self.control_ws_thread:
  278. self.control_ws_thread.join()
  279. if self.data_ws_thread:
  280. self.data_ws_thread.join()
  281. if self.recording_thread:
  282. self.recording_thread.join()
  283. # Clean up audio resources
  284. if self.stream:
  285. self.stream.stop_stream()
  286. self.stream.close()
  287. if self.audio_interface:
  288. self.audio_interface.terminate()
  289. def start_recording(self):
  290. self.recording_thread = threading.Thread(target=self.record_and_send_audio)
  291. self.recording_thread.daemon = False # Set to False to ensure proper shutdown
  292. self.recording_thread.start()
  293. def record_and_send_audio(self):
  294. try:
  295. if not self.setup_audio():
  296. raise Exception("Failed to set up audio recording.")
  297. self.debug_print("Recording and sending audio...")
  298. self.show_initial_indicator()
  299. while self.is_running:
  300. try:
  301. audio_data = self.stream.read(CHUNK)
  302. metadata = {"sampleRate": self.device_sample_rate}
  303. metadata_json = json.dumps(metadata)
  304. metadata_length = len(metadata_json)
  305. message = struct.pack('<I', metadata_length) + metadata_json.encode('utf-8') + audio_data
  306. self.data_ws_connected.send(message, opcode=websocket.ABNF.OPCODE_BINARY)
  307. except Exception as e:
  308. self.debug_print(f"Error sending audio data: {e}")
  309. break # Exit the recording loop
  310. except Exception as e:
  311. self.debug_print(f"Error in record_and_send_audio: {e}")
  312. finally:
  313. self.cleanup_audio()
  314. def setup_audio(self):
  315. try:
  316. self.audio_interface = pyaudio.PyAudio()
  317. self.input_device_index = None
  318. try:
  319. default_device = self.audio_interface.get_default_input_device_info()
  320. self.input_device_index = default_device['index']
  321. except OSError as e:
  322. self.debug_print(f"No default input device found: {e}")
  323. return False
  324. self.device_sample_rate = 16000 # Try 16000 Hz first
  325. try:
  326. self.stream = self.audio_interface.open(
  327. format=FORMAT,
  328. channels=CHANNELS,
  329. rate=self.device_sample_rate,
  330. input=True,
  331. frames_per_buffer=CHUNK,
  332. input_device_index=self.input_device_index,
  333. )
  334. self.debug_print(f"Audio recording initialized successfully at {self.device_sample_rate} Hz, device index {self.input_device_index}")
  335. return True
  336. except Exception as e:
  337. self.debug_print(f"Failed to initialize audio stream at {self.device_sample_rate} Hz, device index {self.input_device_index}: {e}")
  338. return False
  339. except Exception as e:
  340. self.debug_print(f"Error initializing audio recording: {e}")
  341. if self.audio_interface:
  342. self.audio_interface.terminate()
  343. return False
  344. def cleanup_audio(self):
  345. try:
  346. if self.stream:
  347. self.stream.stop_stream()
  348. self.stream.close()
  349. self.stream = None
  350. if self.audio_interface:
  351. self.audio_interface.terminate()
  352. self.audio_interface = None
  353. except Exception as e:
  354. self.debug_print(f"Error cleaning up audio resources: {e}")
  355. def set_parameter(self, parameter, value):
  356. command = {
  357. "command": "set_parameter",
  358. "parameter": parameter,
  359. "value": value
  360. }
  361. self.control_ws.send(json.dumps(command))
  362. def get_parameter(self, parameter):
  363. command = {
  364. "command": "get_parameter",
  365. "parameter": parameter
  366. }
  367. self.control_ws.send(json.dumps(command))
  368. def call_method(self, method, args=None, kwargs=None):
  369. command = {
  370. "command": "call_method",
  371. "method": method,
  372. "args": args or [],
  373. "kwargs": kwargs or {}
  374. }
  375. self.control_ws.send(json.dumps(command))
  376. def start_command_processor(self):
  377. self.command_thread = threading.Thread(target=self.command_processor)
  378. self.command_thread.daemon = False # Ensure it is not a daemon thread
  379. self.command_thread.start()
  380. def command_processor(self):
  381. self.debug_print(f"Starting command processor")
  382. while not self.stop_event.is_set():
  383. try:
  384. command = self.commands.get(timeout=0.1)
  385. if command['type'] == 'set_parameter':
  386. self.set_parameter(command['parameter'], command['value'])
  387. elif command['type'] == 'get_parameter':
  388. self.get_parameter(command['parameter'])
  389. elif command['type'] == 'call_method':
  390. self.call_method(command['method'], command.get('args'), command.get('kwargs'))
  391. except queue.Empty:
  392. continue
  393. except Exception as e:
  394. self.debug_print(f"Error in command processor: {e}")
  395. self.debug_print(f"Leaving command processor")
  396. def add_command(self, command):
  397. self.commands.put(command)
  398. def main():
  399. parser = argparse.ArgumentParser(description="STT Client")
  400. parser.add_argument("--control-url", default=DEFAULT_CONTROL_URL, help="STT Control WebSocket URL")
  401. parser.add_argument("--data-url", default=DEFAULT_DATA_URL, help="STT Data WebSocket URL")
  402. parser.add_argument("--debug", action="store_true", help="Enable debug mode")
  403. parser.add_argument("-nort", "--norealtime", action="store_true", help="Disable real-time output")
  404. parser.add_argument("--set-param", nargs=2, metavar=('PARAM', 'VALUE'), action='append',
  405. help="Set a recorder parameter. Can be used multiple times.")
  406. parser.add_argument("--call-method", nargs='+', metavar='METHOD', action='append',
  407. help="Call a recorder method with optional arguments.")
  408. parser.add_argument("--get-param", nargs=1, metavar='PARAM', action='append',
  409. help="Get the value of a recorder parameter. Can be used multiple times.")
  410. args = parser.parse_args()
  411. # Check if output is being redirected
  412. if not os.isatty(sys.stdout.fileno()):
  413. file_output = sys.stdout
  414. else:
  415. file_output = None
  416. client = STTWebSocketClient(args.control_url, args.data_url, args.debug, file_output, args.norealtime)
  417. def signal_handler(sig, frame):
  418. client.stop()
  419. sys.exit(0)
  420. import signal
  421. signal.signal(signal.SIGINT, signal_handler)
  422. try:
  423. if client.connect():
  424. # Process command-line parameters
  425. if args.set_param:
  426. for param, value in args.set_param:
  427. try:
  428. if '.' in value:
  429. value = float(value)
  430. else:
  431. value = int(value)
  432. except ValueError:
  433. pass # Keep as string if not a number
  434. client.add_command({
  435. 'type': 'set_parameter',
  436. 'parameter': param,
  437. 'value': value
  438. })
  439. if args.get_param:
  440. for param_list in args.get_param:
  441. param = param_list[0]
  442. client.add_command({
  443. 'type': 'get_parameter',
  444. 'parameter': param
  445. })
  446. if args.call_method:
  447. for method_call in args.call_method:
  448. method = method_call[0]
  449. args_list = method_call[1:] if len(method_call) > 1 else []
  450. client.add_command({
  451. 'type': 'call_method',
  452. 'method': method,
  453. 'args': args_list
  454. })
  455. # If command-line parameters were used (like --get-param), wait for them to be processed
  456. if args.set_param or args.get_param or args.call_method:
  457. while not client.commands.empty():
  458. time.sleep(0.1)
  459. # Start recording directly if no command-line params were provided
  460. while client.is_running:
  461. time.sleep(0.1)
  462. else:
  463. print("Failed to connect to the server.", file=sys.stderr)
  464. except Exception as e:
  465. print(f"An error occurred: {e}")
  466. finally:
  467. client.stop()
  468. if __name__ == "__main__":
  469. main()