Add simple voice detection API
Change-Id: Ic7e176b734afdd780a4b7749378e85bfdccfea84
diff --git a/Makefile b/Makefile
index b718214..c368b19 100644
--- a/Makefile
+++ b/Makefile
@@ -19,10 +19,19 @@
mobilenet_v2_1.0_224_quant_edgetpu.tflite:
wget "$(TEST_DATA_URL)/$@"
-download: imagenet_labels.txt ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite mobilenet_v2_1.0_224_quant_edgetpu.tflite
+labels_gc2.raw.txt:
+ wget "https://github.com/google-coral/project-keyword-spotter/raw/master/config/$@"
+
+voice_commands_v0.7_edgetpu.tflite:
+ wget "https://github.com/google-coral/project-keyword-spotter/raw/master/models/$@"
+
+download: imagenet_labels.txt \
+ ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite \
+ mobilenet_v2_1.0_224_quant_edgetpu.tflite \
+ labels_gc2.raw.txt \
+ voice_commands_v0.7_edgetpu.tflite
clean:
rm -rf __pycache__ \
- imagenet_labels.txt \
- ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite \
- mobilenet_v2_1.0_224_quant_edgetpu.tflite
+ *.txt \
+ *.tflite
diff --git a/audio_recorder.py b/audio_recorder.py
new file mode 100644
index 0000000..ff23059
--- /dev/null
+++ b/audio_recorder.py
@@ -0,0 +1,220 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Interface to asynchronously capture continuous audio from PyAudio.
+
+This module requires pyaudio. See here for installation instructions:
+http://people.csail.mit.edu/hubert/pyaudio/
+
+This module provides one class, AudioRecorder, which buffers chunks of audio
+from PyAudio.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import math
+import queue
+import time
+
+import numpy as np
+import pyaudio
+
+logger = logging.getLogger(__name__)
+
+
+class TimeoutError(Exception):
+ """A timeout while waiting for pyaudio to buffer samples."""
+ pass
+
+
+class AudioRecorder(object):
+ """Asynchronously record and buffer audio using pyaudio.
+
+ This class wraps the pyaudio interface. It contains a queue.Queue object to
+  hold chunks of raw audio, and a callback function _enqueue_raw_audio() which
+ places raw audio into this queue. This allows the pyaudio.Stream object to
+ record asynchronously at low latency.
+
+ The class acts as a context manager. When entering the context it creates a
+ pyaudio.Stream object and starts recording; it stops recording on exit. The
+  Stream saves all of its audio to the Queue as two-tuples of
+  (raw_audio, timestamp). The raw audio is available from the queue as a numpy
+  array using the get_audio() method.
+
+ This class uses the term "frame" in the same sense that PortAudio does, so
+ "frame" means something different here than elsewhere in the daredevil stack.
+ A frame in PortAudio is one audio sample across all channels, so one frame of
+ 16-bit stereo audio is four bytes of data as two 16-bit integers.
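+
+  Example usage (a minimal sketch; parameter values are illustrative):
+
+    with AudioRecorder(raw_audio_sample_rate_hz=48000,
+                       downsample_factor=3) as recorder:
+      # One second of audio at the downsampled 16 kHz rate.
+      audio, first_ts, last_ts = recorder.get_audio(16000)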
+ """
+ pyaudio_format = pyaudio.paInt16
+ numpy_format = np.int16
+ num_channels = 1
+
+ # How many frames of audio PyAudio will fetch at once.
+  # Higher numbers will increase the latency.
+ frames_per_chunk = 2**9
+
+ # Limit queue to this number of audio chunks.
+ max_queue_chunks = 1200
+
+ # Timeout if we can't get a chunk from the queue for timeout_factor times the
+ # chunk duration.
+ timeout_factor = 4
+
+ def __init__(self, raw_audio_sample_rate_hz=48000,
+ downsample_factor=3,
+ device_index=None):
+ self._downsample_factor = downsample_factor
+ self._raw_audio_sample_rate_hz = raw_audio_sample_rate_hz
+ self.audio_sample_rate_hz = self._raw_audio_sample_rate_hz // self._downsample_factor
+ self._raw_audio_queue = queue.Queue(self.max_queue_chunks)
+ self._audio = pyaudio.PyAudio()
+ self._print_input_devices()
+ self._device_index = device_index
+
+ def __enter__(self):
+ if self._device_index is None:
+ self._device_index = self._audio.get_default_input_device_info()["index"]
+ kwargs = {
+ "input_device_index": self._device_index
+ }
+ device_info = self._audio.get_device_info_by_host_api_device_index(
+ 0, self._device_index)
+ if device_info.get("maxInputChannels") <= 0:
+ raise ValueError("Audio device has insufficient input channels.")
+ print("Using audio device '%s' for index %d" % (
+ device_info["name"], device_info["index"]))
+ self._stream = self._audio.open(
+ format=self.pyaudio_format,
+ channels=self.num_channels,
+ rate=self._raw_audio_sample_rate_hz,
+ input=True,
+ output=False,
+ frames_per_buffer=self.frames_per_chunk,
+ start=True,
+ stream_callback=self._enqueue_raw_audio,
+ **kwargs)
+ logger.info("Started audio stream.")
+ return self
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ self._stream.stop_stream()
+ self._stream.close()
+ logger.info("Stopped and closed audio stream.")
+
+ def __del__(self):
+ self._audio.terminate()
+ logger.info("Terminated PyAudio/PortAudio.")
+
+ @property
+ def is_active(self):
+ return self._stream.is_active()
+
+ @property
+ def bytes_per_sample(self):
+ return pyaudio.get_sample_size(self.pyaudio_format)
+
+ @property
+ def _chunk_duration_seconds(self):
+ return self.frames_per_chunk / self._raw_audio_sample_rate_hz
+
+ def _print_input_devices(self):
+ info = self._audio.get_host_api_info_by_index(0)
+ print("\nInput microphone devices:")
+ for i in range(0, info.get("deviceCount")):
+ device_info = self._audio.get_device_info_by_host_api_device_index(0, i)
+ if device_info.get("maxInputChannels") <= 0: continue
+ print(" ID: ", i, " - ", device_info.get("name"))
+
+  # The unused arguments match the stream-callback signature pyaudio expects.
+  def _enqueue_raw_audio(self, in_data, *_):
+ try:
+ self._raw_audio_queue.put((in_data, time.time()), block=False)
+ return None, pyaudio.paContinue
+ except queue.Full:
+ error_message = "Raw audio buffer full."
+ logger.critical(error_message)
+ raise TimeoutError(error_message)
+
+ def _get_chunk(self, timeout=None):
+ raw_data, timestamp = self._raw_audio_queue.get(timeout=timeout)
+    # np.frombuffer (rather than the deprecated np.fromstring) decodes the
+    # raw bytes without copying.
+    array_data = np.frombuffer(raw_data, self.numpy_format).reshape(
+        -1, self.num_channels)
+ return array_data, timestamp
+
+ def get_audio_device_info(self):
+ if self._device_index is None:
+ return self._audio.get_default_input_device_info()
+ else:
+ return self._audio.get_device_info_by_index(self._device_index)
+
+ def sample_duration_seconds(self, num_samples):
+ return num_samples / self.audio_sample_rate_hz / self.num_channels
+
+ def clear_queue(self):
+ logger.debug("Purging %d chunks from queue.", self._raw_audio_queue.qsize())
+ while not self._raw_audio_queue.empty():
+ self._raw_audio_queue.get()
+
+ def get_audio(self, num_audio_frames):
+ """Grab at least num_audio_frames frames of audio.
+
+ Record at least num_audio_frames of audio and transform it into a
+ numpy array. The term "frame" is in the sense used by PortAudio; see the
+ note in the class docstring for details.
+
+ Audio returned will be the earliest audio in the queue; it could be from
+ before this function was called.
+
+ Args:
+ num_audio_frames: minimum number of samples of audio to grab.
+
+ Returns:
+ A tuple of (audio, first_timestamp, last_timestamp).
+ """
+    num_audio_chunks = max(1, int(math.ceil(
+        num_audio_frames * self._downsample_factor / self.frames_per_chunk)))
+    logger.debug("Capturing %d chunks to get at least %d frames.",
+                 num_audio_chunks, num_audio_frames)
+ try:
+ timeout = self.timeout_factor * self._chunk_duration_seconds
+ chunks, timestamps = zip(
+ *[self._get_chunk(timeout=timeout) for _ in range(num_audio_chunks)])
+ except queue.Empty:
+ error_message = "Audio capture timed out after %.1f seconds." % timeout
+ logger.critical(error_message)
+ raise TimeoutError(error_message)
+
+ assert len(chunks) == num_audio_chunks
+ logger.debug("Got %d chunks. Chunk 0 has shape %s and dtype %s.",
+ len(chunks), chunks[0].shape, chunks[0].dtype)
+ if self._raw_audio_queue.qsize() > (0.8 * self.max_queue_chunks):
+ logger.warning("%d chunks remain in the queue.",
+ self._raw_audio_queue.qsize())
+ else:
+ logger.debug("%d chunks remain in the queue.",
+ self._raw_audio_queue.qsize())
+
+ audio = np.concatenate(chunks)
+ if self._downsample_factor != 1:
+ audio = audio[::self._downsample_factor]
+    logger.debug("Audio array has shape %s and dtype %s.", audio.shape,
+                 audio.dtype)
+    return audio, timestamps[0], timestamps[-1]
diff --git a/mel_features.py b/mel_features.py
new file mode 100644
index 0000000..1a124cb
--- /dev/null
+++ b/mel_features.py
@@ -0,0 +1,222 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines routines to compute mel spectrogram features from audio waveform."""
+
+import numpy as np
+
+
+def frame(data, window_length, hop_length):
+  """Convert array into a sequence of successive, possibly overlapping frames.
+
+ An n-dimensional array of shape (num_samples, ...) is converted into an
+ (n+1)-D array of shape (num_frames, window_length, ...), where each frame
+ starts hop_length points after the preceding one.
+
+ This is accomplished using stride_tricks, so the original data is not
+ copied. However, there is no zero-padding, so any incomplete frames at the
+ end are not included.
+
+ Args:
+ data: np.array of dimension N >= 1.
+ window_length: Number of samples in each frame.
+ hop_length: Advance (in samples) between each window.
+
+ Returns:
+ (N+1)-D np.array with as many rows as there are complete frames that can be
+ extracted.
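+
+  Example:
+    frame(np.arange(10), window_length=4, hop_length=2) returns 4 frames,
+    [0..3], [2..5], [4..7], [6..9], since 1 + floor((10 - 4) / 2) = 4.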
+ """
+ num_samples = data.shape[0]
+ num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
+ shape = (num_frames, window_length) + data.shape[1:]
+ strides = (data.strides[0] * hop_length,) + data.strides
+ return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
+
+
+def periodic_hann(window_length):
+ """Calculate a "periodic" Hann window.
+
+ The classic Hann window is defined as a raised cosine that starts and
+ ends on zero, and where every value appears twice, except the middle
+ point for an odd-length window. Matlab calls this a "symmetric" window
+ and np.hanning() returns it. However, for Fourier analysis, this
+ actually represents just over one cycle of a period N-1 cosine, and
+ thus is not compactly expressed on a length-N Fourier basis. Instead,
+ it's better to use a raised cosine that ends just before the final
+ zero value - i.e. a complete cycle of a period-N cosine. Matlab
+ calls this a "periodic" window. This routine calculates it.
+
+ Args:
+ window_length: The number of points in the returned window.
+
+ Returns:
+ A 1D np.array containing the periodic hann window.
+ """
+ return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
+ np.arange(window_length)))
+
+
+def stft_magnitude(signal, fft_length,
+ hop_length=None,
+ window_length=None):
+ """Calculate the short-time Fourier transform magnitude.
+
+ Args:
+ signal: 1D np.array of the input time-domain signal.
+ fft_length: Size of the FFT to apply.
+ hop_length: Advance (in samples) between each frame passed to FFT.
+ window_length: Length of each block of samples to pass to FFT.
+
+ Returns:
+ 2D np.array where each row contains the magnitudes of the fft_length/2+1
+ unique values of the FFT for the corresponding frame of input samples.
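+
+  Example (illustrative numbers): one second of audio at 8000 Hz with
+  window_length=200, hop_length=80, and fft_length=256 yields an array of
+  shape (98, 129): 1 + floor((8000 - 200) / 80) = 98 frames, each with
+  256 / 2 + 1 = 129 magnitude bins.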
+ """
+ frames = frame(signal, window_length, hop_length)
+ # Apply frame window to each frame. We use a periodic Hann (cosine of period
+ # window_length) instead of the symmetric Hann of np.hanning (period
+ # window_length-1).
+ window = periodic_hann(window_length)
+ windowed_frames = frames * window
+ return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
+
+
+# Mel spectrum constants and functions.
+_MEL_BREAK_FREQUENCY_HERTZ = 700.0
+_MEL_HIGH_FREQUENCY_Q = 1127.0
+
+
+def hertz_to_mel(frequencies_hertz):
+ """Convert frequencies to mel scale using HTK formula.
+
+ Args:
+ frequencies_hertz: Scalar or np.array of frequencies in hertz.
+
+ Returns:
+ Object of same size as frequencies_hertz containing corresponding values
+ on the mel scale.
+ """
+ return _MEL_HIGH_FREQUENCY_Q * np.log(
+ 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
+
+
+def spectrogram_to_mel_matrix(num_mel_bins=20,
+ num_spectrogram_bins=129,
+ audio_sample_rate=8000,
+ lower_edge_hertz=125.0,
+ upper_edge_hertz=3800.0):
+ """Return a matrix that can post-multiply spectrogram rows to make mel.
+
+ Returns a np.array matrix A that can be used to post-multiply a matrix S of
+ spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
+ "mel spectrogram" M of frames x num_mel_bins. M = S A.
+
+ The classic HTK algorithm exploits the complementarity of adjacent mel bands
+ to multiply each FFT bin by only one mel weight, then add it, with positive
+ and negative signs, to the two adjacent mel bands to which that bin
+ contributes. Here, by expressing this operation as a matrix multiply, we go
+ from num_fft multiplies per frame (plus around 2*num_fft adds) to around
+ num_fft^2 multiplies and adds. However, because these are all presumably
+ accomplished in a single call to np.dot(), it's not clear which approach is
+ faster in Python. The matrix multiplication has the attraction of being more
+ general and flexible, and much easier to read.
+
+ Args:
+ num_mel_bins: How many bands in the resulting mel spectrum. This is
+ the number of columns in the output matrix.
+ num_spectrogram_bins: How many bins there are in the source spectrogram
+ data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
+ only contains the nonredundant FFT bins.
+ audio_sample_rate: Samples per second of the audio at the input to the
+ spectrogram. We need this to figure out the actual frequencies for
+ each spectrogram bin, which dictates how they are mapped into mel.
+ lower_edge_hertz: Lower bound on the frequencies to be included in the mel
+ spectrum. This corresponds to the lower edge of the lowest triangular
+ band.
+ upper_edge_hertz: The desired top edge of the highest frequency band.
+
+ Returns:
+ An np.array with shape (num_spectrogram_bins, num_mel_bins).
+
+ Raises:
+ ValueError: if frequency edges are incorrectly ordered or out of range.
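+
+  Example: with the default arguments, the returned matrix has shape
+  (129, 20), so np.dot(spectrogram, matrix) maps a (num_frames, 129)
+  spectrogram to a (num_frames, 20) mel spectrogram.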
+ """
+ nyquist_hertz = audio_sample_rate / 2.
+ if lower_edge_hertz < 0.0:
+ raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
+ if lower_edge_hertz >= upper_edge_hertz:
+ raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
+ (lower_edge_hertz, upper_edge_hertz))
+ if upper_edge_hertz > nyquist_hertz:
+ raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
+ (upper_edge_hertz, nyquist_hertz))
+ spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
+ spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
+ # The i'th mel band (starting from i=1) has center frequency
+ # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
+ # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
+ # the band_edges_mel arrays.
+ band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
+ hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
+ # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
+ # of spectrogram values.
+ mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
+ for i in range(num_mel_bins):
+ lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
+ # Calculate lower and upper slopes for every spectrogram bin.
+ # Line segments are linear in the *mel* domain, not hertz.
+ lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
+ (center_mel - lower_edge_mel))
+ upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
+ (upper_edge_mel - center_mel))
+ # .. then intersect them with each other and zero.
+ mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
+ upper_slope))
+ # HTK excludes the spectrogram DC bin; make sure it always gets a zero
+ # coefficient.
+ mel_weights_matrix[0, :] = 0.0
+ return mel_weights_matrix
+
+
+def log_mel_spectrogram(data,
+ audio_sample_rate=8000,
+ log_offset=0.0,
+ window_length_secs=0.025,
+ hop_length_secs=0.010,
+ **kwargs):
+ """Convert waveform to a log magnitude mel-frequency spectrogram.
+
+ Args:
+ data: 1D np.array of waveform data.
+ audio_sample_rate: The sampling rate of data.
+ log_offset: Add this to values when taking log to avoid -Infs.
+ window_length_secs: Duration of each window to analyze.
+ hop_length_secs: Advance between successive analysis windows.
+ **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
+
+ Returns:
+ 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
+ magnitudes for successive frames.
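+
+  Example (illustrative numbers): one second of 16 kHz audio with the default
+  25 ms window and 10 ms hop gives 1 + floor((16000 - 400) / 160) = 98 frames,
+  so log_mel_spectrogram(data, 16000, num_mel_bins=32) has shape (98, 32).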
+ """
+ window_length_samples = int(round(audio_sample_rate * window_length_secs))
+ hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
+ fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
+ spectrogram = stft_magnitude(
+ data,
+ fft_length=fft_length,
+ hop_length=hop_length_samples,
+ window_length=window_length_samples)
+ mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
+ num_spectrogram_bins=spectrogram.shape[1],
+ audio_sample_rate=audio_sample_rate, **kwargs))
+ return np.log(mel_spectrogram + log_offset)
diff --git a/example.py b/vision_example.py
similarity index 100%
rename from example.py
rename to vision_example.py
diff --git a/voice.py b/voice.py
new file mode 100644
index 0000000..da966b6
--- /dev/null
+++ b/voice.py
@@ -0,0 +1,219 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Keyword spotter model."""
+import logging
+import sys
+import audio_recorder
+import mel_features
+import numpy as np
+import queue
+import tflite_runtime.interpreter as tflite
+
+logging.basicConfig(
+ stream=sys.stdout,
+ format="%(levelname)-8s %(asctime)-15s %(name)s %(message)s")
+audio_recorder.logger.setLevel(logging.ERROR)
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.ERROR)
+
+class Uint8LogMelFeatureExtractor(object):
+ """Provide uint8 log mel spectrogram slices from an AudioRecorder object.
+
+ This class provides one public method, get_next_spectrogram(), which gets
+ a specified number of spectral slices from an AudioRecorder.
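+
+  Example usage (a minimal sketch, assuming a microphone that supports 16 kHz
+  capture):
+
+    extractor = Uint8LogMelFeatureExtractor()
+    with audio_recorder.AudioRecorder(16000, downsample_factor=1) as recorder:
+      spectrogram = extractor.get_next_spectrogram(recorder)  # (198, 32) uint8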
+ """
+
+ def __init__(self, num_frames_hop=33):
+ self.spectrogram_window_length_seconds = 0.025
+ self.spectrogram_hop_length_seconds = 0.010
+ self.num_mel_bins = 32
+ self.frame_length_spectra = 198
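+    # 198 slices at a 10 ms hop (plus one 25 ms window) span roughly 2 s of
+    # audio, matching the 198 x 32 feature window fed to the model.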
+ if self.frame_length_spectra % num_frames_hop:
+      raise ValueError('Invalid num_frames_hop value (%d), '
+                       'must divide %d' % (num_frames_hop,
+                                           self.frame_length_spectra))
+ self.frame_hop_spectra = num_frames_hop
+ self._norm_factor = 3
+ self._clear_buffers()
+
+ def _clear_buffers(self):
+ self._audio_buffer = np.array([], dtype=np.int16).reshape(0, 1)
+ self._spectrogram = np.zeros((self.frame_length_spectra, self.num_mel_bins),
+ dtype=np.float32)
+
+ def _spectrogram_underlap_samples(self, audio_sample_rate_hz):
+ return int((self.spectrogram_window_length_seconds -
+ self.spectrogram_hop_length_seconds) * audio_sample_rate_hz)
+
+ def _frame_duration_seconds(self, num_spectra):
+ return (self.spectrogram_window_length_seconds +
+ (num_spectra - 1) * self.spectrogram_hop_length_seconds)
+
+ def _compute_spectrogram(self, audio_samples, audio_sample_rate_hz):
+ """Compute log-mel spectrogram and scale it to uint8."""
+ samples = audio_samples.flatten() / float(2**15)
+ spectrogram = 30 * (
+ mel_features.log_mel_spectrogram(
+ samples,
+ audio_sample_rate_hz,
+ log_offset=0.001,
+ window_length_secs=self.spectrogram_window_length_seconds,
+ hop_length_secs=self.spectrogram_hop_length_seconds,
+ num_mel_bins=self.num_mel_bins,
+ lower_edge_hertz=60,
+ upper_edge_hertz=3800) - np.log(1e-3))
+ return spectrogram
+
+ def _get_next_spectra(self, recorder, num_spectra):
+ """Returns the next spectrogram.
+
+ Compute num_spectra spectrogram samples from an AudioRecorder.
+ Blocks until num_spectra spectrogram slices are available.
+
+ Args:
+ recorder: an AudioRecorder object from which to get raw audio samples.
+ num_spectra: the number of spectrogram slices to return.
+
+ Returns:
+ num_spectra spectrogram slices computed from the samples.
+ """
+ required_audio_duration_seconds = self._frame_duration_seconds(num_spectra)
+ logger.info("required_audio_duration_seconds %f",
+ required_audio_duration_seconds)
+ required_num_samples = int(
+ np.ceil(required_audio_duration_seconds *
+ recorder.audio_sample_rate_hz))
+ logger.info("required_num_samples %d, %s", required_num_samples,
+ str(self._audio_buffer.shape))
+ audio_samples = np.concatenate(
+ (self._audio_buffer,
+ recorder.get_audio(required_num_samples - len(self._audio_buffer))[0]))
+ self._audio_buffer = audio_samples[
+ required_num_samples -
+ self._spectrogram_underlap_samples(recorder.audio_sample_rate_hz):]
+ spectrogram = self._compute_spectrogram(
+ audio_samples[:required_num_samples], recorder.audio_sample_rate_hz)
+ assert len(spectrogram) == num_spectra
+ return spectrogram
+
+ def get_next_spectrogram(self, recorder):
+ """Get the most recent spectrogram frame.
+
+ Blocks until the frame is available.
+
+ Args:
+ recorder: an AudioRecorder instance which provides the audio samples.
+
+ Returns:
+ The next spectrogram frame as a uint8 numpy array.
+ """
+ assert recorder.is_active
+ logger.info("self._spectrogram shape %s", str(self._spectrogram.shape))
+ self._spectrogram[:-self.frame_hop_spectra] = (
+ self._spectrogram[self.frame_hop_spectra:])
+ self._spectrogram[-self.frame_hop_spectra:] = (
+ self._get_next_spectra(recorder, self.frame_hop_spectra))
+ # Return a copy of the internal state that's safe to persist and won't
+ # change the next time we call this function.
+ logger.info("self._spectrogram shape %s", str(self._spectrogram.shape))
+ spectrogram = self._spectrogram.copy()
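+    # Normalize each mel bin to roughly [-1, 1] (zero mean, scaled by
+    # _norm_factor standard deviations), then map linearly to uint8 [0, 255].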
+ spectrogram -= np.mean(spectrogram, axis=0)
+ if self._norm_factor:
+ spectrogram /= self._norm_factor * np.std(spectrogram, axis=0)
+ spectrogram += 1
+ spectrogram *= 127.5
+ return np.maximum(0, np.minimum(255, spectrogram)).astype(np.uint8)
+
+def read_labels(filename):
+  # Prepend a 'negative' label to line up with the model's background class
+  # at output index 0.
+  with open(filename, "r") as f:
+    return ['negative'] + [l.rstrip() for l in f]
+
+def get_output(interpreter):
+ """Returns entire output, threshold is applied later."""
+ return output_tensor(interpreter, 0)
+
+def output_tensor(interpreter, i):
+ """Returns dequantized output tensor if quantized before."""
+ output_details = interpreter.get_output_details()[i]
+ output_data = np.squeeze(interpreter.tensor(output_details['index'])())
+ if 'quantization' not in output_details:
+ return output_data
+ scale, zero_point = output_details['quantization']
+ if scale == 0:
+ return output_data - zero_point
+ return scale * (output_data - zero_point)
+
+def input_tensor(interpreter):
+ """Returns the input tensor view as numpy array."""
+ tensor_index = interpreter.get_input_details()[0]['index']
+ return interpreter.tensor(tensor_index)()[0]
+
+def set_input(interpreter, data):
+ """Copies data to input tensor."""
+ interpreter_shape = interpreter.get_input_details()[0]['shape']
+ input_tensor(interpreter)[:,:] = np.reshape(data, interpreter_shape[1:3])
+
+def make_interpreter(model_file):
+ model_file, *device = model_file.split('@')
+ return tflite.Interpreter(
+ model_path=model_file,
+ experimental_delegates=[tflite.load_delegate('libedgetpu.so.1',
+ {'device': device[0]} if device else {})])
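+
+# An optional '@device' suffix pins the Edge TPU delegate to a specific
+# device, e.g. make_interpreter('voice_commands_v0.7_edgetpu.tflite@usb:0');
+# without it, the runtime picks the default Edge TPU.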
+
+def classify_audio(model_file, labels_file, detection_callback,
+                   audio_device_index=0, sample_rate_hz=16000,
+                   negative_threshold=0.6, num_frames_hop=33):
+ """Acquire audio, preprocess, and classify."""
+  # The model expects 16 kHz audio; 48 kHz capture (which most microphones
+  # support) is downsampled by a factor of 3.
+  downsample_factor = 1
+  if sample_rate_hz == 48000:
+    downsample_factor = 3
+ recorder = audio_recorder.AudioRecorder(
+ sample_rate_hz,
+ downsample_factor=downsample_factor,
+ device_index=audio_device_index)
+ feature_extractor = Uint8LogMelFeatureExtractor(num_frames_hop=num_frames_hop)
+ labels = read_labels(labels_file)
+
+ interpreter = make_interpreter(model_file)
+ interpreter.allocate_tensors()
+
+ keep_listening = True
+ prev_detection = -1
+ with recorder:
+ while keep_listening:
+ spectrogram = feature_extractor.get_next_spectrogram(recorder)
+      if spectrogram.mean() < 0.001:
+        print("Warning: audio input is nearly zero; the microphone may be "
+              "muted or off.")
+
+ set_input(interpreter, spectrogram.flatten())
+ interpreter.invoke()
+ result = get_output(interpreter)
+
+ if result[0] >= negative_threshold:
+ prev_detection = -1
+ continue
+
+ detection = np.argmax(result)
+ if detection == 0:
+ prev_detection = -1
+ continue
+
+ if detection != prev_detection:
+        keep_listening = detection_callback(labels[detection], result[detection])
+ prev_detection = detection
diff --git a/voice_example.py b/voice_example.py
new file mode 100644
index 0000000..a311d50
--- /dev/null
+++ b/voice_example.py
@@ -0,0 +1,15 @@
+import voice
+
+def detection_callback(label, score):
+ print('Detected: "%s", score=%s' % (label, score))
+ if label.startswith('exit'):
+ return False # stop listening
+ return True # keep listening
+
+def run_voice_example():
+ voice.classify_audio(model_file='voice_commands_v0.7_edgetpu.tflite',
+ labels_file='labels_gc2.raw.txt',
+                       detection_callback=detection_callback)
+
+if __name__ == '__main__':
+ run_voice_example()