diff --git a/fetcher.py b/fetcher.py index 522e76a..7779c58 100755 --- a/fetcher.py +++ b/fetcher.py @@ -25,6 +25,11 @@ parser.add_argument( dest='ca_cert', action='store', required=True) +parser.add_argument( + '--image-dir', + dest='image_dir', + action='store', + required=True) parser.add_argument( '--image-type', dest='image_type', @@ -33,14 +38,28 @@ parser.add_argument( FLAGS = parser.parse_args() +class Error(Exception): + pass + + +class InvalidHash(Error): + pass + + +class NoValidImage(Error): + pass + + class Fetcher(object): + _BUF_SIZE = 2 ** 16 _MAX_BP = 10000 - def __init__(self, base_url, image_type, ca_cert): + def __init__(self, base_url, image_type, ca_cert, image_dir): self._base_url = base_url self._image_type = image_type self._ca_cert_path = ca_cert + self._image_dir = image_dir def _VerifyChain(self, untrusted_certs, cert): tempdir = tempfile.mkdtemp() @@ -92,14 +111,62 @@ class Fetcher(object): my_bp = struct.unpack('!I', hashobj.digest()[-4:])[0] % self._MAX_BP if my_bp < image['rollout_‱']: return image + raise NoValidImage + + def _FetchImage(self, image): + filename = '%s.%d.iso' % (self._image_type, image['timestamp']) + path = os.path.join(self._image_dir, filename) + + if os.path.exists(path): + return + + url = '%s/%s' % (self._base_url, filename) + print('Fetching:', url) + resp = urllib.request.urlopen(url) + + hash_obj = hashlib.sha256() + try: + fh = tempfile.NamedTemporaryFile(dir=self._image_dir, delete=False) + while True: + data = resp.read(self._BUF_SIZE) + if not data: + break + hash_obj.update(data) + fh.write(data) + if hash_obj.hexdigest() != image['hash']: + raise InvalidHash + os.rename(fh.name, path) + except: + os.unlink(fh.name) + raise + + def _SetCurrent(self, image): + filename = '%s.%d.iso' % (self._image_type, image['timestamp']) + path = os.path.join(self._image_dir, filename) + current_path = os.path.join(self._image_dir, 'current') + + try: + link = os.readlink(current_path) + link_path = os.path.join(self._image_dir, link) + if link_path == path: + return + except FileNotFoundError: + pass + + print('Changing current link to:', path) + temp_path = tempfile.mktemp(dir=self._image_dir) + os.symlink(path, temp_path) + os.rename(temp_path, current_path) def Fetch(self): manifest = self._GetManifest() image = self._ChooseImage(manifest) + self._FetchImage(image) + self._SetCurrent(image) def main(): - fetcher = Fetcher(FLAGS.base_url, FLAGS.image_type, FLAGS.ca_cert) + fetcher = Fetcher(FLAGS.base_url, FLAGS.image_type, FLAGS.ca_cert, FLAGS.image_dir) fetcher.Fetch()