shell: Add in a functional shell command
Currently this only works on UNIX-like systems due to the use of termios, so it
will break on Windows.
Change-Id: Ia114fab90198f310a8b58117c0c53c9347353b6f
diff --git a/mdt/config.py b/mdt/config.py
index 4634e8f..dd0e004 100644
--- a/mdt/config.py
+++ b/mdt/config.py
@@ -2,10 +2,10 @@
import os
CONFIG_BASEDIR = os.path.join(os.path.expanduser("~"), ".config", "mdt")
-CONFIG_KEYSDIR = os.path.join(CONFIG_BASEDIR, "keys")
CONFIG_ATTRDIR = os.path.join(CONFIG_BASEDIR, "attribs")
DEFAULT_USERNAME = "mendel"
+DEFAULT_PASSWORD = "mendel"
DEFAULT_SSH_COMMAND = "ssh"
class Config:
@@ -15,8 +15,6 @@
def ensureConfigDirExists(self):
if not os.path.exists(CONFIG_BASEDIR):
os.makedirs(CONFIG_BASEDIR, mode=0o700)
- if not os.path.exists(CONFIG_KEYSDIR):
- os.makedirs(CONFIG_KEYSDIR, mode=0o700)
if not os.path.exists(CONFIG_ATTRDIR):
os.makedirs(CONFIG_ATTRDIR, mode=0o700)
@@ -48,23 +46,16 @@
return self.getAttribute("username", DEFAULT_USERNAME)
self.setAttribute("username", username)
+ def password(self, password=None):
+ if not password:
+ return self.getAttribute("password", DEFAULT_PASSWORD)
+ self.setAttribute("password", password)
+
def sshCommand(self, command=None):
if not command:
return self.getAttribute("ssh-command", DEFAULT_SSH_COMMAND)
self.setAttribute("ssh-command", command)
- def getKey(self, keyname):
- path = os.path.join(CONFIG_KEYSDIR, keyname)
- if os.path.exists(path):
- with open(path, "r") as fp:
- return fp.read()
-
- def privateKey(self):
- return getKey("mdt")
-
- def publicKey(self):
- return getKey("mdt.pub")
-
class Get:
def __init__(self):
diff --git a/mdt/keys.py b/mdt/keys.py
index ce2c5e5..dd62e2a 100644
--- a/mdt/keys.py
+++ b/mdt/keys.py
@@ -2,6 +2,9 @@
import subprocess
import os
+import paramiko
+from paramiko.ssh_exception import SSHException, PasswordRequiredException
+
import config
SUPPORTED_SYSTEMS = [
@@ -10,56 +13,48 @@
'BSD',
]
+KEYSDIR = os.path.join(config.CONFIG_BASEDIR, "keys")
+KEYFILE_PATH = os.path.join(config.CONFIG_BASEDIR, "keys", "mdt.key")
+
class Keystore:
def __init__(self):
- self.config = config.Config()
- self.private_key = self.config.getKey("mdt")
- self.public_key = self.config.getKey("mdt.pub")
-
- def privateKeyPath(self):
- return os.path.join(config.CONFIG_KEYSDIR, "mdt")
+ if not os.path.exists(config.CONFIG_BASEDIR):
+ os.makedirs(CONFIG_BASEDIR, mode=0o700)
+ if not os.path.exists(KEYSDIR):
+ os.makedirs(KEYSDIR, mode=0o700)
+ if not os.path.exists(KEYFILE_PATH):
+ self.pkey = None
+ else:
+ try:
+ self.pkey = paramiko.rsakey.RSAKey.from_private_key_file(KEYFILE_PATH)
+ except IOError as e:
+ print("Unable to read private key from file: {0}".format(e))
+ sys.exit(1)
+ except PasswordRequiredException as e:
+ print("Unable to load in private key: {0}".format(e))
+ sys.exit(1)
def generateKey(self):
- if platform.system() not in SUPPORTED_SYSTEMS:
- print('Sorry, MDT doesn\'t support generating SSH keys on platforms other than:')
- print('\n'.join(SUPPORTED_SYSTEMS))
- return False
+ self.pkey = paramiko.rsakey.RSAKey.generate(bits=4096)
try:
- subprocess.run([
- "ssh-keygen",
- "-f",
- self.privateKeyPath(),
- "-P",
- ""
- ], check=True)
- except FileNotFoundError as e:
- print('Couldn\'t find ssh-keygen in your PATH.')
+ self.pkey.write_private_key_file(KEYFILE_PATH)
+ except IOError as e:
+ print("Unable to write private key to disk: {0}".format(e))
return False
- except subprocess.CalledProcessError as e:
- print('Couldn\'t generate SSH keys.')
- print('ssh-keygen failed with error code {0}'.format(e.returncode))
- return False
+ else:
+ return True
- self.private_key = self.config.getKey("mdt")
- self.public_key = self.config.getKey("mdt.pub")
-
- return True
-
- def publicKey(self):
- return self.public_key
-
- def privateKey(self):
- return self.private_key
-
- def pushKey(self):
- pass
+ def key(self):
+ return self.pkey
class GenKey:
- def __init__(self):
- self.keystore = Keystore()
-
def run(self, args):
- self.keystore.generateKey()
+ if os.path.exists(KEYFILE_PATH):
+ os.unlink(KEYFILE_PATH)
+ keystore = Keystore()
+ if not keystore.generateKey():
+ return 1
+ return 0
diff --git a/mdt/shell.py b/mdt/shell.py
index 678ee3c..7e726c3 100644
--- a/mdt/shell.py
+++ b/mdt/shell.py
@@ -3,23 +3,115 @@
import platform
import subprocess
import os
+import socket
+import select
+import sys
+import termios
+import tty
-import spur
+import paramiko
+from paramiko.ssh_exception import AuthenticationException, SSHException
import discoverer
import config
import keys
+class KeyPushError(Exception):
+ pass
+
+class DefaultLoginError(Exception):
+ pass
+
+class SshClient:
+ def __init__(self, device, address):
+ self.config = config.Config()
+ self.keystore = keys.Keystore()
+
+ self.device = device
+ self.address = address
+
+ self.username = self.config.username()
+ self.password = self.config.password()
+ self.ssh_command = self.config.sshCommand()
+
+ if not self.maybeGenerateSshKeys():
+ return False
+
+ self.client = paramiko.SSHClient()
+ self.client.set_missing_host_key_policy(paramiko.client.AutoAddPolicy())
+
+ def _shouldPushKey(self):
+ try:
+ self.client.connect(
+ self.address,
+ username=self.username,
+ pkey=self.keystore.key(),
+ allow_agent=False,
+ look_for_keys=False,
+ compress=True)
+ except AuthenticationException as e:
+ return True
+ except (SSHException, socket.error) as e:
+ raise e
+ finally:
+ self.client.close()
+
+ def _pushKey(self):
+ try:
+ self.client.connect(
+ self.address,
+ username=self.username,
+ password=self.password,
+ allow_agent=False,
+ look_for_keys=False,
+ compress=True)
+ except AuthenticationException as e:
+ raise DefaultLoginError(e)
+ except (SSHException, socket.error) as e:
+ raise KeyPushError(e)
+ else:
+ public_key = self.keystore.key().get_base64()
+ self.client.exec_command('mkdir -p $HOME/.ssh')
+ self.client.exec_command(
+ 'echo ssh-rsa {0} mdt@localhost >>$HOME/.ssh/authorized_keys'.format(public_key))
+ finally:
+ self.client.close()
+
+ def maybeGenerateSshKeys(self):
+ if not self.keystore.key():
+ print('Looks like you don\'t have a private key yet. Generating one.')
+
+ if not self.keystore.generateKey():
+ print('Unable to generate private key.')
+ return False
+
+ return True
+
+ def openShell(self):
+ term = os.getenv("TERM", default="vt100")
+ width, height = os.get_terminal_size()
+
+ if self._shouldPushKey():
+ print("Key not present on {0} -- pushing".format(self.device))
+ self._pushKey()
+
+ self.client.connect(
+ self.address,
+ username=self.username,
+ pkey=self.keystore.key(),
+ allow_agent=False,
+ look_for_keys=False,
+ compress=True)
+ return self.client.invoke_shell(term=term, width=width, height=height)
+
+ def close(self):
+ self.client.close()
+
+
class Shell:
def __init__(self):
self.config = config.Config()
- self.keystore = keys.Keystore()
self.discoverer = discoverer.Discoverer(self)
-
- self.username = self.config.username()
- self.private_key = self.keystore.privateKey()
- self.ssh_command = self.config.sshCommand()
-
self.device = self.config.preferredDevice()
self.address = None
@@ -34,30 +126,58 @@
if len(args) > 1:
self.device = args[1]
- if not self.private_key:
- # Need to call genkey first.
- print('Looks like you don\'t have a private key yet. Generating one.')
-
- if not self.keystore.generateKey():
- print('Unable to generate private key.')
- return 1
-
if not self.address:
if self.device:
- print('Waiting for device "{0}"...'.format(self.device))
+ print('Waiting for device {0}...'.format(self.device))
else:
print('Waiting for a device...')
while not self.address:
sleep(0.1)
- print('Found "{0}" at {1}'.format(self.device, self.address))
+ print('Connecting to {0} at {1}'.format(self.device, self.address))
+ client = SshClient(self.device, self.address)
- os.execvp(
- self.ssh_command, [
- self.ssh_command,
- '-oStrictHostKeyChecking=no',
- '-i{0}'.format(self.keystore.privateKeyPath()),
- '-Ct',
- '@'.join([self.username, self.address])
- ])
+ try:
+ channel = client.openShell()
+ except KeyPushError as e:
+ print("Unable to push keys to the device: {0}".format(e))
+ return 1
+ except DefaultLoginError as e:
+ print("Can't login using default credentials: {0}".format(e))
+ return 1
+ except SSHException as e:
+ print("Couldn't establish ssh connection to device: {0}".format(e))
+ return 1
+ except socket.error as e:
+ print("Couldn't establish ssh connection to device: {0}".format(e))
+ return 1
+
+ localtty = termios.tcgetattr(sys.stdin)
+ try:
+ tty.setraw(sys.stdin.fileno())
+ tty.setcbreak(sys.stdin.fileno())
+ channel.settimeout(0)
+
+ while True:
+ read, write, exception = select.select([channel, sys.stdin], [], [])
+
+ if channel in read:
+ try:
+ data = channel.recv(256)
+ if len(data) == 0:
+ sys.stdout.write("Connection to {0} at {1} closed.\r\n".format(self.device, self.address))
+ break
+ sys.stdout.write(data.decode("utf-8", errors="ignore"))
+ sys.stdout.flush()
+ except socket.timeout:
+ sys.stdout.write("Connection to {0} at {1} closed: socket timeout\r\n".format(self.device, self.address))
+ break
+ if sys.stdin in read:
+ data = sys.stdin.read(1)
+ if len(data) == 0:
+ break
+ channel.send(data)
+ finally:
+ termios.tcsetattr(sys.stdin, termios.TCSADRAIN, localtty)
+ client.close()