
added clear_audio_queue method

KoljaB 7 months ago
parent
commit
d7fbcc3933
1 changed file with 215 additions and 85 deletions

+ 215 - 85
RealtimeSTT/audio_recorder.py

@@ -43,10 +43,12 @@ import traceback
 import threading
 import webrtcvad
 import itertools
+import datetime
 import platform
 import pyaudio
 import logging
 import struct
+import queue
 import halo
 import time
 import copy
@@ -374,6 +376,7 @@ class AudioToTextRecorder:
         self.on_transcription_start = on_transcription_start
         self.enable_realtime_transcription = enable_realtime_transcription
         self.use_main_model_for_realtime = use_main_model_for_realtime
+        self.main_model_type = model
         self.realtime_model_type = realtime_model_type
         self.realtime_processing_pause = realtime_processing_pause
         self.on_realtime_transcription_update = (
@@ -463,6 +466,7 @@ class AudioToTextRecorder:
         self.was_interrupted = mp.Event()
         self.main_transcription_ready_event = mp.Event()
         self.parent_transcription_pipe, child_transcription_pipe = mp.Pipe()
+        self.parent_stdout_pipe, child_stdout_pipe = mp.Pipe()
 
         # Set device for model
         self.device = "cuda" if self.device == "cuda" and torch.cuda.is_available() else "cpu"
@@ -471,6 +475,7 @@ class AudioToTextRecorder:
             target=AudioToTextRecorder._transcription_worker,
             args=(
                 child_transcription_pipe,
+                child_stdout_pipe,
                 model,
                 self.compute_type,
                 self.gpu_device_index,
@@ -670,6 +675,10 @@ class AudioToTextRecorder:
         self.main_transcription_ready_event.wait()
         logging.debug('Main transcription model ready')
 
+        self.stdout_thread = threading.Thread(target=self._read_stdout)
+        self.stdout_thread.daemon = True
+        self.stdout_thread.start()
+
         logging.debug('RealtimeSTT initialization completed successfully')
                    
     def _start_thread(self, target=None, args=()):
@@ -695,8 +704,15 @@ class AudioToTextRecorder:
             thread.start()
             return thread
 
+    def _read_stdout(self):
+        while not self.shutdown_event.is_set():
+            if self.parent_stdout_pipe.poll(0.1):
+                message = self.parent_stdout_pipe.recv()
+                print(message, flush=True)
+
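The new _read_stdout thread drains the parent end of the stdout pipe so messages emitted inside the worker process surface in the main process. A minimal, self-contained sketch of the same pattern (names here are illustrative, not part of the RealtimeSTT API):

# Sketch of the stdout-forwarding pattern: the child process sends
# strings through a pipe and the parent polls and prints them.
import multiprocessing as mp

def worker(stdout_pipe):
    stdout_pipe.send("hello from the worker process")

if __name__ == '__main__':
    parent_end, child_end = mp.Pipe()
    proc = mp.Process(target=worker, args=(child_end,))
    proc.start()
    if parent_end.poll(1.0):        # wait up to 1 s, like poll(0.1) above
        print(parent_end.recv(), flush=True)
    proc.join()
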
     @staticmethod
     def _transcription_worker(conn,
+                              stdout_pipe,
                               model_path,
                               compute_type,
                               gpu_device_index,
@@ -744,6 +760,12 @@ class AudioToTextRecorder:
             Exception: If there is an error while initializing the
             transcription model.
         """
+        def custom_print(*args, **kwargs):
+            message = ' '.join(map(str, args))
+            stdout_pipe.send(message)
+
+        # Replace the built-in print function with our custom one
+        __builtins__['print'] = custom_print
 
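One caveat on the print replacement: __builtins__ is only a plain dict in imported modules; under __main__ it is the builtins module itself, so item assignment would fail there. Since audio_recorder.py is always imported, the dict form works here, but assigning through the builtins module is the portable variant. A sketch, not the commit's code:

# Portable print override: works whether the module is imported or run
# as __main__. The end/sep/file kwargs are ignored in this sketch.
import builtins

def make_pipe_print(stdout_pipe):
    def pipe_print(*args, **kwargs):
        stdout_pipe.send(' '.join(map(str, args)))
    return pipe_print

# builtins.print = make_pipe_print(stdout_pipe)
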
         logging.info("Initializing faster_whisper "
                      f"main transcription model {model_path}"
@@ -771,7 +793,7 @@ class AudioToTextRecorder:
 
         while not shutdown_event.is_set():
             try:
-                if conn.poll(0.5):
+                if conn.poll(0.01):
                     audio, language = conn.recv()
                     try:
                         segments, info = model.transcribe(
@@ -788,8 +810,8 @@ class AudioToTextRecorder:
                         logging.error(f"General transcription error: {e}")
                         conn.send(('error', str(e)))
                 else:
-                    # If there's no data, sleep / prevent busy waiting
-                    time.sleep(0.02)
+                    time.sleep(TIME_SLEEP)
+
             except KeyboardInterrupt:
                 interrupt_stop_event.set()
                 logging.debug("Transcription worker process "
@@ -799,75 +821,155 @@ class AudioToTextRecorder:
 
     @staticmethod
     def _audio_data_worker(audio_queue,
-                           sample_rate,
-                           buffer_size,
-                           input_device_index,
-                           shutdown_event,
-                           interrupt_stop_event,
-                           use_microphone):
+                        target_sample_rate,
+                        buffer_size,
+                        input_device_index,
+                        shutdown_event,
+                        interrupt_stop_event,
+                        use_microphone):
         """
         Worker method that handles the audio recording process.
 
         This method runs in a separate process and is responsible for:
-        - Setting up the audio input stream for recording.
-        - Continuously reading audio data from the input stream
-          and placing it in a queue.
-        - Handling errors during the recording process, including
-          input overflow.
-        - Gracefully terminating the recording process when a shutdown
-          event is set.
+        - Setting up the audio input stream for recording at the highest possible sample rate.
+        - Continuously reading audio data from the input stream, resampling if necessary,
+          preprocessing the data, and placing complete chunks in a queue.
+        - Handling errors during the recording process.
+        - Gracefully terminating the recording process when a shutdown event is set.
 
         Args:
-            audio_queue (queue.Queue): A queue where recorded audio
-              data is placed.
-            sample_rate (int): The sample rate of the audio input stream.
-            buffer_size (int): The size of the buffer used in the audio
-              input stream.
-            input_device_index (int): The index of the audio input device
-            shutdown_event (threading.Event): An event that, when set, signals
-              this worker method to terminate.
+            audio_queue (queue.Queue): A queue where recorded audio data is placed.
+            target_sample_rate (int): The desired sample rate for the output audio (for Silero VAD).
+            buffer_size (int): The number of samples expected by the Silero VAD model.
+            input_device_index (int): The index of the audio input device.
+            shutdown_event (threading.Event): An event that, when set, signals this worker method to terminate.
+            interrupt_stop_event (threading.Event): An event to signal keyboard interrupt.
+            use_microphone (multiprocessing.Value): A shared value indicating whether to use the microphone.
 
         Raises:
-            Exception: If there is an error while initializing the audio
-              recording.
+            Exception: If there is an error while initializing the audio recording.
         """
+        import pyaudio
+        import numpy as np
+        from scipy import signal
+
+        def get_highest_sample_rate(audio_interface, device_index):
+            """Get the highest supported sample rate for the specified device."""
+            try:
+                device_info = audio_interface.get_device_info_by_index(device_index)
+                max_rate = int(device_info['defaultSampleRate'])
+                
+                if 'supportedSampleRates' in device_info:
+                    supported_rates = [int(rate) for rate in device_info['supportedSampleRates']]
+                    if supported_rates:
+                        max_rate = max(supported_rates)
+                
+                return max_rate
+            except Exception as e:
+                logging.warning(f"Failed to get highest sample rate: {e}")
+                return 48000  # Fallback to a common high sample rate
+
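A caution on this helper: stock PyAudio device info carries a defaultSampleRate entry but, as far as PortAudio exposes it, no supportedSampleRates key, so the branch checking for that key will normally be skipped and the default rate wins. Supported rates can instead be probed explicitly; a sketch assuming the standard PyAudio API:

# Sketch: probe which common rates a capture device accepts, since
# device_info usually has no 'supportedSampleRates' key.
import pyaudio

def probe_input_rates(pa, device_index):
    rates = []
    for rate in (16000, 22050, 32000, 44100, 48000, 96000):
        try:
            if pa.is_format_supported(rate,
                                      input_device=device_index,
                                      input_channels=1,
                                      input_format=pyaudio.paInt16):
                rates.append(rate)
        except ValueError:
            pass  # PortAudio rejects this rate/device combination
    return rates

pa = pyaudio.PyAudio()
print(probe_input_rates(pa, pa.get_default_input_device_info()['index']))
pa.terminate()
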
+        def initialize_audio_stream(audio_interface, device_index, sample_rate, chunk_size):
+            """Initialize the audio stream with error handling."""
+            try:
+                stream = audio_interface.open(
+                    format=pyaudio.paInt16,
+                    channels=1,
+                    rate=sample_rate,
+                    input=True,
+                    frames_per_buffer=chunk_size,
+                    input_device_index=device_index,
+                )
+                return stream
+            except Exception as e:
+                logging.error(f"Error initializing audio stream: {e}")
+                raise
+
+        def preprocess_audio(chunk, original_sample_rate, target_sample_rate):
+            """Preprocess audio chunk similar to feed_audio method."""
+            if isinstance(chunk, np.ndarray):
+                # Handle stereo to mono conversion if necessary
+                if chunk.ndim == 2:
+                    chunk = np.mean(chunk, axis=1)
+
+                # Resample to target_sample_rate if necessary
+                if original_sample_rate != target_sample_rate:
+                    num_samples = int(len(chunk) * target_sample_rate / original_sample_rate)
+                    chunk = signal.resample(chunk, num_samples)
+
+                # Ensure data type is int16
+                chunk = chunk.astype(np.int16)
+            else:
+                # If chunk is bytes, convert to numpy array
+                chunk = np.frombuffer(chunk, dtype=np.int16)
+
+                # Resample if necessary
+                if original_sample_rate != target_sample_rate:
+                    num_samples = int(len(chunk) * target_sample_rate / original_sample_rate)
+                    chunk = signal.resample(chunk, num_samples)
+                    chunk = chunk.astype(np.int16)
+
+            return chunk.tobytes()
+
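preprocess_audio boils down to a mono mixdown plus FFT-based rate conversion on int16 PCM. The same idea in isolation (a sketch: 3072 samples at 48 kHz become the 1024 samples at 16 kHz that Silero VAD expects):

# Sketch: downsample raw int16 mono audio, as preprocess_audio does.
import numpy as np
from scipy import signal

def resample_int16(raw, src_rate=48000, dst_rate=16000):
    samples = np.frombuffer(raw, dtype=np.int16)
    n_out = int(len(samples) * dst_rate / src_rate)
    return signal.resample(samples, n_out).astype(np.int16).tobytes()

chunk = (np.sin(np.linspace(0, 100, 3072)) * 30000).astype(np.int16).tobytes()
print(len(resample_int16(chunk)))  # 2048 bytes == 1024 samples at 16 kHz
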
+        audio_interface = None
+        stream = None
+        device_sample_rate = None
+        chunk_size = 1024  # Increased chunk size for better performance
+
         try:
             audio_interface = pyaudio.PyAudio()
             if input_device_index is None:
-                default_device = audio_interface.get_default_input_device_info()
-                input_device_index = default_device['index']
-            stream = audio_interface.open(
-                rate=sample_rate,
-                format=pyaudio.paInt16,
-                channels=1,
-                input=True,
-                frames_per_buffer=buffer_size,
-                input_device_index=input_device_index,
-                )
+                try:
+                    default_device = audio_interface.get_default_input_device_info()
+                    input_device_index = default_device['index']
+                except OSError:
+                    input_device_index = None
+
+            if input_device_index is not None:
+                device_sample_rate = get_highest_sample_rate(audio_interface, input_device_index)
+            else:
+                device_sample_rate = 16000  # better: try 16000, 48000, ... until it works
+
+            stream = initialize_audio_stream(audio_interface, input_device_index, device_sample_rate, chunk_size)
+
+            if stream is None:
+                raise Exception("Failed to initialize audio stream.")
 
         except Exception as e:
-            logging.exception("Error initializing pyaudio "
-                              f"audio recording: {e}"
-                              )
+            logging.exception(f"Error initializing pyaudio audio recording: {e}")
+            if audio_interface:
+                audio_interface.terminate()
             raise
 
-        logging.debug("Audio recording (pyAudio input "
-                      "stream) initialized successfully"
-                      )
+        logging.debug(f"Audio recording initialized successfully at {device_sample_rate} Hz, reading {chunk_size} frames at a time")
+
+        buffer = bytearray()
+        silero_buffer_size = 2 * buffer_size  # silero complains if too short
 
         try:
             while not shutdown_event.is_set():
                 try:
-                    data = stream.read(buffer_size)
+                    data = stream.read(chunk_size)
+                    
+                    if use_microphone.value:
+                        processed_data = preprocess_audio(data, device_sample_rate, target_sample_rate)
+                        buffer += processed_data
+
+                        # Check if the buffer has reached or exceeded the silero_buffer_size
+                        while len(buffer) >= silero_buffer_size:
+                            # Extract silero_buffer_size amount of data from the buffer
+                            to_process = buffer[:silero_buffer_size]
+                            buffer = buffer[silero_buffer_size:]
+
+                            # Feed the extracted data to the audio_queue
+                            audio_queue.put(to_process)
 
                 except OSError as e:
                     if e.errno == pyaudio.paInputOverflowed:
                         logging.warning("Input overflowed. Frame dropped.")
                     else:
                         logging.error(f"Error during recording: {e}")
-                    tb_str = traceback.format_exc()
-                    print(f"Traceback: {tb_str}")
-                    print(f"Error: {e}")
                     continue
 
                 except Exception as e:
@@ -877,18 +979,19 @@ class AudioToTextRecorder:
                     print(f"Error: {e}")
                     continue
 
-                if use_microphone.value:
-                    audio_queue.put(data)
-
         except KeyboardInterrupt:
             interrupt_stop_event.set()
-            logging.debug("Audio data worker process "
-                          "finished due to KeyboardInterrupt"
-                          )
+            logging.debug("Audio data worker process finished due to KeyboardInterrupt")
         finally:
-            stream.stop_stream()
-            stream.close()
-            audio_interface.terminate()
+            # After recording stops, feed any remaining audio data
+            if buffer:
+                audio_queue.put(bytes(buffer))
+            
+            if stream:
+                stream.stop_stream()
+                stream.close()
+            if audio_interface:
+                audio_interface.terminate()
 
     def wakeup(self):
         """
@@ -927,6 +1030,7 @@ class AudioToTextRecorder:
             self.start_recording_on_voice_activity = True
 
             # Wait until recording starts
+            logging.debug('Waiting for recording start')
             while not self.interrupt_stop_event.is_set():
                 if self.start_recording_event.wait(timeout=0.02):
                     break
@@ -937,6 +1041,7 @@ class AudioToTextRecorder:
             self.stop_recording_on_voice_deactivity = True
 
             # Wait until recording stops
+            logging.debug('Waiting for recording stop')
             while not self.interrupt_stop_event.is_set():
                 if (self.stop_recording_event.wait(timeout=0.02)):
                     break
@@ -979,20 +1084,22 @@ class AudioToTextRecorder:
         """
         self._set_state("transcribing")
         audio_copy = copy.deepcopy(self.audio)
-
-
+        start_time = time.time()  # Start timing
         with self.transcription_lock:
             try:
                 self.parent_transcription_pipe.send((self.audio, self.language))
                 status, result = self.parent_transcription_pipe.recv()
-
                 self._set_state("inactive")
                 if status == 'success':
                     segments, info = result
                     self.detected_language = info.language if info.language_probability > 0 else None
                     self.detected_language_probability = info.language_probability
                     self.last_transcription_bytes = audio_copy
-                    return self._preprocess_output(segments)
+                    transcription = self._preprocess_output(segments)
+                    end_time = time.time()  # End timing
+                    transcription_time = end_time - start_time
+                    # print(f"Model {self.main_model_type} completed transcription in {transcription_time:.2f} seconds")
+                    return transcription
                 else:
                     logging.error(f"Transcription error: {result}")
                     raise Exception(result)
@@ -1252,13 +1359,20 @@ class AudioToTextRecorder:
         try:
             was_recording = False
             delay_was_passed = False
+            wakeword_detected_time = None
+            wakeword_samples_to_remove = None
 
             # Continuously monitor audio for voice activity
             while self.is_running:
 
                 try:
+                    try:
+                        data = self.audio_queue.get(timeout=0.1)
+                    except queue.Empty:
+                        if not self.is_running:
+                            break
+                        continue
 
-                    data = self.audio_queue.get()
                     if self.on_recorded_chunk:
                         self.on_recorded_chunk(data)
 
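Replacing the blocking audio_queue.get() with get(timeout=0.1) is what makes the loop responsive: even when no audio arrives, the worker wakes ten times a second to re-check is_running. The same consumer pattern in isolation (illustrative names; q is a queue.Queue, stop a threading.Event):

# Sketch: a queue consumer that polls with a timeout so it can notice
# a stop flag instead of blocking forever on an empty queue.
import queue

def consume(q, stop):
    while not stop.is_set():
        try:
            item = q.get(timeout=0.1)  # wake regularly to re-check stop
        except queue.Empty:
            continue
        print("got", item)
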
@@ -1329,25 +1443,9 @@ class AudioToTextRecorder:
 
                         # If a wake word is detected                        
                         if wakeword_index >= 0:
-
-                            # Removing the wake word from the recording
-                            samples_time = int(self.sample_rate * self.wake_word_buffer_duration)
-                            start_index = max(
-                                0,
-                                len(self.audio_buffer) - samples_time
-                                )
-                            temp_samples = collections.deque(
-                                itertools.islice(
-                                    self.audio_buffer,
-                                    start_index,
-                                    None)
-                                )
-                            self.audio_buffer.clear()
-                            self.audio_buffer.extend(temp_samples)
-
-                            self.wake_word_detect_time = time.time()
+                            wakeword_detected_time = time.time()
+                            wakeword_samples_to_remove = int(self.sample_rate * self.wake_word_buffer_duration)
                             self.wakeword_detected = True
-                            #self.wake_word_cooldown_time = time.time()
                             if self.on_wakeword_detected:
                                 self.on_wakeword_detected()
 
@@ -1363,13 +1461,12 @@ class AudioToTextRecorder:
 
                             self.start()
 
-                            if self.is_recording:
-                                self.start_recording_on_voice_activity = False
+                            self.start_recording_on_voice_activity = False
 
-                                # Add the buffered audio
-                                # to the recording frames
-                                self.frames.extend(list(self.audio_buffer))
-                                self.audio_buffer.clear()
+                            # Add the buffered audio
+                            # to the recording frames
+                            self.frames.extend(list(self.audio_buffer))
+                            self.audio_buffer.clear()
 
                             self.silero_vad_model.reset_states()
                         else:
@@ -1380,6 +1477,22 @@ class AudioToTextRecorder:
 
                 else:
                     # If we are currently recording
+                    if wakeword_samples_to_remove and wakeword_samples_to_remove > 0:
+                        # Remove samples from the beginning of self.frames
+                        samples_removed = 0
+                        while wakeword_samples_to_remove > 0 and self.frames:
+                            frame = self.frames[0]
+                            frame_samples = len(frame) // 2  # Assuming 16-bit audio
+                            if wakeword_samples_to_remove >= frame_samples:
+                                self.frames.pop(0)
+                                samples_removed += frame_samples
+                                wakeword_samples_to_remove -= frame_samples
+                            else:
+                                self.frames[0] = frame[wakeword_samples_to_remove * 2:]
+                                samples_removed += wakeword_samples_to_remove
+                                wakeword_samples_to_remove = 0
+
+                        wakeword_samples_to_remove = 0
 
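Deferring the trim until recording is active means the wake word is cut out of self.frames, a list of raw byte chunks, so the sample count has to be converted to a byte offset (two bytes per 16-bit sample). The same logic as a standalone sketch, including the counter reset the inner loop needs:

# Sketch: drop the first n_samples of 16-bit mono audio from a list
# of byte frames, mirroring the wake-word removal above.
def trim_leading_samples(frames, n_samples):
    while n_samples > 0 and frames:
        frame_samples = len(frames[0]) // 2   # 2 bytes per sample
        if n_samples >= frame_samples:
            frames.pop(0)                     # frame is all wake word
            n_samples -= frame_samples
        else:
            frames[0] = frames[0][n_samples * 2:]  # slice partial frame
            n_samples = 0

frames = [b'\x00\x00' * 512, b'\x00\x00' * 512]
trim_leading_samples(frames, 600)
print(len(frames), len(frames[0]))  # 1 848 -> one 424-sample frame left
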
                     # Stop the recording if silence is detected after speech
                     if self.stop_recording_on_voice_deactivity:
@@ -1398,9 +1511,10 @@ class AudioToTextRecorder:
 
                         # Wait for silence to stop recording after speech
                         if self.speech_end_silence_start and time.time() - \
-                                self.speech_end_silence_start > \
+                                self.speech_end_silence_start >= \
                                 self.post_speech_silence_duration:
                             logging.info("voice deactivity detected")
+                            self.frames.append(data)
                             self.stop()
 
                 if not self.is_recording and was_recording:
@@ -1433,6 +1547,8 @@ class AudioToTextRecorder:
                 logging.error(f"Unhandled exeption in _recording_worker: {e}")
                 raise
 
+
+
     def _realtime_worker(self):
         """
         Performs real-time transcription if the feature is enabled.
@@ -1683,6 +1799,20 @@ class AudioToTextRecorder:
                     target=self._is_silero_speech,
                     args=(data,)).start()
 
+    def clear_audio_queue(self):
+        """
+        Safely empties the audio queue to ensure no remaining audio 
+        fragments get processed e.g. after waking up the recorder.
+        """
+        self.audio_buffer.clear()
+        try:
+            while True:
+                self.audio_queue.get_nowait()
+        except:
+            # PyTorch's mp.Queue doesn't have a specific Empty exception
+            # so we catch any exception that might occur when the queue is empty
+            pass
+
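The drain loop depends on get_nowait() raising once the queue is empty. For a plain multiprocessing.Queue that exception is queue.Empty (the queue module's class is reused), so the stricter form below also works; the bare except in the commit is the defensive choice for torch.multiprocessing queues. A usage sketch under that assumption:

# Sketch: drain a multiprocessing queue without blocking. A plain
# mp.Queue signals emptiness with queue.Empty from get_nowait().
import multiprocessing as mp
import queue
import time

def drain(q):
    dropped = 0
    while True:
        try:
            q.get_nowait()
            dropped += 1
        except queue.Empty:
            return dropped

if __name__ == '__main__':
    q = mp.Queue()
    for i in range(5):
        q.put(i)
    time.sleep(0.1)  # let the feeder thread flush the puts
    print(drain(q), "items dropped")
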
     def _is_voice_active(self):
         """
         Determine if voice is active.