More cleanup, merged pymumble into the project so I don't have to maintain 2 separate things, tested working version.
This commit is contained in:
383
pymumble_py3/crypto.py
Normal file
383
pymumble_py3/crypto.py
Normal file
@@ -0,0 +1,383 @@
|
||||
'''
|
||||
OCB2 crypto, broadly following the implementation from Mumble
|
||||
'''
|
||||
from typing import Tuple
|
||||
import struct
|
||||
import time
|
||||
from math import ceil
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Random import get_random_bytes
|
||||
|
||||
|
||||
AES_BLOCK_SIZE = 128 // 8 # Number of bytes in a block
|
||||
AES_KEY_SIZE_BITS = 128
|
||||
AES_KEY_SIZE_BYTES = AES_KEY_SIZE_BITS // 8
|
||||
SHIFTBITS = 63 # Shift size for S2 operation
|
||||
MAX64 = (1 << 64) - 1 # Maximum value of uint64
|
||||
|
||||
|
||||
class EncryptFailedException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DecryptFailedException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CryptStateOCB2:
|
||||
"""
|
||||
State tracker for AES-OCB2 crypto.
|
||||
All encryption/decryption should be done through this class
|
||||
and not the `ocb_*` functions.
|
||||
|
||||
A random key and IVs are chosen upon initialization; these can be
|
||||
replaced using `set_key`.
|
||||
|
||||
Attributes intended for external access:
|
||||
raw_key
|
||||
encrypt_iv
|
||||
decrypt_iv
|
||||
decrypt_history
|
||||
|
||||
uiGood
|
||||
uiLate
|
||||
uiLost
|
||||
tLastGood
|
||||
"""
|
||||
_raw_key: bytes # AES key; access through `raw_key` property
|
||||
_aes: object # pycrypto AES cipher object, replaced when `raw_key` is changed
|
||||
_encrypt_iv: bytearray # IV for encryption, access through `encrypt_iv` property
|
||||
_decrypt_iv: bytearray # IV for decryption, access through `decrypt_iv` property
|
||||
decrypt_history: bytearray # History of previous decrypt_iv values
|
||||
|
||||
# Statistics:
|
||||
uiGood: int # Number of packets successfully decrypted
|
||||
uiLate: int # Number of packets which arrived out of order
|
||||
uiLost: int # Number of packets which did not arrive in order (may arrive late)
|
||||
tLastGood: float # time.perf_counter() value for latest good packet
|
||||
|
||||
def __init__(self):
|
||||
self.uiGood = 0
|
||||
self.uiLate = 0
|
||||
self.uiLost = 0
|
||||
self.tLastGood = 0
|
||||
|
||||
self._raw_key = get_random_bytes(AES_KEY_SIZE_BYTES)
|
||||
self._encrypt_iv = get_random_bytes(AES_BLOCK_SIZE)
|
||||
self._decrypt_iv = get_random_bytes(AES_BLOCK_SIZE)
|
||||
self._aes = None
|
||||
self.decrypt_history = bytearray(0x100)
|
||||
|
||||
@property
|
||||
def raw_key(self) -> bytes:
|
||||
return self._raw_key
|
||||
|
||||
@raw_key.setter
|
||||
def raw_key(self, rkey: bytes):
|
||||
if len(rkey) != AES_KEY_SIZE_BYTES:
|
||||
raise Exception('raw_key has wrong length')
|
||||
self._raw_key = bytes(rkey)
|
||||
self._aes = AES.new(key=self.raw_key, mode=AES.MODE_ECB)
|
||||
|
||||
@property
|
||||
def encrypt_iv(self) -> bytearray:
|
||||
return self._encrypt_iv
|
||||
|
||||
@encrypt_iv.setter
|
||||
def encrypt_iv(self, eiv: bytearray):
|
||||
if len(eiv) != AES_BLOCK_SIZE:
|
||||
raise Exception('encrypt_iv wrong length')
|
||||
self._encrypt_iv = bytearray(eiv)
|
||||
|
||||
@property
|
||||
def decrypt_iv(self) -> bytearray:
|
||||
return self._decrypt_iv
|
||||
|
||||
@decrypt_iv.setter
|
||||
def decrypt_iv(self, div: bytearray):
|
||||
if len(div) != AES_BLOCK_SIZE:
|
||||
raise Exception('decrypt_iv has wrong length')
|
||||
self._decrypt_iv = bytearray(div)
|
||||
|
||||
def gen_key(self):
|
||||
"""
|
||||
Randomly generate new keys
|
||||
"""
|
||||
self.raw_key = get_random_bytes(AES_KEY_SIZE_BYTES)
|
||||
self.encrypt_iv = get_random_bytes(AES_BLOCK_SIZE)
|
||||
self.decrypt_iv = get_random_bytes(AES_BLOCK_SIZE)
|
||||
|
||||
def set_key(self, raw_key: bytes, encrypt_iv: bytearray, decrypt_iv: bytearray):
|
||||
"""
|
||||
Set new keys
|
||||
|
||||
Args:
|
||||
raw_key: AES key
|
||||
encrypt_iv: IV for encryption
|
||||
decrypt_iv: IV for decrpytion
|
||||
"""
|
||||
self.raw_key = raw_key
|
||||
self.encrypt_iv = encrypt_iv
|
||||
self.decrypt_iv = decrypt_iv
|
||||
|
||||
def encrypt(self, source: bytes) -> bytes:
|
||||
"""
|
||||
Encrypt a message
|
||||
|
||||
Args:
|
||||
source: The plaintext bytes to be encrypted
|
||||
|
||||
Returns:
|
||||
Encrypted (ciphertext) bytes
|
||||
|
||||
Raises:
|
||||
EncryptFailedException if `source` would result in a vulnerable packet
|
||||
"""
|
||||
eiv = increment_iv(self.encrypt_iv)
|
||||
self.encrypt_iv = eiv
|
||||
|
||||
dst, tag = ocb_encrypt(self._aes, source, bytes(eiv))
|
||||
|
||||
head = bytes((eiv[0], *tag[:3]))
|
||||
return head + dst
|
||||
|
||||
def decrypt(self, source: bytes, len_plain: int) -> bytes:
|
||||
"""
|
||||
Decrypt a message
|
||||
|
||||
Args:
|
||||
source: The ciphertext bytes to be decrypted
|
||||
len_plain: The length of the plaintext
|
||||
|
||||
Returns:
|
||||
Decrypted (plaintext) bytes
|
||||
|
||||
Raises:
|
||||
DecryptFailedException:
|
||||
- if `source` is too short
|
||||
- packet is out of order or duplicate
|
||||
- packet was could have been tampered with
|
||||
"""
|
||||
if len(source) < 4:
|
||||
raise DecryptFailedException('Source <4 bytes long!')
|
||||
|
||||
div = self.decrypt_iv.copy()
|
||||
ivbyte = source[0]
|
||||
late = False
|
||||
lost = 0
|
||||
|
||||
if (div[0] + 1) & 0xFF == ivbyte:
|
||||
# In order as expected.
|
||||
if ivbyte > div[0]:
|
||||
div[0] = ivbyte
|
||||
elif ivbyte < div[0]:
|
||||
div[0] = ivbyte
|
||||
div = increment_iv(div, 1)
|
||||
else:
|
||||
raise DecryptFailedException('ivbyte == decrypt_iv[0]')
|
||||
else:
|
||||
# This is either out of order or a repeat.
|
||||
diff = ivbyte - div[0]
|
||||
if diff > 128:
|
||||
diff -= 256
|
||||
elif diff < -128:
|
||||
diff += 256
|
||||
|
||||
if ivbyte < div[0] and -30 < diff < 0:
|
||||
# Late packet, but no wraparound.
|
||||
late = True
|
||||
lost = -1
|
||||
div[0] = ivbyte
|
||||
elif ivbyte > div[0] and -30 < diff < 0:
|
||||
# Last was 0x02, here comes 0xff from last round
|
||||
late = True
|
||||
lost = -1
|
||||
div[0] = ivbyte
|
||||
div = decrement_iv(div, 1)
|
||||
elif ivbyte > div[0] and diff > 0:
|
||||
# Lost a few packets, but beyond that we're good.
|
||||
lost = ivbyte - div[0] - 1
|
||||
div[0] = ivbyte
|
||||
elif ivbyte < div[0] and diff > 0:
|
||||
# Lost a few packets, and wrapped around
|
||||
lost = 0x100 - div[0] + ivbyte - 1
|
||||
div[0] = ivbyte
|
||||
div = increment_iv(div, 1)
|
||||
else:
|
||||
raise DecryptFailedException('Lost too many packets?')
|
||||
|
||||
if self.decrypt_history[div[0]] == div[1]:
|
||||
raise DecryptFailedException('decrypt_iv in history')
|
||||
|
||||
dst, tag = ocb_decrypt(self._aes, source[4:], bytes(div), len_plain)
|
||||
|
||||
if tag[:3] != source[1:4]:
|
||||
raise DecryptFailedException('Tag did not match!')
|
||||
|
||||
self.decrypt_history[div[0]] = div[1]
|
||||
|
||||
if not late:
|
||||
self.decrypt_iv = div
|
||||
else:
|
||||
self.uiLate += 1
|
||||
|
||||
self.uiGood += 1
|
||||
self.uiLost += lost
|
||||
|
||||
self.tLastGood = time.perf_counter()
|
||||
|
||||
return dst
|
||||
|
||||
|
||||
def ocb_encrypt(aes: object,
|
||||
plain: bytes,
|
||||
nonce: bytes,
|
||||
*,
|
||||
insecure=False,
|
||||
) -> Tuple[bytes, bytes]:
|
||||
"""
|
||||
Encrypt a message.
|
||||
This should be called from CryptStateOCB2.encrypt() and not independently.
|
||||
|
||||
Args:
|
||||
aes: AES-ECB cipher object
|
||||
plain: The plaintext bytes to be encrypted
|
||||
nonce: The encryption IV
|
||||
|
||||
Returns:
|
||||
Encrypted (ciphertext) bytes and tag
|
||||
|
||||
Raises:
|
||||
EncryptFailedException if `source` would result in a vulnerable packet
|
||||
"""
|
||||
delta = aes.encrypt(nonce)
|
||||
checksum = bytes(AES_BLOCK_SIZE)
|
||||
plain_block = b''
|
||||
|
||||
pos = 0
|
||||
encrypted = bytearray(ceil(len(plain) / AES_BLOCK_SIZE) * AES_BLOCK_SIZE)
|
||||
while len(plain) - pos > AES_BLOCK_SIZE:
|
||||
plain_block = plain[pos:pos + AES_BLOCK_SIZE]
|
||||
delta = S2(delta)
|
||||
encrypted_block = xor(delta, aes.encrypt(xor(delta, plain_block)))
|
||||
checksum = xor(checksum, plain_block)
|
||||
|
||||
encrypted[pos:pos + AES_BLOCK_SIZE] = encrypted_block
|
||||
pos += AES_BLOCK_SIZE
|
||||
|
||||
# Counter-cryptanalysis described in section 9 of https://eprint.iacr.org/2019/311
|
||||
# For an attack, the second to last block (i.e. the last iteration of this loop)
|
||||
# must be all 0 except for the last byte (which may be 0 - 128).
|
||||
if not insecure and bytes(plain_block[:-1]) == bytes(AES_BLOCK_SIZE - 1):
|
||||
raise EncryptFailedException('Insecure input block: ' +
|
||||
'see section 9 of https://eprint.iacr.org/2019/311')
|
||||
|
||||
len_remaining = len(plain) - pos
|
||||
delta = S2(delta)
|
||||
pad_in = struct.pack('>QQ', 0, len_remaining * 8)
|
||||
pad = aes.encrypt(xor(pad_in, delta))
|
||||
plain_block = plain[pos:] + pad[len_remaining - AES_BLOCK_SIZE:]
|
||||
|
||||
checksum = xor(checksum, plain_block)
|
||||
encrypted_block = xor(pad, plain_block)
|
||||
encrypted[pos:] = encrypted_block
|
||||
|
||||
delta = xor(delta, S2(delta))
|
||||
tag = aes.encrypt(xor(delta, checksum))
|
||||
|
||||
return encrypted, tag
|
||||
|
||||
|
||||
def ocb_decrypt(aes: object,
|
||||
encrypted: bytes,
|
||||
nonce: bytes,
|
||||
len_plain: int,
|
||||
*,
|
||||
insecure=False,
|
||||
) -> Tuple[bytes, bytes]:
|
||||
"""
|
||||
Decrypt a message.
|
||||
This should be called from CryptStateOCB2.decrypt() and not independently.
|
||||
|
||||
Args:
|
||||
aes: AES-ECB cipher object
|
||||
encrypted: The ciphertext bytes to be decrypted
|
||||
nonce: The decryption IV
|
||||
len_plain: The length of the desired plaintext
|
||||
|
||||
Returns:
|
||||
Decrypted (plaintext) bytes and tag
|
||||
|
||||
Raises:
|
||||
DecryptFailedException:
|
||||
- if `source` is too short
|
||||
- packet is out of order or duplicate
|
||||
- packet was could have been tampered with
|
||||
"""
|
||||
delta = aes.encrypt(nonce)
|
||||
checksum = bytes(AES_BLOCK_SIZE)
|
||||
plain = bytearray(len_plain)
|
||||
|
||||
pos = 0
|
||||
while len_plain - pos > AES_BLOCK_SIZE:
|
||||
encrypted_block = encrypted[pos:pos + AES_BLOCK_SIZE]
|
||||
delta = S2(delta)
|
||||
tmp = aes.decrypt(xor(delta, encrypted_block))
|
||||
plain_block = xor(delta, tmp)
|
||||
checksum = xor(checksum, plain_block)
|
||||
|
||||
plain[pos:pos + AES_BLOCK_SIZE] = plain_block
|
||||
pos += AES_BLOCK_SIZE
|
||||
|
||||
len_remaining = len_plain - pos
|
||||
delta = S2(delta)
|
||||
pad_in = struct.pack('>QQ', 0, len_remaining * 8)
|
||||
pad = aes.encrypt(xor(pad_in, delta))
|
||||
encrypted_zeropad = encrypted[pos:] + bytes(AES_BLOCK_SIZE - len_remaining)
|
||||
plain_block = xor(encrypted_zeropad, pad)
|
||||
|
||||
checksum = xor(checksum, plain_block)
|
||||
plain[pos:] = plain_block[:len_remaining]
|
||||
|
||||
# Counter-cryptanalysis described in section 9 of https://eprint.iacr.org/2019/311
|
||||
# In an attack, the decrypted last block would need to equal `delta ^ len(128)`.
|
||||
# With a bit of luck (or many packets), smaller values than 128 (i.e. non-full blocks) are also
|
||||
# feasible, so we check `plain_block` instead of `plain`.
|
||||
# Since our `len` only ever modifies the last byte, we simply check all remaining ones.
|
||||
if not insecure and plain_block[:-1] == delta[:-1]:
|
||||
raise DecryptFailedException('Possibly tampered/able block, discarding.')
|
||||
|
||||
delta = xor(delta, S2(delta))
|
||||
tag = aes.encrypt(xor(delta, checksum))
|
||||
return plain, tag
|
||||
|
||||
|
||||
def increment_iv(iv: bytearray, start: int = 0) -> bytearray:
|
||||
for i in range(start, AES_BLOCK_SIZE):
|
||||
iv[i] = (iv[i] + 1) % 0x100
|
||||
if iv[i] != 0:
|
||||
break
|
||||
return iv
|
||||
|
||||
|
||||
def decrement_iv(iv: bytearray, start: int = 0) -> bytearray:
|
||||
for i in range(start, AES_BLOCK_SIZE):
|
||||
iv[i] = (iv[i] - 1) % 0x100
|
||||
if iv[i] != 0xFF:
|
||||
break
|
||||
return iv
|
||||
|
||||
|
||||
def xor(a: bytes, b: bytes) -> bytes:
|
||||
return bytes(aa ^ bb for aa, bb in zip(a, b))
|
||||
|
||||
|
||||
def S2(block: bytes) -> bytes:
|
||||
ll, uu = struct.unpack('>QQ', block)
|
||||
carry = ll >> 63
|
||||
block = struct.pack('>QQ',
|
||||
((ll << 1) | (uu >> 63)) & MAX64,
|
||||
((uu << 1) ^ (carry * 0x87)) & MAX64)
|
||||
return block
|
||||
|
Reference in New Issue
Block a user