blob: 3dced3082317e7f5cfc70ec06cf20e808382bdc5 [file] [log] [blame]
#!/usr/bin/python3 -u
import argparse
import os
import re
import requests
import sys
import socket
import struct
import time
from multiprocessing import Process
from multiprocessing import Event
from http.server import HTTPServer
from http.server import BaseHTTPRequestHandler
import netifaces
# Account and group names for the device user. NOTE(review): neither is
# referenced by the code visible in this file -- confirm before removing.
USERNAME = 'mendel'
GROUP = 'mendel'
# Directory holding the user's SSH configuration and the key file inside it.
SSH_BASE_PATH = '/home/mendel/.ssh'
AUTHORIZED_KEYS_PATH = os.path.join(SSH_BASE_PATH, 'authorized_keys')
# TCP port every key-push HTTP server binds to.
SERVER_PORT = 41337
# A pushed key must be one line starting with "ssh-rsa " and ending in " mdt".
MDT_KEY_REGEX = re.compile(r'^ssh-rsa .* mdt$')
# Request body that asks a server to stop without installing a key.
MDT_EXIT_REGEX = re.compile(r'^exit$')
def get_iface_address(iface):
    """Return the first IPv4 address configured on a network interface.

    Args:
        iface: Interface name, passed straight to ``netifaces.ifaddresses``.

    Returns:
        The ``'addr'`` field of the first ``AF_INET`` entry for the interface.

    Any error raised by the lookup or by the indexing below (for example
    when the interface has no ``AF_INET`` entry) propagates to the caller.
    """
    return netifaces.ifaddresses(iface)[socket.AF_INET][0]['addr']
def key_exists():
    """Report whether an MDT key is already installed.

    Scans ``AUTHORIZED_KEYS_PATH`` line by line for an entry matching
    ``MDT_KEY_REGEX`` and logs what it finds to stdout.

    Returns:
        True if the file exists and contains a matching line, False
        otherwise (including when the file does not exist).
    """
    if not os.path.exists(AUTHORIZED_KEYS_PATH):
        return False

    print('authorized_keys file already exists.\n', flush=True)
    # Decode leniently: an undecodable byte elsewhere in the file must not
    # abort the scan for the (pure ASCII) MDT key line.
    with open(AUTHORIZED_KEYS_PATH, encoding='utf-8', errors='replace') as fp:
        for line in fp:
            # Keys are validated against the stripped payload when they are
            # pushed, so strip here too; otherwise a stored key with
            # trailing whitespace would never be recognised.
            if MDT_KEY_REGEX.match(line.strip()):
                print('authorized_keys file contains MDT key already.\n',
                      flush=True)
                return True

    print('authorized_keys file does not contain an MDT key.\n', flush=True)
    return False
class KeyPushHandler(BaseHTTPRequestHandler):
    """Handle keys sent from a client over HTTP PUT and install them.

    The request body is either a single public key line matching
    ``MDT_KEY_REGEX``, which is appended to ``AUTHORIZED_KEYS_PATH``, or
    text matching ``MDT_EXIT_REGEX``, which only signals shutdown.
    """

    # Set once a key has been installed or an exit was requested. The
    # serving loops and main() watch it to know when to stop.
    key_handled = Event()

    # Upper bound on the accepted request body. A public key line is far
    # smaller; this only stops a client from making us buffer arbitrary data.
    MAX_BODY_BYTES = 64 * 1024

    def _reply(self, code, message):
        """Send a body-less response with the given status code and reason."""
        self.send_response(code, message)
        self.end_headers()

    def _read_body(self):
        """Return the request body, or None if Content-Length is unusable."""
        try:
            length = int(self.headers.get('Content-Length', 0))
        except ValueError:
            return None
        # A negative length would make rfile.read() block until EOF.
        if not 0 <= length <= self.MAX_BODY_BYTES:
            return None
        return self.rfile.read(length)

    @staticmethod
    def _install_key(key_line):
        """Append key_line as its own line to the authorized_keys file.

        Raises:
            OSError: if the directory or file cannot be created or written.
        """
        try:
            os.mkdir(SSH_BASE_PATH, 0o700)
        except FileExistsError:
            pass
        with open(AUTHORIZED_KEYS_PATH, 'a+b') as fp:
            # Make sure the new key does not get glued onto an existing
            # last line that lacks a terminating newline.
            fp.seek(0, os.SEEK_END)
            if fp.tell() > 0:
                fp.seek(-1, os.SEEK_END)
                if fp.read(1) != b'\n':
                    fp.write(b'\n')
            fp.write(key_line.encode('utf-8') + b'\n')

    def do_PUT(self):
        """Verify and install the key carried in the request body.

        Responds with:
            200 -- key installed, exit request accepted, or a key/exit was
                   already handled (the request is then only acknowledged).
            400 -- body is not a valid MDT key or exit request.
            500 -- the key was valid but could not be written.
        """
        # Always consume the body so the client is not reset mid-request.
        payload = self._read_body()

        if KeyPushHandler.key_handled.is_set():
            self._reply(200, "Ok")
            return

        text = None
        if payload is not None:
            try:
                text = payload.decode("utf-8").strip()
            except UnicodeDecodeError:
                text = None

        if text is not None and MDT_EXIT_REGEX.match(text):
            KeyPushHandler.key_handled.set()
            print("Received request to exit.", flush=True)
            self._reply(200, "Ok")
            return

        if text is None or not MDT_KEY_REGEX.match(text):
            # The event is deliberately left untouched here: clearing it
            # could undo a set() made concurrently by another server process.
            print("Received invalid key.", flush=True)
            self._reply(400, "Invalid key")
            return

        print("Received valid key.", flush=True)
        try:
            self._install_key(text)
        except OSError as err:
            print("Unable to write authorized_keys file: {0}".format(err),
                  flush=True)
            self._reply(500, "Unable to install key")
            return

        # Signal completion only after the key is safely on disk.
        KeyPushHandler.key_handled.set()
        sys.stdout.write('authorized_keys file written -- exiting.\n')
        sys.stdout.flush()
        self._reply(200, "Ok")
class HTTPServerProcess(Process):
    """Child process that serves KeyPushHandler on one bind address."""

    # Seconds handle_request() waits for a connection before returning, so
    # the serving loop re-checks the shared event even if no request arrives.
    POLL_INTERVAL = 1.0
    # Seconds to wait for the server to answer the exit request.
    EXIT_REQUEST_TIMEOUT = 5.0

    def __init__(self, bind_address, *args, **kwargs):
        """Remember the address to bind; other arguments go to Process."""
        super().__init__(*args, **kwargs)
        self.bind_address = bind_address

    def run(self):
        """Bind to the address and serve until a key or exit is handled."""
        server_address = (self.bind_address, SERVER_PORT)
        # The context manager closes the listening socket on the way out.
        with HTTPServer(server_address, KeyPushHandler) as httpd:
            httpd.timeout = self.POLL_INTERVAL
            print('Waiting for incoming PUT on {0}:{1}\n'.format(
                self.bind_address, SERVER_PORT), flush=True)
            while not KeyPushHandler.key_handled.is_set():
                httpd.handle_request()

    def send_exit_command(self):
        """Send the exit request to this server (best effort)."""
        print("Request exit for {0}:{1}".format(self.bind_address,
                                                SERVER_PORT), flush=True)
        try:
            requests.put(
                "http://{0}:{1}".format(self.bind_address, SERVER_PORT),
                data="exit", timeout=self.EXIT_REQUEST_TIMEOUT)
        except requests.exceptions.RequestException:
            # The server may already have stopped listening or may not
            # answer in time; either way there is nothing left to tell it.
            pass
def main():
    """Run one key-push server per requested interface until a key arrives.

    Exits immediately with status 0 if an MDT key is already installed.
    Otherwise resolves the IPv4 address of every ``--interface`` (retrying
    with a capped exponential back-off until all are known), starts one
    HTTPServerProcess per address, waits for a key or exit request to be
    handled, then stops all servers and exits with status 0.
    """
    parser = argparse.ArgumentParser(description="MDT Keymaster Server.")
    parser.add_argument("--interface", help="Listen on selected interfaces.",
                        action="append", default=[])
    args = parser.parse_args()

    # If key is installed exit.
    if key_exists():
        print('exiting.\n', flush=True)
        sys.exit(0)

    # Drop repeated --interface values while keeping the given order.
    ifaces = list(dict.fromkeys(args.interface))

    # Resolve each interface exactly once; only the ones that failed are
    # retried, so an address can never be recorded twice.
    resolved = {}
    pending = list(ifaces)
    delay = 1
    while pending:
        unresolved = []
        for iface in pending:
            try:
                resolved[iface] = get_iface_address(iface)
            except Exception as e:
                # Deliberately broad: any failure just means "not ready
                # yet", e.g. the interface has no IPv4 address so far.
                print('Unable to determine {0} IPv4 address: {1}.\n'.format(
                    iface, e), flush=True)
                unresolved.append(iface)
        pending = unresolved
        if pending:
            print('Waiting {0} seconds before retrying.'.format(delay),
                  flush=True)
            time.sleep(delay)
            delay = min(10, delay * 2)

    # Two interfaces sharing an address need only one server.
    bind_addresses = list(dict.fromkeys(resolved[iface] for iface in ifaces))

    server_processes = [HTTPServerProcess(address)
                        for address in bind_addresses]
    for process in server_processes:
        process.start()

    # Block until some server reports a handled key or exit request.
    KeyPushHandler.key_handled.wait()

    for process in server_processes:
        process.send_exit_command()
        process.join()

    print('Exiting.\n', flush=True)
    sys.exit(0)
# Start the keymaster only when run as a script, not when imported.
if __name__ == '__main__':
    main()