Get up to generating a CSR

This commit is contained in:
Ian Gulliver
2016-04-10 22:28:30 +00:00
parent d38a1fd25c
commit 88a58e60b8

View File

@@ -2,14 +2,22 @@
import argparse
from oauth2client import client
import os
from urllib import parse
import requests
from http import server
import socket
import ssl
import subprocess
import tempfile
parser = argparse.ArgumentParser(description='oauthproxy')
parser.add_argument(
'--allowed-domain',
dest='allowed_domain',
action='store',
required=True)
parser.add_argument(
'--api-key',
dest='api_key',
@@ -36,6 +44,11 @@ parser.add_argument(
dest='server_cert',
action='store',
required=True)
parser.add_argument(
'--subject',
dest='subject',
action='store',
required=True)
FLAGS = parser.parse_args()
@@ -45,8 +58,10 @@ class HTTPServer6(server.HTTPServer):
class OAuthProxy(object):
def __init__(self, listen_host, listen_port, server_key, server_cert, api_key):
def __init__(self, listen_host, listen_port, server_key, server_cert, api_key, allowed_domain, subject):
self._api_key = api_key
self._allowed_domain = allowed_domain
self._subject = subject
HANDLERS = {
'/': self._ServeRedirect,
@@ -81,9 +96,27 @@ class OAuthProxy(object):
])
return client.flow_from_clientsecrets(
'client_secrets.json',
login_hint=self._allowed_domain,
scope='https://www.googleapis.com/auth/userinfo.email',
redirect_uri=return_url)
def _GetCert(self, email):
with tempfile.TemporaryDirectory() as td:
key_path = os.path.join(td, 'key.pem')
subprocess.check_call([
'openssl', 'ecparam', '-genkey',
'-name', 'secp384r1',
'-out', key_path,
])
csr_path = os.path.join(td, 'csr.pem')
subprocess.check_call([
'openssl', 'req', '-new',
'-key', key_path,
'-out', csr_path,
'-subj', self._subject.replace('EMAIL', email),
])
return open(csr_path, 'rb').read()
def _ServeRedirect(self, req):
req.send_response(302)
req.send_header('Location', self._GetFlow(req).step1_get_authorize_url())
@@ -102,9 +135,12 @@ class OAuthProxy(object):
for x in result.json()['emails']
if x['type'] == 'account'
]
email = emails[0]
assert email.endswith('@%s' % self._allowed_domain)
result = self._GetCert(email)
req.send_response(200)
req.end_headers()
req.wfile.write(emails[0].encode('utf8'))
req.wfile.write(result)
def main():
@@ -113,7 +149,9 @@ def main():
FLAGS.listen_port,
FLAGS.server_key,
FLAGS.server_cert,
FLAGS.api_key)
FLAGS.api_key,
FLAGS.allowed_domain,
FLAGS.subject)
server.Serve()