|
@@ -26,7 +26,8 @@ Author: Kolja Beigel
|
|
|
|
|
|
"""
|
|
|
|
|
|
-from multiprocessing import Process, Pipe, Event, Manager
|
|
|
+import torch.multiprocessing as mp
|
|
|
+from typing import List, Union
|
|
|
import faster_whisper
|
|
|
import collections
|
|
|
import numpy as np
|
|
@@ -43,6 +44,7 @@ import halo
|
|
|
import time
|
|
|
import os
|
|
|
import re
|
|
|
+import gc
|
|
|
|
|
|
INIT_MODEL_TRANSCRIPTION = "tiny"
|
|
|
INIT_MODEL_TRANSCRIPTION_REALTIME = "tiny"
|
|
@@ -74,6 +76,9 @@ class AudioToTextRecorder:
|
|
|
def __init__(self,
|
|
|
model: str = INIT_MODEL_TRANSCRIPTION,
|
|
|
language: str = "",
|
|
|
+ compute_type: str = "default",
|
|
|
+ input_device_index: int = 0,
|
|
|
+ gpu_device_index: Union[int, List[int]] = 0,
|
|
|
on_recording_start=None,
|
|
|
on_recording_stop=None,
|
|
|
on_transcription_start=None,
|
|
@@ -136,6 +141,16 @@ class AudioToTextRecorder:
|
|
|
- language (str, default=""): Language code for speech-to-text engine.
|
|
|
If not specified, the model will attempt to detect the language
|
|
|
automatically.
|
|
|
+ - compute_type (str, default="default"): Specifies the type of
|
|
|
+ computation to be used for transcription.
|
|
|
+ See https://opennmt.net/CTranslate2/quantization.html.
|
|
|
+ - input_device_index (int, default=0): The index of the audio input
|
|
|
+ device to use.
|
|
|
+ - gpu_device_index (int or List[int], default=0): Device ID to use.
|
|
|
+ The model can also be loaded on multiple GPUs by passing a list of
|
|
|
+ IDs (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can
|
|
|
+ run in parallel when transcribe() is called from multiple Python
|
|
|
+ threads.
|
|
|
- on_recording_start (callable, default=None): Callback function to be
|
|
|
called when recording of audio to be transcripted starts.
|
|
|
- on_recording_stop (callable, default=None): Callback function to be
|
|
@@ -238,6 +253,9 @@ class AudioToTextRecorder:
|
|
|
"""
|
|
|
|
|
|
self.language = language
|
|
|
+ self.compute_type = compute_type
|
|
|
+ self.input_device_index = input_device_index
|
|
|
+ self.gpu_device_index = gpu_device_index
|
|
|
self.wake_words = wake_words
|
|
|
self.wake_word_activation_delay = wake_word_activation_delay
|
|
|
self.wake_word_timeout = wake_word_timeout
|
|
@@ -247,6 +265,7 @@ class AudioToTextRecorder:
|
|
|
self.ensure_sentence_ends_with_period = (
|
|
|
ensure_sentence_ends_with_period
|
|
|
)
|
|
|
+ self.use_microphone = use_microphone
|
|
|
self.min_gap_between_recordings = min_gap_between_recordings
|
|
|
self.min_length_of_recording = min_length_of_recording
|
|
|
self.pre_recording_buffer_duration = pre_recording_buffer_duration
|
|
@@ -272,8 +291,7 @@ class AudioToTextRecorder:
|
|
|
self.allowed_latency_limit = ALLOWED_LATENCY_LIMIT
|
|
|
|
|
|
self.level = level
|
|
|
- manager = Manager()
|
|
|
- self.audio_queue = manager.Queue()
|
|
|
+ self.audio_queue = mp.Queue()
|
|
|
self.buffer_size = BUFFER_SIZE
|
|
|
self.sample_rate = SAMPLE_RATE
|
|
|
self.recording_start_time = 0
|
|
@@ -323,20 +341,30 @@ class AudioToTextRecorder:
|
|
|
logger.addHandler(console_handler)
|
|
|
|
|
|
self.is_shut_down = False
|
|
|
- self.shutdown_event = Event()
|
|
|
+ self.shutdown_event = mp.Event()
|
|
|
|
|
|
logging.info("Starting RealTimeSTT")
|
|
|
|
|
|
- # Start transcription process
|
|
|
- self.interrupt_stop_event = Event()
|
|
|
- self.main_transcription_ready_event = Event()
|
|
|
- self.parent_transcription_pipe, child_transcription_pipe = Pipe()
|
|
|
-
|
|
|
- self.transcript_process = Process(
|
|
|
+ # Start transcription worker process
|
|
|
+ try:
|
|
|
+ # Only set the start method if it hasn't been set already
|
|
|
+ if mp.get_start_method(allow_none=True) is None:
|
|
|
+ mp.set_start_method("spawn")
|
|
|
+ except RuntimeError as e:
|
|
|
+ print("Start method has already been set. Details:", e)
|
|
|
+
|
|
|
+ self.interrupt_stop_event = mp.Event()
|
|
|
+ self.was_interrupted = mp.Event()
|
|
|
+ self.main_transcription_ready_event = mp.Event()
|
|
|
+ self.parent_transcription_pipe, child_transcription_pipe = mp.Pipe()
|
|
|
+
|
|
|
+ self.transcript_process = mp.Process(
|
|
|
target=AudioToTextRecorder._transcription_worker,
|
|
|
args=(
|
|
|
child_transcription_pipe,
|
|
|
model,
|
|
|
+ self.compute_type,
|
|
|
+ self.gpu_device_index,
|
|
|
self.main_transcription_ready_event,
|
|
|
self.shutdown_event,
|
|
|
self.interrupt_stop_event
|
|
@@ -346,12 +374,13 @@ class AudioToTextRecorder:
|
|
|
|
|
|
# Start audio data reading process
|
|
|
if use_microphone:
|
|
|
- self.reader_process = Process(
|
|
|
+ self.reader_process = mp.Process(
|
|
|
target=AudioToTextRecorder._audio_data_worker,
|
|
|
args=(
|
|
|
self.audio_queue,
|
|
|
self.sample_rate,
|
|
|
self.buffer_size,
|
|
|
+ self.input_device_index,
|
|
|
self.shutdown_event,
|
|
|
self.interrupt_stop_event
|
|
|
)
|
|
@@ -366,7 +395,9 @@ class AudioToTextRecorder:
|
|
|
)
|
|
|
self.realtime_model_type = faster_whisper.WhisperModel(
|
|
|
model_size_or_path=self.realtime_model_type,
|
|
|
- device='cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
+ device='cuda' if torch.cuda.is_available() else 'cpu',
|
|
|
+ compute_type=self.compute_type,
|
|
|
+ device_index=self.gpu_device_index
|
|
|
)
|
|
|
|
|
|
except Exception as e:
|
|
@@ -473,12 +504,11 @@ class AudioToTextRecorder:
|
|
|
|
|
|
logging.debug('RealtimeSTT initialization completed successfully')
|
|
|
|
|
|
- print(f"buffer_size: {self.buffer_size}")
|
|
|
- print(f"samplerate: {self.sample_rate}")
|
|
|
-
|
|
|
@staticmethod
|
|
|
def _transcription_worker(conn,
|
|
|
model_path,
|
|
|
+ compute_type,
|
|
|
+ gpu_device_index,
|
|
|
ready_event,
|
|
|
shutdown_event,
|
|
|
interrupt_stop_event):
|
|
@@ -499,6 +529,9 @@ class AudioToTextRecorder:
|
|
|
for receiving audio data and sending transcription results.
|
|
|
model_path (str): The path to the pre-trained faster_whisper model
|
|
|
for transcription.
|
|
|
+ compute_type (str): Specifies the type of computation to be used
|
|
|
+ for transcription.
|
|
|
+ gpu_device_index (int or List[int]): Device ID to use.
|
|
|
ready_event (threading.Event): An event that is set when the
|
|
|
transcription model is successfully initialized and ready.
|
|
|
shutdown_event (threading.Event): An event that, when set,
|
|
@@ -516,7 +549,9 @@ class AudioToTextRecorder:
|
|
|
try:
|
|
|
model = faster_whisper.WhisperModel(
|
|
|
model_size_or_path=model_path,
|
|
|
- device='cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
+ device='cuda' if torch.cuda.is_available() else 'cpu',
|
|
|
+ compute_type=compute_type,
|
|
|
+ device_index=gpu_device_index
|
|
|
)
|
|
|
|
|
|
except Exception as e:
|
|
@@ -563,6 +598,7 @@ class AudioToTextRecorder:
|
|
|
def _audio_data_worker(audio_queue,
|
|
|
sample_rate,
|
|
|
buffer_size,
|
|
|
+ input_device_index,
|
|
|
shutdown_event,
|
|
|
interrupt_stop_event):
|
|
|
"""
|
|
@@ -583,6 +619,7 @@ class AudioToTextRecorder:
|
|
|
sample_rate (int): The sample rate of the audio input stream.
|
|
|
buffer_size (int): The size of the buffer used in the audio
|
|
|
input stream.
|
|
|
+ input_device_index (int): The index of the audio input device.
|
|
|
shutdown_event (threading.Event): An event that, when set, signals
|
|
|
this worker method to terminate.
|
|
|
|
|
@@ -600,7 +637,8 @@ class AudioToTextRecorder:
|
|
|
format=pyaudio.paInt16,
|
|
|
channels=1,
|
|
|
input=True,
|
|
|
- frames_per_buffer=buffer_size
|
|
|
+ frames_per_buffer=buffer_size,
|
|
|
+ input_device_index=input_device_index,
|
|
|
)
|
|
|
|
|
|
except Exception as e:
|
|
@@ -647,6 +685,20 @@ class AudioToTextRecorder:
|
|
|
stream.close()
|
|
|
audio_interface.terminate()
|
|
|
|
|
|
+ def wakeup(self):
|
|
|
+ """
|
|
|
+ If in wake word mode, wake up as if a wake word was spoken.
|
|
|
+ """
|
|
|
+ self.listen_start = time.time()
|
|
|
+
|
|
|
+ def abort(self):
|
|
|
+ self.start_recording_on_voice_activity = False
|
|
|
+ self.stop_recording_on_voice_deactivity = False
|
|
|
+ self._set_state("inactive")
|
|
|
+ self.interrupt_stop_event.set()
|
|
|
+ self.was_interrupted.wait()
|
|
|
+ self.was_interrupted.clear()
|
|
|
+
|
|
|
def wait_audio(self):
|
|
|
"""
|
|
|
Waits for the start and completion of the audio recording process.
|
|
@@ -671,7 +723,7 @@ class AudioToTextRecorder:
|
|
|
|
|
|
# Wait until recording starts
|
|
|
while not self.interrupt_stop_event.is_set():
|
|
|
- if self.start_recording_event.wait(timeout=0.5):
|
|
|
+ if self.start_recording_event.wait(timeout=0.02):
|
|
|
break
|
|
|
|
|
|
# If recording is ongoing, wait for voice inactivity
|
|
@@ -681,7 +733,7 @@ class AudioToTextRecorder:
|
|
|
|
|
|
# Wait until recording stops
|
|
|
while not self.interrupt_stop_event.is_set():
|
|
|
- if (self.stop_recording_event.wait(timeout=0.5)):
|
|
|
+ if (self.stop_recording_event.wait(timeout=0.02)):
|
|
|
break
|
|
|
|
|
|
# Convert recorded frames to the appropriate audio format.
|
|
@@ -756,9 +808,14 @@ class AudioToTextRecorder:
|
|
|
str: The transcription of the recorded audio
|
|
|
"""
|
|
|
|
|
|
+ self.interrupt_stop_event.clear()
|
|
|
+ self.was_interrupted.clear()
|
|
|
+
|
|
|
self.wait_audio()
|
|
|
|
|
|
if self.is_shut_down or self.interrupt_stop_event.is_set():
|
|
|
+ if self.interrupt_stop_event.is_set():
|
|
|
+ self.was_interrupted.set()
|
|
|
return ""
|
|
|
|
|
|
if on_transcription_finished:
|
|
@@ -873,7 +930,8 @@ class AudioToTextRecorder:
|
|
|
|
|
|
logging.debug('Terminating reader process')
|
|
|
# Give it some time to finish the loop and cleanup.
|
|
|
- self.reader_process.join(timeout=10)
|
|
|
+ if self.use_microphone:
|
|
|
+ self.reader_process.join(timeout=10)
|
|
|
|
|
|
if self.reader_process.is_alive():
|
|
|
logging.warning("Reader process did not terminate "
|
|
@@ -896,6 +954,12 @@ class AudioToTextRecorder:
|
|
|
if self.realtime_thread:
|
|
|
self.realtime_thread.join()
|
|
|
|
|
|
+ if self.enable_realtime_transcription:
|
|
|
+ if self.realtime_model_type:
|
|
|
+ del self.realtime_model_type
|
|
|
+ self.realtime_model_type = None
|
|
|
+ gc.collect()
|
|
|
+
|
|
|
def _recording_worker(self):
|
|
|
"""
|
|
|
The main worker method which constantly monitors the audio
|