mdt: Major functionality changes

The big things that occured in this change was the addition of the install,
push, pull, and exec commands. Past this change, we're actually functionally
equivalent to the board.sh utility in the Mendel tree.

Things that changed in order of biggest to smallest:

  - Added the install, push, and pull commands. They're brittle at the moment --
    refinements to come.
  - Refactored most of the command classes to use NetworkCommand which handles
    discovery of connected devices based upon user preferences and what is
    available.
  - Migrated most of the SSH code into its own class called SshClient.
  - Made ConnectionClosedErrors actually store exit codes so we can return those
    to the host-side shell.
  - Made Console track exit codes from the given Channels.
  - Made Console work with non-PTY shells so we can suck down the output to the
    terminal
  - Added the exec command to run a command without a shell.
  - Fixed a command sanitization check in config

Apologies for the gigantic diff. O.o

Change-Id: I08a2ce18e777b9c047503ad3225083eb0ceab053
diff --git a/mdt/command.py b/mdt/command.py
new file mode 100644
index 0000000..ad5aaa0
--- /dev/null
+++ b/mdt/command.py
@@ -0,0 +1,76 @@
+import os
+import socket
+
+from time import sleep
+
+from paramiko.ssh_exception import SSHException
+
+from mdt import config
+from mdt import console
+from mdt import discoverer
+from mdt import sshclient
+
+
+class NetworkCommand:
+    def __init__(self):
+        self.config = config.Config()
+        self.discoverer = discoverer.Discoverer(self)
+        self.device = self.config.preferredDevice()
+        self.address = None
+
+    def add_device(self, hostname, address):
+        if not self.device:
+            self.device = hostname
+            self.address = address
+        elif self.device == hostname:
+            self.address = address
+
+    def run(self, args):
+        if not self.preConnectRun(args):
+            return 1
+
+        if not self.address:
+            if self.device:
+                print('Waiting for device {0}...'.format(self.device))
+            else:
+                print('Waiting for a device...')
+
+            while not self.address:
+                sleep(0.1)
+
+        client = None
+        try:
+            print('Connecting to {0} at {1}'.format(self.device, self.address))
+            client = sshclient.SshClient(self.device, self.address)
+            return self.runWithClient(client, args)
+        except sshclient.KeyPushError as e:
+            print("Unable to push keys to the device: {0}".format(e))
+            return 1
+        except sshclient.DefaultLoginError as e:
+            print("Can't login using default credentials: {0}".format(e))
+            return 1
+        except SSHException as e:
+            print("Couldn't establish ssh connection to device: {0}".format(e))
+            return 1
+        except socket.error as e:
+            print("Couldn't establish ssh connection to device: socket error: {0}".format(e))
+            return 1
+        except console.SocketTimeoutError as e:
+            print("\r\nConnection to {0} at {1} closed: socket timeout".format(self.device, self.address))
+            return 1
+        except console.ConnectionClosedError as e:
+            if e.exit_code:
+                print("\r\nConnection to {0} at {1} closed "
+                      "with exit code {2}".format(self.device, self.address, e.exit_code))
+            else:
+                print("\r\nConnection to {0} at {1} closed".format(self.device, self.address))
+            return e.exit_code
+        finally:
+            if client:
+                client.close()
+
+    def runWithClient(self, client, args):
+        return 1
+
+    def preConnectRun(self, args):
+        return True
diff --git a/mdt/config.py b/mdt/config.py
index bc4b6b9..eec3400 100644
--- a/mdt/config.py
+++ b/mdt/config.py
@@ -111,8 +111,9 @@
         self.config = Config()
 
     def run(self, args):
-        if len(args) != 2:
+        if len(args) != 3:
             print("Usage: mdt set <variablename> <value>")
+            return 1
 
         self.config.setAttribute(args[1], args[2])
         print("Set {0} to {1}".format(args[1], args[2]))
diff --git a/mdt/console.py b/mdt/console.py
index 5f2f639..bf20916 100644
--- a/mdt/console.py
+++ b/mdt/console.py
@@ -8,7 +8,8 @@
 
 
 class ConnectionClosedError(Exception):
-    pass
+    def __init__(self, exit_code=None):
+        self.exit_code = exit_code
 
 
 class SocketTimeoutError(Exception):
@@ -24,32 +25,46 @@
         import termios
         import tty
 
-        localtty = termios.tcgetattr(sys.stdin)
+        localtty = None
         try:
-            tty.setraw(sys.stdin.fileno())
-            tty.setcbreak(sys.stdin.fileno())
+            localtty = termios.tcgetattr(self.inputfile)
+        except termios.error as e:
+            pass
+
+        try:
+            if localtty:
+                tty.setraw(self.inputfile.fileno())
+                tty.setcbreak(self.inputfile.fileno())
+
             self.channel.settimeout(0)
 
             while True:
-                read, write, exception = select.select([self.channel, sys.stdin], [], [])
+                read, write, exception = select.select([self.channel, self.inputfile], [], [])
 
                 if self.channel in read:
                     try:
                         data = self.channel.recv(256)
                         if len(data) == 0:
-                            raise ConnectionClosedError()
+                            exit_code = None
+                            if self.channel.exit_status_ready():
+                                exit_code = self.channel.recv_exit_status()
+                            raise ConnectionClosedError(exit_code=exit_code)
                         sys.stdout.write(data.decode("utf-8", errors="ignore"))
                         sys.stdout.flush()
                     except socket.timeout as e:
                         raise SocketTimeoutError(e)
 
-                if sys.stdin in read:
-                    data = sys.stdin.read(1)
+                if self.inputfile in read:
+                    data = self.inputfile.read(1)
                     if len(data) == 0:
-                        raise ConnectionClosedError()
+                        exit_code = None
+                        if self.channel.exit_status_ready():
+                            exit_code = self.channel.recv_exit_status()
+                        raise ConnectionClosedError(exit_code=exit_code)
                     self.channel.send(data)
         finally:
-            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, localtty)
+            if localtty:
+                termios.tcsetattr(self.inputfile, termios.TCSADRAIN, localtty)
 
 
 class WindowsConsole:
@@ -81,7 +96,10 @@
                 try:
                     data = self.channel.recv(256)
                     if len(data) == 0:
-                        self.queue.put((TYPE_REMOTE_CLOSED, None))
+                        exit_code = None
+                        if self.channel.exit_status_ready():
+                            exit_code = self.channel.recv_exit_status()
+                        self.queue.put((TYPE_REMOTE_CLOSED, exit_code))
                         break
                     data = data.decode("utf-8", errors="ignore")
                     self.queue.put((TYPE_TERMINAL_OUTPUT, data))
@@ -89,6 +107,9 @@
                     self.queue.put((TYPE_SOCKET_TIMEOUT, None))
                     break
 
+            exit_status = self.channel.recv_exit_status()
+            self.queue.put((TYPE_EXIT_CODE, exit_status))
+
     def __init__(self, channel, inputfile):
         self.channel = channel
         self.dataQueue = queue.Queue()
@@ -108,7 +129,7 @@
                 sys.stdout.write(data)
                 sys.stdout.flush()
             if dataType == TYPE_REMOTE_CLOSED:
-                raise ConnectionClosedError()
+                raise ConnectionClosedError(exit_code=data)
             if dataType == TYPE_SOCKET_TIMEOUT:
                 raise SocketTimeoutError()
 
diff --git a/mdt/devices.py b/mdt/devices.py
index a5e5a7a..0697f48 100644
--- a/mdt/devices.py
+++ b/mdt/devices.py
@@ -23,7 +23,6 @@
 
     def run(self, args):
         sleep(1)
-        print('Devices found:')
         discoveries = self.discoverer.discoveries
         for host, address in discoveries.items():
             if self.device and host == self.device:
diff --git a/mdt/files.py b/mdt/files.py
new file mode 100644
index 0000000..a7cb2f7
--- /dev/null
+++ b/mdt/files.py
@@ -0,0 +1,118 @@
+import os
+import sys
+import time
+
+from mdt import command
+from mdt import config
+from mdt import console
+from mdt import discoverer
+from mdt import sshclient
+
+
+PROGRESS_WIDTH = 45
+FILENAME_WIDTH = 30
+
+
+def MakeProgressFunc(full_filename, width, char='>'):
+    def closure(bytes_xferred, total_bytes):
+        pcnt = bytes_xferred / total_bytes
+        filename = full_filename
+        if len(filename) > FILENAME_WIDTH:
+            filename = filename[0:FILENAME_WIDTH - 3] + '...'
+        left = char * int(pcnt * width)
+        right = ' ' * int((1 - pcnt) * width)
+        pcnt = '%3d' % (int(pcnt * 100))
+        sys.stdout.write('\r{0}% |{1}{2}| {3}'.format(pcnt, left, right, filename))
+        sys.stdout.flush()
+
+    return closure
+
+
+class InstallCommand(command.NetworkCommand):
+    def preConnectRun(self, args):
+        if len(args) < 2:
+            print("Usage: mdt install [<package-filename...>]")
+            return False
+
+        return True
+
+    def runWithClient(self, client, args):
+        package_to_install = args[1]
+        package_filename = os.path.basename(package_to_install)
+        remote_filename = os.path.join('/tmp', package_filename)
+
+        sftp_callback = MakeProgressFunc(package_filename, PROGRESS_WIDTH)
+        sftp = client.openSftp()
+        sftp.put(package_to_install, remote_filename, callback=sftp_callback)
+        sftp.close()
+        client.close()
+        print()
+
+        channel = client.shellExec("sudo /usr/sbin/mdt-install-package {0}; rm -f {0}".format(remote_filename), allocPty=True)
+        cons = console.Console(channel, sys.stdin)
+        return cons.run()
+
+
+class PushCommand(command.NetworkCommand):
+    def preConnectRun(self, args):
+        if len(args) < 3:
+            print("Usage: mdt push <filename...> <destination-directory>")
+            return False
+
+        for file in args[1:-1]:
+            if not os.path.isfile(file):
+                print("{0}: Is a directory -- cannot push".format(file))
+                return False
+
+        return True
+
+    def runWithClient(self, client, args):
+        files_to_push = args[1:-1]
+        destination = args[-1]
+
+        try:
+            sftp = client.openSftp()
+            for file in files_to_push:
+                base_filename = os.path.basename(file)
+                sftp_callback = MakeProgressFunc(file, PROGRESS_WIDTH)
+                remote_filename = os.path.join(destination, base_filename)
+
+                sftp_callback(0, 1)
+                sftp.put(file, remote_filename, callback=sftp_callback)
+                sftp_callback(1, 1)
+                print()
+        finally:
+            print()
+            sftp.close()
+
+        return 0
+
+
+class PullCommand(command.NetworkCommand):
+    def preConnectRun(self, args):
+        if len(args) < 3:
+            print("Usage: mdt pull [<filename...>]")
+            return False
+
+        return True
+
+    def runWithClient(self, client, args):
+        files_to_pull = args[1:-1]
+        destination = args[-1]
+
+        try:
+            sftp = client.openSftp()
+            for file in files_to_pull:
+                base_filename = os.path.basename(file)
+                sftp_callback = MakeProgressFunc(file, PROGRESS_WIDTH, char='<')
+                destination_filename = os.path.join(destination, base_filename)
+
+                sftp_callback(0, 1)
+                sftp.get(file, destination_filename, callback=sftp_callback)
+                sftp_callback(1, 1)
+                print()
+        finally:
+            print()
+            sftp.close()
+
+        return 0
diff --git a/mdt/main.py b/mdt/main.py
index ea91184..5527d30 100755
--- a/mdt/main.py
+++ b/mdt/main.py
@@ -11,6 +11,7 @@
 
 from mdt import config
 from mdt import devices
+from mdt import files
 from mdt import keys
 from mdt import shell
 
@@ -35,6 +36,10 @@
             print('    clear           - clears an MDT variable')
             print('    genkey          - generates an SSH key for connecting to a device')
             print('    shell           - opens an interactive shell to a device')
+            print('    exec            - runs a shell command and returns the output and the exit code')
+            print('    install         - installs a Debian package using mdt-install-package on the device')
+            print('    push            - pushes a file (or files) to the device')
+            print('    pull            - pulls a file (or files) from the device')
             print()
             print('Use "mdt help <subcommand>" for more details.')
             print()
@@ -50,14 +55,20 @@
 
 
 COMMANDS = {
-    'help': HelpCommand(),
-    'devices': devices.DevicesCommand(),
-    'wait-for-device': devices.DevicesWaitCommand(),
-    'get': config.GetCommand(),
-    'set': config.SetCommand(),
     'clear': config.ClearCommand(),
+    'devices': devices.DevicesCommand(),
+    'exec': shell.ExecCommand(),
     'genkey': keys.GenKeyCommand(),
+    'get': config.GetCommand(),
+    'help': HelpCommand(),
+    'install': files.InstallCommand(),
+    'pull': files.PullCommand(),
+    'push': files.PushCommand(),
+    'reboot': shell.RebootCommand(),
+    'reboot-bootloader': shell.RebootBootloaderCommand(),
+    'set': config.SetCommand(),
     'shell': shell.ShellCommand(),
+    'wait-for-device': devices.DevicesWaitCommand(),
 }
 
 
diff --git a/mdt/shell.py b/mdt/shell.py
index 936f407..ffb544b 100644
--- a/mdt/shell.py
+++ b/mdt/shell.py
@@ -1,123 +1,15 @@
-from time import sleep
-
-import os
-import platform
-import select
-import socket
-import subprocess
 import sys
-import termios
-import tty
 
-import paramiko
-from paramiko.ssh_exception import AuthenticationException, SSHException
-
-from mdt import discoverer
-from mdt import config
+from mdt import command
 from mdt import console
-from mdt import keys
 
 
 
-class KeyPushError(Exception):
-    pass
+class ShellCommand(command.NetworkCommand):
+    '''Usage: mdt shell
 
-
-class DefaultLoginError(Exception):
-    pass
-
-
-class SshClient:
-    def __init__(self, device, address):
-        self.config = config.Config()
-        self.keystore = keys.Keystore()
-
-        self.device = device
-        self.address = address
-
-        self.username = self.config.username()
-        self.password = self.config.password()
-        self.ssh_command = self.config.sshCommand()
-
-        if not self.maybeGenerateSshKeys():
-            return False
-
-        self.client = paramiko.SSHClient()
-        self.client.set_missing_host_key_policy(paramiko.client.AutoAddPolicy())
-
-    def _shouldPushKey(self):
-        try:
-            self.client.connect(
-                self.address,
-                username=self.username,
-                pkey=self.keystore.key(),
-                allow_agent=False,
-                look_for_keys=False,
-                compress=True)
-        except AuthenticationException as e:
-            return True
-        except (SSHException, socket.error) as e:
-            raise e
-        finally:
-            self.client.close()
-
-    def _pushKey(self):
-        try:
-            self.client.connect(
-                    self.address,
-                    username=self.username,
-                    password=self.password,
-                    allow_agent=False,
-                    look_for_keys=False,
-                    compress=True)
-        except AuthenticationException as e:
-            raise DefaultLoginError(e)
-        except (SSHException, socket.error) as e:
-            raise KeyPushError(e)
-        else:
-            public_key = self.keystore.key().get_base64()
-            self.client.exec_command('mkdir -p $HOME/.ssh')
-            self.client.exec_command(
-                'echo ssh-rsa {0} mdt@localhost >>$HOME/.ssh/authorized_keys'.format(public_key))
-        finally:
-            self.client.close()
-
-    def maybeGenerateSshKeys(self):
-        if not self.keystore.key():
-            print('Looks like you don\'t have a private key yet. Generating one.')
-
-            if not self.keystore.generateKey():
-                print('Unable to generate private key.')
-                return False
-
-        return True
-
-    def openShell(self):
-        term = os.getenv("TERM", default="vt100")
-        width, height = os.get_terminal_size()
-
-        if self._shouldPushKey():
-            print("Key not present on {0} -- pushing".format(self.device))
-            self._pushKey()
-
-        self.client.connect(
-            self.address,
-            username=self.username,
-            pkey=self.keystore.key(),
-            allow_agent=False,
-            look_for_keys=False,
-            compress=True)
-        return self.client.invoke_shell(term=term, width=width, height=height)
-
-    def close(self):
-        self.client.close()
-
-
-class ShellCommand:
-    '''Usage: mdt shell [<devicename>]
-
-Opens an interactive shell to either your preferred device, the given
-devicename, or to the first device found.
+Opens an interactive shell to either your preferred device or to the first
+device found.
 
 Variables used:
     preferred-device    - set this to your preferred device name to connect
@@ -140,61 +32,56 @@
      login credentials in the 'username' and 'password' variables.
   3. Installs your SSH key to the device after logging in.
   4. Disconnects and reconnects using the SSH key.
-
-Note: this will not return the exit code of the shell executed on the device.
-If you need automation, use 'mdt run' instead.
 '''
 
-    def __init__(self):
-        self.config = config.Config()
-        self.discoverer = discoverer.Discoverer(self)
-        self.device = self.config.preferredDevice()
-        self.address = None
+    def runWithClient(self, client, args):
+        channel = client.openShell()
+        cons = console.Console(channel, sys.stdin)
+        return cons.run()
 
-    def add_device(self, hostname, address):
-        if not self.device:
-            self.device = hostname
-            self.address = address
-        elif self.device == hostname:
-            self.address = address
 
-    def run(self, args):
-        if len(args) > 1:
-            self.device = args[1]
+class ExecCommand(command.NetworkCommand):
+    '''Usage: mdt exec [<shell-command...>]
 
-        if not self.address:
-            if self.device:
-                print('Waiting for device {0}...'.format(self.device))
-            else:
-                print('Waiting for a device...')
+Opens a non-interactive shell to either your preferred device or to the first
+device found.
 
-            while not self.address:
-                sleep(0.1)
+Variables used:
+    preferred-device    - set this to your preferred device name to connect
+                          to by default if no <devicename> is provided on the
+                          command line.
+    username            - set this to the username that should be used to
+                          connect to a device with. Defaults to 'mendel'.
+    password            - set this to the password to use to login to a new
+                          device with. Defaults to 'mendel'. Only used
+                          during the initial setup phase of pushing an SSH
+                          key to the board.
 
-        print('Connecting to {0} at {1}'.format(self.device, self.address))
+If no SSH key is available on disk (ie: you didn't run genkey before running
+shell), this will implicitly run genkey for you. Additionally, shell will
+attempt to connect to a device by doing the following:
 
-        try:
-            client = SshClient(self.device, self.address)
-            channel = client.openShell()
-            cons = console.Console(channel, sys.stdin)
-            cons.run()
-        except KeyPushError as e:
-            print("Unable to push keys to the device: {0}".format(e))
-            return 1
-        except DefaultLoginError as e:
-            print("Can't login using default credentials: {0}".format(e))
-            return 1
-        except SSHException as e:
-            print("Couldn't establish ssh connection to device: {0}".format(e))
-            return 1
-        except socket.error as e:
-            print("Couldn't establish ssh connection to device: {0}".format(e))
-            return 1
-        except console.SocketTimeoutError as e:
-            print("Connection to {0} at {1} closed: socket timeout".format(self.device, self.address))
-            return 1
-        except console.ConnectionClosedError as e:
-            print("Connection to {0} at {1} closed".format(self.device, self.address))
-            return 0
-        finally:
-            client.close()
+  1. Attempt a connection using your SSH key only, with no password.
+  2. If the connection attempt failed due to authentication, will
+     attempt to push the key to the device by using the default
+     login credentials in the 'username' and 'password' variables.
+  3. Installs your SSH key to the device after logging in.
+  4. Disconnects and reconnects using the SSH key.
+'''
+    def runWithClient(self, client, args):
+        channel = client.shellExec(' '.join(args[1:]))
+        cons = console.Console(channel, sys.stdin)
+        return cons.run()
+
+
+class RebootCommand(command.NetworkCommand):
+    def runWithClient(self, client, args):
+        channel = client.shellExec("sudo reboot")
+        cons = console.Console(channel, sys.stdin)
+        return cons.run()
+
+class RebootBootloaderCommand(command.NetworkCommand):
+    def runWithClient(self, client, args):
+        channel = client.shellExec("sudo reboot-bootloader")
+        cons = console.Console(channel, sys.stdin)
+        return cons.run()
diff --git a/mdt/sshclient.py b/mdt/sshclient.py
new file mode 100644
index 0000000..dfdf3ed
--- /dev/null
+++ b/mdt/sshclient.py
@@ -0,0 +1,140 @@
+import os
+
+import paramiko
+from paramiko.ssh_exception import AuthenticationException, SSHException
+
+from mdt import config
+from mdt import discoverer
+from mdt import keys
+from mdt import sshclient
+
+
+class KeyPushError(Exception):
+    pass
+
+
+class DefaultLoginError(Exception):
+    pass
+
+
+class SshClient:
+    def __init__(self, device, address):
+        self.config = config.Config()
+        self.keystore = keys.Keystore()
+
+        self.device = device
+        self.address = address
+
+        self.username = self.config.username()
+        self.password = self.config.password()
+        self.ssh_command = self.config.sshCommand()
+
+        if not self.maybeGenerateSshKeys():
+            return False
+
+        self.client = paramiko.SSHClient()
+        self.client.set_missing_host_key_policy(paramiko.client.AutoAddPolicy())
+
+    def _shouldPushKey(self):
+        try:
+            self.client.connect(
+                self.address,
+                username=self.username,
+                pkey=self.keystore.key(),
+                allow_agent=False,
+                look_for_keys=False,
+                compress=True)
+        except AuthenticationException as e:
+            return True
+        except (SSHException, socket.error) as e:
+            raise e
+        finally:
+            self.client.close()
+
+    def _pushKey(self):
+        try:
+            self.client.connect(
+                    self.address,
+                    username=self.username,
+                    password=self.password,
+                    allow_agent=False,
+                    look_for_keys=False,
+                    compress=True)
+        except AuthenticationException as e:
+            raise DefaultLoginError(e)
+        except (SSHException, socket.error) as e:
+            raise KeyPushError(e)
+        else:
+            public_key = self.keystore.key().get_base64()
+            self.client.exec_command('mkdir -p $HOME/.ssh')
+            self.client.exec_command(
+                'echo ssh-rsa {0} mdt@localhost >>$HOME/.ssh/authorized_keys'.format(public_key))
+        finally:
+            self.client.close()
+
+    def maybeGenerateSshKeys(self):
+        if not self.keystore.key():
+            print('Looks like you don\'t have a private key yet. Generating one.')
+
+            if not self.keystore.generateKey():
+                print('Unable to generate private key.')
+                return False
+
+        return True
+
+    def openShell(self):
+        term = os.getenv("TERM", default="vt100")
+        width, height = os.get_terminal_size()
+
+        if self._shouldPushKey():
+            print("Key not present on {0} -- pushing".format(self.device))
+            self._pushKey()
+
+        self.client.connect(
+            self.address,
+            username=self.username,
+            pkey=self.keystore.key(),
+            allow_agent=False,
+            look_for_keys=False,
+            compress=True)
+        return self.client.invoke_shell(term=term, width=width, height=height)
+
+    def shellExec(self, cmd, allocPty=False):
+        if self._shouldPushKey():
+            print("Key not present on {0} -- pushing".format(self.device))
+            self._pushKey()
+
+        self.client.connect(
+            self.address,
+            username=self.username,
+            pkey=self.keystore.key(),
+            allow_agent=False,
+            look_for_keys=False,
+            compress=True)
+
+        session = self.client.get_transport().open_session()
+        if allocPty:
+            term = os.getenv("TERM", default="vt100")
+            width, height = os.get_terminal_size()
+            session.get_pty(term=term, width=width, height=height)
+        session.exec_command(cmd)
+        return session
+
+    def openSftp(self):
+        if self._shouldPushKey():
+            print("Key not present on {0} -- pushing".format(self.device))
+            self._pushKey()
+
+        self.client.connect(
+            self.address,
+            username=self.username,
+            pkey=self.keystore.key(),
+            allow_agent=False,
+            look_for_keys=False,
+            compress=True)
+
+        session = self.client.open_sftp()
+        return session
+
+    def close(self):
+        self.client.close()