shell: Split console handling to a separate class

This splits out the console handling routines into their own classes, uses
exceptions to great effect, and also provides preliminary Windows console
support via multithreading.

Change-Id: I46d4b227cc62208d9698339957c9f348ed01a8b7
diff --git a/mdt/console.py b/mdt/console.py
new file mode 100644
index 0000000..3b952bd
--- /dev/null
+++ b/mdt/console.py
@@ -0,0 +1,124 @@
+import platform
+import threading
+import queue
+import os
+import socket
+import select
+import sys
+
+
+class ConnectionClosedError(Exception):
+    pass
+
+
+class SocketTimeoutError(Exception):
+    pass
+
+
+class PosixConsole:
+    def __init__(self, channel, inputfile):
+        self.channel = channel
+        self.inputfile = inputfile
+
+    def run(self):
+        import termios
+        import tty
+
+        localtty = termios.tcgetattr(sys.stdin)
+        try:
+            tty.setraw(sys.stdin.fileno())
+            tty.setcbreak(sys.stdin.fileno())
+            self.channel.settimeout(0)
+
+            while True:
+                read, write, exception = select.select([self.channel, sys.stdin], [], [])
+
+                if self.channel in read:
+                    try:
+                        data = self.channel.recv(256)
+                        if len(data) == 0:
+                            raise ConnectionClosedError()
+                        sys.stdout.write(data.decode("utf-8", errors="ignore"))
+                        sys.stdout.flush()
+                    except socket.timeout as e:
+                        raise SocketTimeoutError(e)
+
+                if sys.stdin in read:
+                    data = sys.stdin.read(1)
+                    if len(data) == 0:
+                        raise ConnectionClosedError()
+                    self.channel.send(data)
+        finally:
+            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, localtty)
+
+
+class WindowsConsole:
+    TYPE_KEYBOARD_INPUT = 0
+    TYPE_TERMINAL_OUTPUT = 1
+    TYPE_REMOTE_CLOSED = 2
+    TYPE_SOCKET_TIMEOUT = 3
+
+    class KeyboardInputThread(threading.Thread):
+        def __init__(self, queue):
+            super(KeyboardInputThread, self).__init__()
+            self.daemon = True
+            self.queue = queue
+
+        def run(self):
+            while True:
+                ch = sys.stdin.read(1)
+                self.queue.put((KEYBOARD_INPUT_DATA, ch))
+
+    class TerminalOutputThread(threading.Thread):
+        def __init__(self, queue, channel):
+            super(TerminalOutputThread, self).__init__()
+            self.daemon = True
+            self.queue = queue
+            self.channel = channel
+
+        def run(self):
+            while True:
+                try:
+                    data = self.channel.recv(256)
+                    if len(data) == 0:
+                        self.queue.put((TYPE_REMOTE_CLOSED, None))
+                        break
+                    data = data.decode("utf-8", errors="ignore")
+                    self.queue.put((TYPE_TERMINAL_OUTPUT, data))
+                except socket.timeout:
+                    self.queue.put((TYPE_SOCKET_TIMEOUT, None))
+                    break
+
+    def __init__(self, channel, inputfile):
+        self.channel = channel
+        self.dataQueue = queue.Queue()
+        self.inputThread = KeyboardInputThread(self.dataQueue)
+        self.outputThread = TerminalOutputThread(self.dataQueue, self.channel)
+
+    def run(self):
+        self.inputThread.start()
+        self.outputThread.start()
+
+        while True:
+            dataType, data = self.queue.get()
+
+            if dataType == TYPE_KEYBOARD_INPUT:
+                channel.send(data)
+            if dataType == TYPE_TERMINAL_OUTPUT:
+                sys.stdout.write(data)
+                sys.stdout.flush()
+            if dataType == TYPE_REMOTE_CLOSED:
+                raise ConnectionClosedError()
+            if dataType == TYPE_SOCKET_TIMEOUT:
+                raise SocketTimeoutError()
+
+
+class Console:
+    def __init__(self, channel, inputfile):
+        if os.name == 'nt':
+            self._console = WindowsConsole(channel, inputfile)
+        else:
+            self._console = PosixConsole(channel, inputfile)
+
+    def run(self):
+        return self._console.run()
diff --git a/mdt/shell.py b/mdt/shell.py
index 7e726c3..f10651a 100644
--- a/mdt/shell.py
+++ b/mdt/shell.py
@@ -14,6 +14,7 @@
 
 import discoverer
 import config
+import console
 import keys
 
 class KeyPushError(Exception):
@@ -136,10 +137,12 @@
                 sleep(0.1)
 
         print('Connecting to {0} at {1}'.format(self.device, self.address))
-        client = SshClient(self.device, self.address)
 
         try:
+            client = SshClient(self.device, self.address)
             channel = client.openShell()
+            cons = console.Console(channel, sys.stdin)
+            cons.run()
         except KeyPushError as e:
             print("Unable to push keys to the device: {0}".format(e))
             return 1
@@ -152,32 +155,11 @@
         except socket.error as e:
             print("Couldn't establish ssh connection to device: {0}".format(e))
             return 1
-
-        localtty = termios.tcgetattr(sys.stdin)
-        try:
-            tty.setraw(sys.stdin.fileno())
-            tty.setcbreak(sys.stdin.fileno())
-            channel.settimeout(0)
-
-            while True:
-                read, write, exception = select.select([channel, sys.stdin], [], [])
-
-                if channel in read:
-                    try:
-                        data = channel.recv(256)
-                        if len(data) == 0:
-                            sys.stdout.write("Connection to {0} at {1} closed.\r\n".format(self.device, self.address))
-                            break
-                        sys.stdout.write(data.decode("utf-8", errors="ignore"))
-                        sys.stdout.flush()
-                    except socket.timeout:
-                        sys.stdout.write("Connection to {0} at {1} closed: socket timeout\r\n".format(self.device, self.address))
-                        break
-                if sys.stdin in read:
-                    data = sys.stdin.read(1)
-                    if len(data) == 0:
-                        break
-                    channel.send(data)
+        except console.SocketTimeoutError as e:
+            print("Connection to {0} at {1} closed: socket timeout".format(self.device, self.address))
+            return 1
+        except console.ConnectionClosedError as e:
+            print("Connection to {0} at {1} closed".format(self.device, self.address))
+            return 0
         finally:
-            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, localtty)
             client.close()