Browse Source

Update audio_recorder.py to use a consistent threading library

Currently on Linux there is a CUDA initialization error caused by multiple model loadings when using the PyTorch multiprocessing library. The standard threading.Thread() works fine. This commit consolidates how threads are created so that a single mechanism is used throughout, and defaults to threading.Thread() on Linux.
Daniel Williams 9 months ago
parent
commit
2bf5cdf6d7
1 changed files with 32 additions and 15 deletions
  1. 32 15
      RealtimeSTT/audio_recorder.py

+ 32 - 15
RealtimeSTT/audio_recorder.py

@@ -443,7 +443,7 @@ class AudioToTextRecorder:
         # Set device for model
         # Set device for model
         self.device = "cuda" if self.device == "cuda" and torch.cuda.is_available() else "cpu"
         self.device = "cuda" if self.device == "cuda" and torch.cuda.is_available() else "cpu"
 
 
-        self.transcript_process = mp.Process(
+        self.transcript_process = self._start_thread(
             target=AudioToTextRecorder._transcription_worker,
             target=AudioToTextRecorder._transcription_worker,
             args=(
             args=(
                 child_transcription_pipe,
                 child_transcription_pipe,
@@ -459,7 +459,6 @@ class AudioToTextRecorder:
                 self.suppress_tokens
                 self.suppress_tokens
             )
             )
         )
         )
-        self.transcript_process.start()
 
 
         # Start audio data reading process
         # Start audio data reading process
         if self.use_microphone.value:
         if self.use_microphone.value:
@@ -468,7 +467,7 @@ class AudioToTextRecorder:
                          f" sample rate: {self.sample_rate}"
                          f" sample rate: {self.sample_rate}"
                          f" buffer size: {self.buffer_size}"
                          f" buffer size: {self.buffer_size}"
                          )
                          )
-            self.reader_process = mp.Process(
+            self.reader_process = self._start_thread(
                 target=AudioToTextRecorder._audio_data_worker,
                 target=AudioToTextRecorder._audio_data_worker,
                 args=(
                 args=(
                     self.audio_queue,
                     self.audio_queue,
@@ -480,7 +479,6 @@ class AudioToTextRecorder:
                     self.use_microphone
                     self.use_microphone
                 )
                 )
             )
             )
-            self.reader_process.start()
 
 
         # Initialize the realtime transcription model
         # Initialize the realtime transcription model
         if self.enable_realtime_transcription:
         if self.enable_realtime_transcription:
@@ -632,14 +630,10 @@ class AudioToTextRecorder:
         self.stop_recording_on_voice_deactivity = False
         self.stop_recording_on_voice_deactivity = False
 
 
         # Start the recording worker thread
         # Start the recording worker thread
-        self.recording_thread = threading.Thread(target=self._recording_worker)
-        self.recording_thread.daemon = True
-        self.recording_thread.start()
+        self.recording_thread = self._start_thread(target=self._recording_worker)
 
 
         # Start the realtime transcription worker thread
         # Start the realtime transcription worker thread
-        self.realtime_thread = threading.Thread(target=self._realtime_worker)
-        self.realtime_thread.daemon = True
-        self.realtime_thread.start()
+        self.realtime_thread = self._start_thread(target=self._realtime_worker)
 
 
         # Wait for transcription models to start
         # Wait for transcription models to start
         logging.debug('Waiting for main transcription model to start')
         logging.debug('Waiting for main transcription model to start')
@@ -647,6 +641,29 @@ class AudioToTextRecorder:
         logging.debug('Main transcription model ready')
         logging.debug('Main transcription model ready')
 
 
         logging.debug('RealtimeSTT initialization completed successfully')
         logging.debug('RealtimeSTT initialization completed successfully')
+                   
+    def _start_thread(self, target=None, args=()):
+        """
+        Implement a consistent threading model across the library.
+
+        This method is used to start any thread in this library. It uses the
+        standard threading.Thread (as a daemon) on Linux and, on all other
+        platforms, the pytorch MultiProcessing library 'Process'.
+        Args:
+            target (callable object): is the callable object to be invoked by
+              the run() method. Defaults to None, meaning nothing is called.
+            args (tuple): is a list or tuple of arguments for the target
+              invocation. Defaults to ().
+        Returns:
+            The already-started threading.Thread or mp.Process object.
+        """
+        if platform.system() == 'Linux':
+            # Daemon flag goes to the constructor so it cannot be misspelled.
+            worker = threading.Thread(target=target, args=args, daemon=True)
+        else:
+            worker = mp.Process(target=target, args=args)
+        worker.start()
+        return worker
 
 
     @staticmethod
     @staticmethod
     def _transcription_worker(conn,
     def _transcription_worker(conn,
@@ -992,7 +1009,7 @@ class AudioToTextRecorder:
         using the `faster_whisper` model.
         using the `faster_whisper` model.
 
 
         - Automatically starts recording upon voice activity if not manually
         - Automatically starts recording upon voice activity if not manually
-          started using `recorder.start()`.
+          started using `recorder.start()`.
         - Automatically stops recording upon voice deactivity if not manually
         - Automatically stops recording upon voice deactivity if not manually
           stopped with `recorder.stop()`.
           stopped with `recorder.stop()`.
         - Processes the recorded audio to generate transcription.
         - Processes the recorded audio to generate transcription.
@@ -1020,8 +1037,8 @@ class AudioToTextRecorder:
             return ""
             return ""
 
 
         if on_transcription_finished:
         if on_transcription_finished:
-            threading.Thread(target=on_transcription_finished,
-                             args=(self.transcribe(),)).start()
+            self._start_thread(target=on_transcription_finished,
+                             args=(self.transcribe(),))
         else:
         else:
             return self.transcribe()
             return self.transcribe()
 
 
@@ -1597,9 +1614,9 @@ class AudioToTextRecorder:
                 self.silero_working = True
                 self.silero_working = True
 
 
                 # Run the intensive check in a separate thread
                 # Run the intensive check in a separate thread
-                threading.Thread(
+                self._start_thread(
                     target=self._is_silero_speech,
                     target=self._is_silero_speech,
-                    args=(data,)).start()
+                    args=(data,))
 
 
     def _is_voice_active(self):
     def _is_voice_active(self):
         """
         """