Jelajahi Sumber

Merge pull request #85 from danielwoz/patch-3

Update audio_recorder.py to use a consistent threading library
Kolja Beigel 9 bulan lalu
induk
melakukan
e20f602709
1 mengubah file dengan 27 tambahan dan 10 penghapusan
  1. 27 10
      RealtimeSTT/audio_recorder.py

+ 27 - 10
RealtimeSTT/audio_recorder.py

@@ -443,7 +443,7 @@ class AudioToTextRecorder:
         # Set device for model
         self.device = "cuda" if self.device == "cuda" and torch.cuda.is_available() else "cpu"
 
-        self.transcript_process = mp.Process(
+        self.transcript_process = self._start_thread(
             target=AudioToTextRecorder._transcription_worker,
             args=(
                 child_transcription_pipe,
@@ -459,7 +459,6 @@ class AudioToTextRecorder:
                 self.suppress_tokens
             )
         )
-        self.transcript_process.start()
 
         # Start audio data reading process
         if self.use_microphone.value:
@@ -468,7 +467,7 @@ class AudioToTextRecorder:
                          f" sample rate: {self.sample_rate}"
                          f" buffer size: {self.buffer_size}"
                          )
-            self.reader_process = mp.Process(
+            self.reader_process = self._start_thread(
                 target=AudioToTextRecorder._audio_data_worker,
                 args=(
                     self.audio_queue,
@@ -480,7 +479,6 @@ class AudioToTextRecorder:
                     self.use_microphone
                 )
             )
-            self.reader_process.start()
 
         # Initialize the realtime transcription model
         if self.enable_realtime_transcription:
@@ -632,14 +630,10 @@ class AudioToTextRecorder:
         self.stop_recording_on_voice_deactivity = False
 
         # Start the recording worker thread
-        self.recording_thread = threading.Thread(target=self._recording_worker)
-        self.recording_thread.daemon = True
-        self.recording_thread.start()
+        self.recording_thread = self._start_thread(target=self._recording_worker)
 
         # Start the realtime transcription worker thread
-        self.realtime_thread = threading.Thread(target=self._realtime_worker)
-        self.realtime_thread.daemon = True
-        self.realtime_thread.start()
+        self.realtime_thread = self._start_thread(target=self._realtime_worker)
 
         # Wait for transcription models to start
         logging.debug('Waiting for main transcription model to start')
@@ -647,6 +641,29 @@ class AudioToTextRecorder:
         logging.debug('Main transcription model ready')
 
         logging.debug('RealtimeSTT initialization completed successfully')
+                   
+    def _start_thread(self, target=None, args=()):
+        """
+        Implement a consistent threading model across the library.
+
+        This method is used to start any thread in this library. It uses the
+        standard threading. Thread for Linux and for all others uses the pytorch
+        MultiProcessing library 'Process'.
+        Args:
+            target (callable object): is the callable object to be invoked by
+              the run() method. Defaults to None, meaning nothing is called.
+            args (tuple): is a list or tuple of arguments for the target
+              invocation. Defaults to ().
+        """
+        if (platform.system() == 'Linux'):
+            thread = threading.Thread(target=target, args=args)
+            thread.deamon = True
+            thread.start()
+            return thread
+        else:
+            thread = mp.Process(target=target, args=args)
+            thread.start()
+            return thread
 
     @staticmethod
     def _transcription_worker(conn,