#!/usr/bin/env python

#
# Copyright 2017 Mellanox Technologies. All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of Mellanox nor the names of its contributors
#   may be used to endorse or promote products derived from this software
#   without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
# THE POSSIBILITY OF SUCH DAMAGE.
#

"""
Create or dump a BlueField boot stream.
"""

import binascii
import os
import StringIO
import struct
import sys

from optparse import OptionParser, OptionGroup

# Table of all possible image types.  Order is significant, since we
# will emit images in this order unless overridden by the --expert
# switch.  ATF image IDs match those used internally by ATF; see
# atf/include/common/tbbr/tbbr_img_def.h.  UEFI image IDs are defined by
# us; see edk2/MlxPlatformPkg/Filesystem/BfbFs/BfbFs.c.
#
# Contents: (command-line option, description, image ID for header)
image_list = (
    # ATF images
    ("bl2-cert", "Trusted Boot Firmware BL2 certificate", 6),
    ("bl2", "Trusted Boot Firmware BL2", 1),
    ("sys", "Running system's part number or PSID", 29),
    ("bl30-key-cert", "SCP Firmware BL3-0 key certificate", 8),
    ("bl30-cert", "SCP Firmware BL3-0 certificate", 12),
    ("bl30", "SCP Firmware BL3-0", 2),
    ("trusted-key-cert", "Trusted key certificate", 7),
    ("bl31-key-cert", "EL3 Runtime Firmware BL3-1 key certificate", 9),
    ("bl31-cert", "EL3 Runtime Firmware BL3-1 certificate", 13),
    ("bl31", "EL3 Runtime Firmware BL3-1", 3),
    ("bl32-key-cert", "Secure Payload BL3-2 (Trusted OS) key certificate", 10),
    ("bl32-cert", "Secure Payload BL3-2 (Trusted OS) certificate", 14),
    ("bl32", "Secure Payload BL3-2 (Trusted OS)", 4),
    ("bl33-key-cert", "Non-Trusted Firmware BL3-3 key certificate", 11),
    ("bl33-cert", "Non-Trusted Firmware BL3-3 certificate", 15),
    ("bl33", "Non-Trusted Firmware BL3-3", 5),
    # UEFI images
    ("boot-acpi", "Name of the ACPI table", 55),
    ("boot-dtb", "Name of the dtb file", 56),
    ("boot-desc", "Default boot menu item description", 57),
    ("boot-path", "Boot image path", 58),
    ("boot-args", "Arguments for boot image", 59),
    ("boot-timeout", "Boot menu timeout", 60),
    ("uefi-tests", "Specify what UEFI tests to run", 61),
    ("image", "Boot image", 62),
    ("initramfs", "In-memory filesystem", 63),
)

# Forward lookup: command-line option name -> (description, image ID).
image_table = dict(
    (opt_name, (description, image_id))
    for (opt_name, description, image_id) in image_list)

# Reverse lookup: image ID -> (command-line option name, description).
rev_image_table = dict(
    (image_id, (opt_name, description))
    for (opt_name, description, image_id) in image_list)

# Set to True once the -v option has been seen on the command line.
verbose = False


#
# We pad images out to an even multiple of 8 bytes.
#
def num_padding_bytes(length):
    """Return how many bytes must follow `length` bytes to reach an 8-byte boundary."""

    # The distance up to the next multiple of 8 is the negated length
    # modulo 8; this is 0 when length is already aligned.
    return (-length) % 8


class NoHdr(Exception):
    """
    Raised when an ImageHdr is requested from a stream which is already
    at end-of-file, i.e. there is no further header to read.
    """


class BadHdr(Exception):
    """
    Raised when the bytes read from a stream do not form a valid image
    header.  The human-readable reason is kept in `msg` and is what
    str() returns.
    """

    def __init__(self, msg):
        Exception.__init__(self, msg)
        self.msg = msg

    def __str__(self):
        reason = self.msg
        return reason


#
# Currently we produce/consume a little-endian boot stream.  If we
# ever need to change that, you'll need to modify the code below that
# packs/unpacks the header, and possibly the code that handles padding
# bytes at the end of images.
#
class ImageHdr:
    """
    Represent a BlueField boot stream image header.

    On the wire the header is three little-endian 64-bit words:

      word 0: bits  0-31  magic number
              bits 32-35  major version
              bits 36-39  minor version
              bits 52-55  header length, in units of 8 bytes
              bits 56-63  image ID
      word 1: bits  0-31  image length in bytes (padding excluded)
              bits 32-63  image CRC
      word 2: bitmap of the IDs of the images following this one

    Bits 40-51 of word 0 are written as zero and ignored when parsing.
    """

    magic = 0x13026642  # Magic number, "Bf^B^S"
    default_major = 1   # Major version number
    default_minor = 1   # Minor version number
    length = 24         # Header length; must be multiple of 8

    def __init__(self, infile=None, major=default_major, minor=default_minor,
                 image_id=0, image_len=0, image_crc=0, following_images=0):
        """
        Create a header, either by parsing one or from explicit values.

        infile: if given (and truthy), a file object positioned at a
          header; the header is parsed from it and every other argument
          is ignored.  May raise NoHdr or BadHdr.
        major, minor: header version numbers.
        image_id: numeric image ID.
        image_len: image length in bytes, not counting padding.
        image_crc: CRC of the image data plus its padding.
        following_images: bitmap of IDs of subsequent images.
        """

        if infile:
            self.__set_from_file(infile)
        else:
            self.major = major
            self.minor = minor
            self.image_id = image_id
            self.image_len = image_len
            self.image_crc = image_crc
            self.following_images = following_images

    def get_header_length(self):
        """Return length of header in bytes."""
        #
        # Right now this is constant; if later versions increase the header
        # size, or include optional fields, this might vary based on the
        # specific contents of this header.
        #
        return self.length

    def get_image_id(self):
        """Return the image's ID."""
        return self.image_id

    def get_image_length(self):
        """
        Return the image's length in bytes.  Note that the image following
        this header in the boot stream will be padded with the minimum
        necessary number of trailing zeroes to make it an even multiple
        of 8 bytes long; this padding is not included in the length value.
        """
        return self.image_len

    def get_image_crc(self):
        """
        Return the image's CRC.  Note that this covers both the bytes of
        the image, and any added padding bytes.
        """
        return self.image_crc

    def get_following_images(self):
        """Return bitmap of IDs of subsequent images."""
        return self.following_images

    def set_following_images(self, follow_map):
        """Set bitmap of IDs of subsequent images."""
        self.following_images = follow_map

    def get_bits(self):
        """
        Return the packed header as a byte string of `length` bytes.

        struct.pack raises struct.error if following_images does not fit
        in an unsigned 64-bit word.
        """

        # Use floor division so the length field stays an integer
        # regardless of the interpreter's "/" semantics.
        w0 = (((self.magic & 0xFFFFFFFF) << 0) |
              ((self.major & 0xF) << 32) |
              ((self.minor & 0xF) << 36) |
              (((self.length // 8) & 0xF) << 52) |
              ((self.image_id & 0xFF) << 56))

        w1 = (((self.image_len & 0xFFFFFFFF) << 0) |
              ((self.image_crc & 0xFFFFFFFF) << 32))

        w2 = self.following_images

        return struct.pack("<QQQ", w0, w1, w2)

    def __set_from_file(self, infile):
        """
        Parse the header at the current position of the provided file and
        set our state appropriately, consuming only the header bytes;
        if the current file position does not contain a valid header,
        raise an exception.  Note that if an exception is raised, the
        amount of data consumed from the file is indeterminate.

        Raises NoHdr if the file is at end-of-file, and BadHdr if the
        header is truncated or any of its fixed fields is unexpected.
        """

        #
        # This will need to get more complicated if we ever have
        # variable-length headers, but for now, we can just read a fixed
        # amount.
        #
        raw = infile.read(self.length)
        if not raw:
            raise NoHdr()

        if len(raw) < self.length:
            raise BadHdr("ImageHdr too short")

        (w0, w1, w2) = struct.unpack("<QQQ", raw)

        magic = w0 & 0xFFFFFFFF
        self.major = (w0 >> 32) & 0xF
        self.minor = (w0 >> 36) & 0xF
        # Header length as stored: in units of 8 bytes.
        length = (w0 >> 52) & 0xF
        self.image_id = (w0 >> 56) & 0xFF
        self.image_len = w1 & 0xFFFFFFFF
        self.image_crc = (w1 >> 32) & 0xFFFFFFFF
        self.following_images = w2

        if magic != self.magic:
            raise BadHdr("Bad ImageHdr magic number 0x%x" % magic)

        #
        # Obviously the stuff below will change once we support more than
        # one version.
        #
        if self.major != self.default_major:
            raise BadHdr("Bad ImageHdr major version %d" % self.major)

        if self.minor != self.default_minor:
            raise BadHdr("Bad ImageHdr minor version %d" % self.minor)

        if length != self.length // 8:
            raise BadHdr("Bad ImageHdr header length %d" % length)

    def dump(self, outfile=sys.stdout):
        """Write a one-line summary of all internal state to outfile."""
        outfile.write("Ver: %d.%d Len: %d ID: %d "
                      "ImLen: %d ImCRC: 0x%x FolIm: 0x%lx\n" % (
                          self.major, self.minor, self.length,
                          self.image_id, self.image_len,
                          self.image_crc, self.following_images))


class Image:
    """
    Represent a BlueField boot stream image: an ImageHdr plus the raw
    image bytes it describes.
    """

    def __init__(self, instream=None, infile=None, idstr=None):
        """
        Create an Image object.

        instream: File object from which to read the next header and image
          data.  There may be additional headers, images, or other random
          data following that, which will not be consumed.
        infile: Name of file from which to read image data.  Unlike when
          instream is specified, the file does not contain a header, and
          the entire file is consumed; a header will be constructed which
          describes the data we read.  Mutually exclusive with instream.
        idstr: String ID (first column of image_list) for image.  Required,
          and only useful, with infile.

        With neither instream nor infile, the image gets a default header
        and no data (bits is None).
        """

        if instream:
            self.__set_from_stream(instream)
        elif infile:
            self.__set_from_file(infile, idstr)
        else:
            self.header = ImageHdr()
            self.bits = None

    def get_bits(self):
        """Return image data."""
        return self.bits

    def get_padding(self):
        """Return the zero bytes which pad the image to an 8-byte multiple."""
        length = self.header.get_image_length()
        return b"\0" * num_padding_bytes(length)

    def get_image_id(self):
        """Return the numeric image ID for this image."""
        return self.header.get_image_id()

    def get_image_length(self):
        """Return the image's length in bytes."""
        return self.header.get_image_length()

    def set_following_images(self, follow_map):
        """Set the bitmap of images following this one."""
        self.header.set_following_images(follow_map)

    def write(self, outfile):
        """Write header, image data and padding to an output stream."""

        outfile.write(self.header.get_bits())
        outfile.write(self.bits)
        outfile.write(self.get_padding())

    def dump(self, outfile=sys.stdout):
        """Dump the image's header state to outfile."""
        self.header.dump(outfile)

    def __set_from_stream(self, instream):
        """
        Initialize ourselves from the given stream, which contains a header
        and then image data; only the data described by the header will be
        consumed, so you can call this repeatedly on a stream to pick off
        all of the images it contains.

        instream: stream to read.

        Raises whatever ImageHdr raises for a missing or bad header, and
        Exception if the image data is truncated or fails its CRC check.
        """

        self.header = ImageHdr(infile=instream)

        length = self.header.get_image_length()

        self.bits = instream.read(length)
        if len(self.bits) != length:
            raise Exception("Image ID %d: stream corrupted, image too short" %
                            self.header.get_image_id())

        # The CRC stored in the header covers the image data and then the
        # padding bytes which follow it in the stream.
        crc = binascii.crc32(self.bits, 0)
        padding = instream.read(num_padding_bytes(length))
        crc = binascii.crc32(padding, crc)
        # crc32 may return a signed value; reduce it to 32 unsigned bits.
        crc &= 0xFFFFFFFF

        if crc != self.header.get_image_crc():
            raise Exception("Image ID %d: header CRC of 0x%08x does "
                            "not match calculated CRC of 0x%08x" %
                            (self.header.get_image_id(),
                             self.header.get_image_crc(), crc))

    def __set_from_file(self, infile, idstr):
        """
        Initialize ourselves from the given filename/string.  The entire
        file or string makes up the image, and we then construct a header
        which describes that image data.

        infile: file name, or string preceded by "=".
        idstr: string ID of image; must be a key of image_table.
        """

        if infile.startswith("="):
            # Literal data: everything after the leading "=".
            self.bits = infile[1:]
        else:
            # Close the file promptly rather than leaving it to the
            # garbage collector.
            with open(infile, "rb") as src:
                self.bits = src.read()

        length = len(self.bits)

        # The CRC covers the data plus the zero padding which write()
        # will emit after it.
        crc = binascii.crc32(self.bits, 0)
        crc = binascii.crc32(b"\0" * num_padding_bytes(length), crc)
        crc &= 0xFFFFFFFF

        self.header = ImageHdr(image_id=image_table[idstr][1],
                               image_len=length, image_crc=crc)


def make_stream(infn, outfn, image_tuples, expert):
    """
    Write a boot stream to an output file.

    infn: input file name, or None/empty if there is no input stream.
    outfn: output file name.
    image_tuples: an iterable containing tuples, each of which contains the
      string image identifier (the first column of image_list) and the filename
      (or literal string) from the command line.
    expert: if true, emit images in the order they appear in image_tuples
      rather than in image_list order.

    Raises Exception if the input stream contains an image whose ID is
    not listed in image_list, in addition to whatever Image raises for a
    corrupt stream or unreadable file.
    """

    #
    # We'll construct the following_images bitmap while reading in the
    # images, and then will reduce it with every output image.
    #
    fim_map = 0

    #
    # If there's an input file, open it, and read in its contents.
    #
    infile_images = {}
    if infn:
        with open(infn, "rb") as infile:
            while True:
                try:
                    img = Image(instream=infile)
                except NoHdr:
                    break

                image_id = img.get_image_id()
                if image_id not in rev_image_table:
                    # We have no name under which to file this image, so
                    # fail with an explanation instead of a bare KeyError.
                    raise Exception("Image ID %d: unknown image type in "
                                    "input file %s" % (image_id, infn))

                infile_images[rev_image_table[image_id][0]] = img
                fim_map |= 1 << image_id

    #
    # If there are files from the command line, open them and read in
    # their contents.
    #
    cmdline_images = {}
    for (idstr, fn) in image_tuples:
        img = Image(infile=fn, idstr=idstr)
        cmdline_images[idstr] = img
        fim_map |= 1 << img.get_image_id()

    #
    # In expert mode, we emit images in the order we found them on the command
    # line, but otherwise, we use the order they're found in image_list.
    #
    if expert:
        image_order = [x[0] for x in image_tuples]
    else:
        image_order = [x[0] for x in image_list]

    #
    # Emit each image in order, preferring the version from the command line.
    # The output file is only opened once all inputs have been read.
    #
    with open(outfn, "wb") as outfile:
        for name in image_order:
            if name in cmdline_images:
                img = cmdline_images[name]
            elif name in infile_images:
                img = infile_images[name]
            else:
                continue

            # This image is being emitted now, so it no longer "follows".
            fim_map &= ~(1 << img.get_image_id())
            img.set_following_images(fim_map)
            img.write(outfile)


def dump_stream(infn, do_dump, do_extract, prefix):
    """
    Dump a description of a boot stream and/or extract its images.

    infn: input filename.
    do_dump: if true, print table of contents.
    do_extract: if true, extract images to files.
    prefix: prefix to use with extracted image filenames.

    Images whose ID is not in rev_image_table are reported as unknown
    and extracted to "<prefix>image_id_<ID>".
    """

    with open(infn, "rb") as inf:
        while True:
            try:
                img = Image(instream=inf)
            except NoHdr:
                break

            image_id = img.get_image_id()
            if image_id in rev_image_table:
                (fn, desc) = rev_image_table[image_id]
            else:
                fn = "image_id_%d" % image_id
                desc = "Unknown image type, ID %d" % image_id

            if do_dump:
                if verbose:
                    img.dump()
                else:
                    print("%10d %s" % (img.get_image_length(), desc))

            if do_extract:
                with open(prefix + fn, "wb") as xfile:
                    xfile.write(img.get_bits())


def parse_image_opt(option, opt_str, value, parser):
    """
    optparse callback shared by all of the --<image-type> options: record
    the (image type, value) pair, in command-line order, in the parser's
    "images" list.
    """
    # The image type is the option string minus its leading dashes.
    image_type = opt_str.lstrip("-")
    seen_images = parser.values.images
    seen_images.append((image_type, value))


def main(argv):
    """
    Command-line entry point.

    argv: full argument vector, including the program name in argv[0]
      (e.g. sys.argv).  If None, optparse falls back to sys.argv.

    Depending on the options, either dumps/extracts an existing boot
    stream or creates a new one; usage errors are reported through
    parser.error().
    """
    global verbose

    #
    # Set up our option parser.
    #
    parser = OptionParser(
            usage="\n  %prog [ options ] [ INFILE ] OUTFILE\n"
                  "  %prog -[dx] [ options ] INFILE\n"
                  "  %prog [ -h | --help ]",
            description="%prog creates or dumps a BlueField boot stream.")
    parser.add_option(
            "-d", "--dump", dest="d", action="store_true",
            help="print listing of images within INFILE")
    parser.add_option(
            "-x", "--extract", dest="x", action="store_true",
            help="extract images from INFILE and write to files")
    parser.add_option(
            "-p", "--prefix", dest="p", action="store", type="string",
            metavar="PFX", default="dump-",
            help="prefix for files containing extracted images; "
                 "default is '%default'")
    parser.add_option(
            "-v", "--verbose", dest="v", action="store_true",
            help="provide more verbose output")
    parser.add_option(
            "-e", "--expert", dest="e", action="store_true",
            help="expert mode: emit images in order specified on command line")

    parser.set_defaults(images=[])

    iog = OptionGroup(parser, "Image Options",
                      'Images specified on the command line replace images '
                      'of the same type within the input file.  Note: an '
                      'image specification may either be the name of a file '
                      'which contains the image data (e.g., --image=vmlinux), '
                      'or a literal string prefixed with an equals sign '
                      '(e.g., --boot-args="=debug isolcpus=10").')
    # Walk image_list rather than the image_table dict so the options are
    # listed by --help in a stable, meaningful order.
    for (name, desc, _image_id) in image_list:
        iog.add_option(
                "--" + name, type="string",
                action="callback", callback=parse_image_opt, metavar="IMAGE",
                help="take data for " + desc + " from IMAGE")
    parser.add_option_group(iog)

    #
    # Parse the command line and execute.  Honor the argv we were given,
    # skipping the program name.
    #
    (opt, args) = parser.parse_args(None if argv is None else argv[1:])

    if opt.v:
        verbose = True

    if (opt.d or opt.x) and len(args) > 1:
        parser.error("no output file allowed with -d/-x")

    if len(args) > 2:
        parser.error("only one input and one output file allowed")

    if opt.d or opt.x:
        if opt.e:
            parser.error("-e not applicable with -d/-x")
        if opt.images:
            parser.error("cannot specify images with -d/-x")
        if not args:
            # Without this check args[0] below would raise IndexError.
            parser.error("must specify a file")
        dump_stream(args[0], opt.d, opt.x, opt.p)
    else:
        if len(args) <= 0:
            parser.error("must specify a file")
        elif len(args) == 1:
            infn = None
            outfn = args[0]
            if not opt.images:
                parser.error("must specify input .bfb or at least one image "
                             "option with output file")
        else:
            infn = args[0]
            outfn = args[1]
            if opt.e:
                parser.error("input file not permitted with -e")

        make_stream(infn, outfn, opt.images, opt.e)


if __name__ == "__main__":
    main(sys.argv)
