passwordauth: Disable password authentication once a key is pushed

Change-Id: I4714fa655d1f2c36bc97d6b7eb867062c202c34e
diff --git a/mdt/config.py b/mdt/config.py
index 62cc85d..8b79ec1 100644
--- a/mdt/config.py
+++ b/mdt/config.py
@@ -24,7 +24,7 @@
 
 DEFAULT_USERNAME = "mendel"
 DEFAULT_PASSWORD = "mendel"
-DEFAULT_SSH_COMMAND = "ssh"
+DEFAULT_DISABLE_PASSWD_AUTH = "true"
 
 
 class Config:
@@ -70,10 +70,11 @@
             return self.getAttribute("password", DEFAULT_PASSWORD)
         self.setAttribute("password", password)
 
-    def sshCommand(self, command=None):
-        if not command:
-            return self.getAttribute("ssh-command", DEFAULT_SSH_COMMAND)
-        self.setAttribute("ssh-command", command)
+    def shouldDisablePasswordAuth(self, disablePasswdAuth=None):
+        if disablePasswdAuth == None:
+            return self.getAttribute("disable-password-auth",
+                                     DEFAULT_DISABLE_PASSWD_AUTH)
+        self.setAttribute("disable-password-auth", disablePasswdAuth)
 
 
 class GetCommand:
diff --git a/mdt/sshclient.py b/mdt/sshclient.py
index 1001131..e018559 100644
--- a/mdt/sshclient.py
+++ b/mdt/sshclient.py
@@ -35,6 +35,10 @@
     pass
 
 
+class PasswordAuthDisableError(Exception):
+    pass
+
+
 class SshClient:
     def __init__(self, device, address):
         self.config = config.Config()
@@ -45,7 +49,6 @@
 
         self.username = self.config.username()
         self.password = self.config.password()
-        self.ssh_command = self.config.sshCommand()
 
         if not self.maybeGenerateSshKeys():
             return False
@@ -91,6 +94,29 @@
         finally:
             self.client.close()
 
+        # Ensure the key we just pushed allows us to login
+        try:
+            self.client.connect(
+                self.address,
+                username=self.username,
+                pkey=self.keystore.key(),
+                allow_agent=False,
+                look_for_keys=False,
+                compress=True)
+
+            if self.config.shouldDisablePasswordAuth() == "true":
+                print("Key push successful: disabling password authentication")
+                self.client.exec_command(
+                    "sudo sed -e's/^#?PasswordAuthentication.*/PasswordAuthentication no/g' -i "
+                    "/etc/ssh/sshd_config")
+                self.client.exec_command('sudo systemctl restart ssh')
+        except AuthenticationException as e:
+            raise KeyPushError(e)
+        except (SSHException, socket.error) as e:
+            raise KeyPushError(e)
+        finally:
+            self.client.close()
+
     def maybeGenerateSshKeys(self):
         if not self.keystore.key():
             print('Looks like you don\'t have a private key yet. '