Add moving average filter with hysteresis

This commit is contained in:
Dale 2023-04-30 13:41:59 +02:00
parent 4222295766
commit c5c186accb
2 changed files with 88 additions and 47 deletions

View file

@ -10,16 +10,17 @@ mqtt:
password: mymqttpassword # Login to broker. Delete if not required
# Open Wake Word config
# https://github.com/dscripka/openWakeWord#recommendations-for-usage
oww:
activation_threshold: 0.5
activation_samples: 3 # Number of samples in moving average
activation_threshold: 0.7 # Trigger wakeword when average above this threshold
deactivation_threshold: 0.2 # Do not trigger again until average falls below this threshold
# OWW config, see https://github.com/dscripka/openWakeWord#recommendations-for-usage
vad_threshold: 0.5
enable_speex_noise_suppression: false
activation_ratelimit: 5 # Only 1 activation will be sent to Rhasspy in 5 seconds
# Rhasspy microphone UDP ports, 1 per device/satellite
# https://rhasspy.readthedocs.io/en/latest/tutorials/#udp-audio-streaming
udp_ports:
base: 12202
kitchen: 12203
bedroom: 12204
base: 12202 # Delete or change as needed
kitchen: 12203 # Delete or change as needed
bedroom: 12204 # Delete or change as needed

122
detect.py
View file

@ -1,6 +1,6 @@
"""
Listen on UDP for audio from Rhasspy, detect wake words using Open Wake Word,
and then publish on MQTT when wake word is detected to trigger Rhasspy speech-to-text.
and publish on MQTT when wake word is detected to trigger Rhasspy speech-to-text.
"""
import argparse
@ -10,6 +10,7 @@ import socket
import threading
import time
import wave
from collections import deque
from json import dumps
import numpy as np
@ -19,8 +20,7 @@ from openwakeword.model import Model
RHASSPY_BYTES = 2092
RHASSPY_FRAMES = 1024
CHUNK = 1280 # 80 ms window @ 16 kHz = 1280 frames
OWW_FRAMES = CHUNK * 3 # Increase efficiency of detection but higher latency
OWW_FRAMES = 1280 # 80 ms window @ 16 kHz = 1280 frames
parser = argparse.ArgumentParser(description="Open Wake Word detection for Rhasspy")
@ -50,15 +50,22 @@ def load_config(config_file):
"password": None,
},
"oww": {
"activation_threshold": 0.5,
"activation_threshold": 0.7,
"deactivation_threshold": 0.2,
"activation_samples": 3,
"vad_threshold": 0,
"enable_speex_noise_suppression": False,
"activation_ratelimit": 5,
},
"udp_ports": {"base": 12202},
}
config = {**default_config, **config_override}
if not config["udp_ports"]:
print(
"No UDP ports configured. Configure UDP ports to receive audio for wakeword detection.",
flush=True,
)
exit()
return config
@ -67,48 +74,47 @@ class RhasspyUdpAudio(threading.Thread):
def __init__(self, roomname, port, queue):
threading.Thread.__init__(self)
self._roomname = roomname
self._port = port
self._queue = queue
self._buffer = []
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._sock.bind(("", port))
self.roomname = roomname
self.port = port
self.queue = queue
self.buffer = []
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.bind(("", port))
def run(self):
"""Thread to receive UDP audio and add to processing queue."""
print(
f"Listening for {self._roomname} audio on UDP port {self._port}", flush=True
f"Listening for {self.roomname} audio on UDP port {self.port}", flush=True
)
while True:
data, addr = self._sock.recvfrom(RHASSPY_BYTES)
data, addr = self.sock.recvfrom(RHASSPY_BYTES)
audio = wave.open(io.BytesIO(data))
frames = audio.readframes(RHASSPY_FRAMES)
self._buffer.extend(np.frombuffer(frames, dtype=np.int16))
if len(self._buffer) > OWW_FRAMES:
self._queue.put(
{
"roomname": self._roomname,
"timestamp": time.time(),
"audio": np.asarray(self._buffer[:OWW_FRAMES], dtype=np.int16),
}
self.buffer.extend(np.frombuffer(frames, dtype=np.int16))
if len(self.buffer) > OWW_FRAMES:
self.queue.put(
(
self.roomname,
time.time(),
np.asarray(self.buffer[:OWW_FRAMES], dtype=np.int16),
)
self._buffer = self._buffer[OWW_FRAMES:]
)
self.buffer = self.buffer[OWW_FRAMES:]
class Prediction(threading.Thread):
"""Process wake word detection queue and publishing MQTT message when a wake word is detected."""
def __init__(self, config, queue):
def __init__(self, queue):
threading.Thread.__init__(self)
self.config = config
self.queue = queue
self.published = 0
self.filters = {}
self.mqtt = paho.mqtt.client.Client()
self.mqtt.username_pw_set(
config["mqtt"]["username"], config["mqtt"]["password"]
)
self.mqtt.connect(config["mqtt"]["broker"], config["mqtt"]["port"], 60)
print("Connected to MQTT broker", flush=True)
print("MQTT: Connected to broker", flush=True)
self.oww = Model(
vad_threshold=config["oww"]["vad_threshold"],
@ -122,33 +128,68 @@ class Prediction(threading.Thread):
while True:
roomname, timestamp, audio = self.queue.get()
prediction = self.oww.predict(audio)
for model_name in prediction.keys():
prediction_level = prediction[model_name]
if prediction_level >= self.config["oww"]["activation_threshold"]:
delta = time.time() - self.published
for wakeword in prediction.keys():
confidence = prediction[wakeword]
if self.__filter(wakeword, confidence):
print(
f"{roomname} {model_name} {prediction_level:.3f} {delta:.3f}",
f"Detected wakeword {wakeword} in {roomname}",
flush=True,
)
if delta > self.config["oww"]["activation_ratelimit"]:
self.__publish(model_name, roomname)
self.published = time.time()
self.__publish(wakeword, roomname)
def __publish(self, model_name, roomname):
def __filter(self, wakeword, confidence):
"""
Filter so that a wakeword is only triggered once per utterance.
When simple moving average (of length `activation_samples`) crosses the `activation_threshold`
then trigger Rhasspy. Only "re-arm" the wakeword when the moving average drops below
the `deactivation_threshold`.
"""
try:
self.filters[wakeword]["samples"].append(confidence)
except KeyError:
self.filters[wakeword] = {
"samples": deque(
[confidence], maxlen=config["oww"]["activation_samples"]
),
"active": False,
}
moving_average = (
sum(self.filters[wakeword]["samples"]) / config["oww"]["activation_samples"]
)
activated = False
if (
not self.filters[wakeword]["active"]
and moving_average >= config["oww"]["activation_threshold"]
):
self.filters[wakeword]["active"] = True
activated = True
elif (
self.filters[wakeword]["active"]
and moving_average < config["oww"]["deactivation_threshold"]
):
self.filters[wakeword]["active"] = False
if moving_average > 0.1:
print(
f"{wakeword:<16} {activated!s:<8} {self.filters[wakeword]}", flush=True
)
return activated
def __publish(self, wakeword, roomname):
"""Publish wake word message to Rhasspy Hermes/MQTT."""
payload = {
"modelId": model_name,
"modelId": wakeword,
"modelVersion": "",
"modelType": "universal",
"currentSensitivity": self.config["oww"]["activation_threshold"],
"currentSensitivity": config["oww"]["activation_threshold"],
"siteId": roomname,
"sessionId": None,
"sendAudioCaptured": None,
"lang": None,
"customEntities": None,
}
self.mqtt.publish(f"hermes/hotword/{model_name}/detected", dumps(payload))
print("Sent wakeword to Rhasspy", flush=True)
self.mqtt.publish(f"hermes/hotword/{wakeword}/detected", dumps(payload))
print("MQTT: Published to Rhasspy", flush=True)
if __name__ == "__main__":
@ -160,8 +201,7 @@ if __name__ == "__main__":
t.daemon = True
t.start()
threads.append(t)
t = Prediction(config, q)
t.deamon = True
t = Prediction(q)
t.start()
threads.append(t)
print(f"Threads: {threads}")