diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..c836869 --- /dev/null +++ b/.flake8 @@ -0,0 +1,2 @@ +[flake8] +max-line-length = 110 diff --git a/.gitignore b/.gitignore index d9005f2..1276355 100644 --- a/.gitignore +++ b/.gitignore @@ -150,3 +150,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +config.yaml diff --git a/config.yaml.example b/config.yaml.example new file mode 100644 index 0000000..7f9ff89 --- /dev/null +++ b/config.yaml.example @@ -0,0 +1,15 @@ +# Configuration for using Open Wake Work with Rhasspy voice assistant + +mqtt: + broker: 127.0.0.1 + port: 1883 + username: mymqttusername # Login to broker. Delete if not required + password: mymqttpassword # Login to broker. Delete if not required + +# Open Wake Word config +# https://github.com/dscripka/openWakeWord#recommendations-for-usage +oww: + activation_threshold: 0.5 + vad_threshold: 0.5 + enable_speex_noise_suppression: false + diff --git a/detect.py b/detect.py index c1bedd0..45ca634 100644 --- a/detect.py +++ b/detect.py @@ -1,46 +1,134 @@ +import argparse import io import queue import socket import threading +import time import wave +from json import dumps import numpy as np +import paho.mqtt.client as mqtt +import yaml from openwakeword.model import Model -CHUNK = 1280 +RHASSPY_BYTES = 2092 +RHASSPY_FRAMES = 1024 +CHUNK = 1280 # 80 ms window @ 16 kHz = 1280 frames +OWW_FRAMES = CHUNK * 3 # Increase efficiency of detection but higher latency -oww = Model() q = queue.Queue() +parser = argparse.ArgumentParser(description="PiJuice to MQTT") +parser.add_argument( + "-c", + "--config", + default="config.yaml", + help="Configuration yaml file, defaults to `config.yaml`", + dest="config_file", +) +args = parser.parse_args() + + +def load_config(config_file): + """Load the configuration from config yaml file and use it to override the defaults.""" + with open(config_file, "r") as f: + config_override = yaml.safe_load(f) + + default_config = { + "mqtt": { + "broker": "127.0.0.1", + "port": 1883, + "username": None, + "password": None, + }, + "oww": { + "activation_threshold": 0.5, + "vad_threshold": 0, + "enable_speex_noise_suppression": False, + }, + } + + config = {**default_config, **config_override} + return config + def receive_udp_audio(port=12102): """ Get audio from UDP stream and add to wake word detection queue. Rhasspy sends 1024 x 16bit frames + header = 2092 bytes - Open Wake Word expects 1280 x 16bit frames (CHUNK) + Open Wake Word expects minimum of 1280 x 16bit frames """ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.bind(("", port)) + print(f"Listening on UDP port {port}") audio_buffer = [] while True: - data, addr = sock.recvfrom(2092) + data, addr = sock.recvfrom(RHASSPY_BYTES) audio = wave.open(io.BytesIO(data)) - frames = audio.readframes(1024) + frames = audio.readframes(RHASSPY_FRAMES) audio_buffer.extend(np.frombuffer(frames, dtype=np.int16)) - if len(audio_buffer) > CHUNK: - q.put(audio_buffer[:CHUNK]) - audio_buffer = audio_buffer[CHUNK:] + print(".", end="", flush=True) + if len(audio_buffer) > OWW_FRAMES: + q.put( + np.asarray(audio_buffer[:OWW_FRAMES], dtype=np.int16) + ) # Must be np array for VAD + audio_buffer = audio_buffer[OWW_FRAMES:] -receive_audio_thread = threading.Thread(target=receive_udp_audio) -receive_audio_thread.start() -while True: - prediction = oww.predict(q.get()) - for model_name in prediction.keys(): - prediction_level = prediction[model_name] - if prediction_level >= 0.5: - print(model_name, prediction_level) +def on_connect(client, userdata, flags, rc): + client.subscribe("hermes/hotword/#") + + +def on_message(client, userdata, msg): + # print(f"{msg.topic} {msg.payload}") + pass + + +config = load_config(args.config_file) + +if __name__ == "__main__": + client = mqtt.Client() + client.on_connect = on_connect + client.on_message = on_message + client.username_pw_set(config["mqtt"]["username"], config["mqtt"]["password"]) + client.connect(config["mqtt"]["broker"], config["mqtt"]["port"], 60) + print("Connected to MQTT broker") + + oww = Model( + vad_threshold=config["oww"]["vad_threshold"], + enable_speex_noise_suppression=config["oww"]["enable_speex_noise_suppression"], + ) + receive_audio_thread = threading.Thread(target=receive_udp_audio) + receive_audio_thread.start() + + published = 0 + client.loop_start() + while True: + prediction = oww.predict(q.get()) + for model_name in prediction.keys(): + prediction_level = prediction[model_name] + if prediction_level >= config["oww"]["activation_threshold"]: + delta = time.time() - published + print(f"{model_name} {prediction_level:.3f} {delta:.3f}") + if delta > 5: + payload = { + "modelId": model_name, + "modelVersion": "", + "modelType": "universal", + "currentSensitivity": config["oww"]["activation_threshold"], + "siteId": "bedroom", + "sessionId": None, + "sendAudioCaptured": None, + "lang": None, + "customEntities": None, + } + client.publish( + f"hermes/hotword/{model_name}/detected", dumps(payload) + ) + print("Sent wakeword to Rhasspy") + published = time.time() diff --git a/requirements.txt b/requirements.txt index 9bcf0a0..dcc36b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,3 @@ -openwakeword +openwakeword @ git+https://github.com/dscripka/openWakeWord +paho-mqtt +pyyaml