Add moving average filter with hysteresis
This commit is contained in:
parent
4222295766
commit
c5c186accb
2 changed files with 88 additions and 47 deletions
|
@ -10,16 +10,17 @@ mqtt:
|
|||
password: mymqttpassword # Login to broker. Delete if not required
|
||||
|
||||
# Open Wake Word config
|
||||
# https://github.com/dscripka/openWakeWord#recommendations-for-usage
|
||||
oww:
|
||||
activation_threshold: 0.5
|
||||
activation_samples: 3 # Number of samples in moving average
|
||||
activation_threshold: 0.7 # Trigger wakeword when average above this threshold
|
||||
deactivation_threshold: 0.2 # Do not trigger again until average falls below this threshold
|
||||
# OWW config, see https://github.com/dscripka/openWakeWord#recommendations-for-usage
|
||||
vad_threshold: 0.5
|
||||
enable_speex_noise_suppression: false
|
||||
activation_ratelimit: 5 # Only 1 activation will be sent to Rhasspy in 5 seconds
|
||||
|
||||
# Rhasspy microphone UDP ports, 1 per device/satellite
|
||||
# https://rhasspy.readthedocs.io/en/latest/tutorials/#udp-audio-streaming
|
||||
udp_ports:
|
||||
base: 12202
|
||||
kitchen: 12203
|
||||
bedroom: 12204
|
||||
base: 12202 # Delete or change as needed
|
||||
kitchen: 12203 # Delete or change as needed
|
||||
bedroom: 12204 # Delete or change as needed
|
||||
|
|
122
detect.py
122
detect.py
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Listen on UDP for audio from Rhasspy, detect wake words using Open Wake Word,
|
||||
and then publish on MQTT when wake word is detected to trigger Rhasspy speech-to-text.
|
||||
and publish on MQTT when wake word is detected to trigger Rhasspy speech-to-text.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
@ -10,6 +10,7 @@ import socket
|
|||
import threading
|
||||
import time
|
||||
import wave
|
||||
from collections import deque
|
||||
from json import dumps
|
||||
|
||||
import numpy as np
|
||||
|
@ -19,8 +20,7 @@ from openwakeword.model import Model
|
|||
|
||||
RHASSPY_BYTES = 2092
|
||||
RHASSPY_FRAMES = 1024
|
||||
CHUNK = 1280 # 80 ms window @ 16 kHz = 1280 frames
|
||||
OWW_FRAMES = CHUNK * 3 # Increase efficiency of detection but higher latency
|
||||
OWW_FRAMES = 1280 # 80 ms window @ 16 kHz = 1280 frames
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Open Wake Word detection for Rhasspy")
|
||||
|
@ -50,15 +50,22 @@ def load_config(config_file):
|
|||
"password": None,
|
||||
},
|
||||
"oww": {
|
||||
"activation_threshold": 0.5,
|
||||
"activation_threshold": 0.7,
|
||||
"deactivation_threshold": 0.2,
|
||||
"activation_samples": 3,
|
||||
"vad_threshold": 0,
|
||||
"enable_speex_noise_suppression": False,
|
||||
"activation_ratelimit": 5,
|
||||
},
|
||||
"udp_ports": {"base": 12202},
|
||||
}
|
||||
|
||||
config = {**default_config, **config_override}
|
||||
if not config["udp_ports"]:
|
||||
print(
|
||||
"No UDP ports configured. Configure UDP ports to receive audio for wakeword detection.",
|
||||
flush=True,
|
||||
)
|
||||
exit()
|
||||
return config
|
||||
|
||||
|
||||
|
@ -67,48 +74,47 @@ class RhasspyUdpAudio(threading.Thread):
|
|||
|
||||
def __init__(self, roomname, port, queue):
|
||||
threading.Thread.__init__(self)
|
||||
self._roomname = roomname
|
||||
self._port = port
|
||||
self._queue = queue
|
||||
self._buffer = []
|
||||
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self._sock.bind(("", port))
|
||||
self.roomname = roomname
|
||||
self.port = port
|
||||
self.queue = queue
|
||||
self.buffer = []
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.sock.bind(("", port))
|
||||
|
||||
def run(self):
|
||||
"""Thread to receive UDP audio and add to processing queue."""
|
||||
print(
|
||||
f"Listening for {self._roomname} audio on UDP port {self._port}", flush=True
|
||||
f"Listening for {self.roomname} audio on UDP port {self.port}", flush=True
|
||||
)
|
||||
while True:
|
||||
data, addr = self._sock.recvfrom(RHASSPY_BYTES)
|
||||
data, addr = self.sock.recvfrom(RHASSPY_BYTES)
|
||||
audio = wave.open(io.BytesIO(data))
|
||||
frames = audio.readframes(RHASSPY_FRAMES)
|
||||
self._buffer.extend(np.frombuffer(frames, dtype=np.int16))
|
||||
if len(self._buffer) > OWW_FRAMES:
|
||||
self._queue.put(
|
||||
{
|
||||
"roomname": self._roomname,
|
||||
"timestamp": time.time(),
|
||||
"audio": np.asarray(self._buffer[:OWW_FRAMES], dtype=np.int16),
|
||||
}
|
||||
self.buffer.extend(np.frombuffer(frames, dtype=np.int16))
|
||||
if len(self.buffer) > OWW_FRAMES:
|
||||
self.queue.put(
|
||||
(
|
||||
self.roomname,
|
||||
time.time(),
|
||||
np.asarray(self.buffer[:OWW_FRAMES], dtype=np.int16),
|
||||
)
|
||||
self._buffer = self._buffer[OWW_FRAMES:]
|
||||
)
|
||||
self.buffer = self.buffer[OWW_FRAMES:]
|
||||
|
||||
|
||||
class Prediction(threading.Thread):
|
||||
"""Process wake word detection queue and publishing MQTT message when a wake word is detected."""
|
||||
|
||||
def __init__(self, config, queue):
|
||||
def __init__(self, queue):
|
||||
threading.Thread.__init__(self)
|
||||
self.config = config
|
||||
self.queue = queue
|
||||
self.published = 0
|
||||
self.filters = {}
|
||||
self.mqtt = paho.mqtt.client.Client()
|
||||
self.mqtt.username_pw_set(
|
||||
config["mqtt"]["username"], config["mqtt"]["password"]
|
||||
)
|
||||
self.mqtt.connect(config["mqtt"]["broker"], config["mqtt"]["port"], 60)
|
||||
print("Connected to MQTT broker", flush=True)
|
||||
print("MQTT: Connected to broker", flush=True)
|
||||
|
||||
self.oww = Model(
|
||||
vad_threshold=config["oww"]["vad_threshold"],
|
||||
|
@ -122,33 +128,68 @@ class Prediction(threading.Thread):
|
|||
while True:
|
||||
roomname, timestamp, audio = self.queue.get()
|
||||
prediction = self.oww.predict(audio)
|
||||
for model_name in prediction.keys():
|
||||
prediction_level = prediction[model_name]
|
||||
if prediction_level >= self.config["oww"]["activation_threshold"]:
|
||||
delta = time.time() - self.published
|
||||
for wakeword in prediction.keys():
|
||||
confidence = prediction[wakeword]
|
||||
if self.__filter(wakeword, confidence):
|
||||
print(
|
||||
f"{roomname} {model_name} {prediction_level:.3f} {delta:.3f}",
|
||||
f"Detected wakeword {wakeword} in {roomname}",
|
||||
flush=True,
|
||||
)
|
||||
if delta > self.config["oww"]["activation_ratelimit"]:
|
||||
self.__publish(model_name, roomname)
|
||||
self.published = time.time()
|
||||
self.__publish(wakeword, roomname)
|
||||
|
||||
def __publish(self, model_name, roomname):
|
||||
def __filter(self, wakeword, confidence):
|
||||
"""
|
||||
Filter so that a wakeword is only triggered once per utterance.
|
||||
|
||||
When simple moving average (of length `activation_samples`) crosses the `activation_threshold`
|
||||
then trigger Rhasspy. Only "re-arm" the wakeword when the moving average drops below
|
||||
the `deactivation_threshold`.
|
||||
"""
|
||||
try:
|
||||
self.filters[wakeword]["samples"].append(confidence)
|
||||
except KeyError:
|
||||
self.filters[wakeword] = {
|
||||
"samples": deque(
|
||||
[confidence], maxlen=config["oww"]["activation_samples"]
|
||||
),
|
||||
"active": False,
|
||||
}
|
||||
moving_average = (
|
||||
sum(self.filters[wakeword]["samples"]) / config["oww"]["activation_samples"]
|
||||
)
|
||||
activated = False
|
||||
if (
|
||||
not self.filters[wakeword]["active"]
|
||||
and moving_average >= config["oww"]["activation_threshold"]
|
||||
):
|
||||
self.filters[wakeword]["active"] = True
|
||||
activated = True
|
||||
elif (
|
||||
self.filters[wakeword]["active"]
|
||||
and moving_average < config["oww"]["deactivation_threshold"]
|
||||
):
|
||||
self.filters[wakeword]["active"] = False
|
||||
if moving_average > 0.1:
|
||||
print(
|
||||
f"{wakeword:<16} {activated!s:<8} {self.filters[wakeword]}", flush=True
|
||||
)
|
||||
return activated
|
||||
|
||||
def __publish(self, wakeword, roomname):
|
||||
"""Publish wake word message to Rhasspy Hermes/MQTT."""
|
||||
payload = {
|
||||
"modelId": model_name,
|
||||
"modelId": wakeword,
|
||||
"modelVersion": "",
|
||||
"modelType": "universal",
|
||||
"currentSensitivity": self.config["oww"]["activation_threshold"],
|
||||
"currentSensitivity": config["oww"]["activation_threshold"],
|
||||
"siteId": roomname,
|
||||
"sessionId": None,
|
||||
"sendAudioCaptured": None,
|
||||
"lang": None,
|
||||
"customEntities": None,
|
||||
}
|
||||
self.mqtt.publish(f"hermes/hotword/{model_name}/detected", dumps(payload))
|
||||
print("Sent wakeword to Rhasspy", flush=True)
|
||||
self.mqtt.publish(f"hermes/hotword/{wakeword}/detected", dumps(payload))
|
||||
print("MQTT: Published to Rhasspy", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -160,8 +201,7 @@ if __name__ == "__main__":
|
|||
t.daemon = True
|
||||
t.start()
|
||||
threads.append(t)
|
||||
t = Prediction(config, q)
|
||||
t.deamon = True
|
||||
t = Prediction(q)
|
||||
t.start()
|
||||
threads.append(t)
|
||||
print(f"Threads: {threads}")
|
||||
|
|
Loading…
Add table
Reference in a new issue