فهرست منبع

added beam size

Kolja Beigel 1 سال پیش
والد
کامیت
434ed13840
1فایلهای تغییر یافته به همراه48 افزوده شده و 11 حذف شده
  1. 48 11
      RealtimeSTT/audio_recorder.py

+ 48 - 11
RealtimeSTT/audio_recorder.py

@@ -29,6 +29,7 @@ Author: Kolja Beigel
 import torch.multiprocessing as mp
 import torch.multiprocessing as mp
 import torch
 import torch
 from typing import List, Union
 from typing import List, Union
+from scipy.signal import resample
 import faster_whisper
 import faster_whisper
 import collections
 import collections
 import numpy as np
 import numpy as np
@@ -43,6 +44,7 @@ import logging
 import struct
 import struct
 import halo
 import halo
 import time
 import time
+import copy
 import os
 import os
 import re
 import re
 import gc
 import gc
@@ -133,6 +135,10 @@ class AudioToTextRecorder:
                  on_recorded_chunk=None,
                  on_recorded_chunk=None,
                  debug_mode=False,
                  debug_mode=False,
                  handle_buffer_overflow: bool = INIT_HANDLE_BUFFER_OVERFLOW,
                  handle_buffer_overflow: bool = INIT_HANDLE_BUFFER_OVERFLOW,
+                 beam_size: int = 5,
+                 beam_size_realtime: int = 1,
+                 buffer_size: int = BUFFER_SIZE,
+                 sample_rate: int = SAMPLE_RATE,
                  ):
                  ):
         """
         """
         Initializes an audio recorder and  transcription
         Initializes an audio recorder and  transcription
@@ -305,12 +311,14 @@ class AudioToTextRecorder:
         )
         )
         self.debug_mode = debug_mode
         self.debug_mode = debug_mode
         self.handle_buffer_overflow = handle_buffer_overflow
         self.handle_buffer_overflow = handle_buffer_overflow
+        self.beam_size = beam_size
+        self.beam_size_realtime = beam_size_realtime
         self.allowed_latency_limit = ALLOWED_LATENCY_LIMIT
         self.allowed_latency_limit = ALLOWED_LATENCY_LIMIT
 
 
         self.level = level
         self.level = level
         self.audio_queue = mp.Queue()
         self.audio_queue = mp.Queue()
-        self.buffer_size = BUFFER_SIZE
-        self.sample_rate = SAMPLE_RATE
+        self.buffer_size = buffer_size
+        self.sample_rate = sample_rate
         self.recording_start_time = 0
         self.recording_start_time = 0
         self.recording_stop_time = 0
         self.recording_stop_time = 0
         self.wake_word_detect_time = 0
         self.wake_word_detect_time = 0
@@ -335,6 +343,7 @@ class AudioToTextRecorder:
         self.stream = None
         self.stream = None
         self.start_recording_event = threading.Event()
         self.start_recording_event = threading.Event()
         self.stop_recording_event = threading.Event()
         self.stop_recording_event = threading.Event()
+        self.last_transcription_bytes = None
 
 
         # Initialize the logging configuration with the specified level
         # Initialize the logging configuration with the specified level
         log_format = 'RealTimeSTT: %(name)s - %(levelname)s - %(message)s'
         log_format = 'RealTimeSTT: %(name)s - %(levelname)s - %(message)s'
@@ -384,13 +393,19 @@ class AudioToTextRecorder:
                 self.gpu_device_index,
                 self.gpu_device_index,
                 self.main_transcription_ready_event,
                 self.main_transcription_ready_event,
                 self.shutdown_event,
                 self.shutdown_event,
-                self.interrupt_stop_event
+                self.interrupt_stop_event,
+                self.beam_size
             )
             )
         )
         )
         self.transcript_process.start()
         self.transcript_process.start()
 
 
         # Start audio data reading process
         # Start audio data reading process
         if use_microphone:
         if use_microphone:
+            logging.info("Initializing audio recording"
+                         " (creating pyAudio input stream,"
+                         f" sample rate: {self.sample_rate}"
+                         f" buffer size: {self.buffer_size}"
+                         )
             self.reader_process = mp.Process(
             self.reader_process = mp.Process(
                 target=AudioToTextRecorder._audio_data_worker,
                 target=AudioToTextRecorder._audio_data_worker,
                 args=(
                 args=(
@@ -528,7 +543,8 @@ class AudioToTextRecorder:
                               gpu_device_index,
                               gpu_device_index,
                               ready_event,
                               ready_event,
                               shutdown_event,
                               shutdown_event,
-                              interrupt_stop_event):
+                              interrupt_stop_event,
+                              beam_size):
         """
         """
         Worker method that handles the continuous
         Worker method that handles the continuous
         process of transcribing audio data.
         process of transcribing audio data.
@@ -553,6 +569,9 @@ class AudioToTextRecorder:
               transcription model is successfully initialized and ready.
               transcription model is successfully initialized and ready.
             shutdown_event (threading.Event): An event that, when set,
             shutdown_event (threading.Event): An event that, when set,
               signals this worker method to terminate.
               signals this worker method to terminate.
+            interrupt_stop_event (threading.Event): An event that, when set,
+                signals this worker method to stop processing audio data.
+            beam_size (int): The beam size to use for beam search decoding.
 
 
         Raises:
         Raises:
             Exception: If there is an error while initializing the
             Exception: If there is an error while initializing the
@@ -589,7 +608,9 @@ class AudioToTextRecorder:
                     audio, language = conn.recv()
                     audio, language = conn.recv()
                     try:
                     try:
                         segments = model.transcribe(
                         segments = model.transcribe(
-                            audio, language=language if language else None
+                            audio,
+                            language=language if language else None,
+                            beam_size=beam_size
                         )
                         )
                         segments = segments[0]
                         segments = segments[0]
                         transcription = " ".join(seg.text for seg in segments)
                         transcription = " ".join(seg.text for seg in segments)
@@ -641,10 +662,6 @@ class AudioToTextRecorder:
             Exception: If there is an error while initializing the audio
             Exception: If there is an error while initializing the audio
               recording.
               recording.
         """
         """
-        logging.info("Initializing audio recording "
-                     "(creating pyAudio input stream)"
-                     )
-
         try:
         try:
             audio_interface = pyaudio.PyAudio()
             audio_interface = pyaudio.PyAudio()
             stream = audio_interface.open(
             stream = audio_interface.open(
@@ -788,11 +805,13 @@ class AudioToTextRecorder:
             Exception: If there is an error during the transcription process.
             Exception: If there is an error during the transcription process.
         """
         """
         self._set_state("transcribing")
         self._set_state("transcribing")
+        audio_copy = copy.deepcopy(self.audio)
         self.parent_transcription_pipe.send((self.audio, self.language))
         self.parent_transcription_pipe.send((self.audio, self.language))
         status, result = self.parent_transcription_pipe.recv()
         status, result = self.parent_transcription_pipe.recv()
 
 
         self._set_state("inactive")
         self._set_state("inactive")
         if status == 'success':
         if status == 'success':
+            self.last_transcription_bytes = audio_copy
             return self._preprocess_output(result)
             return self._preprocess_output(result)
         else:
         else:
             logging.error(result)
             logging.error(result)
@@ -901,7 +920,7 @@ class AudioToTextRecorder:
 
 
         return self
         return self
 
 
-    def feed_audio(self, chunk):
+    def feed_audio(self, chunk, original_sample_rate=16000):
         """
         """
         Feed an audio chunk into the processing pipeline. Chunks are
         Feed an audio chunk into the processing pipeline. Chunks are
         accumulated until the buffer size is reached, and then the accumulated
         accumulated until the buffer size is reached, and then the accumulated
@@ -911,6 +930,23 @@ class AudioToTextRecorder:
         if not hasattr(self, 'buffer'):
         if not hasattr(self, 'buffer'):
             self.buffer = bytearray()
             self.buffer = bytearray()
 
 
+        # Check if input is a NumPy array
+        if isinstance(chunk, np.ndarray):
+            # Handle stereo to mono conversion if necessary
+            if chunk.ndim == 2:
+                chunk = np.mean(chunk, axis=1)
+
+            # Resample to 16000 Hz if necessary
+            if original_sample_rate != 16000:
+                num_samples = int(len(chunk) * 16000 / original_sample_rate)
+                chunk = resample(chunk, num_samples)
+
+            # Ensure data type is int16
+            chunk = chunk.astype(np.int16)
+
+            # Convert the NumPy array to bytes
+            chunk = chunk.tobytes()
+
         # Append the chunk to the buffer
         # Append the chunk to the buffer
         self.buffer += chunk
         self.buffer += chunk
         buf_size = 2 * self.buffer_size  # silero complains if too short
         buf_size = 2 * self.buffer_size  # silero complains if too short
@@ -1209,7 +1245,8 @@ class AudioToTextRecorder:
                     # Perform transcription and assemble the text
                     # Perform transcription and assemble the text
                     segments = self.realtime_model_type.transcribe(
                     segments = self.realtime_model_type.transcribe(
                         audio_array,
                         audio_array,
-                        language=self.language if self.language else None
+                        language=self.language if self.language else None,
+                        beam_size=self.beam_size_realtime
                     )
                     )
 
 
                     # double check recording state
                     # double check recording state