浏览代码

Update audio_recorder.py to use a consistent threading library

Currently on linux there is a CUDA initialization error caused by a multiple model loadings that the pytorch Multiprocessing library. Standard thread.Thread() works fine. This commit consolidates how threads are created to use one way or the other and defaults to thread.Thread() for Linux.
Daniel Williams 9 月之前
父节点
当前提交
2bf5cdf6d7
共有 1 个文件被更改,包括 32 次插入15 次删除
  1. 32 15
      RealtimeSTT/audio_recorder.py

+ 32 - 15
RealtimeSTT/audio_recorder.py

@@ -443,7 +443,7 @@ class AudioToTextRecorder:
         # Set device for model
         self.device = "cuda" if self.device == "cuda" and torch.cuda.is_available() else "cpu"
 
-        self.transcript_process = mp.Process(
+        self.transcript_process = self._start_thread(
             target=AudioToTextRecorder._transcription_worker,
             args=(
                 child_transcription_pipe,
@@ -459,7 +459,6 @@ class AudioToTextRecorder:
                 self.suppress_tokens
             )
         )
-        self.transcript_process.start()
 
         # Start audio data reading process
         if self.use_microphone.value:
@@ -468,7 +467,7 @@ class AudioToTextRecorder:
                          f" sample rate: {self.sample_rate}"
                          f" buffer size: {self.buffer_size}"
                          )
-            self.reader_process = mp.Process(
+            self.reader_process = self._start_thread(
                 target=AudioToTextRecorder._audio_data_worker,
                 args=(
                     self.audio_queue,
@@ -480,7 +479,6 @@ class AudioToTextRecorder:
                     self.use_microphone
                 )
             )
-            self.reader_process.start()
 
         # Initialize the realtime transcription model
         if self.enable_realtime_transcription:
@@ -632,14 +630,10 @@ class AudioToTextRecorder:
         self.stop_recording_on_voice_deactivity = False
 
         # Start the recording worker thread
-        self.recording_thread = threading.Thread(target=self._recording_worker)
-        self.recording_thread.daemon = True
-        self.recording_thread.start()
+        self.recording_thread = self._start_thread(target=self._recording_worker)
 
         # Start the realtime transcription worker thread
-        self.realtime_thread = threading.Thread(target=self._realtime_worker)
-        self.realtime_thread.daemon = True
-        self.realtime_thread.start()
+        self.realtime_thread = self._start_thread(target=self._realtime_worker)
 
         # Wait for transcription models to start
         logging.debug('Waiting for main transcription model to start')
@@ -647,6 +641,29 @@ class AudioToTextRecorder:
         logging.debug('Main transcription model ready')
 
         logging.debug('RealtimeSTT initialization completed successfully')
+                   
+    def _start_thread(self, target=None, args=()):
+        """
+        Implement a consistent threading model across the library.
+
+        This method is used to start any thread in this library. It uses the
+        standard threading. Thread for Linux and for all others uses the pytorch
+        MultiProcessing library 'Process'.
+        Args:
+            target (callable object): is the callable object to be invoked by
+              the run() method. Defaults to None, meaning nothing is called.
+            args (tuple): is a list or tuple of arguments for the target
+              invocation. Defaults to ().
+        """
+        if (platform.system() == 'Linux'):
+            thread = threading.Thread(target=target, args=args)
+            thread.deamon = True
+            thread.start()
+            return thread
+        else:
+            thread = mp.Process(target=target, args=args)
+            thread.start()
+            return thread
 
     @staticmethod
     def _transcription_worker(conn,
@@ -992,7 +1009,7 @@ class AudioToTextRecorder:
         using the `faster_whisper` model.
 
         - Automatically starts recording upon voice activity if not manually
-          started using `recorder.start()`.
+          started using `recorder.`.
         - Automatically stops recording upon voice deactivity if not manually
           stopped with `recorder.stop()`.
         - Processes the recorded audio to generate transcription.
@@ -1020,8 +1037,8 @@ class AudioToTextRecorder:
             return ""
 
         if on_transcription_finished:
-            threading.Thread(target=on_transcription_finished,
-                             args=(self.transcribe(),)).start()
+            self._start_thread(target=on_transcription_finished,
+                             args=(self.transcribe(),))
         else:
             return self.transcribe()
 
@@ -1597,9 +1614,9 @@ class AudioToTextRecorder:
                 self.silero_working = True
 
                 # Run the intensive check in a separate thread
-                threading.Thread(
+                self._start_thread(
                     target=self._is_silero_speech,
-                    args=(data,)).start()
+                    args=(data,))
 
     def _is_voice_active(self):
         """