Files
iconograph/client/fetcher.py

195 lines
5.4 KiB
Python
Raw Normal View History

2016-03-28 22:06:04 -07:00
#!/usr/bin/python3
2016-03-29 17:43:21 -07:00
import codecs
2016-03-28 22:06:04 -07:00
import json
import hashlib
2016-03-29 17:43:21 -07:00
import os
2016-04-02 13:22:24 -07:00
import re
import requests
2016-03-29 17:43:21 -07:00
import shutil
2016-03-28 22:06:04 -07:00
import socket
import struct
2016-03-29 17:43:21 -07:00
import subprocess
import tempfile
from OpenSSL import crypto
2016-03-28 22:06:04 -07:00
class Error(Exception):
    """Base class for the exceptions defined in this module."""
class InvalidHash(Error):
    """A downloaded image's SHA-256 hex digest did not match image['hash']."""
class NoValidImage(Error):
    """No image in the manifest matched the requested timestamp or rollout."""
class ManifestTimeRegressed(Error):
    """A fetched manifest's timestamp is older than the stored manifest's."""
2016-03-28 22:06:04 -07:00
class Fetcher(object):
    """Fetches a signed image manifest, downloads an image and activates it.

    Everything is stored in a single image directory:
      * manifest.json  -- the last accepted (already unwrapped) manifest
      * <timestamp>.iso -- downloaded images
      * current        -- symlink to the active <timestamp>.iso
    """

    # Chunk size, in bytes, used when streaming an image to disk.
    _BUF_SIZE = 2 ** 16

    # Rollout values are compared against a per-host value in [0, _MAX_BP).
    _MAX_BP = 10000

    # Matches the image filenames written by _FetchImage.
    _FILE_REGEX = re.compile(r'^(?P<timestamp>\d+)\.iso$')

    def __init__(self, base_url, ca_cert, image_dir, https_ca_cert,
                 https_client_cert, https_client_key):
        """Create a fetcher.

        Args:
          base_url: URL prefix; 'manifest.json' and '<timestamp>.iso' are
            fetched from directly beneath it.
          ca_cert: path of the CA file given to `openssl verify -CAfile` when
            checking the manifest signing certificate.
          image_dir: directory holding images, manifest.json and the
            'current' symlink.
          https_ca_cert: if truthy, used as the requests session's `verify`
            setting.
          https_client_cert: client certificate for the HTTPS session; only
            used when https_client_key is also truthy.
          https_client_key: private key matching https_client_cert.
        """
        self._base_url = base_url
        self._ca_cert_path = ca_cert
        self._image_dir = image_dir

        self._session = requests.Session()
        if https_ca_cert:
            self._session.verify = https_ca_cert
        if https_client_cert and https_client_key:
            self._session.cert = (https_client_cert, https_client_key)

    def _VerifyChain(self, untrusted_certs, cert):
        """Verify `cert` against the configured CA with `openssl verify`.

        Args:
          untrusted_certs: iterable of PEM strings, written to the file that
            is passed to openssl as -untrusted.
          cert: PEM string of the certificate to verify.

        Raises:
          subprocess.CalledProcessError: if openssl exits with a non-zero
            status.
        """
        tempdir = tempfile.mkdtemp()
        try:
            untrusted_path = os.path.join(tempdir, 'untrusted.pem')
            with open(untrusted_path, 'w') as fh:
                for cert_str in untrusted_certs:
                    fh.write(cert_str)

            cert_path = os.path.join(tempdir, 'cert.pem')
            with open(cert_path, 'w') as fh:
                fh.write(cert)

            # Discard stdout. Unlike an unread PIPE, DEVNULL can never fill up
            # and block the child process.
            subprocess.check_call([
                'openssl', 'verify',
                '-CAfile', self._ca_cert_path,
                '-untrusted', untrusted_path,
                cert_path,
            ], stdout=subprocess.DEVNULL)
        finally:
            shutil.rmtree(tempdir)

    def _Unwrap(self, wrapped):
        """Verify a signed wrapper and return its decoded inner JSON value.

        Args:
          wrapped: dict with 'cert' (PEM string), 'sig' (hex string), 'inner'
            (JSON text that was signed) and optionally 'other_certs' (list of
            PEM strings).

        Returns:
          The object decoded from wrapped['inner'].

        Raises:
          subprocess.CalledProcessError: if the certificate chain does not
            verify.
          Whatever crypto.verify raises when the signature check fails.
        """
        self._VerifyChain(wrapped.get('other_certs', []), wrapped['cert'])

        cert = crypto.load_certificate(crypto.FILETYPE_PEM, wrapped['cert'])
        sig = codecs.decode(wrapped['sig'], 'hex')
        crypto.verify(
            cert,
            sig,
            wrapped['inner'].encode('utf8'),
            'sha256')

        # Only parse the payload once its signature has been checked.
        return json.loads(wrapped['inner'])

    def _GetManifest(self):
        """Download, verify and persist the manifest; return it unwrapped.

        Raises:
          An HTTP error from the session if the server answers with an error
          status, plus anything _Unwrap or _ValidateManifest raise.
        """
        url = '%s/manifest.json' % (self._base_url)
        resp = self._session.get(url)
        # Fail with a clear HTTP error instead of trying to parse an error
        # page as a manifest.
        resp.raise_for_status()
        unwrapped = self._Unwrap(resp.json())
        self._ValidateManifest(unwrapped)
        return unwrapped

    def _ValidateManifest(self, new_manifest):
        """Reject manifests older than the stored one, then store the new one.

        Args:
          new_manifest: unwrapped manifest dict with a 'timestamp' key.

        Raises:
          ManifestTimeRegressed: if the stored manifest has a larger
            timestamp than new_manifest.
        """
        path = os.path.join(self._image_dir, 'manifest.json')

        try:
            with open(path, 'r') as fh:
                old_manifest = json.load(fh)
        except FileNotFoundError:
            # First run: there is nothing to compare against.
            old_manifest = None

        # This checks for replay of an old manifest. Injecting an older
        # manifest could allow an attacker to cause us to revert to an older
        # image with security issues. Manifest timestamps are therefore
        # required to never decrease.
        if (old_manifest is not None and
                old_manifest['timestamp'] > new_manifest['timestamp']):
            raise ManifestTimeRegressed

        with open(path, 'w') as fh:
            json.dump(new_manifest, fh, indent=4)

    def _FindImage(self, manifest, timestamp):
        """Return the manifest image whose 'timestamp' equals `timestamp`.

        Raises:
          NoValidImage: if no image has that timestamp.
        """
        for image in manifest['images']:
            if image['timestamp'] == timestamp:
                return image
        raise NoValidImage

    def _ChooseImage(self, manifest):
        """Return the first image whose rollout threshold admits this host.

        For each image, the SHA-256 of the hostname followed by the packed
        image timestamp is reduced to a value in [0, _MAX_BP). The image is
        selected when that value is below its 'rollout_‱' entry.

        Raises:
          NoValidImage: if no image is selected.
        """
        hostname = socket.gethostname()
        # UTF-8 yields the same bytes as ASCII for ASCII hostnames, and does
        # not raise on hostnames containing other characters.
        hash_base = hashlib.sha256(hostname.encode('utf-8'))

        for image in manifest['images']:
            hashobj = hash_base.copy()
            hashobj.update(struct.pack('!L', image['timestamp']))
            my_bp = struct.unpack('!I', hashobj.digest()[-4:])[0] % self._MAX_BP
            if my_bp < image['rollout_‱']:
                return image
        raise NoValidImage

    def _FetchImage(self, image):
        """Download `image` into the image directory unless already present.

        The data is streamed to a temporary file in the image directory and
        renamed to '<timestamp>.iso' only after its SHA-256 hex digest matched
        image['hash'].

        Raises:
          InvalidHash: if the downloaded data does not match image['hash'].
          An HTTP error from the session if the server answers with an error
          status.
        """
        filename = '%d.iso' % (image['timestamp'])
        path = os.path.join(self._image_dir, filename)
        if os.path.exists(path):
            return

        url = '%s/%s' % (self._base_url, filename)
        print('Fetching:', url, flush=True)
        resp = self._session.get(url, stream=True)
        try:
            resp.raise_for_status()

            hash_obj = hashlib.sha256()
            with tempfile.NamedTemporaryFile(
                    dir=self._image_dir, delete=False) as fh:
                try:
                    for data in resp.iter_content(self._BUF_SIZE):
                        hash_obj.update(data)
                        fh.write(data)
                    if hash_obj.hexdigest() != image['hash']:
                        raise InvalidHash
                    # Push buffered data out before the file becomes visible
                    # under its final name.
                    fh.flush()
                    os.rename(fh.name, path)
                except BaseException:
                    # Never leave a partial or unverified download behind,
                    # including on KeyboardInterrupt.
                    os.unlink(fh.name)
                    raise
        finally:
            # Release the streamed connection back to the session.
            resp.close()

    def _SetCurrent(self, image):
        """Point the 'current' symlink at the file for `image`.

        Does nothing if the link already points there. Otherwise the new link
        is created under a temporary name and renamed over 'current', so the
        link is replaced in a single rename.
        """
        filename = '%d.iso' % (image['timestamp'])
        path = os.path.join(self._image_dir, filename)
        current_path = os.path.join(self._image_dir, 'current')

        try:
            link = os.readlink(current_path)
        except FileNotFoundError:
            pass
        else:
            if os.path.join(self._image_dir, link) == path:
                return

        print('Changing current link to:', filename, flush=True)
        # Stage the link inside a private temporary directory rather than at
        # a tempfile.mktemp() name, which is deprecated and race-prone. The
        # link target is relative, so it resolves against the image directory
        # once the link has been renamed into place.
        staging_dir = tempfile.mkdtemp(dir=self._image_dir)
        try:
            temp_path = os.path.join(staging_dir, 'current')
            os.symlink(filename, temp_path)
            os.rename(temp_path, current_path)
        finally:
            shutil.rmtree(staging_dir)

    def Fetch(self, force_timestamp=None):
        """Fetch the manifest, download the selected image and activate it.

        Args:
          force_timestamp: if truthy, use the image with exactly this
            timestamp instead of choosing one by rollout.

        Raises:
          NoValidImage: if no suitable image is listed in the manifest.
        """
        manifest = self._GetManifest()
        if force_timestamp:
            image = self._FindImage(manifest, force_timestamp)
        else:
            image = self._ChooseImage(manifest)
        self._FetchImage(image)
        self._SetCurrent(image)

    def DeleteOldImages(self, max_images=5, skip=None):
        """Delete all but the newest `max_images` image files.

        Args:
          max_images: number of newest images (by filename timestamp) to
            keep. A falsy value disables deletion entirely.
          skip: optional collection of filenames that are never deleted.
        """
        if not max_images:
            return
        skip = skip or set()

        images = []
        for filename in os.listdir(self._image_dir):
            match = self._FILE_REGEX.match(filename)
            if not match:
                continue
            images.append((int(match.group('timestamp')), filename))

        # Newest first, so everything past the first max_images is stale.
        images.sort(reverse=True)
        for _, filename in images[max_images:]:
            if filename in skip:
                continue
            print('Deleting old image:', filename, flush=True)
            path = os.path.join(self._image_dir, filename)
            os.unlink(path)