Kolja Beigel il y a 1 an
Parent
commit
8d0ba914a2
5 fichiers modifiés avec 100 ajouts et 25 suppressions
  1. 8 2
      README.md
  2. 84 20
      RealtimeSTT/audio_recorder.py
  3. 1 1
      install_with_gpu_support.bat
  4. 2 2
      requirements.txt
  5. 5 0
      requirements_raw.txt

+ 8 - 2
README.md

@@ -186,7 +186,7 @@ recorder = AudioToTextRecorder(on_recording_start=my_start_callback,
 
 ### Feed chunks
 
-If you don't want to use the local microphone set use_microphone parameter to false and provide raw PCM audiochunks in 16-bit mono with this method:
+If you don't want to use the local microphone set use_microphone parameter to false and provide raw PCM audiochunks in 16-bit mono (samplerate 16000) with this method:
 
 ```python
 recorder.feed_audio(audio_chunk)
@@ -253,7 +253,13 @@ When you initialize the `AudioToTextRecorder` class, you have various options to
     - Options: 'tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium', 'medium.en', 'large-v1', 'large-v2'.
     - Note: If a size is provided, the model will be downloaded from the Hugging Face Hub.
 
-- **language** (str, default=""): Language code for transcription. If left empty, the model will try to auto-detect the language.
+- **language** (str, default=""): Language code for transcription. If left empty, the model will try to auto-detect the language. Supported language codes are listed in [Whisper Tokenizer library](https://github.com/openai/whisper/blob/main/whisper/tokenizer.py).
+
+- **compute_type** (str, default="default"): Specifies the type of computation to be used for transcription. See [Whisper Quantization](https://opennmt.net/CTranslate2/quantization.html)
+
+- **input_device_index** (int, default=0): Audio Input Device Index to use.
+
+- **gpu_device_index** (int, default=0): GPU Device Index to use. The model can also be loaded on multiple GPUs by passing a list of IDs (e.g. [0, 1, 2, 3]).
 
 - **on_recording_start**: A callable function triggered when recording starts.
 

+ 84 - 20
RealtimeSTT/audio_recorder.py

@@ -26,7 +26,8 @@ Author: Kolja Beigel
 
 """
 
-from multiprocessing import Process, Pipe, Event, Manager
+import torch.multiprocessing as mp
+from typing import List, Union
 import faster_whisper
 import collections
 import numpy as np
@@ -43,6 +44,7 @@ import halo
 import time
 import os
 import re
+import gc
 
 INIT_MODEL_TRANSCRIPTION = "tiny"
 INIT_MODEL_TRANSCRIPTION_REALTIME = "tiny"
@@ -74,6 +76,9 @@ class AudioToTextRecorder:
     def __init__(self,
                  model: str = INIT_MODEL_TRANSCRIPTION,
                  language: str = "",
+                 compute_type: str = "default",
+                 input_device_index: int = 0,
+                 gpu_device_index: Union[int, List[int]] = 0,
                  on_recording_start=None,
                  on_recording_stop=None,
                  on_transcription_start=None,
@@ -136,6 +141,16 @@ class AudioToTextRecorder:
         - language (str, default=""): Language code for speech-to-text engine.
             If not specified, the model will attempt to detect the language
             automatically.
+        - compute_type (str, default="default"): Specifies the type of
+            computation to be used for transcription.
+            See https://opennmt.net/CTranslate2/quantization.html.
+        - input_device_index (int, default=0): The index of the audio input
+            device to use.
+        - gpu_device_index (int, default=0): Device ID to use.
+            The model can also be loaded on multiple GPUs by passing a list of
+            IDs (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can
+            run in parallel when transcribe() is called from multiple Python
+            threads
         - on_recording_start (callable, default=None): Callback function to be
             called when recording of audio to be transcripted starts.
         - on_recording_stop (callable, default=None): Callback function to be
@@ -238,6 +253,9 @@ class AudioToTextRecorder:
         """
 
         self.language = language
+        self.compute_type = compute_type
+        self.input_device_index = input_device_index
+        self.gpu_device_index = gpu_device_index
         self.wake_words = wake_words
         self.wake_word_activation_delay = wake_word_activation_delay
         self.wake_word_timeout = wake_word_timeout
@@ -247,6 +265,7 @@ class AudioToTextRecorder:
         self.ensure_sentence_ends_with_period = (
             ensure_sentence_ends_with_period
         )
+        self.use_microphone = use_microphone
         self.min_gap_between_recordings = min_gap_between_recordings
         self.min_length_of_recording = min_length_of_recording
         self.pre_recording_buffer_duration = pre_recording_buffer_duration
@@ -272,8 +291,7 @@ class AudioToTextRecorder:
         self.allowed_latency_limit = ALLOWED_LATENCY_LIMIT
 
         self.level = level
-        manager = Manager()
-        self.audio_queue = manager.Queue()
+        self.audio_queue = mp.Queue()
         self.buffer_size = BUFFER_SIZE
         self.sample_rate = SAMPLE_RATE
         self.recording_start_time = 0
@@ -323,20 +341,30 @@ class AudioToTextRecorder:
         logger.addHandler(console_handler)
 
         self.is_shut_down = False
-        self.shutdown_event = Event()
+        self.shutdown_event = mp.Event()
 
         logging.info("Starting RealTimeSTT")
 
-        # Start transcription process
-        self.interrupt_stop_event = Event()
-        self.main_transcription_ready_event = Event()
-        self.parent_transcription_pipe, child_transcription_pipe = Pipe()
-
-        self.transcript_process = Process(
+        # Start transcription worker process
+        try:
+            # Only set the start method if it hasn't been set already
+            if mp.get_start_method(allow_none=True) is None:
+                mp.set_start_method("spawn")
+        except RuntimeError as e:
+            print("Start method has already been set. Details:", e)
+
+        self.interrupt_stop_event = mp.Event()
+        self.was_interrupted = mp.Event()
+        self.main_transcription_ready_event = mp.Event()
+        self.parent_transcription_pipe, child_transcription_pipe = mp.Pipe()
+
+        self.transcript_process = mp.Process(
             target=AudioToTextRecorder._transcription_worker,
             args=(
                 child_transcription_pipe,
                 model,
+                self.compute_type,
+                self.gpu_device_index,
                 self.main_transcription_ready_event,
                 self.shutdown_event,
                 self.interrupt_stop_event
@@ -346,12 +374,13 @@ class AudioToTextRecorder:
 
         # Start audio data reading process
         if use_microphone:
-            self.reader_process = Process(
+            self.reader_process = mp.Process(
                 target=AudioToTextRecorder._audio_data_worker,
                 args=(
                     self.audio_queue,
                     self.sample_rate,
                     self.buffer_size,
+                    self.input_device_index,
                     self.shutdown_event,
                     self.interrupt_stop_event
                 )
@@ -366,7 +395,9 @@ class AudioToTextRecorder:
                              )
                 self.realtime_model_type = faster_whisper.WhisperModel(
                     model_size_or_path=self.realtime_model_type,
-                    device='cuda' if torch.cuda.is_available() else 'cpu'
+                    device='cuda' if torch.cuda.is_available() else 'cpu',
+                    compute_type=self.compute_type,
+                    device_index=self.gpu_device_index
                 )
 
             except Exception as e:
@@ -473,12 +504,11 @@ class AudioToTextRecorder:
 
         logging.debug('RealtimeSTT initialization completed successfully')
 
-        print(f"buffer_size: {self.buffer_size}")
-        print(f"samplerate: {self.sample_rate}")
-
     @staticmethod
     def _transcription_worker(conn,
                               model_path,
+                              compute_type,
+                              gpu_device_index,
                               ready_event,
                               shutdown_event,
                               interrupt_stop_event):
@@ -499,6 +529,9 @@ class AudioToTextRecorder:
               for receiving audio data and sending transcription results.
             model_path (str): The path to the pre-trained faster_whisper model
               for transcription.
+            compute_type (str): Specifies the type of computation to be used
+                for transcription.
+            gpu_device_index (int): Device ID to use.
             ready_event (threading.Event): An event that is set when the
               transcription model is successfully initialized and ready.
             shutdown_event (threading.Event): An event that, when set,
@@ -516,7 +549,9 @@ class AudioToTextRecorder:
         try:
             model = faster_whisper.WhisperModel(
                 model_size_or_path=model_path,
-                device='cuda' if torch.cuda.is_available() else 'cpu'
+                device='cuda' if torch.cuda.is_available() else 'cpu',
+                compute_type=compute_type,
+                device_index=gpu_device_index
             )
 
         except Exception as e:
@@ -563,6 +598,7 @@ class AudioToTextRecorder:
     def _audio_data_worker(audio_queue,
                            sample_rate,
                            buffer_size,
+                           input_device_index,
                            shutdown_event,
                            interrupt_stop_event):
         """
@@ -583,6 +619,7 @@ class AudioToTextRecorder:
             sample_rate (int): The sample rate of the audio input stream.
             buffer_size (int): The size of the buffer used in the audio
               input stream.
+            input_device_index (int): The index of the audio input device
             shutdown_event (threading.Event): An event that, when set, signals
               this worker method to terminate.
 
@@ -600,7 +637,8 @@ class AudioToTextRecorder:
                                           format=pyaudio.paInt16,
                                           channels=1,
                                           input=True,
-                                          frames_per_buffer=buffer_size
+                                          frames_per_buffer=buffer_size,
+                                          input_device_index=input_device_index,
                                           )
 
         except Exception as e:
@@ -647,6 +685,20 @@ class AudioToTextRecorder:
             stream.close()
             audio_interface.terminate()
 
+    def wakeup(self):
+        """
+        If in wake work modus, wake up as if a wake word was spoken.
+        """
+        self.listen_start = time.time()
+
+    def abort(self):
+        self.start_recording_on_voice_activity = False
+        self.stop_recording_on_voice_deactivity = False
+        self._set_state("inactive")
+        self.interrupt_stop_event.set()
+        self.was_interrupted.wait()
+        self.was_interrupted.clear()
+
     def wait_audio(self):
         """
         Waits for the start and completion of the audio recording process.
@@ -671,7 +723,7 @@ class AudioToTextRecorder:
 
             # Wait until recording starts
             while not self.interrupt_stop_event.is_set():
-                if self.start_recording_event.wait(timeout=0.5):
+                if self.start_recording_event.wait(timeout=0.02):
                     break
 
         # If recording is ongoing, wait for voice inactivity
@@ -681,7 +733,7 @@ class AudioToTextRecorder:
 
             # Wait until recording stops
             while not self.interrupt_stop_event.is_set():
-                if (self.stop_recording_event.wait(timeout=0.5)):
+                if (self.stop_recording_event.wait(timeout=0.02)):
                     break
 
         # Convert recorded frames to the appropriate audio format.
@@ -756,9 +808,14 @@ class AudioToTextRecorder:
             str: The transcription of the recorded audio
         """
 
+        self.interrupt_stop_event.clear()
+        self.was_interrupted.clear()
+
         self.wait_audio()
 
         if self.is_shut_down or self.interrupt_stop_event.is_set():
+            if self.interrupt_stop_event.is_set():
+                self.was_interrupted.set()
             return ""
 
         if on_transcription_finished:
@@ -873,7 +930,8 @@ class AudioToTextRecorder:
 
         logging.debug('Terminating reader process')
         # Give it some time to finish the loop and cleanup.
-        self.reader_process.join(timeout=10)
+        if self.use_microphone:
+            self.reader_process.join(timeout=10)
 
         if self.reader_process.is_alive():
             logging.warning("Reader process did not terminate "
@@ -896,6 +954,12 @@ class AudioToTextRecorder:
         if self.realtime_thread:
             self.realtime_thread.join()
 
+        if self.enable_realtime_transcription:
+            if self.realtime_model_type:
+                del self.realtime_model_type
+                self.realtime_model_type = None
+        gc.collect()
+
     def _recording_worker(self):
         """
         The main worker method which constantly monitors the audio

+ 1 - 1
install_with_gpu_support.bat

@@ -1,2 +1,2 @@
-pip install torch==2.1.1+cu118 torchaudio==2.1.1+cu118 --index-url https://download.pytorch.org/whl/cu118
+pip install torch==2.1.2+cu118 torchaudio==2.1.2+cu118 --index-url https://download.pytorch.org/whl/cu118
 pip install -r requirements-gpu.txt

+ 2 - 2
requirements.txt

@@ -3,5 +3,5 @@ faster-whisper==0.10.0
 pvporcupine==1.9.5
 webrtcvad==2.0.10
 halo==0.0.31
-torch==2.1.1
-torchaudio==2.1.1
+torch==2.1.2
+torchaudio==2.1.2

+ 5 - 0
requirements_raw.txt

@@ -0,0 +1,5 @@
+PyAudio
+faster-whisper==0.10.0
+pvporcupine==1.9.5
+webrtcvad
+halo