
added initial_prompt and suppress_tokens

Kolja Beigel, 1 year ago
commit 11b27a1134
2 changed files with 93 additions and 23 deletions
  1. README.md (+18, -3)
  2. RealtimeSTT/audio_recorder.py (+75, -20)

+ 18 - 3
README.md

@@ -75,6 +75,8 @@ To use RealtimeSTT with GPU support via CUDA please follow these steps:
 
 3. **Install ffmpeg**:
 
+    > **Note**: *Installation of ffmpeg might not actually be needed to operate RealtimeSTT* <sup>*thanks to jgilbert2017 for pointing this out*</sup>
+
     You can download an installer for your OS from the [ffmpeg Website](https://ffmpeg.org/download.html).  
     
     Or use a package manager:
@@ -262,6 +264,18 @@ When you initialize the `AudioToTextRecorder` class, you have various options to
 
 - **level** (int, default=logging.WARNING): Logging level.
 
+- **handle_buffer_overflow** (bool, default=True): If set, the system will log a warning when an input overflow occurs during recording and remove the data from the buffer.
+
+- **beam_size** (int, default=5): The beam size to use for beam search decoding.
+
+- **initial_prompt** (str or iterable of int, default=None): Initial prompt to be fed to the transcription models.
+
+- **suppress_tokens** (list of int, default=[-1]): Tokens to be suppressed from the transcription output.
+
+- **on_recorded_chunk**: A callback function that is triggered when a chunk of audio is recorded. Receives the chunk data as its argument.
+
+- **debug_mode** (bool, default=False): If set, the system prints additional debug information to the console.
+
 #### Real-time Transcription Parameters
 
 > **Note**: *When enabling realtime transcription, a GPU installation is strongly advised. Using realtime transcription may create high GPU loads.*
@@ -277,14 +291,16 @@ When you initialize the `AudioToTextRecorder` class, you have various options to
 
 - **on_realtime_transcription_stabilized**: A callback function that is triggered whenever there's an update in the real-time transcription and returns a higher quality, stabilized text as its argument.
 
-#### Voice Activation Parameters
+- **beam_size_realtime** (int, default=3): The beam size to use for real-time transcription beam search decoding.
 
-- **silero_sensitivity** (float, default=0.6): Sensitivity for Silero's voice activity detection ranging from 0 (least sensitive) to 1 (most sensitive). Default is 0.6.
+#### Voice Activation Parameters
 
 - **silero_sensitivity** (float, default=0.6): Sensitivity for Silero's voice activity detection ranging from 0 (least sensitive) to 1 (most sensitive). Default is 0.6.
 
 - **silero_use_onnx** (bool, default=False): Enables usage of the pre-trained model from Silero in the ONNX (Open Neural Network Exchange) format instead of the PyTorch format. Default is False. Recommended for faster performance.
 
+- **webrtc_sensitivity** (int, default=3): Sensitivity for the WebRTC Voice Activity Detection engine ranging from 0 (least aggressive, most sensitive) to 3 (most aggressive, least sensitive). Default is 3.
+
 - **post_speech_silence_duration** (float, default=0.2): Duration in seconds of silence that must follow speech before the recording is considered to be completed. This ensures that any brief pauses during speech don't prematurely end the recording.
 
 - **min_gap_between_recordings** (float, default=1.0): Specifies the minimum time interval in seconds that should exist between the end of one recording session and the beginning of another to prevent rapid consecutive recordings.
@@ -315,7 +331,6 @@ When you initialize the `AudioToTextRecorder` class, you have various options to
 
 - **on_wakeword_detection_end**: A callable function triggered when stopping to listen for wake words (e.g. because of timeout or wake word detected)
 
-
 ## Contribution
 
 Contributions are always welcome! 
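
The new constructor options documented in this README diff can be exercised together. Below is a minimal usage sketch; the import, `if __name__ == '__main__'` guard, and `text()` call follow the project's existing examples, while the prompt string, token list, and callback body are illustrative assumptions rather than recommended values.

```python
# Minimal sketch of the options added in this commit; prompt text,
# token list, and callback body are placeholders.
from RealtimeSTT import AudioToTextRecorder

def on_chunk(chunk):
    # Called with each recorded audio chunk (raw bytes).
    print(f"got {len(chunk)} bytes")

if __name__ == '__main__':
    recorder = AudioToTextRecorder(
        beam_size=5,                 # beam width for the main model
        beam_size_realtime=3,        # beam width for the realtime model
        initial_prompt="Glossary: RealtimeSTT, Silero, WebRTC.",
        suppress_tokens=[-1],        # [-1] expands to the default suppress set
        on_recorded_chunk=on_chunk,
        debug_mode=True,
    )
    print(recorder.text())
```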

+ 75 - 20
RealtimeSTT/audio_recorder.py

@@ -26,10 +26,13 @@ Author: Kolja Beigel
 
 """
 
+from typing import Iterable, List, Optional, Union
 import torch.multiprocessing as mp
 import torch
 from typing import List, Union
+from ctypes import c_bool
 from scipy.signal import resample
+from scipy import signal
 import faster_whisper
 import collections
 import numpy as np
@@ -139,6 +142,8 @@ class AudioToTextRecorder:
                  beam_size_realtime: int = 3,
                  buffer_size: int = BUFFER_SIZE,
                  sample_rate: int = SAMPLE_RATE,
+                 initial_prompt: Optional[Union[str, Iterable[int]]] = None,
+                 suppress_tokens: Optional[List[int]] = [-1],
                  ):
         """
         Initializes an audio recorder and transcription
@@ -265,8 +270,22 @@ class AudioToTextRecorder:
             with the recorded audio chunk as its argument.
         - debug_mode (bool, default=False): If set to True, the system will
             print additional debug information to the console.
-        - log_buffer_overflow (bool, default=True): If set to True, the system
-            will log a warning when an input overflow occurs during recording.
+        - handle_buffer_overflow (bool, default=True): If set to True, the system
+            will log a warning when an input overflow occurs during recording and
+            remove the data from the buffer.
+        - beam_size (int, default=5): The beam size to use for beam search
+            decoding.
+        - beam_size_realtime (int, default=3): The beam size to use for beam
+            search decoding in the real-time transcription model.
+        - buffer_size (int, default=512): The buffer size to use for audio
+            recording. Changing this may break functionality.
+        - sample_rate (int, default=16000): The sample rate to use for audio
+            recording. Changing this will very probably break functionality (as
+            the WebRTC VAD model is very sensitive to the sample rate).
+        - initial_prompt (str or iterable of int, default=None): Initial
+            prompt to be fed to the transcription models.
+        - suppress_tokens (list of int, default=[-1]): Tokens to be suppressed
+            from the transcription output.
 
         Raises:
             Exception: Errors related to initializing transcription
@@ -285,7 +304,7 @@ class AudioToTextRecorder:
         self.ensure_sentence_ends_with_period = (
             ensure_sentence_ends_with_period
         )
-        self.use_microphone = use_microphone
+        self.use_microphone = mp.Value(c_bool, use_microphone)
         self.min_gap_between_recordings = min_gap_between_recordings
         self.min_length_of_recording = min_length_of_recording
         self.pre_recording_buffer_duration = pre_recording_buffer_duration
@@ -344,6 +363,8 @@ class AudioToTextRecorder:
         self.start_recording_event = threading.Event()
         self.stop_recording_event = threading.Event()
         self.last_transcription_bytes = None
+        self.initial_prompt = initial_prompt
+        self.suppress_tokens = suppress_tokens
 
         # Initialize the logging configuration with the specified level
         log_format = 'RealTimeSTT: %(name)s - %(levelname)s - %(message)s'
@@ -394,13 +415,15 @@ class AudioToTextRecorder:
                 self.main_transcription_ready_event,
                 self.shutdown_event,
                 self.interrupt_stop_event,
-                self.beam_size
+                self.beam_size,
+                self.initial_prompt,
+                self.suppress_tokens
             )
         )
         self.transcript_process.start()
 
         # Start audio data reading process
-        if use_microphone:
+        if self.use_microphone.value:
             logging.info("Initializing audio recording"
                          " (creating pyAudio input stream,"
                          f" sample rate: {self.sample_rate}"
@@ -414,7 +437,8 @@ class AudioToTextRecorder:
                     self.buffer_size,
                     self.input_device_index,
                     self.shutdown_event,
-                    self.interrupt_stop_event
+                    self.interrupt_stop_event,
+                    self.use_microphone
                 )
             )
             self.reader_process.start()
@@ -544,7 +568,10 @@ class AudioToTextRecorder:
                               ready_event,
                               shutdown_event,
                               interrupt_stop_event,
-                              beam_size):
+                              beam_size,
+                              initial_prompt,
+                              suppress_tokens
+                              ):
         """
         Worker method that handles the continuous
         process of transcribing audio data.
@@ -572,7 +599,10 @@ class AudioToTextRecorder:
             interrupt_stop_event (threading.Event): An event that, when set,
                 signals this worker method to stop processing audio data.
             beam_size (int): The beam size to use for beam search decoding.
-
+            initial_prompt (str or iterable of int): Initial prompt to be fed
+                to the transcription model.
+            suppress_tokens (list of int): Tokens to be suppressed from the
+                transcription output.
         Raises:
             Exception: If there is an error while initializing the
             transcription model.
@@ -587,7 +617,7 @@ class AudioToTextRecorder:
                 model_size_or_path=model_path,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 compute_type=compute_type,
-                device_index=gpu_device_index
+                device_index=gpu_device_index,
             )
 
         except Exception as e:
@@ -610,7 +640,9 @@ class AudioToTextRecorder:
                         segments = model.transcribe(
                             audio,
                             language=language if language else None,
-                            beam_size=beam_size
+                            beam_size=beam_size,
+                            initial_prompt=initial_prompt,
+                            suppress_tokens=suppress_tokens
                         )
                         segments = segments[0]
                         transcription = " ".join(seg.text for seg in segments)
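
For context, the forwarding above targets `faster_whisper`'s `transcribe()` signature, which accepts both new options directly. A standalone sketch, where the "tiny" model, silent dummy buffer, and prompt text are placeholder assumptions:

```python
# Standalone sketch of the call pattern above; model size, audio buffer,
# and prompt are placeholders.
import numpy as np
from faster_whisper import WhisperModel

model = WhisperModel("tiny", device="cpu", compute_type="int8")
audio = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz

segments, info = model.transcribe(
    audio,
    language=None,
    beam_size=5,
    initial_prompt="Glossary: RealtimeSTT, Silero, WebRTC.",
    suppress_tokens=[-1],
)
print(" ".join(seg.text for seg in segments))
```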
@@ -635,7 +667,8 @@ class AudioToTextRecorder:
                            buffer_size,
                            input_device_index,
                            shutdown_event,
-                           interrupt_stop_event):
+                           interrupt_stop_event,
+                           use_microphone):
         """
         Worker method that handles the audio recording process.
 
@@ -705,7 +738,8 @@ class AudioToTextRecorder:
                     print(f"Error: {e}")
                     continue
 
-                audio_queue.put(data)
+                if use_microphone.value:
+                    audio_queue.put(data)
 
         except KeyboardInterrupt:
             interrupt_stop_event.set()
@@ -960,6 +994,13 @@ class AudioToTextRecorder:
             # Feed the extracted data to the audio_queue
             self.audio_queue.put(to_process)
 
+    def set_microphone(self, microphone_on=True):
+        """
+        Set the microphone on or off.
+        """
+        logging.info("Setting microphone to: " + str(microphone_on))
+        self.use_microphone.value = microphone_on
+
     def shutdown(self):
         """
         Safely shuts down the audio recording by stopping the
@@ -980,6 +1021,7 @@ class AudioToTextRecorder:
             self.recording_thread.join()
 
         logging.debug('Terminating reader process')
+
         # Give it some time to finish the loop and cleanup.
         if self.use_microphone:
             self.reader_process.join(timeout=10)
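
The `set_microphone` helper added above flips the shared `use_microphone` flag that the audio reader process checks before queueing chunks, so capture can be paused without tearing the recorder down. A brief usage sketch on an already constructed `recorder` instance:

```python
# Sketch: pause and resume capture on an existing recorder.
# The reader process keeps running; it just stops queueing chunks
# while use_microphone.value is False.
recorder.set_microphone(False)   # mute: incoming chunks are discarded
# ... e.g. play back audio without it being transcribed ...
recorder.set_microphone(True)    # unmute: live capture resumes
```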
@@ -1246,7 +1288,9 @@ class AudioToTextRecorder:
                     segments = self.realtime_model_type.transcribe(
                         audio_array,
                         language=self.language if self.language else None,
-                        beam_size=self.beam_size_realtime
+                        beam_size=self.beam_size_realtime,
+                        initial_prompt=self.initial_prompt,
+                        suppress_tokens=self.suppress_tokens,
                     )
 
                     # double check recording state
@@ -1339,7 +1383,7 @@ class AudioToTextRecorder:
             logging.error(f"Unhandled exeption in _realtime_worker: {e}")
             raise
 
-    def _is_silero_speech(self, data):
+    def _is_silero_speech(self, chunk):
         """
         Returns true if speech is detected in the provided audio data
 
@@ -1347,9 +1391,14 @@ class AudioToTextRecorder:
             chunk (bytes): raw bytes of audio data (1024 raw bytes with
             16000 sample rate and 16 bits per sample)
         """
+        if self.sample_rate != 16000:
+            pcm_data = np.frombuffer(chunk, dtype=np.int16)
+            data_16000 = signal.resample_poly(
+                pcm_data, 16000, self.sample_rate)
+            chunk = data_16000.astype(np.int16).tobytes()
 
         self.silero_working = True
-        audio_chunk = np.frombuffer(data, dtype=np.int16)
+        audio_chunk = np.frombuffer(chunk, dtype=np.int16)
         audio_chunk = audio_chunk.astype(np.float32) / INT16_MAX_ABS_VALUE
         vad_prob = self.silero_vad_model(
             torch.from_numpy(audio_chunk),
@@ -1360,7 +1409,7 @@ class AudioToTextRecorder:
         self.silero_working = False
         return is_silero_speech_active
 
-    def _is_webrtc_speech(self, data, all_frames_must_be_true=False):
+    def _is_webrtc_speech(self, chunk, all_frames_must_be_true=False):
         """
         Returns true if speech is detected in the provided audio data
 
@@ -1368,16 +1417,22 @@ class AudioToTextRecorder:
             chunk (bytes): raw bytes of audio data (1024 raw bytes with
             16000 sample rate and 16 bits per sample)
         """
+        if self.sample_rate != 16000:
+            pcm_data = np.frombuffer(chunk, dtype=np.int16)
+            data_16000 = signal.resample_poly(
+                pcm_data, 16000, self.sample_rate)
+            chunk = data_16000.astype(np.int16).tobytes()
+
         # Number of audio samples per 10 ms frame
-        frame_length = int(self.sample_rate * 0.01)  # for 10ms frame
-        num_frames = int(len(data) / (2 * frame_length))
+        frame_length = int(16000 * 0.01)  # for 10ms frame
+        num_frames = int(len(chunk) / (2 * frame_length))
         speech_frames = 0
 
         for i in range(num_frames):
             start_byte = i * frame_length * 2
             end_byte = start_byte + frame_length * 2
-            frame = data[start_byte:end_byte]
-            if self.webrtc_vad_model.is_speech(frame, self.sample_rate):
+            frame = chunk[start_byte:end_byte]
+            if self.webrtc_vad_model.is_speech(frame, 16000):
                 speech_frames += 1
                 if not all_frames_must_be_true:
                     if self.debug_mode:
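
Both VAD checks now normalize incoming audio to 16 kHz before inference, since Silero and WebRTC are sensitive to the sample rate. The standalone helper below mirrors that new logic; the function name `to_16k` is our own, not part of the library API.

```python
import numpy as np
from scipy import signal

def to_16k(chunk: bytes, sample_rate: int) -> bytes:
    """Resample 16-bit PCM bytes to the 16 kHz rate both VAD engines expect."""
    if sample_rate == 16000:
        return chunk
    pcm = np.frombuffer(chunk, dtype=np.int16)
    resampled = signal.resample_poly(pcm, 16000, sample_rate)
    return resampled.astype(np.int16).tobytes()
```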