resetkeys: Add a new command to remove MDT keys and start keymaster

This gives us a way to return a board with previously-pushed MDT keys on it back
to factory defaults, with mdt-keymaster running again.

Change-Id: Ife3b71526c9905365426ecca553ef32504692105
diff --git a/mdt/main.py b/mdt/main.py
index f796a0d..36b2818 100755
--- a/mdt/main.py
+++ b/mdt/main.py
@@ -66,6 +66,8 @@
     pushkey           - pushes an SSH public key to a device
     setkey            - imports a PEM-format SSH private key into the MDT
                         keystore
+    resetkeys         - removes all keys from the given board and resets key
+                        authentication to factory defaults
     shell             - opens an interactive shell to a device
     exec              - runs a shell command and returns the output and the
                         exit code
@@ -125,6 +127,7 @@
     'pushkey': shell.PushKeyCommand(),
     'reboot': shell.RebootCommand(),
     'reboot-bootloader': shell.RebootBootloaderCommand(),
+    'resetkeys': shell.ResetKeysCommand(),
     'set': config.SetCommand(),
     'setkey': keys.SetKeyCommand(),
     'shell': shell.ShellCommand(),
diff --git a/mdt/shell.py b/mdt/shell.py
index 0a1befd..3b78f61 100644
--- a/mdt/shell.py
+++ b/mdt/shell.py
@@ -15,11 +15,13 @@
 '''
 
 
+import re
 import os
 import sys
 
 from mdt import command
 from mdt import console
+from mdt import keys
 
 
 class ShellCommand(command.NetworkCommand):
@@ -101,21 +103,21 @@
   4. Disconnects and reconnects using the SSH key.
 '''
     def runWithClient(self, client, args):
-        channel = client.shellExec(' '.join(args[1:]))
+        channel = client.shellExec(' '.join(args[1:]), allocPty=True)
         cons = console.Console(channel, sys.stdin)
         return cons.run()
 
 
 class RebootCommand(command.NetworkCommand):
     def runWithClient(self, client, args):
-        channel = client.shellExec("sudo reboot")
+        channel = client.shellExec("sudo reboot", allocPty=True)
         cons = console.Console(channel, sys.stdin)
         return cons.run()
 
 
 class RebootBootloaderCommand(command.NetworkCommand):
     def runWithClient(self, client, args):
-        channel = client.shellExec("sudo reboot-bootloader")
+        channel = client.shellExec("sudo reboot-bootloader", allocPty=True)
         cons = console.Console(channel, sys.stdin)
         return cons.run()
 
@@ -128,20 +130,15 @@
 public key from ~/.config/mdt/keys/mdt.key.
 '''
 
-    def runWithClient(self, client, args):
-        key_to_push = None
+    def _pushMdtKey(self, client):
+        print('Pushing {0}'.format(keys.KEYFILE_PATH))
+        client.pushKey()
+        print('Push complete.')
+        return 0
 
-        if len(args) == 1:
-            # The key was most likely pushed by the NetworkCommand substrate. We
-            # can simply return here.
-            print("MDT Key pushed.")
-            return 0
+    def _pushOtherKey(self, client, keyfile):
+        sftp = client.openSftp()
 
-        if len(args) != 2:
-            print("Usage: mdt pushkey [<path-to-public-key>]")
-            return 1
-
-        source_keyfile = args[1]
         if not os.path.exists(source_keyfile):
             print("Can't copy {0}: no such file or directory.".format(source_keyfile))
             return 1
@@ -150,7 +147,6 @@
         with open(args[1], 'rb') as fp:
             source_key = fp.read()
 
-        sftp = client.openSftp()
         try:
             sftp.chdir('/home/mendel/.ssh')
         except FileNotFoundError as e:
@@ -162,3 +158,66 @@
 
         print("Key {0} pushed.".format(source_keyfile))
         return 0
+
+    def runWithClient(self, client, args):
+        key_to_push = None
+
+        # No arguments? Let the usual client push take effect.
+        if len(args) == 1:
+            return self._pushMdtKey(client)
+
+        source_keyfile = args[1]
+        print('Pushing {0}'.format(source_keyfile))
+        return self._pushOtherKey(client, source_keyfile)
+
+
+class ResetKeysCommand(command.NetworkCommand):
+    '''Usage: mdt resetkeys <device-or-ip-address>
+
+Resets a device to it's pre-MDT state by removing all MDT keys and restarting
+the mdt-keymaster on the device so that new keys can be pushed again.'''
+
+    def preConnectRun(self, args):
+        if len(args) != 2:
+            print("Usage: mdt resetkeys <device-or-ip-address>")
+            return False
+
+        if len(args) == 2:
+            self.device = args[1]
+
+        return True
+
+    def runWithClient(self, client, args):
+        # Setup this session now, since once we remove the key from
+        # authorized_keys, we won't be able to use execSession.
+        channel = client.openChannel()
+
+        sftp = client.openSftp()
+        try:
+            sftp.chdir('/home/mendel/.ssh')
+        except FileNotFoundError as e:
+            print('No keys were previously pushed to the board.')
+        else:
+            lines = []
+
+            with sftp.open('/home/mendel/.ssh/authorized_keys', 'r') as fp:
+                lines = fp.readlines()
+
+            with sftp.open('/home/mendel/.ssh/authorized_keys', 'w') as fp:
+                for line in lines:
+                    if ' mdt' not in line:
+                        print('wrote: {0}'.format(line))
+                        fp.write(line)
+
+        channel.exec_command("sudo systemctl restart mdt-keymaster")
+        cons = console.Console(channel, sys.stdin)
+        try:
+            cons.run()
+        except console.ConnectionClosedError as e:
+            if e.exit_code:
+                print('`systemctl restart mdt-keymaster` exited with code {0}'.format(e.exit_code))
+                print('Your device may be in an inconsistent state. Verify using')
+                print('the serial console.')
+            else:
+                print('Successfully reset {0}'.format(self.device))
+            return e.exit_code
diff --git a/mdt/sshclient.py b/mdt/sshclient.py
index a319097..270f758 100644
--- a/mdt/sshclient.py
+++ b/mdt/sshclient.py
@@ -28,7 +28,6 @@
 from mdt import config
 from mdt import discoverer
 from mdt import keys
-from mdt import sshclient
 
 
 KEYMASTER_PORT = 41337
@@ -64,7 +63,7 @@
         self.client = paramiko.SSHClient()
         self.client.set_missing_host_key_policy(AutoAddPolicy())
 
-    def _shouldPushKey(self):
+    def shouldPushKey(self):
         try:
             self.client.connect(
                 self.address,
@@ -135,7 +134,7 @@
         finally:
             self.client.close()
 
-    def _pushKey(self):
+    def pushKey(self):
         try:
             self._pushKeyViaKeymaster()
         except KeyPushError as e:
@@ -184,9 +183,9 @@
         env = self._generateEnvironment()
         width, height = os.get_terminal_size()
 
-        if self._shouldPushKey():
+        if self.shouldPushKey():
             print("Key not present on {0} -- pushing".format(self.device))
-            self._pushKey()
+            self.pushKey()
 
         self.client.connect(
             self.address,
@@ -200,10 +199,10 @@
         # support have added in Paramiko v2.1.x or newer.
         return self.client.invoke_shell(term=term, width=width, height=height)
 
-    def shellExec(self, cmd, allocPty=False):
-        if self._shouldPushKey():
+    def openChannel(self, allocPty=False):
+        if self.shouldPushKey():
             print("Key not present on {0} -- pushing".format(self.device))
-            self._pushKey()
+            self.pushKey()
 
         self.client.connect(
             self.address,
@@ -218,13 +217,18 @@
             term = os.getenv("TERM", default="vt100")
             width, height = os.get_terminal_size()
             session.get_pty(term=term, width=width, height=height)
-        session.exec_command(cmd)
+
         return session
 
+    def shellExec(self, cmd, allocPty=False):
+        channel = self.openChannel(allocPty=allocPty)
+        channel.exec_command(cmd)
+        return channel
+
     def openSftp(self):
-        if self._shouldPushKey():
+        if self.shouldPushKey():
             print("Key not present on {0} -- pushing".format(self.device))
-            self._pushKey()
+            self.pushKey()
 
         self.client.connect(
             self.address,