@@ -26,10 +26,13 @@ Author: Kolja Beigel
 
 """
 
+from typing import Iterable, List, Optional, Union
 import torch.multiprocessing as mp
 import torch
 from typing import List, Union
+from ctypes import c_bool
 from scipy.signal import resample
+from scipy import signal
 import faster_whisper
 import collections
 import numpy as np
@@ -139,6 +142,8 @@ class AudioToTextRecorder:
                  beam_size_realtime: int = 3,
                  buffer_size: int = BUFFER_SIZE,
                  sample_rate: int = SAMPLE_RATE,
+                 initial_prompt: Optional[Union[str, Iterable[int]]] = None,
+                 suppress_tokens: Optional[List[int]] = [-1],
                  ):
         """
         Initializes an audio recorder and transcription
@@ -265,8 +270,22 @@ class AudioToTextRecorder:
             with the recorded audio chunk as its argument.
         - debug_mode (bool, default=False): If set to True, the system will
             print additional debug information to the console.
-        - log_buffer_overflow (bool, default=True): If set to True, the system
-            will log a warning when an input overflow occurs during recording.
+        - handle_buffer_overflow (bool, default=True): If set to True, the system
+            will log a warning when an input overflow occurs during recording and
+            remove the data from the buffer.
+        - beam_size (int, default=5): The beam size to use for beam search
+            decoding.
+        - beam_size_realtime (int, default=3): The beam size to use for beam
+            search decoding in the real-time transcription model.
+        - buffer_size (int, default=512): The buffer size to use for audio
+            recording. Changing this may break functionality.
+        - sample_rate (int, default=16000): The sample rate to use for audio
+            recording. Changing this will very probably break functionality (as the
+            WebRTC VAD model is very sensitive towards the sample rate).
+        - initial_prompt (str or iterable of int, default=None): Initial
+            prompt to be fed to the transcription models.
+        - suppress_tokens (list of int, default=[-1]): Tokens to be suppressed
+            from the transcription output.
 
         Raises:
             Exception: Errors related to initializing transcription
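
Both new arguments are simply forwarded to faster-whisper's decoder by the worker methods below. A minimal usage sketch (assuming the package is installed and exposes AudioToTextRecorder as defined in this module; model name and prompt text are placeholders):

from RealtimeSTT import AudioToTextRecorder

# Bias decoding toward domain vocabulary via initial_prompt and keep
# faster-whisper's default suppression list via suppress_tokens=[-1].
recorder = AudioToTextRecorder(
    model="tiny",
    language="en",
    initial_prompt="Glossary: RealtimeSTT, WebRTC VAD, Silero.",
    suppress_tokens=[-1],
)
print(recorder.text())
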
@@ -285,7 +304,7 @@ class AudioToTextRecorder:
         self.ensure_sentence_ends_with_period = (
             ensure_sentence_ends_with_period
         )
-        self.use_microphone = use_microphone
+        self.use_microphone = mp.Value(c_bool, use_microphone)
         self.min_gap_between_recordings = min_gap_between_recordings
         self.min_length_of_recording = min_length_of_recording
         self.pre_recording_buffer_duration = pre_recording_buffer_duration
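
Wrapping the flag in mp.Value is what makes the set_microphone() method added further down work: a plain bool would be copied into the reader process when it is spawned, while a c_bool in shared memory can still be flipped by the parent afterwards. A standalone sketch of the pattern (names are illustrative):

from ctypes import c_bool
import time
import torch.multiprocessing as mp

def worker(flag):
    # The child observes updates made by the parent because the
    # underlying c_bool lives in shared memory.
    while flag.value:
        time.sleep(0.1)
    print("worker saw the flag go False")

if __name__ == "__main__":
    use_microphone = mp.Value(c_bool, True)
    p = mp.Process(target=worker, args=(use_microphone,))
    p.start()
    time.sleep(0.5)
    use_microphone.value = False  # takes effect inside the child
    p.join()
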
@@ -344,6 +363,8 @@ class AudioToTextRecorder:
         self.start_recording_event = threading.Event()
         self.stop_recording_event = threading.Event()
         self.last_transcription_bytes = None
+        self.initial_prompt = initial_prompt
+        self.suppress_tokens = suppress_tokens
 
         # Initialize the logging configuration with the specified level
         log_format = 'RealTimeSTT: %(name)s - %(levelname)s - %(message)s'
@@ -394,13 +415,15 @@ class AudioToTextRecorder:
                 self.main_transcription_ready_event,
                 self.shutdown_event,
                 self.interrupt_stop_event,
-                self.beam_size
+                self.beam_size,
+                self.initial_prompt,
+                self.suppress_tokens
             )
         )
         self.transcript_process.start()
 
         # Start audio data reading process
-        if use_microphone:
+        if self.use_microphone.value:
             logging.info("Initializing audio recording"
                          " (creating pyAudio input stream,"
                          f" sample rate: {self.sample_rate}"
@@ -414,7 +437,8 @@ class AudioToTextRecorder:
                     self.buffer_size,
                     self.input_device_index,
                     self.shutdown_event,
-                    self.interrupt_stop_event
+                    self.interrupt_stop_event,
+                    self.use_microphone
                 )
             )
             self.reader_process.start()
@@ -544,7 +568,10 @@ class AudioToTextRecorder:
                               ready_event,
                               shutdown_event,
                               interrupt_stop_event,
-                              beam_size):
+                              beam_size,
+                              initial_prompt,
+                              suppress_tokens
+                              ):
         """
         Worker method that handles the continuous
         process of transcribing audio data.
@@ -572,7 +599,10 @@ class AudioToTextRecorder:
             interrupt_stop_event (threading.Event): An event that, when set,
                 signals this worker method to stop processing audio data.
             beam_size (int): The beam size to use for beam search decoding.
-
+            initial_prompt (str or iterable of int): Initial prompt to be fed
+                to the transcription model.
+            suppress_tokens (list of int): Tokens to be suppressed from the
+                transcription output.
         Raises:
             Exception: If there is an error while initializing the
                 transcription model.
@@ -587,7 +617,7 @@ class AudioToTextRecorder:
                 model_size_or_path=model_path,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 compute_type=compute_type,
-                device_index=gpu_device_index
+                device_index=gpu_device_index,
             )
 
         except Exception as e:
@@ -610,7 +640,9 @@ class AudioToTextRecorder:
                     segments = model.transcribe(
                         audio,
                         language=language if language else None,
-                        beam_size=beam_size
+                        beam_size=beam_size,
+                        initial_prompt=initial_prompt,
+                        suppress_tokens=suppress_tokens
                     )
                     segments = segments[0]
                     transcription = " ".join(seg.text for seg in segments)
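
initial_prompt and suppress_tokens are both part of faster-whisper's public transcribe() signature, so the worker only has to pass them through. The call in isolation (model size, audio file and prompt are placeholders):

import faster_whisper

model = faster_whisper.WhisperModel("tiny", device="cpu", compute_type="int8")
segments, info = model.transcribe(
    "sample.wav",
    language="en",
    beam_size=5,
    initial_prompt="Spelling: RealtimeSTT.",  # str or iterable of token ids
    suppress_tokens=[-1],  # -1 expands to the model's default suppression set
)
print(" ".join(segment.text for segment in segments))
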
@@ -635,7 +667,8 @@ class AudioToTextRecorder:
                            buffer_size,
                            input_device_index,
                            shutdown_event,
-                           interrupt_stop_event):
+                           interrupt_stop_event,
+                           use_microphone):
         """
         Worker method that handles the audio recording process.
 
@@ -705,7 +738,8 @@ class AudioToTextRecorder:
                     print(f"Error: {e}")
                     continue
 
-                audio_queue.put(data)
+                if use_microphone.value:
+                    audio_queue.put(data)
 
         except KeyboardInterrupt:
             interrupt_stop_event.set()
@@ -960,6 +994,13 @@ class AudioToTextRecorder:
         # Feed the extracted data to the audio_queue
         self.audio_queue.put(to_process)
 
+    def set_microphone(self, microphone_on=True):
+        """
+        Set the microphone on or off.
+        """
+        logging.info("Setting microphone to: " + str(microphone_on))
+        self.use_microphone.value = microphone_on
+
     def shutdown(self):
         """
         Safely shuts down the audio recording by stopping the
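
Because set_microphone() only flips the shared flag, the pyAudio stream keeps running and _audio_data_worker simply drops chunks while the flag is False, so muting and unmuting take effect without reopening the stream. A usage sketch (construction as in the earlier example):

from RealtimeSTT import AudioToTextRecorder

recorder = AudioToTextRecorder(model="tiny")

recorder.set_microphone(False)  # reader keeps running; chunks are discarded
# ... e.g. play synthesized speech here without it being transcribed ...
recorder.set_microphone(True)   # resume feeding audio into the queue
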
@@ -980,6 +1021,7 @@ class AudioToTextRecorder:
             self.recording_thread.join()
 
         logging.debug('Terminating reader process')
+
         # Give it some time to finish the loop and cleanup.
         if self.use_microphone:
             self.reader_process.join(timeout=10)
@@ -1246,7 +1288,9 @@ class AudioToTextRecorder:
                 segments = self.realtime_model_type.transcribe(
                     audio_array,
                     language=self.language if self.language else None,
-                    beam_size=self.beam_size_realtime
+                    beam_size=self.beam_size_realtime,
+                    initial_prompt=self.initial_prompt,
+                    suppress_tokens=self.suppress_tokens,
                 )
 
                 # double check recording state
@@ -1339,7 +1383,7 @@ class AudioToTextRecorder:
             logging.error(f"Unhandled exeption in _realtime_worker: {e}")
             raise
 
-    def _is_silero_speech(self, data):
+    def _is_silero_speech(self, chunk):
         """
         Returns true if speech is detected in the provided audio data
 
@@ -1347,9 +1391,14 @@ class AudioToTextRecorder:
             data (bytes): raw bytes of audio data (1024 raw bytes with
                 16000 sample rate and 16 bits per sample)
         """
+        if self.sample_rate != 16000:
+            pcm_data = np.frombuffer(chunk, dtype=np.int16)
+            data_16000 = signal.resample_poly(
+                pcm_data, 16000, self.sample_rate)
+            chunk = data_16000.astype(np.int16).tobytes()
 
         self.silero_working = True
-        audio_chunk = np.frombuffer(data, dtype=np.int16)
+        audio_chunk = np.frombuffer(chunk, dtype=np.int16)
         audio_chunk = audio_chunk.astype(np.float32) / INT16_MAX_ABS_VALUE
         vad_prob = self.silero_vad_model(
             torch.from_numpy(audio_chunk),
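
signal.resample_poly converts by the rational factor 16000/self.sample_rate with a built-in anti-aliasing filter, which behaves better on short chunks than FFT-based scipy.signal.resample, and Silero VAD expects 16 kHz input. The conversion in isolation (the 48 kHz source rate is an example):

import numpy as np
from scipy import signal

source_rate = 48000  # example device rate; any rate != 16000 takes this path
t = np.arange(1536) / source_rate
chunk = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16).tobytes()

pcm_data = np.frombuffer(chunk, dtype=np.int16)
data_16000 = signal.resample_poly(pcm_data, 16000, source_rate)
chunk_16k = data_16000.astype(np.int16).tobytes()
print(len(pcm_data), "->", len(data_16000))  # 1536 -> 512 samples
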
@@ -1360,7 +1409,7 @@ class AudioToTextRecorder:
         self.silero_working = False
         return is_silero_speech_active
 
-    def _is_webrtc_speech(self, data, all_frames_must_be_true=False):
+    def _is_webrtc_speech(self, chunk, all_frames_must_be_true=False):
         """
         Returns true if speech is detected in the provided audio data
 
@@ -1368,16 +1417,22 @@ class AudioToTextRecorder:
             data (bytes): raw bytes of audio data (1024 raw bytes with
                 16000 sample rate and 16 bits per sample)
         """
+        if self.sample_rate != 16000:
+            pcm_data = np.frombuffer(chunk, dtype=np.int16)
+            data_16000 = signal.resample_poly(
+                pcm_data, 16000, self.sample_rate)
+            chunk = data_16000.astype(np.int16).tobytes()
+
         # Number of audio frames per millisecond
-        frame_length = int(self.sample_rate * 0.01)  # for 10ms frame
-        num_frames = int(len(data) / (2 * frame_length))
+        frame_length = int(16000 * 0.01)  # for 10ms frame
+        num_frames = int(len(chunk) / (2 * frame_length))
         speech_frames = 0
 
         for i in range(num_frames):
             start_byte = i * frame_length * 2
             end_byte = start_byte + frame_length * 2
-            frame = data[start_byte:end_byte]
-            if self.webrtc_vad_model.is_speech(frame, self.sample_rate):
+            frame = chunk[start_byte:end_byte]
+            if self.webrtc_vad_model.is_speech(frame, 16000):
                 speech_frames += 1
                 if not all_frames_must_be_true:
                     if self.debug_mode:
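
Hard-coding 16000 here is deliberate: WebRTC VAD only accepts 10, 20 or 30 ms frames at 8, 16, 32 or 48 kHz, and after the resampling branch above every chunk is 16 kHz regardless of the device rate. The framing logic on its own (assumes the webrtcvad package):

import webrtcvad

vad = webrtcvad.Vad(3)  # aggressiveness 0 (permissive) to 3 (strict)

SAMPLE_RATE = 16000
frame_length = int(SAMPLE_RATE * 0.01)  # 160 samples per 10 ms frame
chunk = b"\x00\x00" * 512               # 512 silent int16 samples

num_frames = len(chunk) // (2 * frame_length)  # 2 bytes per sample -> 3 frames
for i in range(num_frames):
    frame = chunk[i * frame_length * 2:(i + 1) * frame_length * 2]
    print(vad.is_speech(frame, SAMPLE_RATE))   # False for pure silence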