stt_cli_client.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. import os
  2. import sys
  3. import pyaudio
  4. import numpy as np
  5. from scipy import signal
  6. import logging
  7. os.environ['ALSA_LOG_LEVEL'] = 'none'
  8. CHUNK = 1024
  9. FORMAT = pyaudio.paInt16
  10. CHANNELS = 1
  11. RATE = 44100 # Default fallback rate
  12. input_device_index = None
  13. audio_interface = None
  14. stream = None
  15. device_sample_rate = None
  16. chunk_size = CHUNK
  17. def get_highest_sample_rate(audio_interface, device_index):
  18. """Get the highest supported sample rate for the specified device."""
  19. try:
  20. device_info = audio_interface.get_device_info_by_index(device_index)
  21. max_rate = int(device_info['defaultSampleRate'])
  22. if 'supportedSampleRates' in device_info:
  23. supported_rates = [int(rate) for rate in device_info['supportedSampleRates']]
  24. if supported_rates:
  25. max_rate = max(supported_rates)
  26. return max_rate
  27. except Exception as e:
  28. logging.warning(f"Failed to get highest sample rate: {e}")
  29. return 48000 # Fallback to a common high sample rate
  30. def initialize_audio_stream(audio_interface, device_index, sample_rate, chunk_size):
  31. """Initialize the audio stream with error handling."""
  32. try:
  33. stream = audio_interface.open(
  34. format=pyaudio.paInt16,
  35. channels=CHANNELS,
  36. rate=sample_rate,
  37. input=True,
  38. frames_per_buffer=chunk_size,
  39. input_device_index=device_index,
  40. )
  41. return stream
  42. except Exception as e:
  43. logging.error(f"Error initializing audio stream: {e}")
  44. raise
  45. def preprocess_audio(chunk, original_sample_rate, target_sample_rate):
  46. """Preprocess audio chunk similar to feed_audio method."""
  47. if isinstance(chunk, np.ndarray):
  48. if chunk.ndim == 2: # Stereo to mono conversion
  49. chunk = np.mean(chunk, axis=1)
  50. # Resample if needed
  51. if original_sample_rate != target_sample_rate:
  52. num_samples = int(len(chunk) * target_sample_rate / original_sample_rate)
  53. chunk = signal.resample(chunk, num_samples)
  54. chunk = chunk.astype(np.int16)
  55. else:
  56. chunk = np.frombuffer(chunk, dtype=np.int16)
  57. if original_sample_rate != target_sample_rate:
  58. num_samples = int(len(chunk) * target_sample_rate / original_sample_rate)
  59. chunk = signal.resample(chunk, num_samples)
  60. chunk = chunk.astype(np.int16)
  61. return chunk.tobytes()
  62. def setup_audio():
  63. global audio_interface, stream, device_sample_rate, input_device_index
  64. try:
  65. audio_interface = pyaudio.PyAudio()
  66. if input_device_index is None:
  67. try:
  68. default_device = audio_interface.get_default_input_device_info()
  69. input_device_index = default_device['index']
  70. except OSError as e:
  71. input_device_index = None
  72. sample_rates_to_try = [16000] # Try 16000 Hz first
  73. if input_device_index is not None:
  74. highest_rate = get_highest_sample_rate(audio_interface, input_device_index)
  75. if highest_rate != 16000:
  76. sample_rates_to_try.append(highest_rate)
  77. else:
  78. sample_rates_to_try.append(48000) # Fallback sample rate
  79. for rate in sample_rates_to_try:
  80. try:
  81. device_sample_rate = rate
  82. stream = initialize_audio_stream(audio_interface, input_device_index, device_sample_rate, chunk_size)
  83. if stream is not None:
  84. logging.debug(f"Audio recording initialized successfully at {device_sample_rate} Hz, reading {chunk_size} frames at a time")
  85. return True
  86. except Exception as e:
  87. logging.warning(f"Failed to initialize audio stream at {device_sample_rate} Hz: {e}")
  88. continue
  89. raise Exception("Failed to initialize audio stream with all sample rates.")
  90. except Exception as e:
  91. logging.exception(f"Error initializing audio recording: {e}")
  92. if audio_interface:
  93. audio_interface.terminate()
  94. return False
  95. from .install_packages import check_and_install_packages
  96. check_and_install_packages([
  97. {
  98. 'module_name': 'websocket', # Import module
  99. 'install_name': 'websocket-client', # Package name for pip install (websocket-client is the correct package for websocket)
  100. },
  101. {
  102. 'module_name': 'pyaudio', # Import module
  103. 'install_name': 'pyaudio', # Package name for pip install
  104. },
  105. {
  106. 'module_name': 'colorama', # Import module
  107. 'attribute': 'init', # Attribute to check (init method from colorama)
  108. 'install_name': 'colorama', # Package name for pip install
  109. 'version': '', # Optional version constraint
  110. },
  111. ])
  112. import websocket
  113. import pyaudio
  114. from colorama import init, Fore, Style
  115. import argparse
  116. import json
  117. import threading
  118. import time
  119. import struct
  120. import socket
  121. import subprocess
  122. import shutil
  123. from urllib.parse import urlparse
  124. from queue import Queue
  125. # Constants
  126. CHUNK = 1024
  127. FORMAT = pyaudio.paInt16
  128. CHANNELS = 1
  129. RATE = 44100
  130. DEFAULT_SERVER_URL = "ws://localhost:8011"
  131. class STTWebSocketClient:
  132. def __init__(self, server_url, debug=False, file_output=None, norealtime=False):
  133. self.server_url = server_url
  134. self.ws = None
  135. self.is_running = False
  136. self.debug = debug
  137. self.file_output = file_output
  138. self.last_text = ""
  139. self.pbar = None
  140. self.console_width = shutil.get_terminal_size().columns
  141. self.recording_indicator = "🔴"
  142. self.norealtime = norealtime
  143. self.connection_established = threading.Event()
  144. self.message_queue = Queue()
  145. def debug_print(self, message):
  146. if self.debug:
  147. print(message, file=sys.stderr)
  148. def connect(self):
  149. if not self.ensure_server_running():
  150. self.debug_print("Cannot start STT server. Exiting.")
  151. return False
  152. websocket.enableTrace(self.debug)
  153. try:
  154. self.ws = websocket.WebSocketApp(self.server_url,
  155. on_message=self.on_message,
  156. on_error=self.on_error,
  157. on_close=self.on_close,
  158. on_open=self.on_open)
  159. self.ws_thread = threading.Thread(target=self.ws.run_forever)
  160. self.ws_thread.daemon = True
  161. self.ws_thread.start()
  162. # Wait for the connection to be established
  163. if not self.connection_established.wait(timeout=10):
  164. self.debug_print("Timeout while connecting to the server.")
  165. return False
  166. self.debug_print("WebSocket connection established successfully.")
  167. return True
  168. except Exception as e:
  169. self.debug_print(f"Error while connecting to the server: {e}")
  170. return False
  171. def on_open(self, ws):
  172. self.debug_print("WebSocket connection opened.")
  173. self.is_running = True
  174. self.connection_established.set()
  175. self.start_recording()
  176. def on_error(self, ws, error):
  177. self.debug_print(f"WebSocket error: {error}")
  178. def on_close(self, ws, close_status_code, close_msg):
  179. self.debug_print(f"WebSocket connection closed: {close_status_code} - {close_msg}")
  180. self.is_running = False
  181. def is_server_running(self):
  182. parsed_url = urlparse(self.server_url)
  183. host = parsed_url.hostname
  184. port = parsed_url.port or 80
  185. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  186. return s.connect_ex((host, port)) == 0
  187. def ask_to_start_server(self):
  188. response = input("Would you like to start the STT server now? (y/n): ").strip().lower()
  189. return response == 'y' or response == 'yes'
  190. def start_server(self):
  191. if os.name == 'nt': # Windows
  192. subprocess.Popen('start /min cmd /c stt-server', shell=True)
  193. else: # Unix-like systems
  194. subprocess.Popen(['stt-server'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True)
  195. print("STT server start command issued. Please wait a moment for it to initialize.", file=sys.stderr)
  196. def ensure_server_running(self):
  197. if not self.is_server_running():
  198. print("STT server is not running.", file=sys.stderr)
  199. if self.ask_to_start_server():
  200. self.start_server()
  201. print("Waiting for STT server to start...", file=sys.stderr)
  202. for _ in range(20): # Wait up to 20 seconds
  203. if self.is_server_running():
  204. print("STT server started successfully.", file=sys.stderr)
  205. time.sleep(2) # Give the server a moment to fully initialize
  206. return True
  207. time.sleep(1)
  208. print("Failed to start STT server.", file=sys.stderr)
  209. return False
  210. else:
  211. print("STT server is required. Please start it manually.", file=sys.stderr)
  212. return False
  213. return True
  214. def on_message(self, ws, message):
  215. try:
  216. data = json.loads(message)
  217. if data['type'] == 'realtime':
  218. if data['text'] != self.last_text:
  219. self.last_text = data['text']
  220. if not self.norealtime:
  221. self.update_progress_bar(self.last_text)
  222. elif data['type'] == 'fullSentence':
  223. if self.file_output:
  224. sys.stderr.write('\r\033[K')
  225. sys.stderr.write(data['text'])
  226. sys.stderr.write('\n')
  227. sys.stderr.flush()
  228. print(data['text'], file=self.file_output)
  229. self.file_output.flush() # Ensure it's written immediately
  230. else:
  231. self.finish_progress_bar()
  232. print(f"{data['text']}")
  233. self.stop()
  234. except json.JSONDecodeError:
  235. self.debug_print(f"\nReceived non-JSON message: {message}")
  236. def show_initial_indicator(self):
  237. if self.norealtime:
  238. return
  239. initial_text = f"{self.recording_indicator}\b\b"
  240. sys.stderr.write(initial_text)
  241. sys.stderr.flush()
  242. def update_progress_bar(self, text):
  243. # Reserve some space for the progress bar decorations
  244. available_width = self.console_width - 5
  245. # Clear the current line
  246. sys.stderr.write('\r\033[K') # Move to the beginning of the line and clear it
  247. # Get the last 'available_width' characters, but don't cut words
  248. words = text.split()
  249. last_chars = ""
  250. for word in reversed(words):
  251. if len(last_chars) + len(word) + 1 > available_width:
  252. break
  253. last_chars = word + " " + last_chars
  254. last_chars = last_chars.strip()
  255. # Color the text yellow and add recording indicator
  256. colored_text = f"{Fore.YELLOW}{last_chars}{Style.RESET_ALL}{self.recording_indicator}\b\b"
  257. sys.stderr.write(colored_text)
  258. sys.stderr.flush()
  259. def finish_progress_bar(self):
  260. # Clear the current line
  261. sys.stderr.write('\r\033[K')
  262. sys.stderr.flush()
  263. def stop(self):
  264. self.finish_progress_bar()
  265. self.is_running = False
  266. if self.ws:
  267. self.ws.close()
  268. #if hasattr(self, 'ws_thread'):
  269. # self.ws_thread.join(timeout=2)
  270. def start_recording(self):
  271. threading.Thread(target=self.record_and_send_audio).start()
  272. def record_and_send_audio(self):
  273. if not setup_audio():
  274. raise Exception("Failed to set up audio recording.")
  275. self.debug_print("Recording and sending audio...")
  276. self.show_initial_indicator()
  277. while self.is_running:
  278. try:
  279. audio_data = stream.read(CHUNK)
  280. # Prepare metadata
  281. metadata = {
  282. "sampleRate": device_sample_rate
  283. }
  284. metadata_json = json.dumps(metadata)
  285. metadata_length = len(metadata_json)
  286. # Construct the message
  287. message = struct.pack('<I', metadata_length) + metadata_json.encode('utf-8') + audio_data
  288. self.ws.send(message, opcode=websocket.ABNF.OPCODE_BINARY)
  289. except Exception as e:
  290. self.debug_print(f"Error sending audio data: {e}")
  291. break
  292. self.debug_print("Stopped recording.")
  293. stream.stop_stream()
  294. stream.close()
  295. audio_interface.terminate()
  296. def main():
  297. parser = argparse.ArgumentParser(description="STT Client")
  298. parser.add_argument("--server", default=DEFAULT_SERVER_URL, help="STT WebSocket server URL")
  299. parser.add_argument("--debug", action="store_true", help="Enable debug mode")
  300. parser.add_argument("-nort", "--norealtime", action="store_true", help="Disable real-time output")
  301. args = parser.parse_args()
  302. # Check if output is being redirected
  303. if not os.isatty(sys.stdout.fileno()):
  304. file_output = sys.stdout
  305. else:
  306. file_output = None
  307. client = STTWebSocketClient(args.server, args.debug, file_output, args.norealtime)
  308. def signal_handler(sig, frame):
  309. # print("\nInterrupted by user, shutting down...")
  310. client.stop()
  311. sys.exit(0)
  312. import signal
  313. signal.signal(signal.SIGINT, signal_handler)
  314. try:
  315. if client.connect():
  316. # print("Connection established. Recording... (Press Ctrl+C to stop)", file=sys.stderr)
  317. while client.is_running:
  318. time.sleep(0.1)
  319. else:
  320. print("Failed to connect to the server.", file=sys.stderr)
  321. except Exception as e:
  322. print(f"An error occurred: {e}")
  323. finally:
  324. client.stop()
  325. if __name__ == "__main__":
  326. main()