shell: Add in a functional shell command

Currently this only works on UNIX-like systems due to the use of termios, so it
will break on Windows.

Change-Id: Ia114fab90198f310a8b58117c0c53c9347353b6f
diff --git a/mdt/config.py b/mdt/config.py
index 4634e8f..dd0e004 100644
--- a/mdt/config.py
+++ b/mdt/config.py
@@ -2,10 +2,10 @@
 import os
 
 CONFIG_BASEDIR = os.path.join(os.path.expanduser("~"), ".config", "mdt")
-CONFIG_KEYSDIR = os.path.join(CONFIG_BASEDIR, "keys")
 CONFIG_ATTRDIR = os.path.join(CONFIG_BASEDIR, "attribs")
 
 DEFAULT_USERNAME = "mendel"
+DEFAULT_PASSWORD = "mendel"
 DEFAULT_SSH_COMMAND = "ssh"
 
 class Config:
@@ -15,8 +15,6 @@
     def ensureConfigDirExists(self):
         if not os.path.exists(CONFIG_BASEDIR):
             os.makedirs(CONFIG_BASEDIR, mode=0o700)
-        if not os.path.exists(CONFIG_KEYSDIR):
-            os.makedirs(CONFIG_KEYSDIR, mode=0o700)
         if not os.path.exists(CONFIG_ATTRDIR):
             os.makedirs(CONFIG_ATTRDIR, mode=0o700)
 
@@ -48,23 +46,16 @@
             return self.getAttribute("username", DEFAULT_USERNAME)
         self.setAttribute("username", username)
 
+    def password(self, password=None):
+        if not password:
+            return self.getAttribute("password", DEFAULT_PASSWORD)
+        self.setAttribute("password", password)
+
     def sshCommand(self, command=None):
         if not command:
             return self.getAttribute("ssh-command", DEFAULT_SSH_COMMAND)
         self.setAttribute("ssh-command", command)
 
-    def getKey(self, keyname):
-        path = os.path.join(CONFIG_KEYSDIR, keyname)
-        if os.path.exists(path):
-            with open(path, "r") as fp:
-                return fp.read()
-
-    def privateKey(self):
-        return getKey("mdt")
-
-    def publicKey(self):
-        return getKey("mdt.pub")
-
 
 class Get:
     def __init__(self):
diff --git a/mdt/keys.py b/mdt/keys.py
index ce2c5e5..dd62e2a 100644
--- a/mdt/keys.py
+++ b/mdt/keys.py
@@ -2,6 +2,9 @@
 import subprocess
 import os
 
+import paramiko
+from paramiko.ssh_exception import SSHException, PasswordRequiredException
+
 import config
 
 SUPPORTED_SYSTEMS = [
@@ -10,56 +13,48 @@
     'BSD',
 ]
 
+KEYSDIR = os.path.join(config.CONFIG_BASEDIR, "keys")
+KEYFILE_PATH = os.path.join(config.CONFIG_BASEDIR, "keys", "mdt.key")
+
 
 class Keystore:
     def __init__(self):
-        self.config = config.Config()
-        self.private_key = self.config.getKey("mdt")
-        self.public_key = self.config.getKey("mdt.pub")
-
-    def privateKeyPath(self):
-        return os.path.join(config.CONFIG_KEYSDIR, "mdt")
+        if not os.path.exists(config.CONFIG_BASEDIR):
+            os.makedirs(CONFIG_BASEDIR, mode=0o700)
+        if not os.path.exists(KEYSDIR):
+            os.makedirs(KEYSDIR, mode=0o700)
+        if not os.path.exists(KEYFILE_PATH):
+            self.pkey = None
+        else:
+            try:
+                self.pkey = paramiko.rsakey.RSAKey.from_private_key_file(KEYFILE_PATH)
+            except IOError as e:
+                print("Unable to read private key from file: {0}".format(e))
+                sys.exit(1)
+            except PasswordRequiredException as e:
+                print("Unable to load in private key: {0}".format(e))
+                sys.exit(1)
 
     def generateKey(self):
-        if platform.system() not in SUPPORTED_SYSTEMS:
-            print('Sorry, MDT doesn\'t support generating SSH keys on platforms other than:')
-            print('\n'.join(SUPPORTED_SYSTEMS))
-            return False
+        self.pkey = paramiko.rsakey.RSAKey.generate(bits=4096)
 
         try:
-            subprocess.run([
-                "ssh-keygen",
-                "-f",
-                self.privateKeyPath(),
-                "-P",
-                ""
-            ], check=True)
-        except FileNotFoundError as e:
-            print('Couldn\'t find ssh-keygen in your PATH.')
+            self.pkey.write_private_key_file(KEYFILE_PATH)
+        except IOError as e:
+            print("Unable to write private key to disk: {0}".format(e))
             return False
-        except subprocess.CalledProcessError as e:
-            print('Couldn\'t generate SSH keys.')
-            print('ssh-keygen failed with error code {0}'.format(e.returncode))
-            return False
+        else:
+            return True
 
-        self.private_key = self.config.getKey("mdt")
-        self.public_key = self.config.getKey("mdt.pub")
-
-        return True
-
-    def publicKey(self):
-        return self.public_key
-
-    def privateKey(self):
-        return self.private_key
-
-    def pushKey(self):
-        pass
+    def key(self):
+        return self.pkey
 
 
 class GenKey:
-    def __init__(self):
-        self.keystore = Keystore()
-
     def run(self, args):
-        self.keystore.generateKey()
+        if os.path.exists(KEYFILE_PATH):
+            os.unlink(KEYFILE_PATH)
+        keystore = Keystore()
+        if not keystore.generateKey():
+            return 1
+        return 0
diff --git a/mdt/shell.py b/mdt/shell.py
index 678ee3c..7e726c3 100644
--- a/mdt/shell.py
+++ b/mdt/shell.py
@@ -3,23 +3,115 @@
 import platform
 import subprocess
 import os
+import socket
+import select
+import sys
+import termios
+import tty
 
-import spur
+import paramiko
+from paramiko.ssh_exception import AuthenticationException, SSHException
 
 import discoverer
 import config
 import keys
 
+class KeyPushError(Exception):
+    pass
+
+class DefaultLoginError(Exception):
+    pass
+
+class SshClient:
+    def __init__(self, device, address):
+        self.config = config.Config()
+        self.keystore = keys.Keystore()
+
+        self.device = device
+        self.address = address
+
+        self.username = self.config.username()
+        self.password = self.config.password()
+        self.ssh_command = self.config.sshCommand()
+
+        if not self.maybeGenerateSshKeys():
+            return False
+
+        self.client = paramiko.SSHClient()
+        self.client.set_missing_host_key_policy(paramiko.client.AutoAddPolicy())
+
+    def _shouldPushKey(self):
+        try:
+            self.client.connect(
+                self.address,
+                username=self.username,
+                pkey=self.keystore.key(),
+                allow_agent=False,
+                look_for_keys=False,
+                compress=True)
+        except AuthenticationException as e:
+            return True
+        except (SSHException, socket.error) as e:
+            raise e
+        finally:
+            self.client.close()
+
+    def _pushKey(self):
+        try:
+            self.client.connect(
+                    self.address,
+                    username=self.username,
+                    password=self.password,
+                    allow_agent=False,
+                    look_for_keys=False,
+                    compress=True)
+        except AuthenticationException as e:
+            raise DefaultLoginError(e)
+        except (SSHException, socket.error) as e:
+            raise KeyPushError(e)
+        else:
+            public_key = self.keystore.key().get_base64()
+            self.client.exec_command('mkdir -p $HOME/.ssh')
+            self.client.exec_command(
+                'echo ssh-rsa {0} mdt@localhost >>$HOME/.ssh/authorized_keys'.format(public_key))
+        finally:
+            self.client.close()
+
+    def maybeGenerateSshKeys(self):
+        if not self.keystore.key():
+            print('Looks like you don\'t have a private key yet. Generating one.')
+
+            if not self.keystore.generateKey():
+                print('Unable to generate private key.')
+                return False
+
+        return True
+
+    def openShell(self):
+        term = os.getenv("TERM", default="vt100")
+        width, height = os.get_terminal_size()
+
+        if self._shouldPushKey():
+            print("Key not present on {0} -- pushing".format(self.device))
+            self._pushKey()
+
+        self.client.connect(
+            self.address,
+            username=self.username,
+            pkey=self.keystore.key(),
+            allow_agent=False,
+            look_for_keys=False,
+            compress=True)
+        return self.client.invoke_shell(term=term, width=width, height=height)
+
+    def close(self):
+        self.client.close()
+
+
 class Shell:
     def __init__(self):
         self.config = config.Config()
-        self.keystore = keys.Keystore()
         self.discoverer = discoverer.Discoverer(self)
-
-        self.username = self.config.username()
-        self.private_key = self.keystore.privateKey()
-        self.ssh_command = self.config.sshCommand()
-
         self.device = self.config.preferredDevice()
         self.address = None
 
@@ -34,30 +126,58 @@
         if len(args) > 1:
             self.device = args[1]
 
-        if not self.private_key:
-            # Need to call genkey first.
-            print('Looks like you don\'t have a private key yet. Generating one.')
-
-            if not self.keystore.generateKey():
-                print('Unable to generate private key.')
-                return 1
-
         if not self.address:
             if self.device:
-                print('Waiting for device "{0}"...'.format(self.device))
+                print('Waiting for device {0}...'.format(self.device))
             else:
                 print('Waiting for a device...')
 
             while not self.address:
                 sleep(0.1)
 
-            print('Found "{0}" at {1}'.format(self.device, self.address))
+        print('Connecting to {0} at {1}'.format(self.device, self.address))
+        client = SshClient(self.device, self.address)
 
-        os.execvp(
-            self.ssh_command, [
-                self.ssh_command,
-                '-oStrictHostKeyChecking=no',
-                '-i{0}'.format(self.keystore.privateKeyPath()),
-                '-Ct',
-                '@'.join([self.username, self.address])
-            ])
+        try:
+            channel = client.openShell()
+        except KeyPushError as e:
+            print("Unable to push keys to the device: {0}".format(e))
+            return 1
+        except DefaultLoginError as e:
+            print("Can't login using default credentials: {0}".format(e))
+            return 1
+        except SSHException as e:
+            print("Couldn't establish ssh connection to device: {0}".format(e))
+            return 1
+        except socket.error as e:
+            print("Couldn't establish ssh connection to device: {0}".format(e))
+            return 1
+
+        localtty = termios.tcgetattr(sys.stdin)
+        try:
+            tty.setraw(sys.stdin.fileno())
+            tty.setcbreak(sys.stdin.fileno())
+            channel.settimeout(0)
+
+            while True:
+                read, write, exception = select.select([channel, sys.stdin], [], [])
+
+                if channel in read:
+                    try:
+                        data = channel.recv(256)
+                        if len(data) == 0:
+                            sys.stdout.write("Connection to {0} at {1} closed.\r\n".format(self.device, self.address))
+                            break
+                        sys.stdout.write(data.decode("utf-8", errors="ignore"))
+                        sys.stdout.flush()
+                    except socket.timeout:
+                        sys.stdout.write("Connection to {0} at {1} closed: socket timeout\r\n".format(self.device, self.address))
+                        break
+                if sys.stdin in read:
+                    data = sys.stdin.read(1)
+                    if len(data) == 0:
+                        break
+                    channel.send(data)
+        finally:
+            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, localtty)
+            client.close()