{
int i;
int ret = 0;
- int hLen, maskedDBLen, emBits, emLen;
+ int hLen, maskedDBLen, MSBits, emLen;
const unsigned char *H;
unsigned char *DB = NULL;
EVP_MD_CTX ctx;
unsigned char H_[EVP_MAX_MD_SIZE];
- emBits = BN_num_bits(rsa->n) - 1;
- emLen = (emBits + 7) >> 3;
+ MSBits = (BN_num_bits(rsa->n) - 1) & 0x7;
+ emLen = RSA_size(rsa);
hLen = EVP_MD_size(Hash);
if (emLen < (hLen + sLen + 2))
{
RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS, RSA_R_LAST_OCTET_INVALID);
goto err;
}
- if (EM[0] & (0xFF << (emBits & 0x7)))
+ if (EM[0] & (0xFF << MSBits))
{
RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS, RSA_R_FIRST_OCTET_INVALID);
goto err;
}
+ if (!MSBits)
+ {
+ EM++;
+ emLen--;
+ }
maskedDBLen = emLen - hLen - 1;
H = EM + maskedDBLen;
DB = OPENSSL_malloc(maskedDBLen);
PKCS1_MGF1(DB, maskedDBLen, H, hLen, Hash);
for (i = 0; i < maskedDBLen; i++)
DB[i] ^= EM[i];
- DB[0] &= 0xFF >> (8 - (emBits & 0x7));
+ if (MSBits)
+ DB[0] &= 0xFF >> (8 - MSBits);
for (i = 0; i < (emLen - hLen - sLen - 2); i++)
{
if (DB[i] != 0)
{
int i;
int ret = 0;
- int hLen, maskedDBLen, emBits, emLen;
+ int hLen, maskedDBLen, MSBits, emLen;
unsigned char *H, *salt = NULL, *p;
EVP_MD_CTX ctx;
- emBits = BN_num_bits(rsa->n) - 1;
- emLen = (emBits + 7) >> 3;
+ MSBits = (BN_num_bits(rsa->n) - 1) & 0x7;
+ emLen = RSA_size(rsa);
hLen = EVP_MD_size(Hash);
if (sLen < 0)
sLen = 0;
RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
goto err;
}
+ if (MSBits == 0)
+ {
+ *EM++ = 0;
+ emLen--;
+ }
if (sLen > 0)
{
salt = OPENSSL_malloc(sLen);
for (i = 0; i < sLen; i++)
*p++ ^= salt[i];
}
- EM[0] &= 0xFF >> (8 - (emBits & 0x7));
+ if (MSBits)
+ EM[0] &= 0xFF >> (8 - MSBits);
/* H is already in place so just set final 0xbc */