diff --git a/cmds/fittool/commands/init/command.go b/cmds/fittool/commands/init/command.go index f883c32a..8e20ee3f 100644 --- a/cmds/fittool/commands/init/command.go +++ b/cmds/fittool/commands/init/command.go @@ -11,14 +11,14 @@ import ( "github.com/linuxboot/fiano/cmds/fittool/commands" "github.com/linuxboot/fiano/pkg/intel/metadata/fit" - "github.com/linuxboot/fiano/pkg/intel/metadata/fit/consts" ) var _ commands.Command = (*Command)(nil) type Command struct { - UEFIPath string `short:"f" long:"uefi" description:"path to UEFI image" required:"true"` - Pointer uint64 `short:"p" long:"pointer" description:"the FIT pointer value" required:"true"` + UEFIPath string `short:"f" long:"uefi" description:"path to UEFI image" required:"true"` + Pointer *uint64 `short:"p" long:"pointer" description:"the FIT pointer value"` + PointerFromOffset *uint64 `long:"pointer-from-offset" description:"the FIT pointer value defined by an offset from the beginning of the image"` } // ShortDescription explains what this command does in one line @@ -40,6 +40,13 @@ func (cmd *Command) Execute(args []string) error { return commands.ErrArgs{Err: fmt.Errorf("there are extra arguments")} } + if cmd.PointerFromOffset == nil && cmd.Pointer == nil { + return commands.ErrArgs{Err: fmt.Errorf("either '--pointer' or '--pointer-from-offset' is required")} + } + if cmd.PointerFromOffset != nil && cmd.Pointer != nil { + return commands.ErrArgs{Err: fmt.Errorf("it does not make sense to use '--pointer' and '--pointer-from-offset' together")} + } + file, err := os.OpenFile(cmd.UEFIPath, os.O_RDWR, 0) if err != nil { return fmt.Errorf("unable to open the firmware image file '%s': %w", cmd.UEFIPath, err) @@ -50,6 +57,14 @@ func (cmd *Command) Execute(args []string) error { return fmt.Errorf("unable to detect file size (through seek): %w", err) } + var fitOffset uint64 + if cmd.Pointer != nil { + fitOffset = fit.Address64(*cmd.Pointer).Offset(uint64(fileSize)) + } + if cmd.PointerFromOffset != nil { + fitOffset = *cmd.PointerFromOffset + } + entries := fit.Entries{ &fit.EntryFITHeaderEntry{}, } @@ -57,7 +72,7 @@ func (cmd *Command) Execute(args []string) error { return fmt.Errorf("unable to recalculate headers: %w", err) } - if err := entries.InjectTo(file, fit.CalculateOffsetFromPhysAddr(consts.BasePhysAddr-cmd.Pointer, uint64(fileSize))); err != nil { + if err := entries.InjectTo(file, fitOffset); err != nil { return fmt.Errorf("unable to inject entries to the firmware: %w", err) } diff --git a/cmds/utk/utk.go b/cmds/utk/utk.go index f4b2ee1e..9750d017 100644 --- a/cmds/utk/utk.go +++ b/cmds/utk/utk.go @@ -56,23 +56,53 @@ package main import ( "flag" "fmt" + "strconv" "github.com/linuxboot/fiano/pkg/log" + "github.com/linuxboot/fiano/pkg/uefi" "github.com/linuxboot/fiano/pkg/utk" "github.com/linuxboot/fiano/pkg/visitors" ) -func init() { +type config struct { + ErasePolarity *byte +} + +func parseArguments() (config, []string, error) { flag.Usage = func() { fmt.Fprintf(flag.CommandLine.Output(), "Usage: utk [flags] [0 or more operations]\n") flag.PrintDefaults() fmt.Fprintf(flag.CommandLine.Output(), "\nOperations:\n%s", visitors.ListCLI()) } + erasePolarityFlag := flag.String("erase-polarity", "", "set erase polarity; possible values: '', '0x00', '0xFF'") + flag.Parse() + + var cfg config + + if *erasePolarityFlag != "" { + erasePolarity, err := strconv.ParseUint(*erasePolarityFlag, 0, 8) + if err != nil { + return config{}, nil, fmt.Errorf("unable to parse erase polarity '%s': %w", *erasePolarityFlag, err) + } + cfg.ErasePolarity = &[]uint8{uint8(erasePolarity)}[0] + } + + return cfg, flag.Args(), nil } func main() { - flag.Parse() - if err := utk.Run(flag.Args()...); err != nil { + cfg, args, err := parseArguments() + if err != nil { + panic(err) + } + + if cfg.ErasePolarity != nil { + if err := uefi.SetErasePolarity(*cfg.ErasePolarity); err != nil { + panic(fmt.Errorf("unable to set erase polarity 0x%X: %w", *cfg.ErasePolarity, err)) + } + } + + if err := utk.Run(args...); err != nil { log.Fatalf("%v", err) } } diff --git a/pkg/intel/metadata/fit/ent_startup_ac_module_entry.go b/pkg/intel/metadata/fit/ent_startup_ac_module_entry.go index 219e6e30..58238bf7 100644 --- a/pkg/intel/metadata/fit/ent_startup_ac_module_entry.go +++ b/pkg/intel/metadata/fit/ent_startup_ac_module_entry.go @@ -5,15 +5,16 @@ package fit import ( - "bytes" "crypto/rsa" "encoding/binary" "encoding/json" "fmt" "io" "math/big" + "reflect" "github.com/linuxboot/fiano/pkg/intel/metadata/fit/check" + "github.com/xaionaro-go/bytesextra" ) // EntrySACM represents a FIT entry of type "Startup AC Module Entry" (0x02) @@ -52,6 +53,9 @@ func (entry *EntrySACM) CustomRecalculateHeaders() error { // EntrySACMDataInterface is the interface of a startup AC module // data (of any version) type EntrySACMDataInterface interface { + io.ReadWriter + io.ReaderFrom + io.WriterTo // Field getters: @@ -124,8 +128,9 @@ type BCDDate uint32 type SizeM4 uint32 // Size return the size in bytes -func (size SizeM4) Size() uint64 { return uint64(size) << 2 } -func (size SizeM4) String() string { return fmt.Sprintf("%d*4", uint32(size)) } +func (size SizeM4) Size() uint64 { return uint64(size) << 2 } +func (size SizeM4) String() string { return fmt.Sprintf("%d*4", uint32(size)) } +func (size *SizeM4) SetSize(v uint64) { *size = SizeM4(v >> 2) } // TXTSVN is the TXT Security Version Number type TXTSVN uint16 @@ -184,6 +189,38 @@ type EntrySACMDataCommon struct { ScratchSize SizeM4 } +var entrySACMDataCommonSize = uint(binary.Size(EntrySACMDataCommon{})) + +// Read parses the ACM common headers +func (entryData *EntrySACMDataCommon) Read(b []byte) (int, error) { + n, err := entryData.ReadFrom(bytesextra.NewReadWriteSeeker(b)) + return int(n), err +} + +// ReadFrom parses the ACM common headers +func (entryData *EntrySACMDataCommon) ReadFrom(r io.Reader) (int64, error) { + err := binary.Read(r, binary.LittleEndian, entryData) + if err != nil { + return -1, err + } + return int64(entrySACMDataCommonSize), nil +} + +// Write compiles the SACM common headers into a binary representation +func (entryData *EntrySACMDataCommon) Write(b []byte) (int, error) { + n, err := entryData.WriteTo(bytesextra.NewReadWriteSeeker(b)) + return int(n), err +} + +// WriteTo compiles the SACM common headers into a binary representation +func (entryData *EntrySACMDataCommon) WriteTo(w io.Writer) (int64, error) { + err := binary.Write(w, binary.LittleEndian, entryData) + if err != nil { + return -1, err + } + return int64(entrySACMDataCommonSize), nil +} + // GetModuleType returns the type of AC module func (entryData *EntrySACMDataCommon) GetModuleType() ACModuleType { return entryData.ModuleType } @@ -314,6 +351,36 @@ type EntrySACMData0 struct { var entrySACMData0Size = uint(binary.Size(EntrySACMData0{})) +// Read parses the ACM v0 headers +func (entryData *EntrySACMData0) Read(b []byte) (int, error) { + n, err := entryData.ReadFrom(bytesextra.NewReadWriteSeeker(b)) + return int(n), err +} + +// ReadFrom parses the ACM v0 headers +func (entryData *EntrySACMData0) ReadFrom(r io.Reader) (int64, error) { + err := binary.Read(r, binary.LittleEndian, entryData) + if err != nil { + return -1, err + } + return int64(entrySACMData0Size), nil +} + +// Write compiles the SACM v0 headers into a binary representation +func (entryData *EntrySACMData0) Write(b []byte) (int, error) { + n, err := entryData.WriteTo(bytesextra.NewReadWriteSeeker(b)) + return int(n), err +} + +// WriteTo compiles the SACM v0 headers into a binary representation +func (entryData *EntrySACMData0) WriteTo(w io.Writer) (int64, error) { + err := binary.Write(w, binary.LittleEndian, entryData) + if err != nil { + return -1, err + } + return int64(entrySACMData0Size), nil +} + // GetRSAPubKey returns the RSA public key func (entryData *EntrySACMData0) GetRSAPubKey() rsa.PublicKey { pubKey := rsa.PublicKey{ @@ -353,6 +420,36 @@ type EntrySACMData3 struct { var entrySACMData3Size = uint(binary.Size(EntrySACMData3{})) +// Read parses the ACM v3 headers +func (entryData *EntrySACMData3) Read(b []byte) (int, error) { + n, err := entryData.ReadFrom(bytesextra.NewReadWriteSeeker(b)) + return int(n), err +} + +// ReadFrom parses the ACM v3 headers +func (entryData *EntrySACMData3) ReadFrom(r io.Reader) (int64, error) { + err := binary.Read(r, binary.LittleEndian, entryData) + if err != nil { + return -1, err + } + return int64(entrySACMData3Size), nil +} + +// Write compiles the SACM v3 headers into a binary representation +func (entryData *EntrySACMData3) Write(b []byte) (int, error) { + n, err := entryData.WriteTo(bytesextra.NewReadWriteSeeker(b)) + return int(n), err +} + +// WriteTo compiles the SACM v3 headers into a binary representation +func (entryData *EntrySACMData3) WriteTo(w io.Writer) (int64, error) { + err := binary.Write(w, binary.LittleEndian, entryData) + if err != nil { + return -1, err + } + return int64(entrySACMData3Size), nil +} + // GetRSAPubKey returns the RSA public key func (entryData *EntrySACMData3) GetRSAPubKey() rsa.PublicKey { pubKey := rsa.PublicKey{ @@ -382,6 +479,47 @@ type EntrySACMData struct { UserArea []byte } +// Read parses the ACM +func (entryData *EntrySACMData) Read(b []byte) (int, error) { + n, err := entryData.ReadFrom(bytesextra.NewReadWriteSeeker(b)) + return int(n), err +} + +// ReadFrom parses the ACM +func (entryData *EntrySACMData) ReadFrom(r io.Reader) (int64, error) { + parsedEntryData, err := ParseSACMData(r) + if err != nil { + return -1, err + } + *entryData = *parsedEntryData + return int64(binary.Size(entryData.EntrySACMDataInterface) + len(entryData.UserArea)), nil +} + +// Write compiles the SACM into a binary representation +func (entryData *EntrySACMData) Write(b []byte) (int, error) { + n, err := entryData.WriteTo(bytesextra.NewReadWriteSeeker(b)) + return int(n), err +} + +// WriteTo compiles the SACM into a binary representation +func (entryData *EntrySACMData) WriteTo(w io.Writer) (int64, error) { + totalN, err := entryData.EntrySACMDataInterface.WriteTo(w) + if err != nil { + return -1, err + } + n, err := w.Write(entryData.UserArea) + if n >= 0 { + totalN += int64(n) + } + if err != nil { + return totalN, fmt.Errorf("unable to write UserArea: %w", err) + } + if n != len(entryData.UserArea) { + return totalN, fmt.Errorf("unable to complete writing UserArea: %d != %d: %w", n, len(entryData.UserArea), err) + } + return totalN, nil +} + // GetCommon returns the common part of the structures for different ACM versions. func (entryData *EntrySACMData) GetCommon() *EntrySACMDataCommon { if entryData == nil { @@ -425,8 +563,21 @@ func EntrySACMParseSize(b []byte) (uint32, error) { // ParseData parses SACM entry and returns EntrySACMData. func (entry *EntrySACM) ParseData() (*EntrySACMData, error) { + entryData := EntrySACMData{} + _, err := entryData.Read(entry.DataSegmentBytes) + if err != nil { + return nil, err + } + return &entryData, nil +} + +// ParseSACMData parses SACM entry and returns EntrySACMData. +func ParseSACMData(r io.Reader) (*EntrySACMData, error) { + + // Read common headers + common := EntrySACMDataCommon{} - if err := binary.Read(bytes.NewReader(entry.DataSegmentBytes), binary.LittleEndian, &common); err != nil { + if _, err := common.ReadFrom(r); err != nil { return nil, fmt.Errorf("unable to parse startup AC module entry: %w", err) } result := &EntrySACMData{EntrySACMDataInterface: &common, UserArea: nil} @@ -434,10 +585,10 @@ func (entry *EntrySACM) ParseData() (*EntrySACMData, error) { var requiredKeySize uint64 switch common.HeaderVersion { case ACHeaderVersion0: - result.EntrySACMDataInterface = &EntrySACMData0{} + result.EntrySACMDataInterface = &EntrySACMData0{EntrySACMDataCommon: common} requiredKeySize = uint64(len(EntrySACMData0{}.RSAPubKey)) case ACHeaderVersion3: - result.EntrySACMDataInterface = &EntrySACMData3{} + result.EntrySACMDataInterface = &EntrySACMData3{EntrySACMDataCommon: common} requiredKeySize = uint64(len(EntrySACMData3{}.RSAPubKey)) default: return result, &ErrUnknownACMHeaderVersion{ACHeaderVersion: common.HeaderVersion} @@ -447,17 +598,45 @@ func (entry *EntrySACM) ParseData() (*EntrySACMData, error) { return result, &ErrACMInvalidKeySize{ExpectedKeySize: requiredKeySize, RealKeySize: common.KeySize.Size()} } - if err := binary.Read(bytes.NewReader(entry.DataSegmentBytes), binary.LittleEndian, result.EntrySACMDataInterface); err != nil { - return result, fmt.Errorf("cannot parse AC header of version %v: %w", common.HeaderVersion, err) + // Read version-specific headers + // + // Here we need to continue reading from the reader, + // but in the resulting struct we need to skip the first field (because it contains + // already read common headers). + + // Creating a substruct without the first field (which is already read) + t := reflect.TypeOf(result.EntrySACMDataInterface).Elem() + var fieldsToBeFilled []reflect.StructField + for fieldNum := 1; fieldNum < t.NumField(); fieldNum++ { + fieldsToBeFilled = append(fieldsToBeFilled, t.Field(fieldNum)) + } + subStructToBeFilled := reflect.New(reflect.StructOf(fieldsToBeFilled)) + // Reading the substruct + if err := binary.Read(r, binary.LittleEndian, subStructToBeFilled.Interface()); err != nil { + return result, fmt.Errorf("cannot parse version-specific headers (version 0x%04X): %w", common.HeaderVersion, err) } + // Copying values from the substruct to the headers struct + subStructToBeFilled = subStructToBeFilled.Elem() + v := reflect.ValueOf(result.EntrySACMDataInterface).Elem() + for fieldNum := 1; fieldNum < v.NumField(); fieldNum++ { + v.Field(fieldNum).Set(subStructToBeFilled.Field(fieldNum - 1)) + } + + // Read UserArea // `UserArea` has variable length and therefore was not included into // `EntrySACMData0` and `EntrySACMData3`, but it is in the tail, // so we just calculate the startIndex as the end of // EntrySACMData0/EntrySACMData3. - userAreaStartIdx := binary.Size(result.EntrySACMDataInterface) + userAreaStartIdx := uint64(binary.Size(result.EntrySACMDataInterface)) userAreaEndIdx := result.EntrySACMDataInterface.GetSize().Size() - result.UserArea = entry.DataSegmentBytes[userAreaStartIdx:userAreaEndIdx] + if userAreaEndIdx > userAreaStartIdx { + var err error + result.UserArea, err = readBytesFromReader(r, userAreaEndIdx-userAreaStartIdx) + if err != nil { + return result, fmt.Errorf("unable to read user area: %w", err) + } + } return result, nil } diff --git a/pkg/intel/metadata/fit/ent_startup_ac_module_entry_test.go b/pkg/intel/metadata/fit/ent_startup_ac_module_entry_test.go index 1e40b16f..938a118d 100644 --- a/pkg/intel/metadata/fit/ent_startup_ac_module_entry_test.go +++ b/pkg/intel/metadata/fit/ent_startup_ac_module_entry_test.go @@ -5,6 +5,7 @@ package fit import ( + "bytes" "encoding/binary" "math/rand" "testing" @@ -36,6 +37,25 @@ func TestEntrySACM_ParseData(t *testing.T) { }, } + testPositive := func(t *testing.T, headersSize int) { + data, err := entry.ParseData() + require.NoError(t, err) + + _ = data.GetRSAPubKey() + require.Zero(t, len(data.UserArea)) + require.Zero(t, len(entry.DataSegmentBytes)-headersSize) + require.NotZero(t, data.GetKeySize()) + + var buf bytes.Buffer + _, err = data.WriteTo(&buf) + require.NoError(t, err) + + dataCopy, err := ParseSACMData(&buf) + require.NoError(t, err) + + require.Equal(t, data, dataCopy) + } + t.Run("SACMv0", func(t *testing.T) { binary.LittleEndian.PutUint32(entry.DataSegmentBytes[versionOffset:versionEndOffset], uint32(ACHeaderVersion0)) binary.LittleEndian.PutUint32(entry.DataSegmentBytes[sizeOffset:sizeEndOffset], uint32(entrySACMData0Size)>>2) @@ -44,13 +64,7 @@ func TestEntrySACM_ParseData(t *testing.T) { entry.DataSegmentBytes = entry.DataSegmentBytes[:dataSize] t.Run("positive", func(t *testing.T) { binary.LittleEndian.PutUint32(entry.DataSegmentBytes[keySizeOffset:keySizeEndOffset], 256>>2) - - data, err := entry.ParseData() - require.NoError(t, err) - - _ = data.GetRSAPubKey() - require.Zero(t, len(data.UserArea)) - require.Zero(t, len(entry.DataSegmentBytes)-int(entrySACMData0Size)) + testPositive(t, int(entrySACMData0Size)) }) }) @@ -62,13 +76,7 @@ func TestEntrySACM_ParseData(t *testing.T) { entry.DataSegmentBytes = entry.DataSegmentBytes[:dataSize] t.Run("positive", func(t *testing.T) { binary.LittleEndian.PutUint32(entry.DataSegmentBytes[keySizeOffset:keySizeEndOffset], 384>>2) - - data, err := entry.ParseData() - require.NoError(t, err) - - _ = data.GetRSAPubKey() - require.Zero(t, len(data.UserArea)) - require.Zero(t, len(entry.DataSegmentBytes)-int(entrySACMData3Size)) + testPositive(t, int(entrySACMData3Size)) }) t.Run("negative_keySize", func(t *testing.T) { binary.LittleEndian.PutUint32(entry.DataSegmentBytes[keySizeOffset:keySizeEndOffset], 256>>2) diff --git a/pkg/intel/metadata/fit/entry.go b/pkg/intel/metadata/fit/entry.go index a1af3067..793f289a 100644 --- a/pkg/intel/metadata/fit/entry.go +++ b/pkg/intel/metadata/fit/entry.go @@ -211,7 +211,10 @@ func copyBytesFrom(r io.ReadSeeker, startIdx, endIdx uint64) ([]byte, error) { return nil, fmt.Errorf("endIdx < startIdx: %d < %d", endIdx, startIdx) } - size := endIdx - startIdx + return readBytesFromReader(r, endIdx-startIdx) +} + +func readBytesFromReader(r io.Reader, size uint64) ([]byte, error) { result := make([]byte, size) written, err := io.CopyN(bytesextra.NewReadWriteSeeker(result), r, int64(size)) if err != nil { diff --git a/pkg/visitors/insert.go b/pkg/visitors/insert.go index 3b48b9d1..ad6719b8 100644 --- a/pkg/visitors/insert.go +++ b/pkg/visitors/insert.go @@ -8,6 +8,8 @@ import ( "errors" "fmt" "io/ioutil" + "strconv" + "strings" "github.com/linuxboot/fiano/pkg/uefi" ) @@ -17,39 +19,61 @@ type InsertType int // Insert Types const ( + + // == Deprectated == + // These first two specify a firmware volume. - // InsertFront inserts a file at the beginning of the firmware volume, + // InsertTypeFront inserts a file at the beginning of the firmware volume, // which is specified by 1) FVname GUID, or (File GUID/File name) of a file // inside that FV. - InsertFront InsertType = iota - // InsertEnd inserts a file at the end of the specified firmware volume. - InsertEnd + InsertTypeFront InsertType = iota + // InsertTypeEnd inserts a file at the end of the specified firmware volume. + InsertTypeEnd // These two specify a File to insert before or after - // InsertAfter inserts after the specified file, + // InsertTypeAfter inserts after the specified file, // which is specified by a File GUID or File name. - InsertAfter - // InsertBefore inserts before the specified file. - InsertBefore - // InsertDXE inserts into the Dxe Firmware Volume. This works by searching + InsertTypeAfter + // InsertTypeBefore inserts before the specified file. + InsertTypeBefore + // InsertTypeDXE inserts into the Dxe Firmware Volume. This works by searching // for the DxeCore first to identify the Dxe Firmware Volume. - InsertDXE + InsertTypeDXE + + // == Not deprecated == - // ReplaceFFS replaces the found file with the new FFS. This is used + // InsertTypeReplaceFFS replaces the found file with the new FFS. This is used // as a shortcut for remove and insert combined, but also when we want to make // sure that the starting offset of the new file is the same as the old. - ReplaceFFS + InsertTypeReplaceFFS // TODO: Add InsertIn + + // InsertTypeInsert is generalization of all InsertTypeInsert* above. Arguments: + // * The first argument specifies the type of what to insert (possible values: "file" or "pad_file") + // * The second argument specifies the content of what to insert: + // - If the first argument is "file" then a path to the file content is expected. + // - If the first argument is "pad_file" then the size is expected. + // * The third argument specifies the preposition of where to insert to (possible values: "front", "end", "after", "before"). + // * The forth argument specifies the preposition object of where to insert to. It could be FV_or_File GUID_or_name. + // For example combination "end 5C60F367-A505-419A-859E-2A4FF6CA6FE5" means to insert to the end of volume + // "5C60F367-A505-419A-859E-2A4FF6CA6FE5". + // + // A complete example: "pad_file 256 after FC510EE7-FFDC-11D4-BD41-0080C73C8881" means to insert a pad file + // of size 256 bytes after file with GUID "FC510EE7-FFDC-11D4-BD41-0080C73C8881". + InsertTypeInsert ) var insertTypeNames = map[InsertType]string{ - InsertFront: "insert_front", - InsertEnd: "insert_end", - InsertAfter: "insert_after", - InsertBefore: "insert_before", - InsertDXE: "insert_dxe", - ReplaceFFS: "replace_ffs", + InsertTypeInsert: "insert", + InsertTypeReplaceFFS: "replace_ffs", + + // Deprecated: + InsertTypeFront: "insert_front", + InsertTypeEnd: "insert_end", + InsertTypeAfter: "insert_after", + InsertTypeBefore: "insert_before", + InsertTypeDXE: "insert_dxe", } // String creates a string representation for the insert type. @@ -60,8 +84,90 @@ func (i InsertType) String() string { return "UNKNOWN" } +// InsertWhatType defines the type of inserting object +type InsertWhatType int + +const ( + InsertWhatTypeUndefined = InsertWhatType(iota) + InsertWhatTypeFile + InsertWhatTypePadFile + + EndOfInsertWhatType +) + +// String implements fmt.Stringer. +func (t InsertWhatType) String() string { + switch t { + case InsertWhatTypeUndefined: + return "undefined" + case InsertWhatTypeFile: + return "file" + case InsertWhatTypePadFile: + return "pad_file" + } + return fmt.Sprintf("unknown_%d", t) +} + +// ParseInsertWhatType converts a string to InsertWhatType +func ParseInsertWhatType(s string) InsertWhatType { + // TODO: it is currently O(n), optimize + + s = strings.Trim(strings.ToLower(s), " \t") + for t := InsertWhatTypeUndefined; t < EndOfInsertWhatType; t++ { + if t.String() == s { + return t + } + } + return InsertWhatTypeUndefined +} + +// InsertWherePreposition defines the type of inserting object +type InsertWherePreposition int + +const ( + InsertWherePrepositionUndefined = InsertWherePreposition(iota) + InsertWherePrepositionFront + InsertWherePrepositionEnd + InsertWherePrepositionAfter + InsertWherePrepositionBefore + + EndOfInsertWherePreposition +) + +// String implements fmt.Stringer. +func (p InsertWherePreposition) String() string { + switch p { + case InsertWherePrepositionUndefined: + return "undefined" + case InsertWherePrepositionFront: + return "front" + case InsertWherePrepositionEnd: + return "end" + case InsertWherePrepositionAfter: + return "after" + case InsertWherePrepositionBefore: + return "before" + } + return fmt.Sprintf("unknown_%d", p) +} + +// ParseInsertWherePreposition converts a string to InsertWherePreposition +func ParseInsertWherePreposition(s string) InsertWherePreposition { + // TODO: it is currently O(n), optimize + + s = strings.Trim(strings.ToLower(s), " \t") + for t := InsertWherePrepositionUndefined; t < EndOfInsertWherePreposition; t++ { + if t.String() == s { + return t + } + } + return InsertWherePrepositionUndefined +} + // Insert inserts a firmware file into an FV type Insert struct { + // TODO: use InsertWherePreposition to define the location, instead of InsertType + // Input Predicate func(f uefi.Firmware) bool NewFile *uefi.File @@ -91,9 +197,9 @@ func (v *Insert) Run(f uefi.Firmware) error { // edit the FV directly. if fvMatch, ok := find.Matches[0].(*uefi.FirmwareVolume); ok { switch v.InsertType { - case InsertFront: + case InsertTypeFront: fvMatch.Files = append([]*uefi.File{v.NewFile}, fvMatch.Files...) - case InsertEnd: + case InsertTypeEnd: fvMatch.Files = append(fvMatch.Files, v.NewFile) default: return fmt.Errorf("matched FV but insert operation was %s, which only matches Files", @@ -115,18 +221,19 @@ func (v *Insert) Visit(f uefi.Firmware) error { case *uefi.FirmwareVolume: for i := 0; i < len(f.Files); i++ { if f.Files[i] == v.FileMatch { + // TODO: use InsertWherePreposition to define the location, instead of InsertType switch v.InsertType { - case InsertFront: + case InsertTypeFront: f.Files = append([]*uefi.File{v.NewFile}, f.Files...) - case InsertDXE: + case InsertTypeDXE: fallthrough - case InsertEnd: + case InsertTypeEnd: f.Files = append(f.Files, v.NewFile) - case InsertAfter: + case InsertTypeAfter: f.Files = append(f.Files[:i+1], append([]*uefi.File{v.NewFile}, f.Files[i+1:]...)...) - case InsertBefore: + case InsertTypeBefore: f.Files = append(f.Files[:i], append([]*uefi.File{v.NewFile}, f.Files[i:]...)...) - case ReplaceFFS: + case InsertTypeReplaceFFS: f.Files = append(f.Files[:i], append([]*uefi.File{v.NewFile}, f.Files[i+1:]...)...) } return nil @@ -137,13 +244,27 @@ func (v *Insert) Visit(f uefi.Firmware) error { return f.ApplyChildren(v) } -func genInsertCLI(iType InsertType) func(args []string) (uefi.Visitor, error) { +func parseFile(filePath string) (*uefi.File, error) { + fileBytes, err := ioutil.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("unable to read file '%s': %w", filePath, err) + } + + file, err := uefi.NewFile(fileBytes) + if err != nil { + return nil, fmt.Errorf("unable to parse file '%s': %w", filePath, err) + } + + return file, nil +} + +func genInsertRegularFileCLI(iType InsertType) func(args []string) (uefi.Visitor, error) { return func(args []string) (uefi.Visitor, error) { var pred FindPredicate var err error var filename string - if iType == InsertDXE { + if iType == InsertTypeDXE { pred = FindFileTypePredicate(uefi.FVFileTypeDXECore) filename = args[0] } else { @@ -154,14 +275,9 @@ func genInsertCLI(iType InsertType) func(args []string) (uefi.Visitor, error) { filename = args[1] } - newFileBuf, err := ioutil.ReadFile(filename) + file, err := parseFile(filename) if err != nil { - return nil, err - } - // Parse the file. - file, err := uefi.NewFile(newFileBuf) - if err != nil { - return nil, err + return nil, fmt.Errorf("unable to parse file '%s': %w", args[1], err) } // Insert File. @@ -173,17 +289,82 @@ func genInsertCLI(iType InsertType) func(args []string) (uefi.Visitor, error) { } } +func genInsertFileCLI() func(args []string) (uefi.Visitor, error) { + return func(args []string) (uefi.Visitor, error) { + whatType := ParseInsertWhatType(args[0]) + if whatType == InsertWhatTypeUndefined { + return nil, fmt.Errorf("unknown what-type: '%s'", args[0]) + } + + var file *uefi.File + switch whatType { + case InsertWhatTypeFile: + var err error + file, err = parseFile(args[1]) + if err != nil { + return nil, fmt.Errorf("unable to parse file '%s': %w", args[1], err) + } + case InsertWhatTypePadFile: + padSize, err := strconv.ParseUint(args[1], 0, 64) + if err != nil { + return nil, fmt.Errorf("unable to parse pad file size '%s': %w", args[1], err) + } + file, err = uefi.CreatePadFile(padSize) + if err != nil { + return nil, fmt.Errorf("unable to create a pad file of size %d: %w", padSize, err) + } + default: + return nil, fmt.Errorf("what-type '%s' is not supported, yet", whatType) + } + + wherePreposition := ParseInsertWherePreposition(args[2]) + if wherePreposition == InsertWherePrepositionUndefined { + return nil, fmt.Errorf("unknown where-preposition: '%s'", args[2]) + } + + pred, err := FindFileFVPredicate(args[3]) + if err != nil { + return nil, fmt.Errorf("unable to parse the predicate parameters '%s': %w", args[0], err) + } + + // TODO: use InsertWherePreposition to define the location, instead of InsertType + var insertType InsertType + switch wherePreposition { + case InsertWherePrepositionFront: + insertType = InsertTypeFront + case InsertWherePrepositionEnd: + insertType = InsertTypeEnd + case InsertWherePrepositionAfter: + insertType = InsertTypeAfter + case InsertWherePrepositionBefore: + insertType = InsertTypeBefore + default: + return nil, fmt.Errorf("where-preposition '%s' is not supported, yet", wherePreposition) + } + + // Insert File. + return &Insert{ + Predicate: pred, + NewFile: file, + // TODO: use InsertWherePreposition to define the location, instead of InsertType + InsertType: insertType, + }, nil + } +} + func init() { - RegisterCLI(insertTypeNames[InsertFront], - "insert a file at the beginning of a firmware volume", 2, genInsertCLI(InsertFront)) - RegisterCLI(insertTypeNames[InsertEnd], - "insert a file at the end of a firmware volume", 2, genInsertCLI(InsertEnd)) - RegisterCLI(insertTypeNames[InsertDXE], - "insert a file at the end of the DXE firmware volume", 1, genInsertCLI(InsertDXE)) - RegisterCLI(insertTypeNames[InsertAfter], - "insert a file after another file", 2, genInsertCLI(InsertAfter)) - RegisterCLI(insertTypeNames[InsertBefore], - "insert a file before another file", 2, genInsertCLI(InsertBefore)) - RegisterCLI(insertTypeNames[ReplaceFFS], - "replace a file with another file", 2, genInsertCLI(ReplaceFFS)) + RegisterCLI(insertTypeNames[InsertTypeInsert], + "insert a file", 4, genInsertFileCLI()) + RegisterCLI(insertTypeNames[InsertTypeReplaceFFS], + "replace a file with another file", 2, genInsertRegularFileCLI(InsertTypeReplaceFFS)) + RegisterCLI(insertTypeNames[InsertTypeFront], + "(deprecated) insert a file at the beginning of a firmware volume", 2, genInsertRegularFileCLI(InsertTypeFront)) + RegisterCLI(insertTypeNames[InsertTypeEnd], + "(deprecated) insert a file at the end of a firmware volume", 2, genInsertRegularFileCLI(InsertTypeEnd)) + RegisterCLI(insertTypeNames[InsertTypeDXE], + "(deprecated) insert a file at the end of the DXE firmware volume", 1, genInsertRegularFileCLI(InsertTypeDXE)) + RegisterCLI(insertTypeNames[InsertTypeAfter], + "(deprecated) insert a file after another file", 2, genInsertRegularFileCLI(InsertTypeAfter)) + RegisterCLI(insertTypeNames[InsertTypeBefore], + "(deprecated) insert a file before another file", 2, genInsertRegularFileCLI(InsertTypeBefore)) } diff --git a/pkg/visitors/insert_test.go b/pkg/visitors/insert_test.go index 1b7f6e2f..02af188d 100644 --- a/pkg/visitors/insert_test.go +++ b/pkg/visitors/insert_test.go @@ -13,8 +13,12 @@ import ( "github.com/linuxboot/fiano/pkg/uefi" ) -func testRunInsert(t *testing.T, f uefi.Firmware, insertType InsertType, testGUID guid.GUID) (*Insert, error) { - file, err := ioutil.ReadFile("../../integration/roms/testfile.ffs") +const ( + insertTestFile = "../../integration/roms/testfile.ffs" +) + +func testRunObsoleteInsert(t *testing.T, f uefi.Firmware, insertType InsertType, testGUID guid.GUID) (*Insert, error) { + file, err := ioutil.ReadFile(insertTestFile) if err != nil { t.Fatal(err) } @@ -25,7 +29,7 @@ func testRunInsert(t *testing.T, f uefi.Firmware, insertType InsertType, testGUI } // Apply the visitor. var pred FindPredicate - if insertType == InsertDXE { + if insertType == InsertTypeDXE { pred = FindFileTypePredicate(uefi.FVFileTypeDXECore) } else { pred = FindFileGUIDPredicate(testGUID) @@ -39,22 +43,22 @@ func testRunInsert(t *testing.T, f uefi.Firmware, insertType InsertType, testGUI return insert, insert.Run(f) } -func TestInsert(t *testing.T) { +func TestObsoleteInsert(t *testing.T) { var tests = []struct { name string InsertType }{ - {InsertFront.String(), InsertFront}, - {InsertEnd.String(), InsertEnd}, - {InsertAfter.String(), InsertAfter}, - {InsertBefore.String(), InsertBefore}, - {InsertDXE.String(), InsertDXE}, + {InsertTypeFront.String(), InsertTypeFront}, + {InsertTypeEnd.String(), InsertTypeEnd}, + {InsertTypeAfter.String(), InsertTypeAfter}, + {InsertTypeBefore.String(), InsertTypeBefore}, + {InsertTypeDXE.String(), InsertTypeDXE}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { f := parseImage(t) - _, err := testRunInsert(t, f, test.InsertType, *testGUID) + _, err := testRunObsoleteInsert(t, f, test.InsertType, *testGUID) if err != nil { t.Fatal(err) } @@ -74,18 +78,94 @@ func TestInsert(t *testing.T) { } } +func testInsertCLI(t *testing.T, whatType InsertWhatType, wherePreposition InsertWherePreposition) { + f := parseImage(t) + + args := []string{ + whatType.String(), + } + switch whatType { + case InsertWhatTypeFile: + args = append(args, insertTestFile) + case InsertWhatTypePadFile: + args = append(args, "256") + default: + t.Fatalf("unknown what-type '%s'", whatType) + } + + args = append(args, wherePreposition.String(), testGUID.String()) + + visitor, err := genInsertFileCLI()(args) + if err != nil { + t.Fatal(err) + } + + if err := visitor.Run(f); err != nil { + t.Fatal(err) + } + + switch whatType { + case InsertWhatTypeFile: + find := &Find{ + Predicate: FindFileGUIDPredicate(*testGUID), + } + if err = find.Run(f); err != nil { + t.Fatal(err) + } + if len(find.Matches) != 2 { + t.Errorf("incorrect number of matches after insertion! expected 2, got %v", len(find.Matches)) + } + case InsertWhatTypePadFile: + find := &Find{ + Predicate: func(f uefi.Firmware) bool { + file, ok := f.(*uefi.File) + if !ok { + return false + } + if file.Header.GUID.String() != "FFFFFFFF-FFFF-FFFF-FFFF-FFFFFFFFFFFF" { + return false + } + if len(file.Buf()) != 256 { + return false + } + return true + }, + } + if err = find.Run(f); err != nil { + t.Fatal(err) + } + if len(find.Matches) != 1 { + t.Errorf("incorrect number of matches after insertion! expected 1, got %v", len(find.Matches)) + } + default: + t.Fatalf("unknown what-type '%s'", whatType) + } +} + +func TestInsert(t *testing.T) { + for whatType := InsertWhatTypeUndefined + 1; whatType < EndOfInsertWhatType; whatType++ { + t.Run(whatType.String(), func(t *testing.T) { + for wherePreposition := InsertWherePrepositionUndefined + 1; wherePreposition < EndOfInsertWherePreposition; wherePreposition++ { + t.Run(wherePreposition.String(), func(t *testing.T) { + testInsertCLI(t, whatType, wherePreposition) + }) + } + }) + } +} + func TestDoubleFindInsert(t *testing.T) { var tests = []struct { name string InsertType }{ - {"insert_after double result", InsertAfter}, + {"insert_after double result", InsertTypeAfter}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { f := parseImage(t) - insert, err := testRunInsert(t, f, test.InsertType, *testGUID) + insert, err := testRunObsoleteInsert(t, f, test.InsertType, *testGUID) if err != nil { t.Fatal(err) } @@ -107,13 +187,13 @@ func TestNoFindInsert(t *testing.T) { name string InsertType }{ - {"insert_after no file", InsertAfter}, + {"insert_after no file", InsertTypeAfter}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { f := parseImage(t) - _, err := testRunInsert(t, f, test.InsertType, + _, err := testRunObsoleteInsert(t, f, test.InsertType, *guid.MustParse("DECAFBAD-0000-0000-0000-000000000000")) // It should fail due to no such file if err == nil {