浏览代码

added transcribe while speaking functionality

Kolja Beigel 1 年之前
父节点
当前提交
dc11e6f048
共有 3 个文件被更改,包括 112 次插入4 次删除
  1. 11 0
      README.md
  2. 75 4
      RealtimeSTT/audio_recorder.py
  3. 26 0
      tests/realtime_test.py

+ 11 - 0
README.md

@@ -171,6 +171,17 @@ When you initialize the `AudioToTextRecorder` class, you have various options to
 
 
 - **level** (int, default=logging.WARNING): Logging level.
 - **level** (int, default=logging.WARNING): Logging level.
 
 
+#### Real-time Transcription Parameters
+
+- **realtime_preview** (bool, default=False): Specifies whether transcription should occur in real-time. If set to True, the audio will also be transcribed as it is recorded.
+
+- **realtime_preview_model** (str, default="tiny"): Specifies the size or path of the machine learning model to be used for real-time transcription.
+    - Valid options: 'tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', 'large-v1', 'large-v2'.
+
+- **realtime_preview_resolution** (float, default=0.1): Specifies the time interval in seconds after a chunk of audio gets transcribed. Lower values will result in more "real-time" (frequent) transcription updates but may increase computational load.
+
+- **on_realtime_preview**: A callable function triggered during real-time transcription. The function is invoked with the transcribed text as its argument.
+
 #### Voice Activation Parameters
 #### Voice Activation Parameters
 
 
 - **silero_sensitivity** (float, default=0.6): Sensitivity for Silero's voice activity detection ranging from 0 (least sensitive) to 1 (most sensitive). Default is 0.6.
 - **silero_sensitivity** (float, default=0.6): Sensitivity for Silero's voice activity detection ranging from 0 (least sensitive) to 1 (most sensitive). Default is 0.6.

+ 75 - 4
RealtimeSTT/audio_recorder.py

@@ -38,7 +38,7 @@ from halo import Halo
 
 
 SAMPLE_RATE = 16000
 SAMPLE_RATE = 16000
 BUFFER_SIZE = 512
 BUFFER_SIZE = 512
-SILERO_SENSITIVITY = 0.6
+SILERO_SENSITIVITY = 0.4
 WEBRTC_SENSITIVITY = 3
 WEBRTC_SENSITIVITY = 3
 WAKE_WORDS_SENSITIVITY = 0.6
 WAKE_WORDS_SENSITIVITY = 0.6
 TIME_SLEEP = 0.02
 TIME_SLEEP = 0.02
@@ -57,6 +57,12 @@ class AudioToTextRecorder:
                  spinner = True,
                  spinner = True,
                  level=logging.WARNING,
                  level=logging.WARNING,
 
 
+                 # Realtime transcription parameters
+                 realtime_preview = False,
+                 realtime_preview_model = "tiny",
+                 realtime_preview_resolution = 0.1,
+                 on_realtime_preview = None,
+
                  # Voice activation parameters
                  # Voice activation parameters
                  silero_sensitivity: float = SILERO_SENSITIVITY,
                  silero_sensitivity: float = SILERO_SENSITIVITY,
                  webrtc_sensitivity: int = WEBRTC_SENSITIVITY,
                  webrtc_sensitivity: int = WEBRTC_SENSITIVITY,
@@ -87,9 +93,13 @@ class AudioToTextRecorder:
         - language (str, default=""): Language code for speech-to-text engine. If not specified, the model will attempt to detect the language automatically.
         - language (str, default=""): Language code for speech-to-text engine. If not specified, the model will attempt to detect the language automatically.
         - on_recording_start (callable, default=None): Callback function to be called when recording of audio to be transcripted starts.
         - on_recording_start (callable, default=None): Callback function to be called when recording of audio to be transcripted starts.
         - on_recording_stop (callable, default=None): Callback function to be called when recording of audio to be transcripted stops.
         - on_recording_stop (callable, default=None): Callback function to be called when recording of audio to be transcripted stops.
-        - on_transcription_start (callable, default=None): Callback function to be called when transcription of audio to text starts.       
+        - on_transcription_start (callable, default=None): Callback function to be called when transcription of audio to text starts.  
         - spinner (bool, default=True): Show spinner animation with current state.
         - spinner (bool, default=True): Show spinner animation with current state.
         - level (int, default=logging.WARNING): Logging level.
         - level (int, default=logging.WARNING): Logging level.
+        - realtime_preview (bool, default=False): Specifies whether a preview of the transcription should occur in real-time. If set to True, the audio will be transcribed as it is recorded.
+        - realtime_preview_model (str, default="tiny"): Specifies the size or path of the machine learning model to be used for real-time transcription. Valid options include 'tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', 'large-v1', 'large-v2'.
+        - realtime_preview_resolution (float, default=0.1): Specifies the time interval in seconds after a chunk of audio gets transcribed. Lower values will result in more "real-time" (frequent) transcription updates but may increase computational load.
+        - on_realtime_preview = A callable function triggered during real-time transcription. The function is invoked with the transcribed text as its argument.
         - silero_sensitivity (float, default=SILERO_SENSITIVITY): Sensitivity for the Silero Voice Activity Detection model ranging from 0 (least sensitive) to 1 (most sensitive). Default is 0.5.
         - silero_sensitivity (float, default=SILERO_SENSITIVITY): Sensitivity for the Silero Voice Activity Detection model ranging from 0 (least sensitive) to 1 (most sensitive). Default is 0.5.
         - webrtc_sensitivity (int, default=WEBRTC_SENSITIVITY): Sensitivity for the WebRTC Voice Activity Detection engine ranging from 1 (least sensitive) to 3 (most sensitive). Default is 3.
         - webrtc_sensitivity (int, default=WEBRTC_SENSITIVITY): Sensitivity for the WebRTC Voice Activity Detection engine ranging from 1 (least sensitive) to 3 (most sensitive). Default is 3.
         - post_speech_silence_duration (float, default=0.2): Duration in seconds of silence that must follow speech before the recording is considered to be completed. This ensures that any brief pauses during speech don't prematurely end the recording.
         - post_speech_silence_duration (float, default=0.2): Duration in seconds of silence that must follow speech before the recording is considered to be completed. This ensures that any brief pauses during speech don't prematurely end the recording.
@@ -129,6 +139,10 @@ class AudioToTextRecorder:
         self.on_wakeword_detection_start = on_wakeword_detection_start
         self.on_wakeword_detection_start = on_wakeword_detection_start
         self.on_wakeword_detection_end = on_wakeword_detection_end
         self.on_wakeword_detection_end = on_wakeword_detection_end
         self.on_transcription_start = on_transcription_start
         self.on_transcription_start = on_transcription_start
+        self.realtime_preview = realtime_preview
+        self.realtime_preview_model = realtime_preview_model
+        self.realtime_preview_resolution = realtime_preview_resolution
+        self.on_realtime_preview = on_realtime_preview
     
     
         self.level = level
         self.level = level
         self.buffer_size = BUFFER_SIZE
         self.buffer_size = BUFFER_SIZE
@@ -153,6 +167,10 @@ class AudioToTextRecorder:
         try:
         try:
             self.model = faster_whisper.WhisperModel(model_size_or_path=model, device='cuda' if torch.cuda.is_available() else 'cpu')
             self.model = faster_whisper.WhisperModel(model_size_or_path=model, device='cuda' if torch.cuda.is_available() else 'cpu')
 
 
+            if self.realtime_preview:
+                self.realtime_preview_model = faster_whisper.WhisperModel(model_size_or_path=self.realtime_preview_model, device='cuda' if torch.cuda.is_available() else 'cpu')
+
+
         except Exception as e:
         except Exception as e:
             logging.exception(f"Error initializing faster_whisper transcription model: {e}")
             logging.exception(f"Error initializing faster_whisper transcription model: {e}")
             raise            
             raise            
@@ -220,6 +238,11 @@ class AudioToTextRecorder:
         self.recording_thread.daemon = True
         self.recording_thread.daemon = True
         self.recording_thread.start()
         self.recording_thread.start()
 
 
+        # Start the realtime transcription worker thread
+        self.realtime_thread = threading.Thread(target=self._realtime_worker)
+        self.realtime_thread.daemon = True
+        self.realtime_thread.start()
+
 
 
     def text(self):
     def text(self):
         """
         """
@@ -601,6 +624,7 @@ class AudioToTextRecorder:
 
 
                             # Add the buffered audio to the recording frames
                             # Add the buffered audio to the recording frames
                             self.frames.extend(list(self.audio_buffer))
                             self.frames.extend(list(self.audio_buffer))
+                            self.audio_buffer.clear()
 
 
                         self.silero_vad_model.reset_states()
                         self.silero_vad_model.reset_states()
 
 
@@ -641,13 +665,60 @@ class AudioToTextRecorder:
 
 
             if self.is_recording:
             if self.is_recording:
                 self.frames.append(data)
                 self.frames.append(data)
-
-            self.audio_buffer.append(data)	
+            else:
+                self.audio_buffer.append(data)	
 
 
             was_recording = self.is_recording
             was_recording = self.is_recording
             time.sleep(TIME_SLEEP)
             time.sleep(TIME_SLEEP)
 
 
 
 
+
+    def _realtime_worker(self):
+        """
+        Performs real-time transcription if the feature is enabled.
+
+        The method is responsible transcribing recorded audio frames in real-time
+         based on the specified resolution interval.
+        The transcribed text is stored in `self.realtime_preview_text` and a callback
+        function is invoked with this text if specified.
+        """
+
+        # Return immediately if real-time transcription is not enabled
+        if not self.realtime_preview:
+            return
+        
+        # Continue running as long as the main process is active
+        while self.is_running:
+
+            # Check if the recording is active
+            if self.is_recording:
+                
+                # Sleep for the duration of the transcription resolution
+                time.sleep(self.realtime_preview_resolution)
+                
+                # Convert the buffer frames to a NumPy array
+                audio_array = np.frombuffer(b''.join(self.frames), dtype=np.int16)
+                
+                # Normalize the array to a [-1, 1] range
+                audio_array = audio_array.astype(np.float32) / 32768.0
+
+                # Perform transcription and assemble the text
+                segments = self.realtime_preview_model.transcribe(
+                    audio_array,
+                    language=self.language if self.language else None
+                )
+                self.realtime_preview_text = " ".join(seg.text for seg in segments[0]).strip()
+
+                # Invoke the callback with the transcribed text
+                if self.is_recording:
+                    if self.on_realtime_preview:
+                        self.on_realtime_preview(self.realtime_preview_text)
+
+            # If not recording, sleep briefly before checking again
+            else:
+                time.sleep(0.1)
+
+
     def __del__(self):
     def __del__(self):
         """
         """
         Destructor method ensures safe shutdown of the recorder when the instance is destroyed.
         Destructor method ensures safe shutdown of the recorder when the instance is destroyed.

+ 26 - 0
tests/realtime_test.py

@@ -0,0 +1,26 @@
+import os
+from RealtimeSTT import AudioToTextRecorder
+
+detected_text = ""
+displayed_text = ""
+
+def clear_console():
+    if os.name == 'posix':  # For UNIX or macOS
+        os.system('clear')
+    elif os.name == 'nt':  # For Windows
+        os.system('cls')
+
+def text_detected(text):
+    global displayed_text
+    if detected_text + text != displayed_text:
+        displayed_text = detected_text + text
+        clear_console()
+        print(displayed_text)
+
+recorder = AudioToTextRecorder(spinner=False, model="large-v2", language="en", silero_sensitivity=0.2, post_speech_silence_duration=0.4, min_length_of_recording=0.5, min_gap_between_recordings=0.05, realtime_preview_resolution = 0.05, realtime_preview = True, realtime_preview_model = "tiny", on_realtime_preview=text_detected)
+
+print("Say something...")
+
+while (True): 
+    detected_text += recorder.text() + " "
+    text_detected("")