
added beam size

Kolja Beigel, 1 year ago
parent
commit 434ed13840
1 changed file with 48 additions and 11 deletions

+ 48 - 11   RealtimeSTT/audio_recorder.py

@@ -29,6 +29,7 @@ Author: Kolja Beigel
 import torch.multiprocessing as mp
 import torch
 from typing import List, Union
+from scipy.signal import resample
 import faster_whisper
 import collections
 import numpy as np
@@ -43,6 +44,7 @@ import logging
 import struct
 import halo
 import time
+import copy
 import os
 import re
 import gc
@@ -133,6 +135,10 @@ class AudioToTextRecorder:
                  on_recorded_chunk=None,
                  debug_mode=False,
                  handle_buffer_overflow: bool = INIT_HANDLE_BUFFER_OVERFLOW,
+                 beam_size: int = 5,
+                 beam_size_realtime: int = 1,
+                 buffer_size: int = BUFFER_SIZE,
+                 sample_rate: int = SAMPLE_RATE,
                  ):
         """
         Initializes an audio recorder and transcription
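
The four new constructor arguments (beam_size, beam_size_realtime, buffer_size, sample_rate) are plain keyword parameters, so decoding width and audio I/O can now be tuned per instance instead of relying on the module constants. A minimal sketch of such a call, with illustrative values rather than the project's defaults:

```python
from RealtimeSTT import AudioToTextRecorder

# Illustrative configuration: a wider beam for the final transcription
# pass, greedy decoding (beam size 1) for the low-latency realtime pass.
recorder = AudioToTextRecorder(
    model="tiny",
    beam_size=5,            # beam width for the final transcription
    beam_size_realtime=1,   # beam width for realtime partial results
    sample_rate=16000,      # rate of the pyAudio input stream
    buffer_size=512,        # frames read per audio chunk
)
print(recorder.text())
```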
@@ -305,12 +311,14 @@ class AudioToTextRecorder:
         )
         self.debug_mode = debug_mode
         self.handle_buffer_overflow = handle_buffer_overflow
+        self.beam_size = beam_size
+        self.beam_size_realtime = beam_size_realtime
         self.allowed_latency_limit = ALLOWED_LATENCY_LIMIT
 
         self.level = level
         self.audio_queue = mp.Queue()
-        self.buffer_size = BUFFER_SIZE
-        self.sample_rate = SAMPLE_RATE
+        self.buffer_size = buffer_size
+        self.sample_rate = sample_rate
         self.recording_start_time = 0
         self.recording_stop_time = 0
         self.wake_word_detect_time = 0
@@ -335,6 +343,7 @@ class AudioToTextRecorder:
         self.stream = None
         self.start_recording_event = threading.Event()
         self.stop_recording_event = threading.Event()
+        self.last_transcription_bytes = None
 
         # Initialize the logging configuration with the specified level
         log_format = 'RealTimeSTT: %(name)s - %(levelname)s - %(message)s'
@@ -384,13 +393,19 @@ class AudioToTextRecorder:
                 self.gpu_device_index,
                 self.main_transcription_ready_event,
                 self.shutdown_event,
-                self.interrupt_stop_event
+                self.interrupt_stop_event,
+                self.beam_size
             )
         )
         self.transcript_process.start()
 
         # Start audio data reading process
         if use_microphone:
+            logging.info("Initializing audio recording"
+                         " (creating pyAudio input stream,"
+                         f" sample rate: {self.sample_rate},"
+                         f" buffer size: {self.buffer_size})"
+                         )
             self.reader_process = mp.Process(
                 target=AudioToTextRecorder._audio_data_worker,
                 args=(
@@ -528,7 +543,8 @@ class AudioToTextRecorder:
                               gpu_device_index,
                               ready_event,
                               shutdown_event,
-                              interrupt_stop_event):
+                              interrupt_stop_event,
+                              beam_size):
         """
         Worker method that handles the continuous
         process of transcribing audio data.
@@ -553,6 +569,9 @@ class AudioToTextRecorder:
               transcription model is successfully initialized and ready.
             shutdown_event (threading.Event): An event that, when set,
               signals this worker method to terminate.
+            interrupt_stop_event (threading.Event): An event that, when set,
+                signals this worker method to stop processing audio data.
+            beam_size (int): The beam size to use for beam search decoding.
 
         Raises:
             Exception: If there is an error while initializing the
@@ -589,7 +608,9 @@ class AudioToTextRecorder:
                     audio, language = conn.recv()
                     try:
                         segments = model.transcribe(
-                            audio, language=language if language else None
+                            audio,
+                            language=language if language else None,
+                            beam_size=beam_size
                         )
                         segments = segments[0]
                         transcription = " ".join(seg.text for seg in segments)
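
The worker forwards beam_size straight into faster_whisper's transcribe() call, which accepts it as the beam-search width. A standalone sketch of the same call outside the worker (model size and audio path are placeholders):

```python
import faster_whisper

# Placeholder model size and audio file; language=None autodetects,
# mirroring the worker's `language if language else None`.
model = faster_whisper.WhisperModel("tiny", device="cpu", compute_type="int8")
segments, info = model.transcribe("sample.wav", language=None, beam_size=5)

# Segments are yielded lazily; joining them forces the actual decoding.
text = " ".join(seg.text for seg in segments)
print(text)
```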
@@ -641,10 +662,6 @@ class AudioToTextRecorder:
             Exception: If there is an error while initializing the audio
               recording.
         """
-        logging.info("Initializing audio recording "
-                     "(creating pyAudio input stream)"
-                     )
-
         try:
             audio_interface = pyaudio.PyAudio()
             stream = audio_interface.open(
@@ -788,11 +805,13 @@ class AudioToTextRecorder:
             Exception: If there is an error during the transcription process.
         """
         self._set_state("transcribing")
+        audio_copy = copy.deepcopy(self.audio)
         self.parent_transcription_pipe.send((self.audio, self.language))
         status, result = self.parent_transcription_pipe.recv()
 
         self._set_state("inactive")
         if status == 'success':
+            self.last_transcription_bytes = audio_copy
             return self._preprocess_output(result)
         else:
             logging.error(result)
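
Deep-copying self.audio before it is sent to the transcription process lets the recorder keep the exact buffer that produced the last successful transcription in last_transcription_bytes, even after the live buffer moves on. A minimal sketch of reading it back, assuming the retained value is a 1-D NumPy array of mono samples at recorder.sample_rate (adjust the dtype handling if the project stores normalized float32 audio instead of int16):

```python
import numpy as np
from scipy.io import wavfile
from RealtimeSTT import AudioToTextRecorder

recorder = AudioToTextRecorder(model="tiny")
print(recorder.text())  # a successful transcription fills last_transcription_bytes

audio = recorder.last_transcription_bytes
if audio is not None:
    # Assumption: mono samples at recorder.sample_rate; wavfile.write
    # accepts both int16 and float32 arrays.
    wavfile.write("last_utterance.wav", recorder.sample_rate, np.asarray(audio))
```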
@@ -901,7 +920,7 @@ class AudioToTextRecorder:
 
         return self
 
-    def feed_audio(self, chunk):
+    def feed_audio(self, chunk, original_sample_rate=16000):
         """
         Feed an audio chunk into the processing pipeline. Chunks are
         accumulated until the buffer size is reached, and then the accumulated
@@ -911,6 +930,23 @@ class AudioToTextRecorder:
         if not hasattr(self, 'buffer'):
             self.buffer = bytearray()
 
+        # Check if input is a NumPy array
+        if isinstance(chunk, np.ndarray):
+            # Handle stereo to mono conversion if necessary
+            if chunk.ndim == 2:
+                chunk = np.mean(chunk, axis=1)
+
+            # Resample to 16000 Hz if necessary
+            if original_sample_rate != 16000:
+                num_samples = int(len(chunk) * 16000 / original_sample_rate)
+                chunk = resample(chunk, num_samples)
+
+            # Ensure data type is int16
+            chunk = chunk.astype(np.int16)
+
+            # Convert the NumPy array to bytes
+            chunk = chunk.tobytes()
+
         # Append the chunk to the buffer
         self.buffer += chunk
         buf_size = 2 * self.buffer_size  # silero complains if too short
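
feed_audio() now also accepts raw NumPy chunks: stereo input is averaged down to mono, anything not at 16 kHz is resampled with scipy.signal.resample, and the result is cast to int16 bytes before being buffered. A hedged usage sketch, assuming the recorder was created with use_microphone=False so that all audio arrives through this method:

```python
import numpy as np
from RealtimeSTT import AudioToTextRecorder

recorder = AudioToTextRecorder(use_microphone=False)

# Illustrative chunk: half a second of silence shaped like a stereo
# 44.1 kHz capture; feed_audio() downmixes, resamples, and buffers it.
stereo_chunk = np.zeros((22050, 2), dtype=np.int16)
recorder.feed_audio(stereo_chunk, original_sample_rate=44100)

# Plain 16-bit PCM bytes at the native rate still work as before.
recorder.feed_audio(b"\x00\x00" * recorder.buffer_size)
```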
@@ -1209,7 +1245,8 @@ class AudioToTextRecorder:
                     # Perform transcription and assemble the text
                     segments = self.realtime_model_type.transcribe(
                         audio_array,
-                        language=self.language if self.language else None
+                        language=self.language if self.language else None,
+                        beam_size=self.beam_size_realtime
                     )
 
                     # double check recording state
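
Keeping a separate, smaller beam_size_realtime (default 1) for this realtime path trades a little accuracy for much cheaper decoding on the frequently repeated partial transcriptions, while the final pass keeps the wider beam_size (default 5). A rough, illustrative way to measure that trade-off directly with faster_whisper (model size and audio file are placeholders):

```python
import time
import faster_whisper

model = faster_whisper.WhisperModel("tiny", device="cpu")

for beam in (1, 5):
    start = time.perf_counter()
    segments, _ = model.transcribe("sample.wav", beam_size=beam)
    text = " ".join(seg.text for seg in segments)  # forces the lazy decoding
    print(f"beam_size={beam}: {time.perf_counter() - start:.2f}s -> {text!r}")
```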