Bläddra i källkod

update client server

KoljaB 6 månader sedan
förälder
incheckning
64ff132374
4 ändrade filer med 102 tillägg och 14 borttagningar
  1. + 5 - 3
      RealtimeSTT/audio_recorder.py
  2. + 78 - 1
      RealtimeSTT/audio_recorder_client.py
  3. + 3 - 2
      server/stt_cli_client.py
  4. + 16 - 8
      server/stt_server.py

+ 5 - 3
RealtimeSTT/audio_recorder.py

@@ -29,7 +29,6 @@ Author: Kolja Beigel
 from typing import Iterable, List, Optional, Union
 import torch.multiprocessing as mp
 import torch
-from typing import List, Union
 from ctypes import c_bool
 from openwakeword.model import Model
 from scipy.signal import resample
@@ -49,6 +48,7 @@ import platform
 import pyaudio
 import logging
 import struct
+import base64
 import queue
 import halo
 import time
@@ -541,6 +541,7 @@ class AudioToTextRecorder:
         self.start_recording_event = threading.Event()
         self.stop_recording_event = threading.Event()
         self.last_transcription_bytes = None
+        self.last_transcription_bytes_b64 = None
         self.initial_prompt = initial_prompt
         self.suppress_tokens = suppress_tokens
         self.use_wake_words = wake_words or wakeword_backend in {'oww', 'openwakeword', 'openwakewords'}
@@ -1209,7 +1210,7 @@ class AudioToTextRecorder:
                 if self.transcribe_count == 0:
                     logging.debug("Adding transcription request, no early transcription started")
                     start_time = time.time()  # Start timing
-                    self.parent_transcription_pipe.send((self.audio, self.language))
+                    self.parent_transcription_pipe.send((audio_copy, self.language))
                     self.transcribe_count += 1
 
                 while self.transcribe_count > 0:
@@ -1223,7 +1224,8 @@ class AudioToTextRecorder:
                     segments, info = result
                     self.detected_language = info.language if info.language_probability > 0 else None
                     self.detected_language_probability = info.language_probability
-                    self.last_transcription_bytes = audio_copy
+                    self.last_transcription_bytes = copy.deepcopy(audio_copy)                    
+                    self.last_transcription_bytes_b64 = base64.b64encode(self.last_transcription_bytes.tobytes()).decode('utf-8')
                     transcription = self._preprocess_output(segments)
                     end_time = time.time()  # End timing
                     transcription_time = end_time - start_time

+ 78 - 1
RealtimeSTT/audio_recorder_client.py

@@ -199,6 +199,8 @@ class AudioToTextRecorderClient:
         self.autostart_server = autostart_server
 
         # Instance variables
+        self.muted = False
+        self.recording_thread = None
         self.is_running = True
         self.connection_established = threading.Event()
         self.recording_start = threading.Event()
@@ -214,7 +216,8 @@ class AudioToTextRecorderClient:
             if self.debug_mode:
                 print("STT server is running and connected.")
 
-        self.start_recording()
+        if self.use_microphone:
+            self.start_recording()
 
     def text(self, on_transcription_finished=None):
         self.realtime_text = ""
@@ -255,11 +258,45 @@ class AudioToTextRecorderClient:
                 threading.Thread(target=on_transcription_finished, args=(self.final_text,)).start()
 
             return self.final_text
+
         except KeyboardInterrupt:
             if self.debug_mode:
                 print("KeyboardInterrupt in record_and_send_audio, exiting...")
             raise KeyboardInterrupt
 
+        except Exception as e:
+            print(f"Error in AudioToTextRecorderClient.text(): {e}")
+            return ""
+
+    def feed_audio(self, chunk, original_sample_rate=16000):
+        metadata = {"sampleRate": original_sample_rate}
+        metadata_json = json.dumps(metadata)
+        metadata_length = len(metadata_json)
+        message = struct.pack('<I', metadata_length) + metadata_json.encode('utf-8') + chunk
+
+        if self.is_running:
+            self.data_ws.send(message, opcode=websocket.ABNF.OPCODE_BINARY)
+
+    def set_microphone(self, microphone_on=True):
+        """
+        Set the microphone on or off.
+        """
+        self.muted = not microphone_on
+        #self.call_method("set_microphone", [microphone_on])
+        # self.use_microphone.value = microphone_on
+
+    def abort(self):
+        self.call_method("abort")
+
+    def wakeup(self):
+        self.call_method("wakeup")
+
+    def clear_audio_queue(self):
+        self.call_method("clear_audio_queue")
+
+    def stop(self):
+        self.call_method("stop")
+
     def connect(self):
         if not self.ensure_server_running():
             print("Cannot start STT server. Exiting.")
@@ -423,9 +460,19 @@ class AudioToTextRecorderClient:
                 print("Recording and sending audio...")
 
             while self.is_running:
+                if self.muted:
+                    time.sleep(0.01)
+                    continue
+
                 try:
                     audio_data = self.stream.read(CHUNK)
 
+                    if self.on_recorded_chunk:
+                        self.on_recorded_chunk(audio_data)
+
+                    if self.muted:
+                        continue
+
                     if self.recording_start.is_set():
                         metadata = {"sampleRate": self.device_sample_rate}
                         metadata_json = json.dumps(metadata)
@@ -503,6 +550,12 @@ class AudioToTextRecorderClient:
             elif data.get('type') == 'vad_detect_start':
                 if self.on_vad_detect_start:
                     self.on_vad_detect_start()
+            elif data.get('type') == 'wakeword_detection_start':
+                if self.on_wakeword_detection_start:
+                    self.on_wakeword_detection_start()
+            elif data.get('type') == 'wakeword_detection_end':
+                if self.on_wakeword_detection_end:
+                    self.on_wakeword_detection_end()
 
             else:
                 print(f"Unknown data message format: {data}")
@@ -533,6 +586,30 @@ class AudioToTextRecorderClient:
         if self.debug_mode:
             print("Data WebSocket connection opened.")
 
+    def set_parameter(self, parameter, value):
+        command = {
+            "command": "set_parameter",
+            "parameter": parameter,
+            "value": value
+        }
+        self.control_ws.send(json.dumps(command))
+
+    def get_parameter(self, parameter):
+        command = {
+            "command": "get_parameter",
+            "parameter": parameter
+        }
+        self.control_ws.send(json.dumps(command))
+
+    def call_method(self, method, args=None, kwargs=None):
+        command = {
+            "command": "call_method",
+            "method": method,
+            "args": args or [],
+            "kwargs": kwargs or {}
+        }
+        self.control_ws.send(json.dumps(command))
+
     def shutdown(self):
         self.is_running = False
         #self.stop_event.set()

+ 3 - 2
server/stt_cli_client.py

@@ -80,6 +80,7 @@ import threading
 import time
 import struct
 import socket
+import subprocess
 import shutil
 from urllib.parse import urlparse
 import queue 
@@ -92,8 +93,8 @@ CHUNK = 1024
 FORMAT = pyaudio.paInt16
 CHANNELS = 1
 RATE = 44100
-DEFAULT_CONTROL_URL = "ws://localhost:8011"
-DEFAULT_DATA_URL = "ws://localhost:8012"
+DEFAULT_CONTROL_URL = "ws://127.0.0.1:8011"
+DEFAULT_DATA_URL = "ws://127.0.0.1:8012"
 
 # Initialize colorama
 from colorama import init, Fore, Style

+ 16 - 8
server/stt_server.py

@@ -119,7 +119,7 @@ allowed_methods = [
     'clear_audio_queue',
     'wakeup',
     'shutdown',
-    'text',  # Allow 'text' method to initiate transcription
+    'text',
 ]
 allowed_parameters = [
     'silero_sensitivity',
@@ -127,8 +127,8 @@ allowed_parameters = [
     'post_speech_silence_duration',
     'listen_start',
     'recording_stop_time',
-    'recorderActive',
-    # Add other parameters as needed
+    'last_transcription_bytes',
+    'last_transcription_bytes_b64',
 ]
 
 # Queues and connections for control and data
@@ -308,7 +308,6 @@ def _recorder_thread(loop):
     
     def process_text(full_sentence):
         full_sentence = preprocess_text(full_sentence)
-        prev_text = ""
         message = json.dumps({
             'type': 'fullSentence',
             'text': full_sentence
@@ -367,16 +366,24 @@ async def control_handler(websocket, path):
                             # Optionally send a response back to the client
                             await websocket.send(json.dumps({"status": "success", "message": f"Parameter {parameter} set to {value}"}))
                         else:
-                            print(f"Parameter {parameter} is not allowed or does not exist")
-                            await websocket.send(json.dumps({"status": "error", "message": f"Parameter {parameter} is not allowed or does not exist"}))
+                            if not parameter in allowed_parameters:
+                                print(f"Parameter {parameter} is not allowed (set_parameter)")
+                                await websocket.send(json.dumps({"status": "error", "message": f"Parameter {parameter} is not allowed (set_parameter)"}))
+                            else:
+                                print(f"Parameter {parameter} does not exist (set_parameter)")
+                                await websocket.send(json.dumps({"status": "error", "message": f"Parameter {parameter} does not exist (set_parameter)"}))
                     elif command == "get_parameter":
                         parameter = command_data.get("parameter")
                         if parameter in allowed_parameters and hasattr(recorder, parameter):
                             value = getattr(recorder, parameter)
                             await websocket.send(json.dumps({"status": "success", "parameter": parameter, "value": value}))
                         else:
-                            print(f"Parameter {parameter} is not allowed or does not exist")
-                            await websocket.send(json.dumps({"status": "error", "message": f"Parameter {parameter} is not allowed or does not exist"}))
+                            if not parameter in allowed_parameters:
+                                print(f"Parameter {parameter} is not allowed (get_parameter)")
+                                await websocket.send(json.dumps({"status": "error", "message": f"Parameter {parameter} is not allowed (get_parameter)"}))
+                            else:
+                                print(f"Parameter {parameter} does not exist (get_parameter)")
+                                await websocket.send(json.dumps({"status": "error", "message": f"Parameter {parameter} does not exist (get_parameter)"}))
                     elif command == "call_method":
                         method_name = command_data.get("method")
                         if method_name in allowed_methods:
@@ -434,6 +441,7 @@ async def broadcast_audio_messages():
         message = await audio_queue.get()
         for conn in list(data_connections):
             try:
+                # print(f"Sending message: {message}")
                 await conn.send(message)
             except websockets.exceptions.ConnectionClosed:
                 data_connections.remove(conn)