فهرست منبع

Merge pull request #60 from vancoder1/device_choice

Adding device choice for whisper model to constructor
Kolja Beigel 11 ماه پیش
والد
کامیت
6dd040c6b7
2فایلهای تغییر یافته به همراه14 افزوده شده و 2 حذف شده
  1. 2 0
      README.md
  2. 12 2
      RealtimeSTT/audio_recorder.py

+ 2 - 0
README.md

@@ -255,6 +255,8 @@ When you initialize the `AudioToTextRecorder` class, you have various options to
 
 - **gpu_device_index** (int, default=0): GPU Device Index to use. The model can also be loaded on multiple GPUs by passing a list of IDs (e.g. [0, 1, 2, 3]).
 
+- **device** (str, default="cuda"): Device for model to use. Can either be "cuda" or "cpu". 
+
 - **on_recording_start**: A callable function triggered when recording starts.
 
 - **on_recording_stop**: A callable function triggered when recording ends.

+ 12 - 2
RealtimeSTT/audio_recorder.py

@@ -92,6 +92,7 @@ class AudioToTextRecorder:
                  compute_type: str = "default",
                  input_device_index: int = 0,
                  gpu_device_index: Union[int, List[int]] = 0,
+                 device: str = "cuda",
                  on_recording_start=None,
                  on_recording_stop=None,
                  on_transcription_start=None,
@@ -173,6 +174,8 @@ class AudioToTextRecorder:
             IDs (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can
             run in parallel when transcribe() is called from multiple Python
             threads
+        - device (str, default="cuda"): Device for model to use. Can either be 
+            "cuda" or "cpu".
         - on_recording_start (callable, default=None): Callback function to be
             called when recording of audio to be transcripted starts.
         - on_recording_stop (callable, default=None): Callback function to be
@@ -298,6 +301,7 @@ class AudioToTextRecorder:
         self.compute_type = compute_type
         self.input_device_index = input_device_index
         self.gpu_device_index = gpu_device_index
+        self.device = device
         self.wake_words = wake_words
         self.wake_word_activation_delay = wake_word_activation_delay
         self.wake_word_timeout = wake_word_timeout
@@ -406,6 +410,9 @@ class AudioToTextRecorder:
         self.main_transcription_ready_event = mp.Event()
         self.parent_transcription_pipe, child_transcription_pipe = mp.Pipe()
 
+        # Set device for model
+        self.device = "cuda" if self.device == "cuda" and torch.cuda.is_available() else "cpu"
+
         self.transcript_process = mp.Process(
             target=AudioToTextRecorder._transcription_worker,
             args=(
@@ -413,6 +420,7 @@ class AudioToTextRecorder:
                 model,
                 self.compute_type,
                 self.gpu_device_index,
+                self.device,
                 self.main_transcription_ready_event,
                 self.shutdown_event,
                 self.interrupt_stop_event,
@@ -452,7 +460,7 @@ class AudioToTextRecorder:
                              )
                 self.realtime_model_type = faster_whisper.WhisperModel(
                     model_size_or_path=self.realtime_model_type,
-                    device='cuda' if torch.cuda.is_available() else 'cpu',
+                    device=self.device,
                     compute_type=self.compute_type,
                     device_index=self.gpu_device_index
                 )
@@ -566,6 +574,7 @@ class AudioToTextRecorder:
                               model_path,
                               compute_type,
                               gpu_device_index,
+                              device,
                               ready_event,
                               shutdown_event,
                               interrupt_stop_event,
@@ -593,6 +602,7 @@ class AudioToTextRecorder:
             compute_type (str): Specifies the type of computation to be used
                 for transcription.
             gpu_device_index (int): Device ID to use.
+            device (str): Device for model to use.
             ready_event (threading.Event): An event that is set when the
               transcription model is successfully initialized and ready.
             shutdown_event (threading.Event): An event that, when set,
@@ -616,7 +626,7 @@ class AudioToTextRecorder:
         try:
             model = faster_whisper.WhisperModel(
                 model_size_or_path=model_path,
-                device='cuda' if torch.cuda.is_available() else 'cpu',
+                device=device,
                 compute_type=compute_type,
                 device_index=gpu_device_index,
             )