Преглед изворни кода

more logging for client/server

KoljaB пре 6 месеци
родитељ
комит
15ddbca7c9
4 измењених фајлова са 283 додато и 89 уклоњено
  1. 1 1
      README.md
  2. 42 12
      RealtimeSTT/audio_recorder.py
  3. 129 34
      RealtimeSTT_server/stt_cli_client.py
  4. 111 42
      RealtimeSTT_server/stt_server.py

+ 1 - 1
README.md

@@ -28,7 +28,7 @@ https://github.com/user-attachments/assets/797e6552-27cd-41b1-a7f3-e5cbc72094f5
 
 ### Updates
 
-Latest Version: v0.3.4
+Latest Version: v0.3.6
 
 See [release history](https://github.com/KoljaB/RealtimeSTT/releases).
 

+ 42 - 12
RealtimeSTT/audio_recorder.py

@@ -63,6 +63,7 @@ os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
 INIT_MODEL_TRANSCRIPTION = "tiny"
 INIT_MODEL_TRANSCRIPTION_REALTIME = "tiny"
 INIT_REALTIME_PROCESSING_PAUSE = 0.2
+INIT_REALTIME_INITIAL_PAUSE = 0.2
 INIT_SILERO_SENSITIVITY = 0.4
 INIT_WEBRTC_SENSITIVITY = 3
 INIT_POST_SPEECH_SILENCE_DURATION = 0.6
@@ -179,6 +180,12 @@ class TranscriptionWorker:
             polling_thread.join()  # Wait for the polling thread to finish
 
 
+class bcolors:
+    OKGREEN = '\033[92m'  # Green for active speech detection
+    WARNING = '\033[93m'  # Yellow for silence detection
+    ENDC = '\033[0m'      # Reset to default color
+
+
 class AudioToTextRecorder:
     """
     A class responsible for capturing audio from the microphone, detecting
@@ -207,6 +214,7 @@ class AudioToTextRecorder:
                  use_main_model_for_realtime=False,
                  realtime_model_type=INIT_MODEL_TRANSCRIPTION_REALTIME,
                  realtime_processing_pause=INIT_REALTIME_PROCESSING_PAUSE,
+                 init_realtime_after_seconds=INIT_REALTIME_INITIAL_PAUSE,
                  on_realtime_transcription_update=None,
                  on_realtime_transcription_stabilized=None,
 
@@ -326,6 +334,9 @@ class AudioToTextRecorder:
             interval in seconds after a chunk of audio gets transcribed. Lower
             values will result in more "real-time" (frequent) transcription
             updates but may increase computational load.
+        - init_realtime_after_seconds (float, default=0.2): Specifies the 
+            initial waiting time after the recording was initiated before
+            yielding the first realtime transcription
         - on_realtime_transcription_update = A callback function that is
             triggered whenever there's an update in the real-time
             transcription. The function is called with the newly transcribed
@@ -499,6 +510,7 @@ class AudioToTextRecorder:
         self.main_model_type = model
         self.realtime_model_type = realtime_model_type
         self.realtime_processing_pause = realtime_processing_pause
+        self.init_realtime_after_seconds = init_realtime_after_seconds
         self.on_realtime_transcription_update = (
             on_realtime_transcription_update
         )
@@ -1619,8 +1631,8 @@ class AudioToTextRecorder:
             # Continuously monitor audio for voice activity
             while self.is_running:
 
-                if self.use_extended_logging:
-                    logging.debug('Debug: Entering inner try block')
+                # if self.use_extended_logging:
+                #     logging.debug('Debug: Entering inner try block')
                 if last_inner_try_time:
                     last_processing_time = time.time() - last_inner_try_time
                     if last_processing_time > 0.1:
@@ -1628,20 +1640,20 @@ class AudioToTextRecorder:
                             logging.warning('### WARNING: PROCESSING TOOK TOO LONG')
                 last_inner_try_time = time.time()
                 try:
-                    if self.use_extended_logging:
-                        logging.debug('Debug: Trying to get data from audio queue')
+                    # if self.use_extended_logging:
+                    #     logging.debug('Debug: Trying to get data from audio queue')
                     try:
                         data = self.audio_queue.get(timeout=0.01)
                         self.last_words_buffer.append(data)
                     except queue.Empty:
-                        if self.use_extended_logging:
-                            logging.debug('Debug: Queue is empty, checking if still running')
+                        # if self.use_extended_logging:
+                        #     logging.debug('Debug: Queue is empty, checking if still running')
                         if not self.is_running:
                             if self.use_extended_logging:
                                 logging.debug('Debug: Not running, breaking loop')
                             break
-                        if self.use_extended_logging:
-                            logging.debug('Debug: Continuing to next iteration')
+                        # if self.use_extended_logging:
+                        #     logging.debug('Debug: Continuing to next iteration')
                         continue
 
                     if self.use_extended_logging:
@@ -2072,7 +2084,7 @@ class AudioToTextRecorder:
                     # double check recording state
                     # because it could have changed mid-transcription
                     if self.is_recording and time.time() - \
-                            self.recording_start_time > 0.5:
+                            self.recording_start_time > self.init_realtime_after_seconds:
 
                         # logging.debug('Starting realtime transcription')
                         self.realtime_transcription_text = realtime_text
@@ -2179,7 +2191,11 @@ class AudioToTextRecorder:
             SAMPLE_RATE).item()
         is_silero_speech_active = vad_prob > (1 - self.silero_sensitivity)
         if is_silero_speech_active:
-            self.is_silero_speech_active = True
+            if not self.is_silero_speech_active and self.use_extended_logging:
+                logging.info(f"{bcolors.OKGREEN}Silero VAD detected speech{bcolors.ENDC}")
+        elif self.is_silero_speech_active and self.use_extended_logging:
+            logging.info(f"{bcolors.WARNING}Silero VAD detected silence{bcolors.ENDC}")
+        self.is_silero_speech_active = is_silero_speech_active
         self.silero_working = False
         return is_silero_speech_active
 
@@ -2191,6 +2207,8 @@ class AudioToTextRecorder:
             data (bytes): raw bytes of audio data (1024 raw bytes with
             16000 sample rate and 16 bits per sample)
         """
+        speech_str = f"{bcolors.OKGREEN}WebRTC VAD detected speech{bcolors.ENDC}"
+        silence_str = f"{bcolors.WARNING}WebRTC VAD detected silence{bcolors.ENDC}"
         if self.sample_rate != 16000:
             pcm_data = np.frombuffer(chunk, dtype=np.int16)
             data_16000 = signal.resample_poly(
@@ -2212,6 +2230,9 @@ class AudioToTextRecorder:
                     if self.debug_mode:
                         logging.info(f"Speech detected in frame {i + 1}"
                               f" of {num_frames}")
+                    if not self.is_webrtc_speech_active and self.use_extended_logging:
+                        logging.info(speech_str)
+                    self.is_webrtc_speech_active = True
                     return True
         if all_frames_must_be_true:
             if self.debug_mode and speech_frames == num_frames:
@@ -2219,10 +2240,19 @@ class AudioToTextRecorder:
                       f"{num_frames} frames")
             elif self.debug_mode:
                 logging.info(f"Speech not detected in all {num_frames} frames")
-            return speech_frames == num_frames
+            speech_detected = speech_frames == num_frames
+            if speech_detected and not self.is_webrtc_speech_active and self.use_extended_logging:
+                logging.info(speech_str)
+            elif not speech_detected and self.is_webrtc_speech_active and self.use_extended_logging:
+                logging.info(silence_str)
+            self.is_webrtc_speech_active = speech_detected
+            return speech_detected
         else:
             if self.debug_mode:
                 logging.info(f"Speech not detected in any of {num_frames} frames")
+            if self.is_webrtc_speech_active and self.use_extended_logging:
+                logging.info(silence_str)
+            self.is_webrtc_speech_active = False
             return False
 
     def _check_voice_activity(self, data):
@@ -2232,7 +2262,7 @@ class AudioToTextRecorder:
         Args:
             data: The audio data to be checked for voice activity.
         """
-        self.is_webrtc_speech_active = self._is_webrtc_speech(data)
+        self._is_webrtc_speech(data)
 
         # First quick performing check for voice activity using WebRTC
         if self.is_webrtc_speech_active:

+ 129 - 34
RealtimeSTT_server/stt_cli_client.py

@@ -1,3 +1,21 @@
+"""
+This is a command-line client for the Speech-to-Text (STT) server.
+It records audio from the default input device and sends it to the server for speech recognition.
+It can also process commands to set parameters, get parameter values, or call methods on the server.
+
+Usage:
+    stt [--control-url CONTROL_URL] [--data-url DATA_URL] [--debug] [--norealtime] [--set-param PARAM VALUE] [--call-method METHOD [ARGS ...]] [--get-param PARAM]
+
+Options:
+    --control-url CONTROL_URL       STT Control WebSocket URL
+    --data-url DATA_URL             STT Data WebSocket URL
+    --debug                         Enable debug mode
+    --norealtime                    Disable real-time output
+    --set-param PARAM VALUE         Set a recorder parameter. Can be used multiple times.
+    --call-method METHOD [ARGS ...] Call a recorder method with optional arguments.
+    --get-param PARAM               Get the value of a recorder parameter. Can be used multiple times.
+"""
+
 from urllib.parse import urlparse
 from scipy import signal
 from queue import Queue
@@ -7,13 +25,13 @@ import threading
 import websocket
 import argparse
 import pyaudio
-import logging
 import struct
 import socket
 import shutil
 import queue 
 import json
 import time
+import wave
 import sys
 import os
 
@@ -35,7 +53,7 @@ init()
 websocket.enableTrace(False)
 
 class STTWebSocketClient:
-    def __init__(self, control_url, data_url, debug=False, file_output=None, norealtime=False):
+    def __init__(self, control_url, data_url, debug=False, file_output=None, norealtime=False, writechunks=None):
         self.control_url = control_url
         self.data_url = data_url
         self.control_ws = None
@@ -52,6 +70,16 @@ class STTWebSocketClient:
         self.message_queue = Queue()
         self.commands = Queue()
         self.stop_event = threading.Event()
+        self.chunks_sent = 0
+        self.last_chunk_time = time.time()
+        self.writechunks = writechunks  # Add this to store the file name for writing audio chunks
+
+        self.debug_print("Initializing STT WebSocket Client")
+        self.debug_print(f"Control URL: {control_url}")
+        self.debug_print(f"Data URL: {data_url}")
+        self.debug_print(f"File Output: {file_output}")
+        self.debug_print(f"No Realtime: {norealtime}")
+        self.debug_print(f"Write Chunks: {writechunks}")
 
         # Audio attributes
         self.audio_interface = None
@@ -64,9 +92,12 @@ class STTWebSocketClient:
         self.data_ws_thread = None
         self.recording_thread = None
 
+
     def debug_print(self, message):
         if self.debug:
-            print(message, file=sys.stderr)
+            timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+            thread_name = threading.current_thread().name
+            print(f"{Fore.CYAN}[DEBUG][{timestamp}][{thread_name}] {message}{Style.RESET_ALL}", file=sys.stderr)
 
     def connect(self):
         if not self.ensure_server_running():
@@ -74,7 +105,10 @@ class STTWebSocketClient:
             return False
 
         try:
+            self.debug_print("Attempting to establish WebSocket connections...")
+
             # Connect to control WebSocket
+            self.debug_print(f"Connecting to control WebSocket at {self.control_url}")
             self.control_ws = websocket.WebSocketApp(self.control_url,
                                                      on_message=self.on_control_message,
                                                      on_error=self.on_error,
@@ -82,10 +116,12 @@ class STTWebSocketClient:
                                                      on_open=self.on_control_open)
 
             self.control_ws_thread = threading.Thread(target=self.control_ws.run_forever)
-            self.control_ws_thread.daemon = False  # Set to False to ensure proper shutdown
+            self.control_ws_thread.daemon = False
+            self.debug_print("Starting control WebSocket thread")
             self.control_ws_thread.start()
 
             # Connect to data WebSocket
+            self.debug_print(f"Connecting to data WebSocket at {self.data_url}")
             self.data_ws_app = websocket.WebSocketApp(self.data_url,
                                                       on_message=self.on_data_message,
                                                       on_error=self.on_error,
@@ -93,10 +129,11 @@ class STTWebSocketClient:
                                                       on_open=self.on_data_open)
 
             self.data_ws_thread = threading.Thread(target=self.data_ws_app.run_forever)
-            self.data_ws_thread.daemon = False  # Set to False to ensure proper shutdown
+            self.data_ws_thread.daemon = False
+            self.debug_print("Starting data WebSocket thread")
             self.data_ws_thread.start()
 
-            # Wait for the connections to be established
+            self.debug_print("Waiting for connections to be established...")
             if not self.connection_established.wait(timeout=10):
                 self.debug_print("Timeout while connecting to the server.")
                 return False
@@ -104,24 +141,30 @@ class STTWebSocketClient:
             self.debug_print("WebSocket connections established successfully.")
             return True
         except Exception as e:
-            self.debug_print(f"Error while connecting to the server: {e}")
+            self.debug_print(f"Error while connecting to the server: {str(e)}")
             return False
 
+
     def on_control_open(self, ws):
-        self.debug_print("Control WebSocket connection opened.")
+        self.debug_print("Control WebSocket connection opened successfully")
         self.connection_established.set()
         self.start_command_processor()
 
     def on_data_open(self, ws):
-        self.debug_print("Data WebSocket connection opened.")
-        self.data_ws_connected = ws  # Store the connected websocket object for sending data
+        self.debug_print("Data WebSocket connection opened successfully")
+        self.data_ws_connected = ws
         self.start_recording()
 
     def on_error(self, ws, error):
-        self.debug_print(f"WebSocket error: {error}")
+        self.debug_print(f"WebSocket error occurred: {str(error)}")
+        self.debug_print(f"WebSocket object: {ws}")
+        self.debug_print(f"Error type: {type(error)}")
 
     def on_close(self, ws, close_status_code, close_msg):
-        self.debug_print(f"WebSocket connection closed: {close_status_code} - {close_msg}")
+        self.debug_print(f"WebSocket connection closed")
+        self.debug_print(f"Close status code: {close_status_code}")
+        self.debug_print(f"Close message: {close_msg}")
+        self.debug_print(f"WebSocket object: {ws}")
         self.is_running = False
         self.stop_event.set()
 
@@ -129,8 +172,11 @@ class STTWebSocketClient:
         parsed_url = urlparse(self.control_url)
         host = parsed_url.hostname
         port = parsed_url.port or 80
+        self.debug_print(f"Checking if server is running at {host}:{port}")
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            return s.connect_ex((host, port)) == 0
+            result = s.connect_ex((host, port)) == 0
+            self.debug_print(f"Server status check result: {'running' if result else 'not running'}")
+            return result
 
     def ask_to_start_server(self):
         response = input("Would you like to start the STT server now? (y/n): ").strip().lower()
@@ -213,37 +259,47 @@ class STTWebSocketClient:
 
     def on_control_message(self, ws, message):
         try:
+            self.debug_print(f"Received control message: {message}")
             data = json.loads(message)
             if 'status' in data:
+                self.debug_print(f"Message status: {data['status']}")
                 if data['status'] == 'success':
                     if 'parameter' in data and 'value' in data:
+                        self.debug_print(f"Parameter update: {data['parameter']} = {data['value']}")
                         print(f"Parameter {data['parameter']} = {data['value']}")
                 elif data['status'] == 'error':
+                    self.debug_print(f"Server error received: {data.get('message', '')}")
                     print(f"Server Error: {data.get('message', '')}")
             else:
                 self.debug_print(f"Unknown control message format: {data}")
         except json.JSONDecodeError:
-            self.debug_print(f"Received non-JSON control message: {message}")
+            self.debug_print(f"Failed to decode JSON control message: {message}")
         except Exception as e:
-            self.debug_print(f"Error processing control message: {e}")
+            self.debug_print(f"Error processing control message: {str(e)}")
 
     def on_data_message(self, ws, message):
         try:
+            self.debug_print(f"Received data message: {message}")
             data = json.loads(message)
             message_type = data.get('type')
+            self.debug_print(f"Message type: {message_type}")
+
             if message_type == 'realtime':
                 if data['text'] != self.last_text:
+                    self.debug_print(f"New realtime text received: {data['text']}")
                     self.last_text = data['text']
                     if not self.norealtime:
                         self.update_progress_bar(self.last_text)
             elif message_type == 'fullSentence':
+                self.debug_print(f"Full sentence received: {data['text']}")
                 if self.file_output:
+                    self.debug_print("Writing to file output")
                     sys.stderr.write('\r\033[K')
                     sys.stderr.write(data['text'])
                     sys.stderr.write('\n')
                     sys.stderr.flush()
                     print(data['text'], file=self.file_output)
-                    self.file_output.flush()  # Ensure it's written immediately
+                    self.file_output.flush()
                 else:
                     self.finish_progress_bar()
                     print(f"{data['text']}")
@@ -258,14 +314,13 @@ class STTWebSocketClient:
                 'wakeword_detection_end',
                 'transcription_start'}:
                 pass  # Known message types, no action needed
-
             else:
-                self.debug_print(f"Unknown data message format: {data}")
+                self.debug_print(f"Other message type received: {message_type}")
 
         except json.JSONDecodeError:
-            self.debug_print(f"Received non-JSON data message: {message}")
+            self.debug_print(f"Failed to decode JSON data message: {message}")
         except Exception as e:
-            self.debug_print(f"Error processing data message: {e}")
+            self.debug_print(f"Error processing data message: {str(e)}")
 
     def show_initial_indicator(self):
         if self.norealtime:
@@ -331,40 +386,67 @@ class STTWebSocketClient:
     def record_and_send_audio(self):
         try:
             if not self.setup_audio():
+                self.debug_print("Failed to set up audio recording")
                 raise Exception("Failed to set up audio recording.")
 
-            self.debug_print("Recording and sending audio...")
+            # Initialize WAV file writer if writechunks is provided
+            if self.writechunks:
+                self.wav_file = wave.open(self.writechunks, 'wb')
+                self.wav_file.setnchannels(CHANNELS)
+                self.wav_file.setsampwidth(pyaudio.get_sample_size(FORMAT))
+                self.wav_file.setframerate(self.device_sample_rate)  # Use self.device_sample_rate
+
+            self.debug_print("Starting audio recording and transmission")
             self.show_initial_indicator()
 
             while self.is_running:
                 try:
                     audio_data = self.stream.read(CHUNK)
+                    self.chunks_sent += 1
+                    current_time = time.time()
+                    elapsed = current_time - self.last_chunk_time
+
+                    # Write to WAV file if enabled
+                    if self.writechunks:
+                        self.wav_file.writeframes(audio_data)
+
+                    if self.chunks_sent % 100 == 0:  # Log every 100 chunks
+                        self.debug_print(f"Sent {self.chunks_sent} chunks. Last chunk took {elapsed:.3f}s")
+
                     metadata = {"sampleRate": self.device_sample_rate}
                     metadata_json = json.dumps(metadata)
                     metadata_length = len(metadata_json)
                     message = struct.pack('<I', metadata_length) + metadata_json.encode('utf-8') + audio_data
+
+                    self.debug_print(f"Sending audio chunk {self.chunks_sent}: {len(audio_data)} bytes, metadata: {metadata_json}")
                     self.data_ws_connected.send(message, opcode=websocket.ABNF.OPCODE_BINARY)
+                    self.last_chunk_time = current_time
+
                 except Exception as e:
-                    self.debug_print(f"Error sending audio data: {e}")
-                    break  # Exit the recording loop
+                    self.debug_print(f"Error sending audio data: {str(e)}")
+                    break
 
         except Exception as e:
-            self.debug_print(f"Error in record_and_send_audio: {e}")
+            self.debug_print(f"Error in record_and_send_audio: {str(e)}")
         finally:
             self.cleanup_audio()
 
     def setup_audio(self):
         try:
+            self.debug_print("Initializing PyAudio interface")
             self.audio_interface = pyaudio.PyAudio()
             self.input_device_index = None
+
             try:
                 default_device = self.audio_interface.get_default_input_device_info()
                 self.input_device_index = default_device['index']
+                self.debug_print(f"Default input device found: {default_device}")
             except OSError as e:
-                self.debug_print(f"No default input device found: {e}")
+                self.debug_print(f"No default input device found: {str(e)}")
                 return False
 
-            self.device_sample_rate = 16000  # Try 16000 Hz first
+            self.device_sample_rate = 16000
+            self.debug_print(f"Attempting to open audio stream with sample rate {self.device_sample_rate} Hz")
 
             try:
                 self.stream = self.audio_interface.open(
@@ -375,29 +457,36 @@ class STTWebSocketClient:
                     frames_per_buffer=CHUNK,
                     input_device_index=self.input_device_index,
                 )
-                self.debug_print(f"Audio recording initialized successfully at {self.device_sample_rate} Hz, device index {self.input_device_index}")
+                self.debug_print(f"Audio stream initialized successfully")
+                self.debug_print(f"Audio parameters: rate={self.device_sample_rate}, channels={CHANNELS}, format={FORMAT}, chunk={CHUNK}")
                 return True
             except Exception as e:
-                self.debug_print(f"Failed to initialize audio stream at {self.device_sample_rate} Hz, device index {self.input_device_index}: {e}")
+                self.debug_print(f"Failed to initialize audio stream: {str(e)}")
                 return False
 
         except Exception as e:
-            self.debug_print(f"Error initializing audio recording: {e}")
+            self.debug_print(f"Error in setup_audio: {str(e)}")
             if self.audio_interface:
                 self.audio_interface.terminate()
             return False
 
     def cleanup_audio(self):
+        self.debug_print("Cleaning up audio resources")
         try:
             if self.stream:
+                self.debug_print("Stopping and closing audio stream")
                 self.stream.stop_stream()
                 self.stream.close()
                 self.stream = None
             if self.audio_interface:
+                self.debug_print("Terminating PyAudio interface")
                 self.audio_interface.terminate()
                 self.audio_interface = None
+            if self.writechunks and self.wav_file:
+                self.debug_print("Closing WAV file")
+                self.wav_file.close()
         except Exception as e:
-            self.debug_print(f"Error cleaning up audio resources: {e}")
+            self.debug_print(f"Error during audio cleanup: {str(e)}")
 
     def set_parameter(self, parameter, value):
         command = {
@@ -428,23 +517,28 @@ class STTWebSocketClient:
         self.command_thread.daemon = False  # Ensure it is not a daemon thread
         self.command_thread.start()
 
+
     def command_processor(self):
-        self.debug_print(f"Starting command processor")
+        self.debug_print("Starting command processor thread")
         while not self.stop_event.is_set():
             try:
                 command = self.commands.get(timeout=0.1)
+                self.debug_print(f"Processing command: {command}")
                 if command['type'] == 'set_parameter':
+                    self.debug_print(f"Setting parameter: {command['parameter']} = {command['value']}")
                     self.set_parameter(command['parameter'], command['value'])
                 elif command['type'] == 'get_parameter':
+                    self.debug_print(f"Getting parameter: {command['parameter']}")
                     self.get_parameter(command['parameter'])
                 elif command['type'] == 'call_method':
+                    self.debug_print(f"Calling method: {command['method']} with args: {command.get('args')} and kwargs: {command.get('kwargs')}")
                     self.call_method(command['method'], command.get('args'), command.get('kwargs'))
             except queue.Empty:
                 continue
             except Exception as e:
-                self.debug_print(f"Error in command processor: {e}")
+                self.debug_print(f"Error in command processor: {str(e)}")
 
-        self.debug_print(f"Leaving command processor")
+        self.debug_print("Command processor thread stopping")
 
     def add_command(self, command):
         self.commands.put(command)
@@ -455,6 +549,7 @@ def main():
     parser.add_argument("--data-url", default=DEFAULT_DATA_URL, help="STT Data WebSocket URL")
     parser.add_argument("--debug", action="store_true", help="Enable debug mode")
     parser.add_argument("-nort", "--norealtime", action="store_true", help="Disable real-time output")
+    parser.add_argument("--writechunks", metavar="FILE", help="Save recorded audio chunks to a WAV file")
     parser.add_argument("--set-param", nargs=2, metavar=('PARAM', 'VALUE'), action='append',
                         help="Set a recorder parameter. Can be used multiple times.")
     parser.add_argument("--call-method", nargs='+', metavar='METHOD', action='append',
@@ -469,7 +564,7 @@ def main():
     else:
         file_output = None
 
-    client = STTWebSocketClient(args.control_url, args.data_url, args.debug, file_output, args.norealtime)
+    client = STTWebSocketClient(args.control_url, args.data_url, args.debug, file_output, args.norealtime, args.writechunks)
 
     def signal_handler(sig, frame):
         client.stop()

+ 111 - 42
RealtimeSTT_server/stt_server.py

@@ -65,18 +65,26 @@ The server supports two WebSocket connections:
 The server will broadcast real-time transcription updates to all connected clients on the data WebSocket.
 """
 
+from .install_packages import check_and_install_packages
+from datetime import datetime
+import logging
+import asyncio
+import pyaudio
+import sys
+
 
-extended_logging = True
+debug_logging = False
+extended_logging = False
 send_recorded_chunk = False
-log_incoming_chunks = True
+log_incoming_chunks = False
 stt_optimizations = False
+writechunks = False#
+wav_file = None
+loglevel = logging.WARNING
 
+FORMAT = pyaudio.paInt16
+CHANNELS = 1
 
-from .install_packages import check_and_install_packages
-from datetime import datetime
-import asyncio
-import base64
-import sys
 
 if sys.platform == 'win32':
     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@@ -117,17 +125,25 @@ class bcolors:
 
 print(f"{bcolors.BOLD}{bcolors.OKCYAN}Starting server, please wait...{bcolors.ENDC}")
 
-import threading
-import json
-import websockets
+# Initialize colorama
+from colorama import init, Fore, Style
+init()
+
 from RealtimeSTT import AudioToTextRecorder
-import numpy as np
 from scipy.signal import resample
+import numpy as np
+import websockets
+import threading
+import logging
+import wave
+import json
+import time
 
 global_args = None
 recorder = None
 recorder_config = {}
 recorder_ready = threading.Event()
+recorder_thread = None
 stop_recorder = False
 prev_text = ""
 
@@ -174,6 +190,12 @@ def preprocess_text(text):
     
     return text
 
+def debug_print(message):
+    if debug_logging:
+        timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+        thread_name = threading.current_thread().name
+        print(f"{Fore.CYAN}[DEBUG][{timestamp}][{thread_name}] {message}{Style.RESET_ALL}", file=sys.stderr)
+
 def text_detected(text, loop):
     global prev_text
 
@@ -279,6 +301,8 @@ def on_transcription_start(loop):
 
 # Define the server's arguments
 def parse_arguments():
+    global debug_logging, extended_logging, loglevel, writechunks, log_incoming_chunks
+
     import argparse
     parser = argparse.ArgumentParser(description='Start the Speech-to-Text (STT) server with various configuration options.')
 
@@ -376,9 +400,27 @@ def parse_arguments():
     parser.add_argument('--use_extended_logging', action='store_true',
                         help='Writes extensive log messages for the recording worker, that processes the audio chunks.')
 
+    parser.add_argument('--debug', action='store_true', help='Enable debug logging for detailed server operations')
+
+    parser.add_argument('--logchunks', action='store_true', help='Enable logging of incoming audio chunks (periods)')
+
+    parser.add_argument("--writechunks", metavar="FILE", help="Save received audio chunks to a WAV file")
+
     # Parse arguments
     args = parser.parse_args()
 
+    debug_logging = args.debug
+    extended_logging = args.use_extended_logging
+    writechunks = args.writechunks
+    log_incoming_chunks = args.logchunks
+
+    if debug_logging:
+        loglevel = logging.DEBUG
+        logging.basicConfig(level=loglevel, format='[%(asctime)s] %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+    else:
+        loglevel = logging.WARNING
+
+
     # Replace escaped newlines with actual newlines in initial_prompt
     if args.initial_prompt:
         args.initial_prompt = args.initial_prompt.replace("\\n", "\n")
@@ -437,11 +479,13 @@ def decode_and_resample(
     return resampled_audio.astype(np.int16).tobytes()
 
 async def control_handler(websocket, path):
+    debug_print(f"New control connection from {websocket.remote_address}")
     print(f"{bcolors.OKGREEN}Control client connected{bcolors.ENDC}")
     global recorder
     control_connections.add(websocket)
     try:
         async for message in websocket:
+            debug_print(f"Received control message: {message[:200]}...")
             if not recorder_ready.is_set():
                 print(f"{bcolors.WARNING}Recorder not ready{bcolors.ENDC}")
                 continue
@@ -530,21 +574,38 @@ async def control_handler(websocket, path):
         control_connections.remove(websocket)
 
 async def data_handler(websocket, path):
+    global writechunks, wav_file
     print(f"{bcolors.OKGREEN}Data client connected{bcolors.ENDC}")
     data_connections.add(websocket)
     try:
         while True:
             message = await websocket.recv()
             if isinstance(message, bytes):
-                if log_incoming_chunks:
+                if debug_logging:
+                    debug_print(f"Received audio chunk (size: {len(message)} bytes)")
+                elif log_incoming_chunks:
                     print(".", end='', flush=True)
                 # Handle binary message (audio data)
                 metadata_length = int.from_bytes(message[:4], byteorder='little')
                 metadata_json = message[4:4+metadata_length].decode('utf-8')
                 metadata = json.loads(metadata_json)
                 sample_rate = metadata['sampleRate']
+
+                debug_print(f"Processing audio chunk with sample rate {sample_rate}")
                 chunk = message[4+metadata_length:]
+
+                if writechunks:
+                    if not wav_file:
+                        wav_file = wave.open(writechunks, 'wb')
+                        wav_file.setnchannels(CHANNELS)
+                        wav_file.setsampwidth(pyaudio.get_sample_size(FORMAT))
+                        wav_file.setframerate(sample_rate)
+
+                    wav_file.writeframes(chunk)
+
                 resampled_chunk = decode_and_resample(chunk, sample_rate, 16000)
+
+                debug_print(f"Resampled chunk size: {len(resampled_chunk)} bytes")
                 recorder.feed_audio(resampled_chunk)
             else:
                 print(f"{bcolors.WARNING}Received non-binary message on data connection{bcolors.ENDC}")
@@ -622,47 +683,55 @@ async def main_async():
         # 'on_recorded_chunk': make_callback(loop, on_recorded_chunk),
         'no_log_file': True,  # Disable logging to file
         'use_extended_logging': args.use_extended_logging,
+        'level': loglevel,
     }
 
-    control_server = await websockets.serve(control_handler, "localhost", args.control_port)
-    data_server = await websockets.serve(data_handler, "localhost", args.data_port)
-    print(f"{bcolors.OKGREEN}Control server started on {bcolors.OKBLUE}ws://localhost:{args.control_port}{bcolors.ENDC}")
-    print(f"{bcolors.OKGREEN}Data server started on {bcolors.OKBLUE}ws://localhost:{args.data_port}{bcolors.ENDC}")
+    try:
+        # Attempt to start control and data servers
+        control_server = await websockets.serve(control_handler, "localhost", args.control_port)
+        data_server = await websockets.serve(data_handler, "localhost", args.data_port)
+        print(f"{bcolors.OKGREEN}Control server started on {bcolors.OKBLUE}ws://localhost:{args.control_port}{bcolors.ENDC}")
+        print(f"{bcolors.OKGREEN}Data server started on {bcolors.OKBLUE}ws://localhost:{args.data_port}{bcolors.ENDC}")
 
-    # Task to broadcast audio messages
-    broadcast_task = asyncio.create_task(broadcast_audio_messages())
+        # Start the broadcast and recorder threads
+        broadcast_task = asyncio.create_task(broadcast_audio_messages())
 
-    recorder_thread = threading.Thread(target=_recorder_thread, args=(loop,))
-    recorder_thread.start()
-    recorder_ready.wait()
+        recorder_thread = threading.Thread(target=_recorder_thread, args=(loop,))
+        recorder_thread.start()
+        recorder_ready.wait()
 
-    print(f"{bcolors.OKGREEN}Server started. Press Ctrl+C to stop the server.{bcolors.ENDC}")
+        print(f"{bcolors.OKGREEN}Server started. Press Ctrl+C to stop the server.{bcolors.ENDC}")
 
-    try:
+        # Run server tasks
         await asyncio.gather(control_server.wait_closed(), data_server.wait_closed(), broadcast_task)
+    except OSError as e:
+        print(f"{bcolors.FAIL}Error: Could not start server on specified ports. It’s possible another instance of the server is already running, or the ports are being used by another application.{bcolors.ENDC}")
     except KeyboardInterrupt:
-        print(f"{bcolors.WARNING}{bcolors.BOLD}Shutting down gracefully...{bcolors.ENDC}")
+        print(f"{bcolors.WARNING}Server interrupted by user, shutting down...{bcolors.ENDC}")
     finally:
-        # Shut down the recorder
-        if recorder:
-            stop_recorder = True
-            recorder.abort()
-            recorder.stop()
-            recorder.shutdown()
-            print(f"{bcolors.OKGREEN}Recorder shut down{bcolors.ENDC}")
-
+        # Shutdown procedures for recorder and server threads
+        await shutdown_procedure()
+        print(f"{bcolors.OKGREEN}Server shutdown complete.{bcolors.ENDC}")
+
+async def shutdown_procedure():
+    global stop_recorder, recorder_thread
+    if recorder:
+        stop_recorder = True
+        recorder.abort()
+        recorder.stop()
+        recorder.shutdown()
+        print(f"{bcolors.OKGREEN}Recorder shut down{bcolors.ENDC}")
+
+        if recorder_thread:
             recorder_thread.join()
             print(f"{bcolors.OKGREEN}Recorder thread finished{bcolors.ENDC}")
-        
-        # Cancel all active tasks in the event loop
-        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-        for task in tasks:
-            task.cancel()
-        
-        # Run pending tasks and handle cancellation
-        await asyncio.gather(*tasks, return_exceptions=True)
-
-        print(f"{bcolors.OKGREEN}All tasks cancelled, closing event loop now.{bcolors.ENDC}")
+
+    tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+    for task in tasks:
+        task.cancel()
+    await asyncio.gather(*tasks, return_exceptions=True)
+
+    print(f"{bcolors.OKGREEN}All tasks cancelled, closing event loop now.{bcolors.ENDC}")
 
 def main():
     try: