Cleanup config

This commit is contained in:
Dale 2023-04-23 13:54:38 +02:00
parent eca74772de
commit c9ca033eda
2 changed files with 31 additions and 16 deletions

View file

@ -12,4 +12,8 @@ oww:
activation_threshold: 0.5 activation_threshold: 0.5
vad_threshold: 0.5 vad_threshold: 0.5
enable_speex_noise_suppression: false enable_speex_noise_suppression: false
activation_ratelimit: 5 # Minimum number of seconds between activations sent to Rhasspy
rhasspy:
audio_udp_port: 12202 # UDP port that Rhasspy streams audio to; see https://rhasspy.readthedocs.io/en/latest/tutorials/#udp-audio-streaming

View file

@ -1,3 +1,8 @@
"""
Listen on UDP for audio from Rhasspy, detect wake words using Open Wake Word,
and publish to MQTT when a wake word is detected to trigger Rhasspy speech-to-text.
"""
import argparse import argparse
import io import io
import queue import queue
@ -8,7 +13,7 @@ import wave
from json import dumps from json import dumps
import numpy as np import numpy as np
import paho.mqtt.client as mqtt import paho.mqtt.client
import yaml import yaml
from openwakeword.model import Model from openwakeword.model import Model
@ -19,7 +24,7 @@ OWW_FRAMES = CHUNK * 3 # Increase efficiency of detection but higher latency
q = queue.Queue() q = queue.Queue()
parser = argparse.ArgumentParser(description="PiJuice to MQTT") parser = argparse.ArgumentParser(description="Open Wake Word detection for Rhasspy")
parser.add_argument( parser.add_argument(
"-c", "-c",
"--config", "--config",
@ -31,7 +36,7 @@ args = parser.parse_args()
def load_config(config_file): def load_config(config_file):
"""Load the configuration from config yaml file and use it to override the defaults.""" """Load the configuration from config.yaml file and use it to override the defaults."""
with open(config_file, "r") as f: with open(config_file, "r") as f:
config_override = yaml.safe_load(f) config_override = yaml.safe_load(f)
@ -46,14 +51,16 @@ def load_config(config_file):
"activation_threshold": 0.5, "activation_threshold": 0.5,
"vad_threshold": 0, "vad_threshold": 0,
"enable_speex_noise_suppression": False, "enable_speex_noise_suppression": False,
"activation_ratelimit": 5,
}, },
"rhasspy": {"audio_udp_port": 12202},
} }
config = {**default_config, **config_override} config = {**default_config, **config_override}
return config return config
def receive_udp_audio(port=12102): def receive_udp_audio(port=12202):
""" """
Get audio from UDP stream and add to wake word detection queue. Get audio from UDP stream and add to wake word detection queue.
@ -80,11 +87,12 @@ def receive_udp_audio(port=12102):
audio_buffer = audio_buffer[OWW_FRAMES:] audio_buffer = audio_buffer[OWW_FRAMES:]
def on_connect(client, userdata, flags, rc): def mqtt_on_connect(mqtt, userdata, flags, rc):
client.subscribe("hermes/hotword/#") # mqtt.subscribe("hermes/hotword/#")
pass
def on_message(client, userdata, msg): def mqtt_on_message(mqtt, userdata, msg):
# print(f"{msg.topic} {msg.payload}") # print(f"{msg.topic} {msg.payload}")
pass pass
@ -92,22 +100,25 @@ def on_message(client, userdata, msg):
config = load_config(args.config_file) config = load_config(args.config_file)
if __name__ == "__main__": if __name__ == "__main__":
client = mqtt.Client() mqtt = paho.mqtt.client.Client()
client.on_connect = on_connect mqtt.on_connect = mqtt_on_connect
client.on_message = on_message mqtt.on_message = mqtt_on_message
client.username_pw_set(config["mqtt"]["username"], config["mqtt"]["password"]) mqtt.username_pw_set(config["mqtt"]["username"], config["mqtt"]["password"])
client.connect(config["mqtt"]["broker"], config["mqtt"]["port"], 60) mqtt.connect(config["mqtt"]["broker"], config["mqtt"]["port"], 60)
print("Connected to MQTT broker") print("Connected to MQTT broker")
oww = Model( oww = Model(
vad_threshold=config["oww"]["vad_threshold"], vad_threshold=config["oww"]["vad_threshold"],
enable_speex_noise_suppression=config["oww"]["enable_speex_noise_suppression"], enable_speex_noise_suppression=config["oww"]["enable_speex_noise_suppression"],
) )
receive_audio_thread = threading.Thread(target=receive_udp_audio) receive_audio_thread = threading.Thread(
target=receive_udp_audio, kwargs={"port": config["rhasspy"]["audio_udp_port"]}
)
receive_audio_thread.daemon = True
receive_audio_thread.start() receive_audio_thread.start()
published = 0 published = 0
client.loop_start() mqtt.loop_start()
while True: while True:
prediction = oww.predict(q.get()) prediction = oww.predict(q.get())
for model_name in prediction.keys(): for model_name in prediction.keys():
@ -115,7 +126,7 @@ if __name__ == "__main__":
if prediction_level >= config["oww"]["activation_threshold"]: if prediction_level >= config["oww"]["activation_threshold"]:
delta = time.time() - published delta = time.time() - published
print(f"{model_name} {prediction_level:.3f} {delta:.3f}") print(f"{model_name} {prediction_level:.3f} {delta:.3f}")
if delta > 5: if delta > config["oww"]["activation_ratelimit"]:
payload = { payload = {
"modelId": model_name, "modelId": model_name,
"modelVersion": "", "modelVersion": "",
@ -127,7 +138,7 @@ if __name__ == "__main__":
"lang": None, "lang": None,
"customEntities": None, "customEntities": None,
} }
client.publish( mqtt.publish(
f"hermes/hotword/{model_name}/detected", dumps(payload) f"hermes/hotword/{model_name}/detected", dumps(payload)
) )
print("Sent wakeword to Rhasspy") print("Sent wakeword to Rhasspy")