stt_cli_client.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. from .install_packages import check_and_install_packages
  2. check_and_install_packages([
  3. {
  4. 'module_name': 'websocket', # Import module
  5. 'install_name': 'websocket-client', # Package name for pip install (websocket-client is the correct package for websocket)
  6. },
  7. {
  8. 'module_name': 'pyaudio', # Import module
  9. 'install_name': 'pyaudio', # Package name for pip install
  10. },
  11. {
  12. 'module_name': 'colorama', # Import module
  13. 'attribute': 'init', # Attribute to check (init method from colorama)
  14. 'install_name': 'colorama', # Package name for pip install
  15. 'version': '', # Optional version constraint
  16. },
  17. ])
  18. import websocket
  19. import pyaudio
  20. from colorama import init, Fore, Style
  21. import argparse
  22. import json
  23. import threading
  24. import time
  25. import struct
  26. import os
  27. import sys
  28. import socket
  29. import subprocess
  30. import shutil
  31. from urllib.parse import urlparse
  32. from queue import Queue
  33. # Constants
  34. CHUNK = 1024
  35. FORMAT = pyaudio.paInt16
  36. CHANNELS = 1
  37. RATE = 16000
  38. DEFAULT_SERVER_URL = "ws://localhost:8011"
  39. class STTWebSocketClient:
  40. def __init__(self, server_url, debug=False, file_output=None, norealtime=False):
  41. self.server_url = server_url
  42. self.ws = None
  43. self.is_running = False
  44. self.debug = debug
  45. self.file_output = file_output
  46. self.last_text = ""
  47. self.pbar = None
  48. self.console_width = shutil.get_terminal_size().columns
  49. self.recording_indicator = "🔴"
  50. self.norealtime = norealtime
  51. self.connection_established = threading.Event()
  52. self.message_queue = Queue()
  53. def debug_print(self, message):
  54. if self.debug:
  55. print(message, file=sys.stderr)
  56. def connect(self):
  57. if not self.ensure_server_running():
  58. self.debug_print("Cannot start STT server. Exiting.")
  59. return False
  60. websocket.enableTrace(self.debug)
  61. try:
  62. self.ws = websocket.WebSocketApp(self.server_url,
  63. on_message=self.on_message,
  64. on_error=self.on_error,
  65. on_close=self.on_close,
  66. on_open=self.on_open)
  67. self.ws_thread = threading.Thread(target=self.ws.run_forever)
  68. self.ws_thread.daemon = True
  69. self.ws_thread.start()
  70. # Wait for the connection to be established
  71. if not self.connection_established.wait(timeout=10):
  72. self.debug_print("Timeout while connecting to the server.")
  73. return False
  74. self.debug_print("WebSocket connection established successfully.")
  75. return True
  76. except Exception as e:
  77. self.debug_print(f"Error while connecting to the server: {e}")
  78. return False
  79. def on_open(self, ws):
  80. self.debug_print("WebSocket connection opened.")
  81. self.is_running = True
  82. self.connection_established.set()
  83. self.start_recording()
  84. def on_error(self, ws, error):
  85. self.debug_print(f"WebSocket error: {error}")
  86. def on_close(self, ws, close_status_code, close_msg):
  87. self.debug_print(f"WebSocket connection closed: {close_status_code} - {close_msg}")
  88. self.is_running = False
  89. def is_server_running(self):
  90. parsed_url = urlparse(self.server_url)
  91. host = parsed_url.hostname
  92. port = parsed_url.port or 80
  93. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  94. return s.connect_ex((host, port)) == 0
  95. def ask_to_start_server(self):
  96. response = input("Would you like to start the STT server now? (y/n): ").strip().lower()
  97. return response == 'y' or response == 'yes'
  98. def start_server(self):
  99. if os.name == 'nt': # Windows
  100. subprocess.Popen('start /min cmd /c stt-server', shell=True)
  101. else: # Unix-like systems
  102. subprocess.Popen(['stt-server'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True)
  103. print("STT server start command issued. Please wait a moment for it to initialize.", file=sys.stderr)
  104. def ensure_server_running(self):
  105. if not self.is_server_running():
  106. print("STT server is not running.", file=sys.stderr)
  107. if self.ask_to_start_server():
  108. self.start_server()
  109. print("Waiting for STT server to start...", file=sys.stderr)
  110. for _ in range(20): # Wait up to 20 seconds
  111. if self.is_server_running():
  112. print("STT server started successfully.", file=sys.stderr)
  113. time.sleep(2) # Give the server a moment to fully initialize
  114. return True
  115. time.sleep(1)
  116. print("Failed to start STT server.", file=sys.stderr)
  117. return False
  118. else:
  119. print("STT server is required. Please start it manually.", file=sys.stderr)
  120. return False
  121. return True
  122. def on_message(self, ws, message):
  123. try:
  124. data = json.loads(message)
  125. if data['type'] == 'realtime':
  126. if data['text'] != self.last_text:
  127. self.last_text = data['text']
  128. if not self.norealtime:
  129. self.update_progress_bar(self.last_text)
  130. elif data['type'] == 'fullSentence':
  131. if self.file_output:
  132. sys.stderr.write('\r\033[K')
  133. sys.stderr.write(data['text'])
  134. sys.stderr.write('\n')
  135. sys.stderr.flush()
  136. print(data['text'], file=self.file_output)
  137. self.file_output.flush() # Ensure it's written immediately
  138. else:
  139. self.finish_progress_bar()
  140. print(f"{data['text']}")
  141. # self.update_progress_bar("")
  142. # print(f"\r\033[K{data['text']}")
  143. # #print(f"\r\033[KHello")
  144. # self.stop()
  145. # print("what the fuck")
  146. # print("what the fuck")
  147. # print(f"what the fuck self.file_output {self.file_output}")
  148. # self.update_progress_bar("FGINAAL")
  149. # self.stop()
  150. # sys.stderr.write(f"\n{data['text']}")
  151. # sys.stderr.write(f"\n{data['text']}")
  152. # sys.stderr.write(f"\nTEEEST")
  153. # sys.stderr.write(f"\nTEEEST")
  154. # sys.stderr.flush()
  155. # print("what the fuck")
  156. # print("what the fuck")
  157. # print("what the fuck")
  158. # print("what the fuck")
  159. # print("what the fuck")
  160. self.stop()
  161. except json.JSONDecodeError:
  162. self.debug_print(f"\nReceived non-JSON message: {message}")
  163. def show_initial_indicator(self):
  164. if self.norealtime:
  165. return
  166. initial_text = f"{self.recording_indicator}\b\b"
  167. sys.stderr.write(initial_text)
  168. sys.stderr.flush()
  169. def update_progress_bar(self, text):
  170. # Reserve some space for the progress bar decorations
  171. available_width = self.console_width - 5
  172. # Clear the current line
  173. sys.stderr.write('\r\033[K') # Move to the beginning of the line and clear it
  174. # Get the last 'available_width' characters, but don't cut words
  175. words = text.split()
  176. last_chars = ""
  177. for word in reversed(words):
  178. if len(last_chars) + len(word) + 1 > available_width:
  179. break
  180. last_chars = word + " " + last_chars
  181. last_chars = last_chars.strip()
  182. # Color the text yellow and add recording indicator
  183. colored_text = f"{Fore.YELLOW}{last_chars}{Style.RESET_ALL}{self.recording_indicator}\b\b"
  184. sys.stderr.write(colored_text)
  185. sys.stderr.flush()
  186. def finish_progress_bar(self):
  187. # Clear the current line
  188. sys.stderr.write('\r\033[K')
  189. sys.stderr.flush()
  190. def stop(self):
  191. self.finish_progress_bar()
  192. self.is_running = False
  193. if self.ws:
  194. self.ws.close()
  195. if hasattr(self, 'ws_thread'):
  196. self.ws_thread.join(timeout=2)
  197. def start_recording(self):
  198. self.show_initial_indicator()
  199. threading.Thread(target=self.record_and_send_audio).start()
  200. def record_and_send_audio(self):
  201. p = pyaudio.PyAudio()
  202. stream = p.open(format=FORMAT,
  203. input_device_index=1,
  204. channels=CHANNELS,
  205. rate=RATE,
  206. input=True,
  207. frames_per_buffer=CHUNK)
  208. self.debug_print("Recording and sending audio...")
  209. while self.is_running:
  210. try:
  211. audio_data = stream.read(CHUNK)
  212. # Prepare metadata
  213. metadata = {
  214. "sampleRate": RATE
  215. }
  216. metadata_json = json.dumps(metadata)
  217. metadata_length = len(metadata_json)
  218. # Construct the message
  219. message = struct.pack('<I', metadata_length) + metadata_json.encode('utf-8') + audio_data
  220. self.ws.send(message, opcode=websocket.ABNF.OPCODE_BINARY)
  221. except Exception as e:
  222. self.debug_print(f"\nError sending audio data: {e}")
  223. break
  224. self.debug_print("Stopped recording.")
  225. stream.stop_stream()
  226. stream.close()
  227. p.terminate()
  228. def main():
  229. parser = argparse.ArgumentParser(description="STT Client")
  230. parser.add_argument("--server", default=DEFAULT_SERVER_URL, help="STT WebSocket server URL")
  231. parser.add_argument("--debug", action="store_true", help="Enable debug mode")
  232. parser.add_argument("-nort", "--norealtime", action="store_true", help="Disable real-time output")
  233. args = parser.parse_args()
  234. # Check if output is being redirected
  235. if not os.isatty(sys.stdout.fileno()):
  236. file_output = sys.stdout
  237. else:
  238. file_output = None
  239. client = STTWebSocketClient(args.server, args.debug, file_output, args.norealtime)
  240. def signal_handler(sig, frame):
  241. # print("\nInterrupted by user, shutting down...")
  242. client.stop()
  243. sys.exit(0)
  244. import signal
  245. signal.signal(signal.SIGINT, signal_handler)
  246. try:
  247. if client.connect():
  248. # print("Connection established. Recording... (Press Ctrl+C to stop)", file=sys.stderr)
  249. while client.is_running:
  250. time.sleep(0.1)
  251. else:
  252. print("Failed to connect to the server.", file=sys.stderr)
  253. except Exception as e:
  254. print(f"An error occurred: {e}")
  255. finally:
  256. client.stop()
  257. if __name__ == "__main__":
  258. main()