|
@@ -92,6 +92,7 @@ class AudioToTextRecorder:
|
|
|
compute_type: str = "default",
|
|
|
input_device_index: int = 0,
|
|
|
gpu_device_index: Union[int, List[int]] = 0,
|
|
|
+ device: str = "cuda",
|
|
|
on_recording_start=None,
|
|
|
on_recording_stop=None,
|
|
|
on_transcription_start=None,
|
|
@@ -173,6 +174,8 @@ class AudioToTextRecorder:
|
|
|
IDs (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can
|
|
|
run in parallel when transcribe() is called from multiple Python
|
|
|
threads
|
|
|
+ - device (str, default="cuda"): Device for model to use. Can either be
|
|
|
+ "cuda" or "cpu".
|
|
|
- on_recording_start (callable, default=None): Callback function to be
|
|
|
called when recording of audio to be transcripted starts.
|
|
|
- on_recording_stop (callable, default=None): Callback function to be
|
|
@@ -298,6 +301,7 @@ class AudioToTextRecorder:
|
|
|
self.compute_type = compute_type
|
|
|
self.input_device_index = input_device_index
|
|
|
self.gpu_device_index = gpu_device_index
|
|
|
+ self.device = device
|
|
|
self.wake_words = wake_words
|
|
|
self.wake_word_activation_delay = wake_word_activation_delay
|
|
|
self.wake_word_timeout = wake_word_timeout
|
|
@@ -406,6 +410,9 @@ class AudioToTextRecorder:
|
|
|
self.main_transcription_ready_event = mp.Event()
|
|
|
self.parent_transcription_pipe, child_transcription_pipe = mp.Pipe()
|
|
|
|
|
|
+ # Set device for model
|
|
|
+ self.device = "cuda" if self.device == "cuda" and torch.cuda.is_available() else "cpu"
|
|
|
+
|
|
|
self.transcript_process = mp.Process(
|
|
|
target=AudioToTextRecorder._transcription_worker,
|
|
|
args=(
|
|
@@ -413,6 +420,7 @@ class AudioToTextRecorder:
|
|
|
model,
|
|
|
self.compute_type,
|
|
|
self.gpu_device_index,
|
|
|
+ self.device,
|
|
|
self.main_transcription_ready_event,
|
|
|
self.shutdown_event,
|
|
|
self.interrupt_stop_event,
|
|
@@ -452,7 +460,7 @@ class AudioToTextRecorder:
|
|
|
)
|
|
|
self.realtime_model_type = faster_whisper.WhisperModel(
|
|
|
model_size_or_path=self.realtime_model_type,
|
|
|
- device='cuda' if torch.cuda.is_available() else 'cpu',
|
|
|
+ device=self.device,
|
|
|
compute_type=self.compute_type,
|
|
|
device_index=self.gpu_device_index
|
|
|
)
|
|
@@ -566,6 +574,7 @@ class AudioToTextRecorder:
|
|
|
model_path,
|
|
|
compute_type,
|
|
|
gpu_device_index,
|
|
|
+ device,
|
|
|
ready_event,
|
|
|
shutdown_event,
|
|
|
interrupt_stop_event,
|
|
@@ -593,6 +602,7 @@ class AudioToTextRecorder:
|
|
|
compute_type (str): Specifies the type of computation to be used
|
|
|
for transcription.
|
|
|
gpu_device_index (int): Device ID to use.
|
|
|
+ device (str): Device for model to use.
|
|
|
ready_event (threading.Event): An event that is set when the
|
|
|
transcription model is successfully initialized and ready.
|
|
|
shutdown_event (threading.Event): An event that, when set,
|
|
@@ -616,7 +626,7 @@ class AudioToTextRecorder:
|
|
|
try:
|
|
|
model = faster_whisper.WhisperModel(
|
|
|
model_size_or_path=model_path,
|
|
|
- device='cuda' if torch.cuda.is_available() else 'cpu',
|
|
|
+ device=device,
|
|
|
compute_type=compute_type,
|
|
|
device_index=gpu_device_index,
|
|
|
)
|