Browse Source

added language detection

KoljaB 8 months ago
parent
commit
909aeddeeb
2 changed files with 18 additions and 8 deletions
  1. 15 6
      RealtimeSTT/audio_recorder.py
  2. 3 2
      tests/realtimestt_test.py

+ 15 - 6
RealtimeSTT/audio_recorder.py

@@ -410,6 +410,10 @@ class AudioToTextRecorder:
         self.initial_prompt = initial_prompt
         self.suppress_tokens = suppress_tokens
         self.use_wake_words = wake_words or wakeword_backend in {'oww', 'openwakeword', 'openwakewords'}
+        self.detected_language = None
+        self.detected_language_probability = 0
+        self.detected_realtime_language = None
+        self.detected_realtime_language_probability = 0
 
         # Initialize the logging configuration with the specified level
         log_format = 'RealTimeSTT: %(name)s - %(levelname)s - %(message)s'
@@ -758,17 +762,16 @@ class AudioToTextRecorder:
                 if conn.poll(0.5):
                     audio, language = conn.recv()
                     try:
-                        segments = model.transcribe(
+                        segments, info = model.transcribe(
                             audio,
                             language=language if language else None,
                             beam_size=beam_size,
                             initial_prompt=initial_prompt,
                             suppress_tokens=suppress_tokens
                         )
-                        segments = segments[0]
                         transcription = " ".join(seg.text for seg in segments)
                         transcription = transcription.strip()
-                        conn.send(('success', transcription))
+                        conn.send(('success', (transcription, info)))
                     except Exception as e:
                         logging.error(f"General transcription error: {e}")
                         conn.send(('error', str(e)))
@@ -969,8 +972,11 @@ class AudioToTextRecorder:
 
         self._set_state("inactive")
         if status == 'success':
+            segments, info = result
+            self.detected_language = info.language if info.language_probability > 0 else None
+            self.detected_language_probability = info.language_probability
             self.last_transcription_bytes = audio_copy
-            return self._preprocess_output(result)
+            return self._preprocess_output(segments)
         else:
             logging.error(result)
             raise Exception(result)
@@ -1447,7 +1453,7 @@ class AudioToTextRecorder:
                         INT16_MAX_ABS_VALUE
 
                     # Perform transcription and assemble the text
-                    segments = self.realtime_model_type.transcribe(
+                    segments, info = self.realtime_model_type.transcribe(
                         audio_array,
                         language=self.language if self.language else None,
                         beam_size=self.beam_size_realtime,
@@ -1455,6 +1461,9 @@ class AudioToTextRecorder:
                         suppress_tokens=self.suppress_tokens,
                     )
 
+                    self.detected_realtime_language = info.language if info.language_probability > 0 else None
+                    self.detected_realtime_language_probability = info.language_probability
+
                     # double check recording state
                     # because it could have changed mid-transcription
                     if self.is_recording and time.time() - \
@@ -1462,7 +1471,7 @@ class AudioToTextRecorder:
 
                         logging.debug('Starting realtime transcription')
                         self.realtime_transcription_text = " ".join(
-                            seg.text for seg in segments[0]
+                            seg.text for seg in segments
                         )
                         self.realtime_transcription_text = \
                             self.realtime_transcription_text.strip()

+ 3 - 2
tests/realtimestt_test.py

@@ -26,6 +26,7 @@ if __name__ == '__main__':
         if new_text != displayed_text:
             displayed_text = new_text
             clear_console()
+            print(f"Language: {recorder.detected_language} (realtime: {recorder.detected_realtime_language})")
             print(displayed_text, end="", flush=True)
 
     def process_text(text):
@@ -35,7 +36,6 @@ if __name__ == '__main__':
     recorder_config = {
         'spinner': False,
         'model': 'large-v2',
-        'language': 'en',
         'silero_sensitivity': 0.4,
         'webrtc_sensitivity': 2,
         'post_speech_silence_duration': 0.4,
@@ -43,8 +43,9 @@ if __name__ == '__main__':
         'min_gap_between_recordings': 0,
         'enable_realtime_transcription': True,
         'realtime_processing_pause': 0.2,
-        'realtime_model_type': 'tiny.en',
+        'realtime_model_type': 'tiny',
         'on_realtime_transcription_update': text_detected, 
+        'silero_deactivity_detection': True,
     }
 
     recorder = AudioToTextRecorder(**recorder_config)