diff --git a/oauthproxy.py b/oauthproxy.py index 2301d2d..f13f55a 100755 --- a/oauthproxy.py +++ b/oauthproxy.py @@ -2,14 +2,22 @@ import argparse from oauth2client import client +import os from urllib import parse import requests from http import server import socket import ssl +import subprocess +import tempfile parser = argparse.ArgumentParser(description='oauthproxy') +parser.add_argument( + '--allowed-domain', + dest='allowed_domain', + action='store', + required=True) parser.add_argument( '--api-key', dest='api_key', @@ -36,6 +44,11 @@ parser.add_argument( dest='server_cert', action='store', required=True) +parser.add_argument( + '--subject', + dest='subject', + action='store', + required=True) FLAGS = parser.parse_args() @@ -45,8 +58,10 @@ class HTTPServer6(server.HTTPServer): class OAuthProxy(object): - def __init__(self, listen_host, listen_port, server_key, server_cert, api_key): + def __init__(self, listen_host, listen_port, server_key, server_cert, api_key, allowed_domain, subject): self._api_key = api_key + self._allowed_domain = allowed_domain + self._subject = subject HANDLERS = { '/': self._ServeRedirect, @@ -81,9 +96,27 @@ class OAuthProxy(object): ]) return client.flow_from_clientsecrets( 'client_secrets.json', + login_hint=self._allowed_domain, scope='https://www.googleapis.com/auth/userinfo.email', redirect_uri=return_url) + def _GetCert(self, email): + with tempfile.TemporaryDirectory() as td: + key_path = os.path.join(td, 'key.pem') + subprocess.check_call([ + 'openssl', 'ecparam', '-genkey', + '-name', 'secp384r1', + '-out', key_path, + ]) + csr_path = os.path.join(td, 'csr.pem') + subprocess.check_call([ + 'openssl', 'req', '-new', + '-key', key_path, + '-out', csr_path, + '-subj', self._subject.replace('EMAIL', email), + ]) + return open(csr_path, 'rb').read() + def _ServeRedirect(self, req): req.send_response(302) req.send_header('Location', self._GetFlow(req).step1_get_authorize_url()) @@ -102,9 +135,12 @@ class OAuthProxy(object): for x in result.json()['emails'] if x['type'] == 'account' ] + email = emails[0] + assert email.endswith('@%s' % self._allowed_domain) + result = self._GetCert(email) req.send_response(200) req.end_headers() - req.wfile.write(emails[0].encode('utf8')) + req.wfile.write(result) def main(): @@ -113,7 +149,9 @@ def main(): FLAGS.listen_port, FLAGS.server_key, FLAGS.server_cert, - FLAGS.api_key) + FLAGS.api_key, + FLAGS.allowed_domain, + FLAGS.subject) server.Serve()