Переглянути джерело

Update audio_recorder.py to use a consistent threading library

Currently on linux there is a CUDA initialization error caused by a multiple model loadings that the pytorch Multiprocessing library. Standard thread.Thread() works fine. This commit consolidates how threads are created to use one way or the other and defaults to thread.Thread() for Linux.
Daniel Williams 9 місяців тому
батько
коміт
2bf5cdf6d7
1 змінених файлів з 32 додано та 15 видалено
  1. 32 15
      RealtimeSTT/audio_recorder.py

+ 32 - 15
RealtimeSTT/audio_recorder.py

@@ -443,7 +443,7 @@ class AudioToTextRecorder:
         # Set device for model
         self.device = "cuda" if self.device == "cuda" and torch.cuda.is_available() else "cpu"
 
-        self.transcript_process = mp.Process(
+        self.transcript_process = self._start_thread(
             target=AudioToTextRecorder._transcription_worker,
             args=(
                 child_transcription_pipe,
@@ -459,7 +459,6 @@ class AudioToTextRecorder:
                 self.suppress_tokens
             )
         )
-        self.transcript_process.start()
 
         # Start audio data reading process
         if self.use_microphone.value:
@@ -468,7 +467,7 @@ class AudioToTextRecorder:
                          f" sample rate: {self.sample_rate}"
                          f" buffer size: {self.buffer_size}"
                          )
-            self.reader_process = mp.Process(
+            self.reader_process = self._start_thread(
                 target=AudioToTextRecorder._audio_data_worker,
                 args=(
                     self.audio_queue,
@@ -480,7 +479,6 @@ class AudioToTextRecorder:
                     self.use_microphone
                 )
             )
-            self.reader_process.start()
 
         # Initialize the realtime transcription model
         if self.enable_realtime_transcription:
@@ -632,14 +630,10 @@ class AudioToTextRecorder:
         self.stop_recording_on_voice_deactivity = False
 
         # Start the recording worker thread
-        self.recording_thread = threading.Thread(target=self._recording_worker)
-        self.recording_thread.daemon = True
-        self.recording_thread.start()
+        self.recording_thread = self._start_thread(target=self._recording_worker)
 
         # Start the realtime transcription worker thread
-        self.realtime_thread = threading.Thread(target=self._realtime_worker)
-        self.realtime_thread.daemon = True
-        self.realtime_thread.start()
+        self.realtime_thread = self._start_thread(target=self._realtime_worker)
 
         # Wait for transcription models to start
         logging.debug('Waiting for main transcription model to start')
@@ -647,6 +641,29 @@ class AudioToTextRecorder:
         logging.debug('Main transcription model ready')
 
         logging.debug('RealtimeSTT initialization completed successfully')
+                   
+    def _start_thread(self, target=None, args=()):
+        """
+        Implement a consistent threading model across the library.
+
+        This method is used to start any thread in this library. It uses the
+        standard threading. Thread for Linux and for all others uses the pytorch
+        MultiProcessing library 'Process'.
+        Args:
+            target (callable object): is the callable object to be invoked by
+              the run() method. Defaults to None, meaning nothing is called.
+            args (tuple): is a list or tuple of arguments for the target
+              invocation. Defaults to ().
+        """
+        if (platform.system() == 'Linux'):
+            thread = threading.Thread(target=target, args=args)
+            thread.deamon = True
+            thread.start()
+            return thread
+        else:
+            thread = mp.Process(target=target, args=args)
+            thread.start()
+            return thread
 
     @staticmethod
     def _transcription_worker(conn,
@@ -992,7 +1009,7 @@ class AudioToTextRecorder:
         using the `faster_whisper` model.
 
         - Automatically starts recording upon voice activity if not manually
-          started using `recorder.start()`.
+          started using `recorder.`.
         - Automatically stops recording upon voice deactivity if not manually
           stopped with `recorder.stop()`.
         - Processes the recorded audio to generate transcription.
@@ -1020,8 +1037,8 @@ class AudioToTextRecorder:
             return ""
 
         if on_transcription_finished:
-            threading.Thread(target=on_transcription_finished,
-                             args=(self.transcribe(),)).start()
+            self._start_thread(target=on_transcription_finished,
+                             args=(self.transcribe(),))
         else:
             return self.transcribe()
 
@@ -1597,9 +1614,9 @@ class AudioToTextRecorder:
                 self.silero_working = True
 
                 # Run the intensive check in a separate thread
-                threading.Thread(
+                self._start_thread(
                     target=self._is_silero_speech,
-                    args=(data,)).start()
+                    args=(data,))
 
     def _is_voice_active(self):
         """