Black Lives Matter. Support the Equal Justice Initiative.

Source file src/crypto/rsa/pss.go

Documentation: crypto/rsa

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rsa
     6  
     7  // This file implements the RSASSA-PSS signature scheme according to RFC 8017.
     8  
     9  import (
    10  	"bytes"
    11  	"crypto"
    12  	"errors"
    13  	"hash"
    14  	"io"
    15  	"math/big"
    16  )
    17  
    18  // Per RFC 8017, Section 9.1
    19  //
    20  //     EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc
    21  //
    22  // where
    23  //
    24  //     DB = PS || 0x01 || salt
    25  //
    26  // and PS can be empty so
    27  //
    28  //     emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2
    29  //
    30  
    31  func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
    32  	// See RFC 8017, Section 9.1.1.
    33  
    34  	hLen := hash.Size()
    35  	sLen := len(salt)
    36  	emLen := (emBits + 7) / 8
    37  
    38  	// 1.  If the length of M is greater than the input limitation for the
    39  	//     hash function (2^61 - 1 octets for SHA-1), output "message too
    40  	//     long" and stop.
    41  	//
    42  	// 2.  Let mHash = Hash(M), an octet string of length hLen.
    43  
    44  	if len(mHash) != hLen {
    45  		return nil, errors.New("crypto/rsa: input must be hashed with given hash")
    46  	}
    47  
    48  	// 3.  If emLen < hLen + sLen + 2, output "encoding error" and stop.
    49  
    50  	if emLen < hLen+sLen+2 {
    51  		return nil, errors.New("crypto/rsa: key size too small for PSS signature")
    52  	}
    53  
    54  	em := make([]byte, emLen)
    55  	psLen := emLen - sLen - hLen - 2
    56  	db := em[:psLen+1+sLen]
    57  	h := em[psLen+1+sLen : emLen-1]
    58  
    59  	// 4.  Generate a random octet string salt of length sLen; if sLen = 0,
    60  	//     then salt is the empty string.
    61  	//
    62  	// 5.  Let
    63  	//       M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
    64  	//
    65  	//     M' is an octet string of length 8 + hLen + sLen with eight
    66  	//     initial zero octets.
    67  	//
    68  	// 6.  Let H = Hash(M'), an octet string of length hLen.
    69  
    70  	var prefix [8]byte
    71  
    72  	hash.Write(prefix[:])
    73  	hash.Write(mHash)
    74  	hash.Write(salt)
    75  
    76  	h = hash.Sum(h[:0])
    77  	hash.Reset()
    78  
    79  	// 7.  Generate an octet string PS consisting of emLen - sLen - hLen - 2
    80  	//     zero octets. The length of PS may be 0.
    81  	//
    82  	// 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
    83  	//     emLen - hLen - 1.
    84  
    85  	db[psLen] = 0x01
    86  	copy(db[psLen+1:], salt)
    87  
    88  	// 9.  Let dbMask = MGF(H, emLen - hLen - 1).
    89  	//
    90  	// 10. Let maskedDB = DB \xor dbMask.
    91  
    92  	mgf1XOR(db, hash, h)
    93  
    94  	// 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
    95  	//     maskedDB to zero.
    96  
    97  	db[0] &= 0xff >> (8*emLen - emBits)
    98  
    99  	// 12. Let EM = maskedDB || H || 0xbc.
   100  	em[emLen-1] = 0xbc
   101  
   102  	// 13. Output EM.
   103  	return em, nil
   104  }
   105  
   106  func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
   107  	// See RFC 8017, Section 9.1.2.
   108  
   109  	hLen := hash.Size()
   110  	if sLen == PSSSaltLengthEqualsHash {
   111  		sLen = hLen
   112  	}
   113  	emLen := (emBits + 7) / 8
   114  	if emLen != len(em) {
   115  		return errors.New("rsa: internal error: inconsistent length")
   116  	}
   117  
   118  	// 1.  If the length of M is greater than the input limitation for the
   119  	//     hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
   120  	//     and stop.
   121  	//
   122  	// 2.  Let mHash = Hash(M), an octet string of length hLen.
   123  	if hLen != len(mHash) {
   124  		return ErrVerification
   125  	}
   126  
   127  	// 3.  If emLen < hLen + sLen + 2, output "inconsistent" and stop.
   128  	if emLen < hLen+sLen+2 {
   129  		return ErrVerification
   130  	}
   131  
   132  	// 4.  If the rightmost octet of EM does not have hexadecimal value
   133  	//     0xbc, output "inconsistent" and stop.
   134  	if em[emLen-1] != 0xbc {
   135  		return ErrVerification
   136  	}
   137  
   138  	// 5.  Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
   139  	//     let H be the next hLen octets.
   140  	db := em[:emLen-hLen-1]
   141  	h := em[emLen-hLen-1 : emLen-1]
   142  
   143  	// 6.  If the leftmost 8 * emLen - emBits bits of the leftmost octet in
   144  	//     maskedDB are not all equal to zero, output "inconsistent" and
   145  	//     stop.
   146  	var bitMask byte = 0xff >> (8*emLen - emBits)
   147  	if em[0] & ^bitMask != 0 {
   148  		return ErrVerification
   149  	}
   150  
   151  	// 7.  Let dbMask = MGF(H, emLen - hLen - 1).
   152  	//
   153  	// 8.  Let DB = maskedDB \xor dbMask.
   154  	mgf1XOR(db, hash, h)
   155  
   156  	// 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
   157  	//     to zero.
   158  	db[0] &= bitMask
   159  
   160  	// If we don't know the salt length, look for the 0x01 delimiter.
   161  	if sLen == PSSSaltLengthAuto {
   162  		psLen := bytes.IndexByte(db, 0x01)
   163  		if psLen < 0 {
   164  			return ErrVerification
   165  		}
   166  		sLen = len(db) - psLen - 1
   167  	}
   168  
   169  	// 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
   170  	//     or if the octet at position emLen - hLen - sLen - 1 (the leftmost
   171  	//     position is "position 1") does not have hexadecimal value 0x01,
   172  	//     output "inconsistent" and stop.
   173  	psLen := emLen - hLen - sLen - 2
   174  	for _, e := range db[:psLen] {
   175  		if e != 0x00 {
   176  			return ErrVerification
   177  		}
   178  	}
   179  	if db[psLen] != 0x01 {
   180  		return ErrVerification
   181  	}
   182  
   183  	// 11.  Let salt be the last sLen octets of DB.
   184  	salt := db[len(db)-sLen:]
   185  
   186  	// 12.  Let
   187  	//          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
   188  	//     M' is an octet string of length 8 + hLen + sLen with eight
   189  	//     initial zero octets.
   190  	//
   191  	// 13. Let H' = Hash(M'), an octet string of length hLen.
   192  	var prefix [8]byte
   193  	hash.Write(prefix[:])
   194  	hash.Write(mHash)
   195  	hash.Write(salt)
   196  
   197  	h0 := hash.Sum(nil)
   198  
   199  	// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
   200  	if !bytes.Equal(h0, h) { // TODO: constant time?
   201  		return ErrVerification
   202  	}
   203  	return nil
   204  }
   205  
   206  // signPSSWithSalt calculates the signature of hashed using PSS with specified salt.
   207  // Note that hashed must be the result of hashing the input message using the
   208  // given hash function. salt is a random sequence of bytes whose length will be
   209  // later used to verify the signature.
   210  func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
   211  	emBits := priv.N.BitLen() - 1
   212  	em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	m := new(big.Int).SetBytes(em)
   217  	c, err := decryptAndCheck(rand, priv, m)
   218  	if err != nil {
   219  		return nil, err
   220  	}
   221  	s := make([]byte, priv.Size())
   222  	return c.FillBytes(s), nil
   223  }
   224  
   225  const (
   226  	// PSSSaltLengthAuto causes the salt in a PSS signature to be as large
   227  	// as possible when signing, and to be auto-detected when verifying.
   228  	PSSSaltLengthAuto = 0
   229  	// PSSSaltLengthEqualsHash causes the salt length to equal the length
   230  	// of the hash used in the signature.
   231  	PSSSaltLengthEqualsHash = -1
   232  )
   233  
   234  // PSSOptions contains options for creating and verifying PSS signatures.
   235  type PSSOptions struct {
   236  	// SaltLength controls the length of the salt used in the PSS
   237  	// signature. It can either be a number of bytes, or one of the special
   238  	// PSSSaltLength constants.
   239  	SaltLength int
   240  
   241  	// Hash is the hash function used to generate the message digest. If not
   242  	// zero, it overrides the hash function passed to SignPSS. It's required
   243  	// when using PrivateKey.Sign.
   244  	Hash crypto.Hash
   245  }
   246  
   247  // HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts.
   248  func (opts *PSSOptions) HashFunc() crypto.Hash {
   249  	return opts.Hash
   250  }
   251  
   252  func (opts *PSSOptions) saltLength() int {
   253  	if opts == nil {
   254  		return PSSSaltLengthAuto
   255  	}
   256  	return opts.SaltLength
   257  }
   258  
   259  // SignPSS calculates the signature of digest using PSS.
   260  //
   261  // digest must be the result of hashing the input message using the given hash
   262  // function. The opts argument may be nil, in which case sensible defaults are
   263  // used. If opts.Hash is set, it overrides hash.
   264  func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) {
   265  	if opts != nil && opts.Hash != 0 {
   266  		hash = opts.Hash
   267  	}
   268  
   269  	saltLength := opts.saltLength()
   270  	switch saltLength {
   271  	case PSSSaltLengthAuto:
   272  		saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size()
   273  	case PSSSaltLengthEqualsHash:
   274  		saltLength = hash.Size()
   275  	}
   276  
   277  	salt := make([]byte, saltLength)
   278  	if _, err := io.ReadFull(rand, salt); err != nil {
   279  		return nil, err
   280  	}
   281  	return signPSSWithSalt(rand, priv, hash, digest, salt)
   282  }
   283  
   284  // VerifyPSS verifies a PSS signature.
   285  //
   286  // A valid signature is indicated by returning a nil error. digest must be the
   287  // result of hashing the input message using the given hash function. The opts
   288  // argument may be nil, in which case sensible defaults are used. opts.Hash is
   289  // ignored.
   290  func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error {
   291  	if len(sig) != pub.Size() {
   292  		return ErrVerification
   293  	}
   294  	s := new(big.Int).SetBytes(sig)
   295  	m := encrypt(new(big.Int), pub, s)
   296  	emBits := pub.N.BitLen() - 1
   297  	emLen := (emBits + 7) / 8
   298  	if m.BitLen() > emLen*8 {
   299  		return ErrVerification
   300  	}
   301  	em := m.FillBytes(make([]byte, emLen))
   302  	return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
   303  }
   304  

View as plain text