Add simple voice detection API
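
Adds an asynchronous PyAudio recorder (audio_recorder.py), log-mel
spectrogram features (mel_features.py), and a keyword-spotting pipeline
with a small example (voice.py, voice_example.py); renames example.py to
vision_example.py to distinguish the two demos.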

Change-Id: Ic7e176b734afdd780a4b7749378e85bfdccfea84
diff --git a/Makefile b/Makefile
index b718214..c368b19 100644
--- a/Makefile
+++ b/Makefile
@@ -19,10 +19,19 @@
 mobilenet_v2_1.0_224_quant_edgetpu.tflite:
 	wget "$(TEST_DATA_URL)/$@"
 
-download: imagenet_labels.txt ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite mobilenet_v2_1.0_224_quant_edgetpu.tflite
+labels_gc2.raw.txt:
+	wget "https://github.com/google-coral/project-keyword-spotter/raw/master/config/$@"
+
+voice_commands_v0.7_edgetpu.tflite:
+	wget "https://github.com/google-coral/project-keyword-spotter/raw/master/models/$@"
+
+download: imagenet_labels.txt \
+          ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite \
+          mobilenet_v2_1.0_224_quant_edgetpu.tflite \
+          labels_gc2.raw.txt \
+          voice_commands_v0.7_edgetpu.tflite
 
 clean:
 	rm -rf __pycache__ \
-	       imagenet_labels.txt \
-	       ssd_mobilenet_v2_face_quant_postprocess_edgetpu.tflite \
-	       mobilenet_v2_1.0_224_quant_edgetpu.tflite
+	       *.txt \
+	       *.tflite
diff --git a/audio_recorder.py b/audio_recorder.py
new file mode 100644
index 0000000..ff23059
--- /dev/null
+++ b/audio_recorder.py
@@ -0,0 +1,220 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Interface to asynchronously capture continuous audio from PyAudio.
+
+This module requires pyaudio. See here for installation instructions:
+http://people.csail.mit.edu/hubert/pyaudio/
+
+This module provides one class, AudioRecorder, which buffers chunks of audio
+from PyAudio.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import math
+import queue
+import time
+
+import numpy as np
+import pyaudio
+
+logger = logging.getLogger(__name__)
+
+
+class TimeoutError(Exception):
+  """A timeout while waiting for pyaudio to buffer samples."""
+
+
+class AudioRecorder(object):
+  """Asynchronously record and buffer audio using pyaudio.
+
+  This class wraps the pyaudio interface. It contains a queue.Queue object to
+  hold chunks of raw audio, and a callback function _enqueue_audio() which
+  places raw audio into this queue. This allows the pyaudio.Stream object to
+  record asynchronously at low latency.
+
+  The class acts as a context manager. When entering the context it creates a
+  pyaudio.Stream object and starts recording; it stops recording on exit. The
+  Stream saves all of its audio to the Queue as two-tuples of
+  (timestamp, raw_audio). The raw_audio is available from the queue as a numpy
+  array using the get_audio() function.
+
+  This class uses the term "frame" in the same sense that PortAudio does, so
+  "frame" means something different here than elsewhere in the daredevil stack.
+  A frame in PortAudio is one audio sample across all channels, so one frame of
+  16-bit stereo audio is four bytes of data as two 16-bit integers.
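+
+  A minimal usage sketch (assuming a working default input device):
+
+    with AudioRecorder(raw_audio_sample_rate_hz=48000,
+                       downsample_factor=3) as recorder:
+      audio, first_ts, last_ts = recorder.get_audio(16000)  # ~1 s at 16 kHz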
+  """
+  pyaudio_format = pyaudio.paInt16
+  numpy_format = np.int16
+  num_channels = 1
+
+  # How many frames of audio PyAudio will fetch at once.
+  # Higher numbers will increase the latency.
+  frames_per_chunk = 2**9
+
+  # Limit queue to this number of audio chunks.
+  max_queue_chunks = 1200
+
+  # Timeout if we can't get a chunk from the queue for timeout_factor times the
+  # chunk duration.
+  timeout_factor = 4
+
+  def __init__(self, raw_audio_sample_rate_hz=48000,
+                     downsample_factor=3,
+                     device_index=None):
+    self._downsample_factor = downsample_factor
+    self._raw_audio_sample_rate_hz = raw_audio_sample_rate_hz
+    self.audio_sample_rate_hz = (
+        self._raw_audio_sample_rate_hz // self._downsample_factor)
+    self._raw_audio_queue = queue.Queue(self.max_queue_chunks)
+    self._audio = pyaudio.PyAudio()
+    self._print_input_devices()
+    self._device_index = device_index
+
+  def __enter__(self):
+    if self._device_index is None:
+      self._device_index = self._audio.get_default_input_device_info()["index"]
+    kwargs = {
+        "input_device_index": self._device_index
+    }
+    device_info = self._audio.get_device_info_by_host_api_device_index(
+        0, self._device_index)
+    if device_info.get("maxInputChannels") <= 0:
+      raise ValueError("Audio device has insufficient input channels.")
+    print("Using audio device '%s' for index %d" % (
+        device_info["name"], device_info["index"]))
+    self._stream = self._audio.open(
+        format=self.pyaudio_format,
+        channels=self.num_channels,
+        rate=self._raw_audio_sample_rate_hz,
+        input=True,
+        output=False,
+        frames_per_buffer=self.frames_per_chunk,
+        start=True,
+        stream_callback=self._enqueue_raw_audio,
+        **kwargs)
+    logger.info("Started audio stream.")
+    return self
+
+  def __exit__(self, exception_type, exception_value, traceback):
+    self._stream.stop_stream()
+    self._stream.close()
+    logger.info("Stopped and closed audio stream.")
+
+  def __del__(self):
+    self._audio.terminate()
+    logger.info("Terminated PyAudio/PortAudio.")
+
+  @property
+  def is_active(self):
+    return self._stream.is_active()
+
+  @property
+  def bytes_per_sample(self):
+    return pyaudio.get_sample_size(self.pyaudio_format)
+
+  @property
+  def _chunk_duration_seconds(self):
+    return self.frames_per_chunk / self._raw_audio_sample_rate_hz
+
+  def _print_input_devices(self):
+    info = self._audio.get_host_api_info_by_index(0)
+    print("\nInput microphone devices:")
+    for i in range(0, info.get("deviceCount")):
+      device_info = self._audio.get_device_info_by_host_api_device_index(0, i)
+      if device_info.get("maxInputChannels") <= 0:
+        continue
+      print("  ID: ", i, " - ", device_info.get("name"))
+
+  # The extra positional args in the PyAudio stream callback are unused.
+  def _enqueue_raw_audio(self, in_data, *_):
+    try:
+      self._raw_audio_queue.put((in_data, time.time()), block=False)
+      return None, pyaudio.paContinue
+    except queue.Full:
+      error_message = "Raw audio buffer full."
+      logger.critical(error_message)
+      raise TimeoutError(error_message)
+
+  def _get_chunk(self, timeout=None):
+    raw_data, timestamp = self._raw_audio_queue.get(timeout=timeout)
+    array_data = np.frombuffer(raw_data, self.numpy_format).reshape(
+        -1, self.num_channels)
+    return array_data, timestamp
+
+  def get_audio_device_info(self):
+    if self._device_index is None:
+      return self._audio.get_default_input_device_info()
+    else:
+      return self._audio.get_device_info_by_index(self._device_index)
+
+  def sample_duration_seconds(self, num_samples):
+    return num_samples / self.audio_sample_rate_hz / self.num_channels
+
+  def clear_queue(self):
+    logger.debug("Purging %d chunks from queue.", self._raw_audio_queue.qsize())
+    while not self._raw_audio_queue.empty():
+      self._raw_audio_queue.get()
+
+  def get_audio(self, num_audio_frames):
+    """Grab at least num_audio_frames frames of audio.
+
+    Record at least num_audio_frames of audio and transform it into a
+    numpy array. The term "frame" is in the sense used by PortAudio; see the
+    note in the class docstring for details.
+
+    Audio returned will be the earliest audio in the queue; it could be from
+    before this function was called.
+
+    Args:
+      num_audio_frames: minimum number of samples of audio to grab.
+
+    Returns:
+      A tuple of (audio, first_timestamp, last_timestamp).
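+      audio is an (N, 1) int16 array with N >= num_audio_frames; e.g.
+      get_audio(16000) at a 16 kHz audio_sample_rate_hz is roughly one
+      second of mono audio.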
+    """
+    num_audio_chunks = int(math.ceil(num_audio_frames *
+                                     self._downsample_factor /
+                                     self.frames_per_chunk))
+    logger.debug("Capturing %d chunks to get at least %d frames.",
+                 num_audio_chunks, num_audio_frames)
+    if num_audio_chunks < 1:
+      num_audio_chunks = 1
+    try:
+      timeout = self.timeout_factor * self._chunk_duration_seconds
+      chunks, timestamps = zip(
+          *[self._get_chunk(timeout=timeout) for _ in range(num_audio_chunks)])
+    except queue.Empty:
+      error_message = "Audio capture timed out after %.1f seconds." % timeout
+      logger.critical(error_message)
+      raise TimeoutError(error_message)
+
+    assert len(chunks) == num_audio_chunks
+    logger.debug("Got %d chunks. Chunk 0 has shape %s and dtype %s.",
+                 len(chunks), chunks[0].shape, chunks[0].dtype)
+    if self._raw_audio_queue.qsize() > (0.8 * self.max_queue_chunks):
+      logger.warning("%d chunks remain in the queue.",
+                     self._raw_audio_queue.qsize())
+    else:
+      logger.debug("%d chunks remain in the queue.",
+                   self._raw_audio_queue.qsize())
+
+    audio = np.concatenate(chunks)
+    if self._downsample_factor != 1:
+      audio = audio[::self._downsample_factor]
+    logger.debug("Audio array has shape %s and dtype %s.", audio.shape,
+                 audio.dtype)
+    return audio, timestamps[0], timestamps[-1]
diff --git a/mel_features.py b/mel_features.py
new file mode 100644
index 0000000..1a124cb
--- /dev/null
+++ b/mel_features.py
@@ -0,0 +1,222 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines routines to compute mel spectrogram features from audio waveform."""
+
+import numpy as np
+
+
+def frame(data, window_length, hop_length):
+  """Convert array into a sequence of successive possibly overlapping frames.
+
+  An n-dimensional array of shape (num_samples, ...) is converted into an
+  (n+1)-D array of shape (num_frames, window_length, ...), where each frame
+  starts hop_length points after the preceding one.
+
+  This is accomplished using stride_tricks, so the original data is not
+  copied.  However, there is no zero-padding, so any incomplete frames at the
+  end are not included.
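+
+  For example, frame(np.arange(10), window_length=4, hop_length=2) returns a
+  (4, 4) array whose rows start at samples 0, 2, 4, and 6; a final partial
+  frame starting at sample 8 is dropped.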
+
+  Args:
+    data: np.array of dimension N >= 1.
+    window_length: Number of samples in each frame.
+    hop_length: Advance (in samples) between each window.
+
+  Returns:
+    (N+1)-D np.array with as many rows as there are complete frames that can be
+    extracted.
+  """
+  num_samples = data.shape[0]
+  num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
+  shape = (num_frames, window_length) + data.shape[1:]
+  strides = (data.strides[0] * hop_length,) + data.strides
+  return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
+
+
+def periodic_hann(window_length):
+  """Calculate a "periodic" Hann window.
+
+  The classic Hann window is defined as a raised cosine that starts and
+  ends on zero, and where every value appears twice, except the middle
+  point for an odd-length window.  Matlab calls this a "symmetric" window
+  and np.hanning() returns it.  However, for Fourier analysis, this
+  actually represents just over one cycle of a period N-1 cosine, and
+  thus is not compactly expressed on a length-N Fourier basis.  Instead,
+  it's better to use a raised cosine that ends just before the final
+  zero value - i.e. a complete cycle of a period-N cosine.  Matlab
+  calls this a "periodic" window. This routine calculates it.
+
+  Args:
+    window_length: The number of points in the returned window.
+
+  Returns:
+    A 1D np.array containing the periodic hann window.
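+    This is equivalent to np.hanning(window_length + 1)[:-1].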
+  """
+  return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
+                             np.arange(window_length)))
+
+
+def stft_magnitude(signal, fft_length,
+                   hop_length=None,
+                   window_length=None):
+  """Calculate the short-time Fourier transform magnitude.
+
+  Args:
+    signal: 1D np.array of the input time-domain signal.
+    fft_length: Size of the FFT to apply.
+    hop_length: Advance (in samples) between each frame passed to FFT.
+    window_length: Length of each block of samples to pass to FFT.
+
+  Returns:
+    2D np.array where each row contains the magnitudes of the fft_length/2+1
+    unique values of the FFT for the corresponding frame of input samples.
+  """
+  frames = frame(signal, window_length, hop_length)
+  # Apply frame window to each frame. We use a periodic Hann (cosine of period
+  # window_length) instead of the symmetric Hann of np.hanning (period
+  # window_length-1).
+  window = periodic_hann(window_length)
+  windowed_frames = frames * window
+  return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
+
+
+# Mel spectrum constants and functions.
+_MEL_BREAK_FREQUENCY_HERTZ = 700.0
+_MEL_HIGH_FREQUENCY_Q = 1127.0
+
+
+def hertz_to_mel(frequencies_hertz):
+  """Convert frequencies to mel scale using HTK formula.
+
+  Args:
+    frequencies_hertz: Scalar or np.array of frequencies in hertz.
+
+  Returns:
+    Object of same size as frequencies_hertz containing corresponding values
+    on the mel scale.
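+
+    For example, hertz_to_mel(700.0) equals 1127 * ln(2), about 781.2 mel.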
+  """
+  return _MEL_HIGH_FREQUENCY_Q * np.log(
+      1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
+
+
+def spectrogram_to_mel_matrix(num_mel_bins=20,
+                              num_spectrogram_bins=129,
+                              audio_sample_rate=8000,
+                              lower_edge_hertz=125.0,
+                              upper_edge_hertz=3800.0):
+  """Return a matrix that can post-multiply spectrogram rows to make mel.
+
+  Returns a np.array matrix A that can be used to post-multiply a matrix S of
+  spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
+  "mel spectrogram" M of frames x num_mel_bins.  M = S A.
+
+  The classic HTK algorithm exploits the complementarity of adjacent mel bands
+  to multiply each FFT bin by only one mel weight, then add it, with positive
+  and negative signs, to the two adjacent mel bands to which that bin
+  contributes.  Here, by expressing this operation as a matrix multiply, we go
+  from num_fft multiplies per frame (plus around 2*num_fft adds) to around
+  num_fft^2 multiplies and adds.  However, because these are all presumably
+  accomplished in a single call to np.dot(), it's not clear which approach is
+  faster in Python.  The matrix multiplication has the attraction of being more
+  general and flexible, and much easier to read.
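+
+  For example, an STFT with fft_size 512 has 257 nonredundant bins, so
+  spectrogram_to_mel_matrix(num_mel_bins=32, num_spectrogram_bins=257)
+  returns a (257, 32) matrix A, and a (num_frames, 257) spectrogram S maps
+  to a (num_frames, 32) mel spectrogram via np.dot(S, A).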
+
+  Args:
+    num_mel_bins: How many bands in the resulting mel spectrum.  This is
+      the number of columns in the output matrix.
+    num_spectrogram_bins: How many bins there are in the source spectrogram
+      data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
+      only contains the nonredundant FFT bins.
+    audio_sample_rate: Samples per second of the audio at the input to the
+      spectrogram. We need this to figure out the actual frequencies for
+      each spectrogram bin, which dictates how they are mapped into mel.
+    lower_edge_hertz: Lower bound on the frequencies to be included in the mel
+      spectrum.  This corresponds to the lower edge of the lowest triangular
+      band.
+    upper_edge_hertz: The desired top edge of the highest frequency band.
+
+  Returns:
+    An np.array with shape (num_spectrogram_bins, num_mel_bins).
+
+  Raises:
+    ValueError: if frequency edges are incorrectly ordered or out of range.
+  """
+  nyquist_hertz = audio_sample_rate / 2.
+  if lower_edge_hertz < 0.0:
+    raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
+  if lower_edge_hertz >= upper_edge_hertz:
+    raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
+                     (lower_edge_hertz, upper_edge_hertz))
+  if upper_edge_hertz > nyquist_hertz:
+    raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
+                     (upper_edge_hertz, nyquist_hertz))
+  spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
+  spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
+  # The i'th mel band (starting from i=1) has center frequency
+  # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
+  # band_edges_mel[i+1].  Thus, we need num_mel_bins + 2 values in
+  # the band_edges_mel arrays.
+  band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
+                               hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
+  # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
+  # of spectrogram values.
+  mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
+  for i in range(num_mel_bins):
+    lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
+    # Calculate lower and upper slopes for every spectrogram bin.
+    # Line segments are linear in the *mel* domain, not hertz.
+    lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
+                   (center_mel - lower_edge_mel))
+    upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
+                   (upper_edge_mel - center_mel))
+    # .. then intersect them with each other and zero.
+    mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
+                                                          upper_slope))
+  # HTK excludes the spectrogram DC bin; make sure it always gets a zero
+  # coefficient.
+  mel_weights_matrix[0, :] = 0.0
+  return mel_weights_matrix
+
+
+def log_mel_spectrogram(data,
+                        audio_sample_rate=8000,
+                        log_offset=0.0,
+                        window_length_secs=0.025,
+                        hop_length_secs=0.010,
+                        **kwargs):
+  """Convert waveform to a log magnitude mel-frequency spectrogram.
+
+  Args:
+    data: 1D np.array of waveform data.
+    audio_sample_rate: The sampling rate of data.
+    log_offset: Add this to values when taking log to avoid -Infs.
+    window_length_secs: Duration of each window to analyze.
+    hop_length_secs: Advance between successive analysis windows.
+    **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
+
+  Returns:
+    2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
+    magnitudes for successive frames.
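+
+    For example, 16000 samples at audio_sample_rate=16000 with the default
+    25 ms window and 10 ms hop give a 400-sample window, a 160-sample hop, a
+    512-point FFT, and 1 + (16000 - 400) // 160 = 98 output frames.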
+  """
+  window_length_samples = int(round(audio_sample_rate * window_length_secs))
+  hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
+  fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
+  spectrogram = stft_magnitude(
+      data,
+      fft_length=fft_length,
+      hop_length=hop_length_samples,
+      window_length=window_length_samples)
+  mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
+      num_spectrogram_bins=spectrogram.shape[1],
+      audio_sample_rate=audio_sample_rate, **kwargs))
+  return np.log(mel_spectrogram + log_offset)
diff --git a/example.py b/vision_example.py
similarity index 100%
rename from example.py
rename to vision_example.py
diff --git a/voice.py b/voice.py
new file mode 100644
index 0000000..da966b6
--- /dev/null
+++ b/voice.py
@@ -0,0 +1,219 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Keyword spotter model."""
+import logging
+import sys
+
+import numpy as np
+import tflite_runtime.interpreter as tflite
+
+import audio_recorder
+import mel_features
+
+logging.basicConfig(
+    stream=sys.stdout,
+    format="%(levelname)-8s %(asctime)-15s %(name)s %(message)s")
+audio_recorder.logger.setLevel(logging.ERROR)
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.ERROR)
+
+class Uint8LogMelFeatureExtractor(object):
+  """Provide uint8 log mel spectrogram slices from an AudioRecorder object.
+
+  This class provides one public method, get_next_spectrogram(), which gets
+  a specified number of spectral slices from an AudioRecorder.
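+
+  A minimal usage sketch (assuming an active AudioRecorder named recorder):
+
+    extractor = Uint8LogMelFeatureExtractor(num_frames_hop=33)
+    spectrogram = extractor.get_next_spectrogram(recorder)  # (198, 32) uint8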
+  """
+
+  def __init__(self, num_frames_hop=33):
+    self.spectrogram_window_length_seconds = 0.025
+    self.spectrogram_hop_length_seconds = 0.010
+    self.num_mel_bins = 32
+    self.frame_length_spectra = 198
+    if self.frame_length_spectra % num_frames_hop:
+      raise ValueError('Invalid num_frames_hop value (%d), '
+                       'must divide %d' % (num_frames_hop,
+                                           self.frame_length_spectra))
+    self.frame_hop_spectra = num_frames_hop
+    self._norm_factor = 3
+    self._clear_buffers()
+
+  def _clear_buffers(self):
+    self._audio_buffer = np.array([], dtype=np.int16).reshape(0, 1)
+    self._spectrogram = np.zeros((self.frame_length_spectra, self.num_mel_bins),
+                                 dtype=np.float32)
+
+  def _spectrogram_underlap_samples(self, audio_sample_rate_hz):
+    return int((self.spectrogram_window_length_seconds -
+                self.spectrogram_hop_length_seconds) * audio_sample_rate_hz)
+
+  def _frame_duration_seconds(self, num_spectra):
+    return (self.spectrogram_window_length_seconds +
+            (num_spectra - 1) * self.spectrogram_hop_length_seconds)
+
+  def _compute_spectrogram(self, audio_samples, audio_sample_rate_hz):
+    """Compute log-mel spectrogram and scale it to uint8."""
+    samples = audio_samples.flatten() / float(2**15)
+    spectrogram = 30 * (
+        mel_features.log_mel_spectrogram(
+            samples,
+            audio_sample_rate_hz,
+            log_offset=0.001,
+            window_length_secs=self.spectrogram_window_length_seconds,
+            hop_length_secs=self.spectrogram_hop_length_seconds,
+            num_mel_bins=self.num_mel_bins,
+            lower_edge_hertz=60,
+            upper_edge_hertz=3800) - np.log(1e-3))
+    return spectrogram
+
+  def _get_next_spectra(self, recorder, num_spectra):
+    """Returns the next spectrogram.
+
+    Compute num_spectra spectrogram samples from an AudioRecorder.
+    Blocks until num_spectra spectrogram slices are available.
+
+    Args:
+      recorder: an AudioRecorder object from which to get raw audio samples.
+      num_spectra: the number of spectrogram slices to return.
+
+    Returns:
+      num_spectra spectrogram slices computed from the samples.
+    """
+    required_audio_duration_seconds = self._frame_duration_seconds(num_spectra)
+    logger.info("required_audio_duration_seconds %f",
+                required_audio_duration_seconds)
+    required_num_samples = int(
+        np.ceil(required_audio_duration_seconds *
+                recorder.audio_sample_rate_hz))
+    logger.info("required_num_samples %d, %s", required_num_samples,
+                str(self._audio_buffer.shape))
+    audio_samples = np.concatenate(
+        (self._audio_buffer,
+         recorder.get_audio(required_num_samples - len(self._audio_buffer))[0]))
+    self._audio_buffer = audio_samples[
+        required_num_samples -
+        self._spectrogram_underlap_samples(recorder.audio_sample_rate_hz):]
+    spectrogram = self._compute_spectrogram(
+        audio_samples[:required_num_samples], recorder.audio_sample_rate_hz)
+    assert len(spectrogram) == num_spectra
+    return spectrogram
+
+  def get_next_spectrogram(self, recorder):
+    """Get the most recent spectrogram frame.
+
+    Blocks until the frame is available.
+
+    Args:
+      recorder: an AudioRecorder instance which provides the audio samples.
+
+    Returns:
+      The next spectrogram frame as a uint8 numpy array.
+    """
+    assert recorder.is_active
+    logger.info("self._spectrogram shape %s", str(self._spectrogram.shape))
+    self._spectrogram[:-self.frame_hop_spectra] = (
+        self._spectrogram[self.frame_hop_spectra:])
+    self._spectrogram[-self.frame_hop_spectra:] = (
+        self._get_next_spectra(recorder, self.frame_hop_spectra))
+    # Return a copy of the internal state that's safe to persist and won't
+    # change the next time we call this function.
+    logger.info("self._spectrogram shape %s", str(self._spectrogram.shape))
+    spectrogram = self._spectrogram.copy()
+    spectrogram -= np.mean(spectrogram, axis=0)
+    if self._norm_factor:
+      spectrogram /= self._norm_factor * np.std(spectrogram, axis=0)
+      spectrogram += 1
+      spectrogram *= 127.5
+    return np.maximum(0, np.minimum(255, spectrogram)).astype(np.uint8)
+
+def read_labels(filename):
+  """Reads one label per line, prepending the implicit 'negative' class."""
+  with open(filename, "r") as f:
+    return ['negative'] + [line.rstrip() for line in f]
+
+def get_output(interpreter):
+  """Returns the entire output; a threshold is applied later."""
+  return output_tensor(interpreter, 0)
+
+def output_tensor(interpreter, i):
+  """Returns the output tensor, dequantized if the model is quantized."""
+  output_details = interpreter.get_output_details()[i]
+  output_data = np.squeeze(interpreter.tensor(output_details['index'])())
+  if 'quantization' not in output_details:
+    return output_data
+  # Dequantize: real_value = scale * (quantized_value - zero_point).
+  scale, zero_point = output_details['quantization']
+  if scale == 0:
+    return output_data - zero_point
+  return scale * (output_data - zero_point)
+
+def input_tensor(interpreter):
+  """Returns the input tensor view as a numpy array."""
+  tensor_index = interpreter.get_input_details()[0]['index']
+  return interpreter.tensor(tensor_index)()[0]
+
+def set_input(interpreter, data):
+  """Copies data into the input tensor."""
+  interpreter_shape = interpreter.get_input_details()[0]['shape']
+  input_tensor(interpreter)[:, :] = np.reshape(data, interpreter_shape[1:3])
+
+def make_interpreter(model_file):
+  model_file, *device = model_file.split('@')
+  return tflite.Interpreter(
+      model_path=model_file,
+      experimental_delegates=[tflite.load_delegate('libedgetpu.so.1',
+                              {'device': device[0]} if device else {})])
+
+def classify_audio(model_file, labels_file, detection_callback,
+                   audio_device_index=0, sample_rate_hz=16000,
+                   negative_threshold=0.6, num_frames_hop=33):
+  """Acquires audio, extracts features, and runs keyword classification.
+
+  Calls detection_callback(label, score) for each new detection; the
+  callback returns True to keep listening or False to stop. Results where
+  the negative class scores at or above negative_threshold are ignored.
+  """
+  # Most microphones support 48 kHz; the model expects 16 kHz audio, so in
+  # that case we downsample by a factor of 3.
+  downsample_factor = 3 if sample_rate_hz == 48000 else 1
+  recorder = audio_recorder.AudioRecorder(
+      sample_rate_hz,
+      downsample_factor=downsample_factor,
+      device_index=audio_device_index)
+  feature_extractor = Uint8LogMelFeatureExtractor(num_frames_hop=num_frames_hop)
+  labels = read_labels(labels_file)
+
+  interpreter = make_interpreter(model_file)
+  interpreter.allocate_tensors()
+
+  keep_listening = True
+  prev_detection = -1
+  with recorder:
+    while keep_listening:
+      spectrogram = feature_extractor.get_next_spectrogram(recorder)
+      if spectrogram.mean() < 0.001:
+        print("Warning: Input audio signal is nearly 0. Mic may be off ?")
+
+      set_input(interpreter, spectrogram.flatten())
+      interpreter.invoke()
+      result = get_output(interpreter)
+
+      if result[0] >= negative_threshold:
+        prev_detection = -1
+        continue
+
+      detection = np.argmax(result)
+      if detection == 0:
+        prev_detection = -1
+        continue
+
+      if detection != prev_detection:
+        keep_listening = detection_callback(labels[detection],
+                                            result[detection])
+        prev_detection = detection
diff --git a/voice_example.py b/voice_example.py
new file mode 100644
index 0000000..a311d50
--- /dev/null
+++ b/voice_example.py
@@ -0,0 +1,15 @@
+import voice
+
+def detection_callback(label, score):
+  print('Detected: "%s", score=%s' % (label, score))
+  if label.startswith('exit'):
+    return False  # stop listening
+  return True  # keep listening
+
+def run_voice_example():
+  voice.classify_audio(model_file='voice_commands_v0.7_edgetpu.tflite',
+                       labels_file='labels_gc2.raw.txt',
+                       detection_callback=detection_callback)
+
+if __name__ == '__main__':
+  run_voice_example()