blob: 9c616b19768c7346425c4b51e4dd1e76792862b9 [file] [log] [blame]
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import ctypes
import jwt
import logging
import subprocess
import tempfile
from asn1crypto.core import Sequence
from coral.cloudiot.utils import ascii_hex_string
from cryptography.hazmat.primitives import hashes
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Handle to the shared library used by the functions below; it is None here
# and is rebound by the LoadLibrary call in the try block later in this module.
library = None
# Status value that the library calls in this module are compared against to
# decide whether they succeeded.
kA71ChOk = 0x9000
def a71ch_serial():
    """Read the A71CH unique identifier and return it formatted as hex.

    Calls ``A71_GetUniqueID`` from the loaded library into an 18-byte buffer
    and hands the raw bytes to ``ascii_hex_string``.

    Returns:
        The result of ``ascii_hex_string`` for the 18 raw identifier bytes.

    Raises:
        AssertionError: If ``A71_GetUniqueID`` does not return ``kA71ChOk``.
    """
    uid_len = 18
    ret_len = ctypes.c_uint16(uid_len)
    uid = ctypes.create_string_buffer(uid_len)

    get_unique_id = library.A71_GetUniqueID
    get_unique_id.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_uint16)]
    get_unique_id.restype = ctypes.c_uint16

    # The call must not live inside an ``assert`` statement: ``python -O``
    # removes asserts entirely, which would skip the hardware call itself.
    status = get_unique_id(uid, ctypes.byref(ret_len))
    if status != kA71ChOk:
        # AssertionError is kept so callers see the same exception type as
        # before.
        raise AssertionError(
            'A71_GetUniqueID failed with status 0x%04x' % status)

    return ascii_hex_string(uid.raw, l=uid_len)
def a71ch_public_key():
    """Export a public key through the ``A71CHConfigTool`` command-line tool.

    Runs ``A71CHConfigTool get pub -c 10 -x 0 -k <tempfile>`` and reads the
    temporary file back once the tool has finished.

    Returns:
        The lines of the file written by the tool, each stripped of
        surrounding whitespace and joined with ``'\\n'``.

    Raises:
        subprocess.CalledProcessError: If the tool exits with a non-zero
            status.
        OSError: If the tool cannot be started.
    """
    with tempfile.NamedTemporaryFile(mode='w+') as tempkey:
        # The tool's output is not used. DEVNULL replaces PIPE because
        # check_call never reads the pipes, so a child that prints enough to
        # fill the OS pipe buffer would block forever.
        subprocess.check_call(
            ['A71CHConfigTool', 'get', 'pub', '-c', '10', '-x', '0',
             '-k', tempkey.name],
            stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        return '\n'.join(line.strip() for line in tempkey)
def a71ch_hw_sign(msg, key_id=0):
    """Hash ``msg`` with ``A71_GetSha256`` and sign the digest with ``A71_EccSign``.

    Args:
        msg: Bytes to sign; passed to the library as a ``c_char_p`` together
            with ``len(msg)``.
        key_id: Value passed as the first (``c_uint8``) argument of
            ``A71_EccSign``. Presumably the index of the key pair on the
            device -- confirm against the library documentation.

    Returns:
        64 bytes: the two integers of the ASN.1 signature sequence returned by
        the library, each encoded as 32 big-endian bytes and concatenated.

    Raises:
        AssertionError: If either library call does not return ``kA71ChOk``.
    """
    get_sha256 = library.A71_GetSha256
    get_sha256.argtypes = [ctypes.c_char_p, ctypes.c_uint16,
                           ctypes.c_char_p, ctypes.POINTER(ctypes.c_uint16)]
    get_sha256.restype = ctypes.c_uint16

    digest = ctypes.create_string_buffer(32)
    digest_len = ctypes.c_uint16(32)
    # The calls must not live inside ``assert`` statements: ``python -O``
    # removes asserts entirely, which would skip the hardware calls.
    status = get_sha256(msg, ctypes.c_uint16(len(msg)), digest,
                        ctypes.byref(digest_len))
    if status != kA71ChOk:
        # AssertionError is kept so callers see the same exception type as
        # before.
        raise AssertionError(
            'A71_GetSha256 failed with status 0x%04x' % status)

    ecc_sign = library.A71_EccSign
    ecc_sign.argtypes = [ctypes.c_uint8, ctypes.c_char_p, ctypes.c_uint16,
                         ctypes.c_char_p, ctypes.POINTER(ctypes.c_uint16)]
    ecc_sign.restype = ctypes.c_uint16

    sig = ctypes.create_string_buffer(256)
    sig_len = ctypes.c_uint16(256)
    status = ecc_sign(key_id, digest, digest_len, sig, ctypes.byref(sig_len))
    if status != kA71ChOk:
        raise AssertionError(
            'A71_EccSign failed with status 0x%04x' % status)

    # Parse only the bytes the library reported through sig_len, not the
    # zero padding that fills the rest of the 256-byte buffer.
    asn1 = Sequence.load(sig.raw[:sig_len.value])
    r_bytes = asn1[0].native.to_bytes(32, 'big')
    s_bytes = asn1[1].native.to_bytes(32, 'big')
    return r_bytes + s_bytes
class HwEcAlgorithm(jwt.algorithms.Algorithm):
    """PyJWT algorithm whose signing step is delegated to ``a71ch_hw_sign``.

    Verification does not use the device: it converts the raw signature to
    DER and calls ``key.verify``.
    """

    def __init__(self):
        # Stored as the class, not an instance; verify() instantiates it.
        self.hash_alg = hashes.SHA256

    def prepare_key(self, key):
        """Return ``key`` unchanged."""
        return key

    def sign(self, msg, key):
        """Sign ``msg`` via ``a71ch_hw_sign``; ``key`` is ignored."""
        return a71ch_hw_sign(msg)

    def verify(self, msg, key, sig):
        """Check the raw signature ``sig`` over ``msg`` using ``key``.

        Args:
            msg: The signed bytes.
            key: Object providing ``curve`` and ``verify`` -- presumably a
                ``cryptography`` EC public key; confirm against callers.
            sig: Raw signature as accepted by
                ``jwt.utils.raw_to_der_signature``.

        Returns:
            True if the signature verifies; False if it cannot be converted
            to DER or ``key.verify`` raises ``InvalidSignature``.
        """
        # Imported here because the module's import block does not bring
        # these names into scope; without them this method raised NameError.
        from cryptography.exceptions import InvalidSignature
        from cryptography.hazmat.primitives.asymmetric import ec

        try:
            der_sig = jwt.utils.raw_to_der_signature(sig, key.curve)
        except ValueError:
            return False

        try:
            key.verify(der_sig, msg, ec.ECDSA(self.hash_alg()))
        except InvalidSignature:
            return False
        return True
class SmCommState_t(ctypes.Structure):
    """ctypes structure passed by reference to ``SM_Connect``.

    The layout must be declared through ``_fields_``; ctypes ignores any other
    attribute name, which would leave the structure with size 0 and no field
    descriptors.
    """

    # NOTE(review): field order and types are taken from the original
    # declaration in this file -- confirm against the C header of the library.
    _fields_ = [
        ('connType', ctypes.c_uint16),
        ('param1', ctypes.c_uint16),
        ('param2', ctypes.c_uint16),
        ('hostLibVersion', ctypes.c_uint16),
        ('appletVersion', ctypes.c_uint32),
        ('sbVersion', ctypes.c_uint16),
        ('skip_select_applet', ctypes.c_uint8),
    ]


# Assigned after class creation, as in the original code: at this point it is
# only a plain list attribute holding the field names and does not restrict
# instance attributes.
SmCommState_t.__slots__ = [
    'connType',
    'param1',
    'param2',
    'hostLibVersion',
    'appletVersion',
    'sbVersion',
    'skip_select_applet',
]
# Best-effort initialisation: load the library, connect, and build a PyJWT
# instance whose ES256 algorithm signs through HwEcAlgorithm. Any failure is
# logged at debug level and leaves a71ch_jwt_with_hw_alg as None.
try:
    a71ch_jwt_with_hw_alg = None
    library = ctypes.cdll.LoadLibrary('libsss_engine.so')

    sm_connect = library.SM_Connect
    sm_connect.argtypes = [ctypes.POINTER(None),
                           ctypes.POINTER(SmCommState_t),
                           ctypes.c_char_p,
                           ctypes.POINTER(ctypes.c_uint16)]
    sm_connect.restype = ctypes.c_uint16

    comm_state = SmCommState_t()
    atr_len = ctypes.c_uint16(64)
    atr = ctypes.create_string_buffer(64)
    # The call must not live inside an ``assert`` statement: ``python -O``
    # removes asserts entirely, which would skip the connection attempt.
    connect_status = sm_connect(None, ctypes.byref(comm_state), atr,
                                ctypes.byref(atr_len))
    if connect_status != kA71ChOk:
        raise RuntimeError(
            'SM_Connect failed with status 0x%04x' % connect_status)

    a71ch_jwt_with_hw_alg = jwt.PyJWT(algorithms=[])
    a71ch_jwt_with_hw_alg.register_algorithm('ES256', HwEcAlgorithm())
except Exception:
    # Deliberately broad: the device or library being absent must not break
    # importing this module. exc_info keeps the cause in the debug log.
    logger.debug('Unable to load A71CH', exc_info=True)