Prechádzať zdrojové kódy

reworker TranscriptionWorker

KoljaB 7 mesiacov pred
rodič
commit
8c38c5cfdb
1 zmenil súbory, kde vykonal 209 pridanie a 113 odobranie
  1. 209 113
      RealtimeSTT/audio_recorder.py

+ 209 - 113
RealtimeSTT/audio_recorder.py

@@ -85,6 +85,99 @@ if platform.system() != 'Darwin':
     INIT_HANDLE_BUFFER_OVERFLOW = True
 
 
class TranscriptionWorker:
    """Runs the main faster_whisper transcription model in a worker process.

    (audio, language) work items arrive over ``conn`` (a multiprocessing
    connection), are transcribed, and a ``(status, payload)`` tuple is sent
    back on the same connection — ``('success', (text, info))`` or
    ``('error', message)``.  Anything the worker ``print()``s is forwarded to
    the parent over ``stdout_pipe`` instead of the worker's stdout.
    """

    def __init__(self, conn, stdout_pipe, model_path, compute_type, gpu_device_index, device,
                 ready_event, shutdown_event, interrupt_stop_event, beam_size, initial_prompt, suppress_tokens):
        # Connection endpoints and events are created/owned by the parent process.
        self.conn = conn
        self.stdout_pipe = stdout_pipe
        self.model_path = model_path
        self.compute_type = compute_type
        self.gpu_device_index = gpu_device_index
        self.device = device
        self.ready_event = ready_event
        self.shutdown_event = shutdown_event
        self.interrupt_stop_event = interrupt_stop_event
        self.beam_size = beam_size
        self.initial_prompt = initial_prompt
        self.suppress_tokens = suppress_tokens
        # Hands work items from the polling thread to the transcription loop.
        self.queue = queue.Queue()

    def custom_print(self, *args, **kwargs):
        """print() replacement that forwards the message to the parent process."""
        message = ' '.join(map(str, args))
        try:
            self.stdout_pipe.send(message)
        except (BrokenPipeError, EOFError, OSError):
            # The pipe has probably been closed by the parent; drop the message.
            pass

    def poll_connection(self):
        """Drain the parent connection into the local queue until shutdown."""
        while not self.shutdown_event.is_set():
            if self.conn.poll(0.01):
                try:
                    data = self.conn.recv()
                    self.queue.put(data)
                except Exception as e:
                    logging.error(f"Error receiving data from connection: {e}")
            else:
                time.sleep(TIME_SLEEP)

    def run(self):
        """Load the model, then transcribe queued audio until shutdown.

        Raises:
            Exception: re-raised if the faster_whisper model fails to load.
        """
        # The parent process owns Ctrl-C handling; the worker must ignore
        # SIGINT so it only stops via the shutdown/interrupt events.
        system_signal.signal(system_signal.SIGINT, system_signal.SIG_IGN)

        # Redirect print() to the parent.  Keep a reference to the real
        # builtin first: after the swap, the bare name `print` resolves to the
        # replacement, so "restoring" via `print` later would be a no-op.
        import builtins
        original_print = builtins.print
        builtins.print = self.custom_print

        logging.info(f"Initializing faster_whisper main transcription model {self.model_path}")

        try:
            model = faster_whisper.WhisperModel(
                model_size_or_path=self.model_path,
                device=self.device,
                compute_type=self.compute_type,
                device_index=self.gpu_device_index,
            )
        except Exception as e:
            logging.exception(f"Error initializing main faster_whisper transcription model: {e}")
            builtins.print = original_print  # don't leave print() hijacked on failure
            raise

        self.ready_event.set()
        logging.debug("Faster_whisper main speech to text transcription model initialized successfully")

        # Feed self.queue from the connection on a separate thread so the
        # transcription loop below never blocks on the pipe itself.
        polling_thread = threading.Thread(target=self.poll_connection)
        polling_thread.start()

        try:
            while not self.shutdown_event.is_set():
                try:
                    audio, language = self.queue.get(timeout=0.1)
                    try:
                        segments, info = model.transcribe(
                            audio,
                            language=language if language else None,
                            beam_size=self.beam_size,
                            initial_prompt=self.initial_prompt,
                            suppress_tokens=self.suppress_tokens
                        )
                        transcription = " ".join(seg.text for seg in segments).strip()
                        logging.debug(f"Final text detected with main model: {transcription}")
                        self.conn.send(('success', (transcription, info)))
                    except Exception as e:
                        logging.error(f"General error in transcription: {e}")
                        self.conn.send(('error', str(e)))
                except queue.Empty:
                    continue
                except KeyboardInterrupt:
                    self.interrupt_stop_event.set()
                    logging.debug("Transcription worker process finished due to KeyboardInterrupt")
                    break
                except Exception as e:
                    logging.error(f"General error in processing queue item: {e}")
        finally:
            builtins.print = original_print  # restore the original print function
            # Stop and join the polling thread BEFORE closing the connection;
            # closing first would make the poller poll/recv on a closed pipe.
            self.shutdown_event.set()
            polling_thread.join()
            self.conn.close()
            self.stdout_pipe.close()
+
 class AudioToTextRecorder:
     """
     A class responsible for capturing audio from the microphone, detecting
@@ -758,136 +851,139 @@ class AudioToTextRecorder:
                 break 
             time.sleep(0.1)
 
-    @staticmethod
-    def _transcription_worker(conn,
-                              stdout_pipe,
-                              model_path,
-                              compute_type,
-                              gpu_device_index,
-                              device,
-                              ready_event,
-                              shutdown_event,
-                              interrupt_stop_event,
-                              beam_size,
-                              initial_prompt,
-                              suppress_tokens
-                              ):
-        """
-        Worker method that handles the continuous
-        process of transcribing audio data.
+    def _transcription_worker(*args, **kwargs):
+        worker = TranscriptionWorker(*args, **kwargs)
+        worker.run()
+    # @staticmethod
+    # def _transcription_worker(conn,
+    #                           stdout_pipe,
+    #                           model_path,
+    #                           compute_type,
+    #                           gpu_device_index,
+    #                           device,
+    #                           ready_event,
+    #                           shutdown_event,
+    #                           interrupt_stop_event,
+    #                           beam_size,
+    #                           initial_prompt,
+    #                           suppress_tokens
+    #                           ):
+    #     """
+    #     Worker method that handles the continuous
+    #     process of transcribing audio data.
 
-        This method runs in a separate process and is responsible for:
-        - Initializing the `faster_whisper` model used for transcription.
-        - Receiving audio data sent through a pipe and using the model
-          to transcribe it.
-        - Sending transcription results back through the pipe.
-        - Continuously checking for a shutdown event to gracefully
-          terminate the transcription process.
+    #     This method runs in a separate process and is responsible for:
+    #     - Initializing the `faster_whisper` model used for transcription.
+    #     - Receiving audio data sent through a pipe and using the model
+    #       to transcribe it.
+    #     - Sending transcription results back through the pipe.
+    #     - Continuously checking for a shutdown event to gracefully
+    #       terminate the transcription process.
 
-        Args:
-            conn (multiprocessing.Connection): The connection endpoint used
-              for receiving audio data and sending transcription results.
-            model_path (str): The path to the pre-trained faster_whisper model
-              for transcription.
-            compute_type (str): Specifies the type of computation to be used
-                for transcription.
-            gpu_device_index (int): Device ID to use.
-            device (str): Device for model to use.
-            ready_event (threading.Event): An event that is set when the
-              transcription model is successfully initialized and ready.
-            shutdown_event (threading.Event): An event that, when set,
-              signals this worker method to terminate.
-            interrupt_stop_event (threading.Event): An event that, when set,
-                signals this worker method to stop processing audio data.
-            beam_size (int): The beam size to use for beam search decoding.
-            initial_prompt (str or iterable of int): Initial prompt to be fed
-                to the transcription model.
-            suppress_tokens (list of int): Tokens to be suppressed from the
-                transcription output.
-        Raises:
-            Exception: If there is an error while initializing the
-            transcription model.
-        """
+    #     Args:
+    #         conn (multiprocessing.Connection): The connection endpoint used
+    #           for receiving audio data and sending transcription results.
+    #         model_path (str): The path to the pre-trained faster_whisper model
+    #           for transcription.
+    #         compute_type (str): Specifies the type of computation to be used
+    #             for transcription.
+    #         gpu_device_index (int): Device ID to use.
+    #         device (str): Device for model to use.
+    #         ready_event (threading.Event): An event that is set when the
+    #           transcription model is successfully initialized and ready.
+    #         shutdown_event (threading.Event): An event that, when set,
+    #           signals this worker method to terminate.
+    #         interrupt_stop_event (threading.Event): An event that, when set,
+    #             signals this worker method to stop processing audio data.
+    #         beam_size (int): The beam size to use for beam search decoding.
+    #         initial_prompt (str or iterable of int): Initial prompt to be fed
+    #             to the transcription model.
+    #         suppress_tokens (list of int): Tokens to be suppressed from the
+    #             transcription output.
+    #     Raises:
+    #         Exception: If there is an error while initializing the
+    #         transcription model.
+    #     """
 
-        system_signal.signal(system_signal.SIGINT, system_signal.SIG_IGN)
+    #     system_signal.signal(system_signal.SIGINT, system_signal.SIG_IGN)
 
-        def custom_print(*args, **kwargs):
-            message = ' '.join(map(str, args))
-            try:
-                stdout_pipe.send(message)
-            except (BrokenPipeError, EOFError, OSError):
-                # The pipe probably has been closed, so we ignore the error
-                pass
+    #     def custom_print(*args, **kwargs):
+    #         message = ' '.join(map(str, args))
+    #         try:
+    #             stdout_pipe.send(message)
+    #         except (BrokenPipeError, EOFError, OSError):
+    #             # The pipe probably has been closed, so we ignore the error
+    #             pass
 
-        # Replace the built-in print function with our custom one
-        __builtins__['print'] = custom_print
+    #     # Replace the built-in print function with our custom one
+    #     __builtins__['print'] = custom_print
 
-        logging.info("Initializing faster_whisper "
-                     f"main transcription model {model_path}"
-                     )
+    #     logging.info("Initializing faster_whisper "
+    #                  f"main transcription model {model_path}"
+    #                  )
 
-        try:
-            model = faster_whisper.WhisperModel(
-                model_size_or_path=model_path,
-                device=device,
-                compute_type=compute_type,
-                device_index=gpu_device_index,
-            )
+    #     try:
+    #         model = faster_whisper.WhisperModel(
+    #             model_size_or_path=model_path,
+    #             device=device,
+    #             compute_type=compute_type,
+    #             device_index=gpu_device_index,
+    #         )
 
-        except Exception as e:
-            logging.exception("Error initializing main "
-                              f"faster_whisper transcription model: {e}"
-                              )
-            raise
+    #     except Exception as e:
+    #         logging.exception("Error initializing main "
+    #                           f"faster_whisper transcription model: {e}"
+    #                           )
+    #         raise
 
-        ready_event.set()
+    #     ready_event.set()
 
-        logging.debug("Faster_whisper main speech to text "
-                      "transcription model initialized successfully"
-                      )
+    #     logging.debug("Faster_whisper main speech to text "
+    #                   "transcription model initialized successfully"
+    #                   )
 
-        try:
-            while not shutdown_event.is_set():
-                try:
-                    if conn.poll(0.01):
-                        logging.debug("Receive from _transcription_worker  pipe")
-                        audio, language = conn.recv()
-                        try:
-                            segments, info = model.transcribe(
-                                audio,
-                                language=language if language else None,
-                                beam_size=beam_size,
-                                initial_prompt=initial_prompt,
-                                suppress_tokens=suppress_tokens
-                            )
-                            transcription = " ".join(seg.text for seg in segments)
-                            transcription = transcription.strip()
-                            logging.debug(f"Final text detected with main model: {transcription}")
-                            conn.send(('success', (transcription, info)))
-                        except Exception as e:
-                            logging.error(f"General error in _transcription_worker in transcription: {e}")
-                            conn.send(('error', str(e)))
-                    else:
-                        time.sleep(TIME_SLEEP)
+    #     try:
+    #         while not shutdown_event.is_set():
+    #             try:
+    #                 if conn.poll(0.01):
+    #                     logging.debug("Receive from _transcription_worker  pipe")
+    #                     audio, language = conn.recv()
+    #                     try:
+    #                         segments, info = model.transcribe(
+    #                             audio,
+    #                             language=language if language else None,
+    #                             beam_size=beam_size,
+    #                             initial_prompt=initial_prompt,
+    #                             suppress_tokens=suppress_tokens
+    #                         )
+    #                         transcription = " ".join(seg.text for seg in segments)
+    #                         transcription = transcription.strip()
+    #                         logging.debug(f"Final text detected with main model: {transcription}")
+    #                         conn.send(('success', (transcription, info)))
+    #                     except Exception as e:
+    #                         logging.error(f"General error in _transcription_worker in transcription: {e}")
+    #                         conn.send(('error', str(e)))
+    #                 else:
+    #                     time.sleep(TIME_SLEEP)
 
 
 
-                except KeyboardInterrupt:
-                    interrupt_stop_event.set()
+    #             except KeyboardInterrupt:
+    #                 interrupt_stop_event.set()
                     
-                    logging.debug("Transcription worker process "
-                                    "finished due to KeyboardInterrupt"
-                                    )
-                    stdout_pipe.close()
-                    break
+    #                 logging.debug("Transcription worker process "
+    #                                 "finished due to KeyboardInterrupt"
+    #                                 )
+    #                 stdout_pipe.close()
+    #                 break
 
-                except Exception as e:
-                    logging.error(f"General error in _transcription_worker in accessing pipe: {e}")
+    #             except Exception as e:
+    #                 logging.error(f"General error in _transcription_worker in accessing pipe: {e}")
 
-        finally:
-            __builtins__['print'] = print  # Restore the original print function            
-            conn.close()
-            stdout_pipe.close()
+    #     finally:
+    #         __builtins__['print'] = print  # Restore the original print function            
+    #         conn.close()
+    #         stdout_pipe.close()
 
     @staticmethod
     def _audio_data_worker(audio_queue,