Add support for multiple network interfaces.

Change-Id: I2ab6106fd4364a5826b31113a621d8650741fc3c
diff --git a/debian/mdt-services.mdt-keymaster.service b/debian/mdt-services.mdt-keymaster.service
index 779f516..0cd4f89 100644
--- a/debian/mdt-services.mdt-keymaster.service
+++ b/debian/mdt-services.mdt-keymaster.service
@@ -5,7 +5,7 @@
 [Service]
 Type=simple
 RemainAfterExit=no
-ExecStart=/usr/bin/mdt-keymaster
+ExecStart=/usr/bin/mdt-keymaster --interface usb0 --interface usb1
 Restart=on-failure
 User=mendel
 Group=mendel
diff --git a/usr/bin/mdt-keymaster b/usr/bin/mdt-keymaster
old mode 100755
new mode 100644
index c2717ac..3dced30
--- a/usr/bin/mdt-keymaster
+++ b/usr/bin/mdt-keymaster
@@ -1,12 +1,16 @@
 #!/usr/bin/python3 -u
 
-import re
+import argparse
 import os
+import re
+import requests
 import sys
 import socket
 import struct
 import time
 
+from multiprocessing import Process
+from multiprocessing import Event
 from http.server import HTTPServer
 from http.server import BaseHTTPRequestHandler
 
@@ -18,30 +22,54 @@
 SSH_BASE_PATH = '/home/mendel/.ssh'
 AUTHORIZED_KEYS_PATH = os.path.join(SSH_BASE_PATH, 'authorized_keys')
 SERVER_PORT = 41337
-BIND_INTERFACE = 'usb0'
 MDT_KEY_REGEX = re.compile(r'^ssh-rsa .* mdt$')
-
-
-key_received = False
+MDT_EXIT_REGEX = re.compile(r'^exit$')
 
 
 def get_iface_address(iface):
+    """ Retreive ip address from a given interface. """
     addresses = netifaces.ifaddresses(iface)
     inet4_addresses = addresses[socket.AF_INET]
     ip_address = inet4_addresses[0]['addr']
     return ip_address
 
 
+def key_exists():
+    """ Check for existence of key. """
+    if os.path.exists(AUTHORIZED_KEYS_PATH):
+        print('authorized_keys file already exists.\n', flush=True)
+        with open(AUTHORIZED_KEYS_PATH) as fp:
+            for line in fp.readlines():
+                if re.match(MDT_KEY_REGEX, line):
+                    print('authorized_keys file contains MDT key already.\n', flush=True)
+                    return True
+        print('authorized_keys file does not contain an MDT key.\n', flush=True)
+    return False
+
+
 class KeyPushHandler(BaseHTTPRequestHandler):
-    key_received = False
-
+    """ Handle keys sent from client and install. """
+    key_handled = Event()
     def do_PUT(self):
-        if not KeyPushHandler.key_received:
-            KeyPushHandler.key_received = True
-            self.close_connection = True
-
+        """ Verify and install key. """
+        if not KeyPushHandler.key_handled.is_set():
             content_length = int(self.headers.get('Content-Length', 0))
             public_key = self.rfile.read(content_length)
+            if not re.match(MDT_KEY_REGEX, public_key.decode("utf-8").strip()):
+                if re.match(MDT_EXIT_REGEX, public_key.decode("utf-8").strip()):
+                    KeyPushHandler.key_handled.set()
+                    print("Received request to exit.", flush=True)
+                    self.send_response(200, "Ok")
+                    self.end_headers()
+                    return
+                KeyPushHandler.key_handled.clear()
+                print("Received invalid key.", flush=True)
+                self.send_response(400, "Invalid key")
+                self.end_headers()
+                return
+            else:
+                KeyPushHandler.key_handled.set()
+                print("Received valid key.", flush=True)
 
             if not os.path.exists(SSH_BASE_PATH):
                 os.mkdir(SSH_BASE_PATH, 0o700)
@@ -56,39 +84,72 @@
             self.end_headers()
 
 
-def main():
-    if os.path.exists(AUTHORIZED_KEYS_PATH):
-        print('authorized_keys file already exists.\n', flush=True)
-        with open(AUTHORIZED_KEYS_PATH) as fp:
-            for line in fp.readlines():
-                if re.match(MDT_KEY_REGEX, line):
-                    print('authorized_keys file contains MDT key already -- exiting.\n', flush=True)
-                    sys.exit(0)
-        print('authorized_keys file does not contain an MDT key.\n', flush=True)
+class HTTPServerProcess(Process):
+    """ Process for starting a HttpServer """
+    def __init__(self, bind_address, *args, **kwargs):
+        super(HTTPServerProcess, self).__init__(*args, **kwargs)
+        self.bind_address = bind_address
 
-    iface = BIND_INTERFACE
-    bind_address = None
+    def run(self):
+        """ Bind and start server on address."""
+        server_address = (self.bind_address, SERVER_PORT)
+        httpd = HTTPServer(server_address, KeyPushHandler)
+        print('Waiting for incoming PUT on {0}:{1}\n'.format(self.bind_address,
+                                                         SERVER_PORT), flush=True)
+        while not KeyPushHandler.key_handled.is_set():
+            httpd.handle_request()
 
-    delay = 1
-    while True:
+    def send_exit_command(self):
+        """ Send command to server to exit. """
+        print("Request exit for {0}".format(self.bind_address, SERVER_PORT))
         try:
-            bind_address = get_iface_address(iface)
-            break
-        except Exception as e:
-            print('Unable to determine {0} IPv4 address: {1}.\n'.format(iface, e), flush=True)
-            print('Waiting {0} seconds before retrying.'.format(delay))
-            time.sleep(delay)
-            delay = min(10, delay * 2)
+            requests.put("http://{0}:{1}".format(self.bind_address, SERVER_PORT), data="exit")
+        except requests.exceptions.ConnectionError:
+            pass
 
-    server_address = (bind_address, SERVER_PORT)
-    httpd = HTTPServer(server_address, KeyPushHandler)
 
-    print('Waiting for incoming PUT on {0}:{1}\n'.format(bind_address, SERVER_PORT), flush=True)
+def main():
+    parser = argparse.ArgumentParser(description="MDT Keymaster Server.")
+    parser.add_argument("--interface", help="Listen on selected interfaces.",
+                        action="append", default=[])
+    args = parser.parse_args()
 
-    while not KeyPushHandler.key_received:
-        httpd.handle_request()
+    interfaces = list(set(args.interface))
 
-    print('Received key. Exiting.\n', flush=True)
+    # If key is installed exit.
+    if key_exists():
+        print('exiting.\n', flush=True)
+        sys.exit(0)
+
+    ifaces = args.interface
+    bind_addresses = list()
+    delay = 1
+    while len(bind_addresses) != len(ifaces):
+        for iface in ifaces:
+            try:
+                bind_addresses.append(get_iface_address(iface))
+            except Exception as e:
+                print('Unable to determine {0} IPv4 address: {1}.\n'.format(iface, e), flush=True)
+                print('Waiting {0} seconds before retrying.'.format(delay), flush=True)
+                time.sleep(delay)
+                delay = min(10, delay * 2)
+
+    server_processes = []
+    for bind_address in bind_addresses:
+        server_processes.append((bind_address,
+            HTTPServerProcess(bind_address)))
+
+    for address, process in server_processes:
+        process.start()
+
+    while not KeyPushHandler.key_handled.is_set():
+        time.sleep(1.0)
+
+    for address, process in server_processes:
+        process.send_exit_command()
+        process.join()
+
+    print('Exiting.\n', flush=True)
     sys.exit(0)