Add fetcher and imager support for HTTPS client/server auth.

This commit is contained in:
Ian Gulliver
2016-04-05 17:39:16 -07:00
parent 1086d58569
commit a49872f808
2 changed files with 66 additions and 16 deletions

View File

@@ -6,12 +6,12 @@ import json
import hashlib import hashlib
import os import os
import re import re
import requests
import shutil import shutil
import socket import socket
import struct import struct
import subprocess import subprocess
import tempfile import tempfile
import urllib.request
from OpenSSL import crypto from OpenSSL import crypto
@@ -26,6 +26,18 @@ parser.add_argument(
dest='ca_cert', dest='ca_cert',
action='store', action='store',
required=True) required=True)
parser.add_argument(
'--https-ca-cert',
dest='https_ca_cert',
action='store')
parser.add_argument(
'--https-client-cert',
dest='https_client_cert',
action='store')
parser.add_argument(
'--https-client-key',
dest='https_client_key',
action='store')
parser.add_argument( parser.add_argument(
'--image-dir', '--image-dir',
dest='image_dir', dest='image_dir',
@@ -62,10 +74,15 @@ class Fetcher(object):
_MAX_BP = 10000 _MAX_BP = 10000
_FILE_REGEX = re.compile('^(?P<timestamp>\d+)\.iso$') _FILE_REGEX = re.compile('^(?P<timestamp>\d+)\.iso$')
def __init__(self, base_url, ca_cert, image_dir): def __init__(self, base_url, ca_cert, image_dir, https_ca_cert, https_client_cert, https_client_key):
self._base_url = base_url self._base_url = base_url
self._ca_cert_path = ca_cert self._ca_cert_path = ca_cert
self._image_dir = image_dir self._image_dir = image_dir
self._session = requests.Session()
if https_ca_cert:
self._session.verify = https_ca_cert
if https_client_cert and https_client_key:
self._session.cert = (https_client_cert, https_client_key)
def _VerifyChain(self, untrusted_certs, cert): def _VerifyChain(self, untrusted_certs, cert):
tempdir = tempfile.mkdtemp() tempdir = tempfile.mkdtemp()
@@ -105,8 +122,8 @@ class Fetcher(object):
def _GetManifest(self): def _GetManifest(self):
url = '%s/manifest.json' % (self._base_url) url = '%s/manifest.json' % (self._base_url)
resp = urllib.request.urlopen(url).read().decode('utf8') resp = self._session.get(url)
unwrapped = self._Unwrap(json.loads(resp)) unwrapped = self._Unwrap(resp.json())
self._ValidateManifest(unwrapped) self._ValidateManifest(unwrapped)
return unwrapped return unwrapped
@@ -148,15 +165,12 @@ class Fetcher(object):
url = '%s/%s' % (self._base_url, filename) url = '%s/%s' % (self._base_url, filename)
print('Fetching:', url) print('Fetching:', url)
resp = urllib.request.urlopen(url) resp = self._session.get(url, stream=True)
hash_obj = hashlib.sha256() hash_obj = hashlib.sha256()
try: try:
fh = tempfile.NamedTemporaryFile(dir=self._image_dir, delete=False) fh = tempfile.NamedTemporaryFile(dir=self._image_dir, delete=False)
while True: for data in resp.iter_content(self._BUF_SIZE):
data = resp.read(self._BUF_SIZE)
if not data:
break
hash_obj.update(data) hash_obj.update(data)
fh.write(data) fh.write(data)
if hash_obj.hexdigest() != image['hash']: if hash_obj.hexdigest() != image['hash']:
@@ -207,7 +221,13 @@ class Fetcher(object):
def main(): def main():
fetcher = Fetcher(FLAGS.base_url, FLAGS.ca_cert, FLAGS.image_dir) fetcher = Fetcher(
FLAGS.base_url,
FLAGS.ca_cert,
FLAGS.image_dir,
FLAGS.https_ca_cert,
FLAGS.https_client_cert,
FLAGS.https_client_key)
fetcher.Fetch() fetcher.Fetch()
fetcher.DeleteOldImages(FLAGS.max_images) fetcher.DeleteOldImages(FLAGS.max_images)

View File

@@ -20,6 +20,18 @@ parser.add_argument(
dest='ca_cert', dest='ca_cert',
action='store', action='store',
required=True) required=True)
parser.add_argument(
'--https-ca-cert',
dest='https_ca_cert',
action='store')
parser.add_argument(
'--https-client-cert',
dest='https_client_cert',
action='store')
parser.add_argument(
'--https-client-key',
dest='https_client_key',
action='store')
parser.add_argument( parser.add_argument(
'--device', '--device',
dest='device', dest='device',
@@ -36,11 +48,23 @@ FLAGS = parser.parse_args()
class Imager(object): class Imager(object):
def __init__(self, device, persistent_percent, base_url, ca_cert): def __init__(self, device, persistent_percent, base_url, ca_cert, https_ca_cert, https_client_cert, https_client_key):
self._device = device self._device = device
self._persistent_percent = persistent_percent self._persistent_percent = persistent_percent
self._base_url = base_url
self._ca_cert = ca_cert self._fetcher_args = [
'--base-url', base_url,
'--ca-cert', ca_cert,
]
if https_ca_cert:
self._fetcher_args.extend([
'--https-ca-cert', https_ca_cert,
])
if https_client_cert and https_client_key:
self._fetcher_args.extend([
'--https-client-cert', https_client_cert,
'--https-client-key', https_client_key,
])
self._icon_path = os.path.dirname(sys.argv[0]) self._icon_path = os.path.dirname(sys.argv[0])
@@ -123,8 +147,7 @@ class Imager(object):
self._Exec( self._Exec(
fetcher, fetcher,
'--image-dir', image_path, '--image-dir', image_path,
'--base-url', self._base_url, *self._fetcher_args)
'--ca-cert', self._ca_cert)
return image_path return image_path
@@ -160,7 +183,14 @@ class Imager(object):
def main(): def main():
imager = Imager(FLAGS.device, FLAGS.persistent_percent, FLAGS.base_url, FLAGS.ca_cert) imager = Imager(
FLAGS.device,
FLAGS.persistent_percent,
FLAGS.base_url,
FLAGS.ca_cert,
FLAGS.https_ca_cert,
FLAGS.https_client_cert,
FLAGS.https_client_key)
imager.Image() imager.Image()