DeDRM_tools/DeDRM_plugin/ion.py
2021-01-21 23:41:22 +09:00

1042 lines
32 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ion.py
# Copyright © 2013-2020 Apprentice Harper et al.
__license__ = 'GPL v3'
__version__ = '3.0'
# Revision history:
# Pascal implementation by lulzkabulz.
# BinaryIon.pas + DrmIon.pas + IonSymbols.pas
# 1.0 - Python translation by apprenticenaomi.
# 1.1 - DeDRM integration by anon.
# 1.2 - Added pylzma import fallback
# 1.3 - Fixed lzma support for calibre 4.6+
# 2.0 - VoucherEnvelope v2/v3 support by apprenticesakuya.
# 3.0 - Added Python 3 compatibility for calibre 5.0
"""
Decrypt Kindle KFX files.
"""
import collections
import hashlib
import hmac
import os
import os.path
import struct
from io import BytesIO
from Crypto.Cipher import AES
from Crypto.Util.py3compat import bchr
try:
# lzma library from calibre 4.6.0 or later
import calibre_lzma.lzma1 as calibre_lzma
except ImportError:
calibre_lzma = None
# lzma library from calibre 2.35.0 or later
try:
import lzma.lzma1 as calibre_lzma
except ImportError:
calibre_lzma = None
try:
import lzma
except ImportError:
# Need pip backports.lzma on Python <3.3
try:
from backports import lzma
except ImportError:
# Windows-friendly choice: pylzma wheels
import pylzma as lzma
# Ion type identifiers: the high nibble of a value's type-descriptor byte.
TID_NULL = 0x0
TID_BOOLEAN = 0x1
TID_POSINT = 0x2
TID_NEGINT = 0x3
TID_FLOAT = 0x4
TID_DECIMAL = 0x5
TID_TIMESTAMP = 0x6
TID_SYMBOL = 0x7
TID_STRING = 0x8
TID_CLOB = 0x9
TID_BLOB = 0xA
TID_LIST = 0xB
TID_SEXP = 0xC
TID_STRUCT = 0xD
TID_TYPEDECL = 0xE
TID_UNUSED = 0xF

# Symbol ids; SID_UNKNOWN doubles as the "no field id / nothing read" marker.
SID_UNKNOWN = -1
SID_ION = 1
SID_ION_1_0 = 2
SID_ION_SYMBOL_TABLE = 3
SID_NAME = 4
SID_VERSION = 5
SID_IMPORTS = 6
SID_SYMBOLS = 7
SID_MAX_ID = 8
SID_ION_SHARED_SYMBOL_TABLE = 9
# Size of the initial symbol table (ids 1..9 plus the unused slot 0).
SID_ION_1_0_MAX = 10

# Special values of the low nibble (length field) of a type-descriptor byte.
LEN_IS_VAR_LEN = 0xE
LEN_IS_NULL = 0xF

# Bytes expected right after a zero-length TID_TYPEDECL descriptor byte.
VERSION_MARKER = [b"\x01", b"\x00", b"\xea"]
# asserts must always raise exceptions for proper functioning
def _assert(test, msg="Exception"):
if not test:
raise Exception(msg)
class SystemSymbols(object):
    """Text of the system symbols pre-loaded into every SymbolTable."""

    ION = "$ion"
    ION_1_0 = "$ion_1_0"
    ION_SYMBOL_TABLE = "$ion_symbol_table"
    NAME = "name"
    VERSION = "version"
    IMPORTS = "imports"
    SYMBOLS = "symbols"
    MAX_ID = "max_id"
    ION_SHARED_SYMBOL_TABLE = "$ion_shared_symbol_table"
class IonCatalogItem(object):
    """One shared symbol table known to a parser: name, version, symbol names."""

    # Class-level defaults; every instance overrides all three in __init__.
    name = ""
    version = 0
    symnames = []

    def __init__(self, name, version, symnames):
        """Store the table's *name*, *version* and list of symbol names."""
        self.name, self.version, self.symnames = name, version, symnames
class SymbolToken(object):
    """A symbol given by its text and/or its numeric symbol id."""

    text = ""
    sid = 0

    def __init__(self, text, sid):
        """Build the token; raise ValueError when text is "" and sid is 0."""
        has_text = text != ""
        has_sid = sid != 0
        if not (has_text or has_sid):
            raise ValueError("Symbol token must have Text or SID")
        self.text, self.sid = text, sid
class SymbolTable(object):
    """Symbol-id to text mapping, stored as a list indexed by symbol id."""

    table = None

    def __init__(self):
        """Create a table holding only the system symbols (ids 1..9)."""
        system_symbols = (
            (SID_ION, SystemSymbols.ION),
            (SID_ION_1_0, SystemSymbols.ION_1_0),
            (SID_ION_SYMBOL_TABLE, SystemSymbols.ION_SYMBOL_TABLE),
            (SID_NAME, SystemSymbols.NAME),
            (SID_VERSION, SystemSymbols.VERSION),
            (SID_IMPORTS, SystemSymbols.IMPORTS),
            (SID_SYMBOLS, SystemSymbols.SYMBOLS),
            (SID_MAX_ID, SystemSymbols.MAX_ID),
            (SID_ION_SHARED_SYMBOL_TABLE, SystemSymbols.ION_SHARED_SYMBOL_TABLE),
        )
        self.table = [None] * SID_ION_1_0_MAX
        for sid, text in system_symbols:
            self.table[sid] = text

    def findbyid(self, sid):
        """Return the text for *sid*, or "" when the id is past the table end.

        Raises ValueError for ids below 1.
        """
        if sid < 1:
            raise ValueError("Invalid symbol id")
        return self.table[sid] if sid < len(self.table) else ""

    def import_(self, table, maxid):
        """Append the first *maxid* names of ``table.symnames`` to this table."""
        self.table.extend(table.symnames[i] for i in range(maxid))

    def importunknown(self, name, maxid):
        """Append *maxid* placeholder names "<name>#1" .. "<name>#<maxid>"."""
        self.table.extend("%s#%d" % (name, n) for n in range(1, maxid + 1))
class ParserState:
    """Positions BinaryIonParser can be in relative to the current value."""

    Invalid = 1
    BeforeField = 2
    BeforeTID = 3
    BeforeValue = 4
    AfterValue = 5
    EOF = 6
# State saved by BinaryIonParser.push() for each container stepped into:
# stream position after the container, enclosing type id, enclosing byte budget.
ContainerRec = collections.namedtuple(
    "ContainerRec",
    ["nextpos", "tid", "remaining"],
)
class BinaryIonParser(object):
    """Pull parser for a binary Ion stream.

    Typical use: call hasnext()/next() to move to a value, stepin()/stepout()
    to enter and leave containers, and the *value() accessors to read the
    current scalar.  Symbol ids are resolved through ``self.symbols``; shared
    symbol tables must be registered with addtocatalog() before parsing.
    """

    # Class-level defaults; the working values are set by reset()/clearvalue().
    eof = False
    state = None
    localremaining = 0      # bytes left in the current container, -1 = unbounded
    needhasnext = False
    isinstruct = False
    valuetid = 0
    valuefieldid = 0
    parenttid = 0
    valuelen = 0
    valueisnull = False
    valueistrue = False
    value = None
    didimports = False

    def __init__(self, stream):
        """Wrap *stream*, a seekable binary file object, at its current position."""
        self.annotations = []
        self.catalog = []
        self.stream = stream
        self.initpos = stream.tell()
        self.reset()
        self.symbols = SymbolTable()

    def reset(self):
        """Rewind to the initial stream position and drop all container state.

        The symbol table and the ``didimports`` flag are left untouched.
        """
        self.state = ParserState.BeforeTID
        self.needhasnext = True
        self.localremaining = -1
        self.eof = False
        self.isinstruct = False
        self.containerstack = []
        self.stream.seek(self.initpos)

    def addtocatalog(self, name, version, symbols):
        """Register a shared symbol table that imports may refer to by name."""
        self.catalog.append(IonCatalogItem(name, version, symbols))

    def hasnext(self):
        """Position on the next value if necessary; return False at the end.

        At the top level, version markers are skipped and structs annotated
        with $ion_symbol_table are consumed via parsesymboltable().
        """
        while self.needhasnext and not self.eof:
            self.hasnextraw()
            if len(self.containerstack) == 0 and not self.valueisnull:
                if self.valuetid == TID_SYMBOL:
                    if self.value == SID_ION_1_0:
                        self.needhasnext = True
                elif self.valuetid == TID_STRUCT:
                    for a in self.annotations:
                        if a == SID_ION_SYMBOL_TABLE:
                            self.parsesymboltable()
                            self.needhasnext = True
                            break
        return not self.eof

    def hasnextraw(self):
        """Run the state machine until a type id has been read or EOF is hit."""
        self.clearvalue()
        while self.valuetid == -1 and not self.eof:
            self.needhasnext = False
            if self.state == ParserState.BeforeField:
                _assert(self.valuefieldid == SID_UNKNOWN)
                self.valuefieldid = self.readfieldid()
                if self.valuefieldid != SID_UNKNOWN:
                    self.state = ParserState.BeforeTID
                else:
                    self.eof = True
            elif self.state == ParserState.BeforeTID:
                self.state = ParserState.BeforeValue
                self.valuetid = self.readtypeid()
                if self.valuetid == -1:
                    self.state = ParserState.EOF
                    self.eof = True
                    break
                if self.valuetid == TID_TYPEDECL:
                    # Zero length: version marker; otherwise an annotation
                    # wrapper around the real value.
                    if self.valuelen == 0:
                        self.checkversionmarker()
                    else:
                        self.loadannotations()
            elif self.state == ParserState.BeforeValue:
                # The previous value was never read; jump over its bytes.
                self.skip(self.valuelen)
                self.state = ParserState.AfterValue
            elif self.state == ParserState.AfterValue:
                if self.isinstruct:
                    self.state = ParserState.BeforeField
                else:
                    self.state = ParserState.BeforeTID
            else:
                _assert(self.state == ParserState.EOF)

    def next(self):
        """Advance to the next value and return its type id, or -1 at the end."""
        if self.hasnext():
            self.needhasnext = True
            return self.valuetid
        else:
            return -1

    def push(self, typeid, nextposition, nextremaining):
        """Save the enclosing container's state on the container stack."""
        self.containerstack.append(ContainerRec(nextpos=nextposition, tid=typeid, remaining=nextremaining))

    def stepin(self):
        """Enter the current struct, list or sexp value.

        Raises Exception when the current value is not a container or the
        parser is not positioned at the start of its content.
        """
        _assert(self.valuetid in [TID_STRUCT, TID_LIST, TID_SEXP] and not self.eof,
                "valuetid=%s eof=%s" % (self.valuetid, self.eof))
        _assert((not self.valueisnull or self.state == ParserState.AfterValue) and
                (self.valueisnull or self.state == ParserState.BeforeValue))
        # Byte budget the parent will have left once this container is done.
        nextrem = self.localremaining
        if nextrem != -1:
            nextrem -= self.valuelen
            if nextrem < 0:
                nextrem = 0
        self.push(self.parenttid, self.stream.tell() + self.valuelen, nextrem)
        self.isinstruct = (self.valuetid == TID_STRUCT)
        if self.isinstruct:
            self.state = ParserState.BeforeField
        else:
            self.state = ParserState.BeforeTID
        self.localremaining = self.valuelen
        self.parenttid = self.valuetid
        self.clearvalue()
        self.needhasnext = True

    def stepout(self):
        """Leave the current container, skipping whatever part was not read."""
        rec = self.containerstack.pop()
        self.eof = False
        self.parenttid = rec.tid
        if self.parenttid == TID_STRUCT:
            self.isinstruct = True
            self.state = ParserState.BeforeField
        else:
            self.isinstruct = False
            self.state = ParserState.BeforeTID
        self.needhasnext = True
        self.clearvalue()
        curpos = self.stream.tell()
        if rec.nextpos > curpos:
            self.skip(rec.nextpos - curpos)
        else:
            _assert(rec.nextpos == curpos)
        self.localremaining = rec.remaining

    def read(self, count=1):
        """Read *count* bytes, charging them against the container budget.

        Raises Exception when the budget is exceeded and EOFError when bytes
        were requested but the stream returned none.
        """
        if self.localremaining != -1:
            self.localremaining -= count
            _assert(self.localremaining >= 0)
        result = self.stream.read(count)
        # A zero-byte request (empty string/blob) legitimately yields b"";
        # only an empty result for a real request means end of stream.
        if count != 0 and len(result) == 0:
            raise EOFError()
        return result

    def readfieldid(self):
        """Return the next struct field id, or -1 when the struct is exhausted."""
        if self.localremaining != -1 and self.localremaining < 1:
            return -1
        try:
            return self.readvaruint()
        except EOFError:
            return -1

    def readtypeid(self):
        """Read a type-descriptor byte; return the type id or -1 at the end.

        Sets ``valuelen`` and, for nulls and booleans (which carry no content
        bytes), ``valueisnull``/``valueistrue`` and the AfterValue state.
        """
        if self.localremaining != -1:
            if self.localremaining < 1:
                return -1
            self.localremaining -= 1
        b = self.stream.read(1)
        if len(b) < 1:
            return -1
        b = ord(b)
        result = b >> 4
        ln = b & 0xF
        if ln == LEN_IS_VAR_LEN:
            ln = self.readvaruint()
        elif ln == LEN_IS_NULL:
            ln = 0
            # Record the null so the valueisnull guards used by the accessors,
            # stepin() and the symbol-table code actually take effect.
            self.valueisnull = True
            self.state = ParserState.AfterValue
        elif result == TID_NULL:
            # Must have LEN_IS_NULL
            _assert(False)
        elif result == TID_BOOLEAN:
            _assert(ln <= 1)
            self.valueistrue = (ln == 1)
            ln = 0
            self.state = ParserState.AfterValue
        elif result == TID_STRUCT:
            # For structs a length nibble of 1 means a VarUInt length follows.
            if ln == 1:
                ln = self.readvaruint()
        self.valuelen = ln
        return result

    def readvarint(self):
        """Read a signed variable-length integer of at most five bytes."""
        b = ord(self.read())
        negative = ((b & 0x40) != 0)
        result = (b & 0x3F)
        i = 0
        # The byte with the high bit set is the last one.
        while (b & 0x80) == 0 and i < 4:
            b = ord(self.read())
            result = (result << 7) | (b & 0x7F)
            i += 1
        _assert(i < 4 or (b & 0x80) != 0, "int overflow")
        if negative:
            return -result
        return result

    def readvaruint(self):
        """Read an unsigned variable-length integer of at most five bytes."""
        b = ord(self.read())
        result = (b & 0x7F)
        i = 0
        # The byte with the high bit set is the last one.
        while (b & 0x80) == 0 and i < 4:
            b = ord(self.read())
            result = (result << 7) | (b & 0x7F)
            i += 1
        _assert(i < 4 or (b & 0x80) != 0, "int overflow")
        return result

    def readdecimal(self):
        """Read the current decimal value as ``coefficient * 10 ** exponent``.

        Returns 0.0 for a zero-length value.  Raises Exception when the value
        holds only an exponent or the coefficient is longer than 8 bytes.
        """
        if self.valuelen == 0:
            return 0.
        # Narrow the byte budget to this value while decoding it.  An
        # unbounded budget (-1) must stay unbounded afterwards.
        if self.localremaining != -1:
            rem = self.localremaining - self.valuelen
        else:
            rem = -1
        self.localremaining = self.valuelen
        exponent = self.readvarint()
        _assert(self.localremaining > 0, "Only exponent in ReadDecimal")
        _assert(self.localremaining <= 8, "Decimal overflow")
        # bytearray yields ints on both Python 2 and Python 3.
        coefficient = bytearray(self.read(self.localremaining))
        # Sign-and-magnitude: the top bit of the first byte is the sign.
        signed = (coefficient[0] & 0x80) != 0
        if signed:
            coefficient[0] &= 0x7F
        # Big-endian (network order) magnitude.
        v = 0
        for byte in coefficient:
            v = (v << 8) | byte
        result = v * (10 ** exponent)
        if signed:
            result = -result
        self.localremaining = rem
        return result

    def skip(self, count):
        """Seek *count* bytes forward; EOFError if that exceeds the budget."""
        if self.localremaining != -1:
            self.localremaining -= count
            if self.localremaining < 0:
                raise EOFError()
        self.stream.seek(count, os.SEEK_CUR)

    def parsesymboltable(self):
        """Process a top-level $ion_symbol_table struct.

        Only the ``imports`` field is supported; any other non-null field
        raises Exception.  Imports are gathered once per parser (see
        ``didimports``); later symbol tables are left to be skipped.
        """
        self.next() # shouldn't do anything?
        _assert(self.valuetid == TID_STRUCT)
        if self.didimports:
            return
        self.stepin()
        fieldtype = self.next()
        while fieldtype != -1:
            if not self.valueisnull:
                _assert(self.valuefieldid == SID_IMPORTS, "Unsupported symbol table field id")
                if fieldtype == TID_LIST:
                    self.gatherimports()
            fieldtype = self.next()
        self.stepout()
        self.didimports = True

    def gatherimports(self):
        """Walk the imports list and process every non-null struct in it."""
        self.stepin()
        t = self.next()
        while t != -1:
            if not self.valueisnull and t == TID_STRUCT:
                self.readimport()
            t = self.next()
        self.stepout()

    def readimport(self):
        """Read one import struct and append its symbols to ``self.symbols``.

        Symbols come from the catalog entry of the same name when there is
        one; ids the catalog cannot supply get placeholder names.
        """
        version = -1
        maxid = -1
        name = ""
        self.stepin()
        t = self.next()
        while t != -1:
            if not self.valueisnull and self.valuefieldid != SID_UNKNOWN:
                if self.valuefieldid == SID_NAME:
                    name = self.stringvalue()
                elif self.valuefieldid == SID_VERSION:
                    version = self.intvalue()
                elif self.valuefieldid == SID_MAX_ID:
                    maxid = self.intvalue()
            t = self.next()
        self.stepout()
        if name == "" or name == SystemSymbols.ION:
            return
        if version < 1:
            version = 1
        table = self.findcatalogitem(name)
        if maxid < 0:
            # Without max_id the size can only come from an exact catalog match.
            _assert(table is not None and version == table.version, "Import %s lacks maxid" % name)
            maxid = len(table.symnames)
        if table is not None:
            self.symbols.import_(table, min(maxid, len(table.symnames)))
            if len(table.symnames) < maxid:
                self.symbols.importunknown(name + "-unknown", maxid - len(table.symnames))
        else:
            self.symbols.importunknown(name, maxid)

    def intvalue(self):
        """Return the current integer value (None when it is null)."""
        _assert(self.valuetid in [TID_POSINT, TID_NEGINT], "Not an int")
        self.preparevalue()
        return self.value

    def stringvalue(self):
        """Return the current string value ("" when it is null)."""
        _assert(self.valuetid == TID_STRING, "Not a string")
        if self.valueisnull:
            return ""
        self.preparevalue()
        return self.value

    def symbolvalue(self):
        """Return the text of the current symbol, or "SYMBOL#<id>" if unknown."""
        _assert(self.valuetid == TID_SYMBOL, "Not a symbol")
        self.preparevalue()
        result = self.symbols.findbyid(self.value)
        if result == "":
            result = "SYMBOL#%d" % self.value
        return result

    def lobvalue(self):
        """Return the bytes of the current clob/blob (None when it is null).

        The bytes are consumed from the stream, so call this once per value.
        """
        _assert(self.valuetid in [TID_CLOB, TID_BLOB], "Not a LOB type: %s" % self.getfieldname())
        if self.valueisnull:
            return None
        result = self.read(self.valuelen)
        self.state = ParserState.AfterValue
        return result

    def decimalvalue(self):
        """Return the current decimal value (see readdecimal())."""
        _assert(self.valuetid == TID_DECIMAL, "Not a decimal")
        self.preparevalue()
        return self.value

    def preparevalue(self):
        """Load the current scalar into ``self.value`` unless already loaded."""
        if self.value is None:
            self.loadscalarvalue()

    def loadscalarvalue(self):
        """Decode the current scalar into ``self.value``.

        Strings, ints, symbols and decimals are decoded; for other scalar
        types ``self.value`` is left as None.
        """
        if self.valuetid not in [TID_NULL, TID_BOOLEAN, TID_POSINT, TID_NEGINT,
                                 TID_FLOAT, TID_DECIMAL, TID_TIMESTAMP,
                                 TID_SYMBOL, TID_STRING]:
            return
        if self.valueisnull:
            self.value = None
            return
        if self.valuetid == TID_STRING:
            self.value = self.read(self.valuelen).decode("UTF-8")
        elif self.valuetid in (TID_POSINT, TID_NEGINT, TID_SYMBOL):
            if self.valuelen == 0:
                self.value = 0
            else:
                _assert(self.valuelen <= 4, "int too long: %d" % self.valuelen)
                # Big-endian magnitude; the sign comes from the type id.
                v = 0
                for i in range(self.valuelen - 1, -1, -1):
                    v = (v | (ord(self.read()) << (i * 8)))
                if self.valuetid == TID_NEGINT:
                    self.value = -v
                else:
                    self.value = v
        elif self.valuetid == TID_DECIMAL:
            self.value = self.readdecimal()
        self.state = ParserState.AfterValue

    def clearvalue(self):
        """Forget the current value, its field id and its annotations."""
        self.valuetid = -1
        self.value = None
        self.valueisnull = False
        self.valuefieldid = SID_UNKNOWN
        self.annotations = []

    def loadannotations(self):
        """Read the annotation symbol ids, then the wrapped value's type id."""
        ln = self.readvaruint()
        maxpos = self.stream.tell() + ln
        while self.stream.tell() < maxpos:
            self.annotations.append(self.readvaruint())
        self.valuetid = self.readtypeid()

    def checkversionmarker(self):
        """Verify the version-marker bytes and expose them as symbol $ion_1_0."""
        for i in VERSION_MARKER:
            _assert(self.read() == i, "Unknown version marker")
        self.valuelen = 0
        self.valuetid = TID_SYMBOL
        self.value = SID_ION_1_0
        self.valueisnull = False
        self.valuefieldid = SID_UNKNOWN
        self.state = ParserState.AfterValue

    def findcatalogitem(self, name):
        """Return the catalog entry called *name*, or None when there is none."""
        for result in self.catalog:
            if result.name == name:
                return result
        return None

    def forceimport(self, symbols):
        """Append *symbols* to the symbol table without an import declaration."""
        item = IonCatalogItem("Forced", 1, symbols)
        self.symbols.import_(item, len(symbols))

    def getfieldname(self):
        """Return the current struct field's name ("" when there is none)."""
        if self.valuefieldid == SID_UNKNOWN:
            return ""
        return self.symbols.findbyid(self.valuefieldid)

    def getfieldnamesymbol(self):
        """Return the current field as a SymbolToken (name and id)."""
        return SymbolToken(self.getfieldname(), self.valuefieldid)

    def gettypename(self):
        """Return the text of the value's first annotation ("" when it has none)."""
        if len(self.annotations) == 0:
            return ""
        return self.symbols.findbyid(self.annotations[0])

    @staticmethod
    def printlob(b):
        """Format bytes as space-separated two-digit hex; None gives "null"."""
        if b is None:
            return "null"
        # bytearray yields ints on both Python 2 and Python 3.
        return " ".join("%02x" % byte for byte in bytearray(b))

    def ionwalk(self, supert, indent, lst):
        """Append a text dump of the remaining values at this level to *lst*.

        *supert* is the type id of the enclosing container (-1 at the top
        level); structs and lists are walked recursively.
        """
        while self.hasnext():
            if supert == TID_STRUCT:
                text = self.getfieldname() + ":"
            else:
                text = ""
            t = self.next()
            if t in [TID_STRUCT, TID_LIST]:
                if text != "":
                    lst.append(indent + text)
                text = self.gettypename()
                if text != "":
                    lst.append(indent + text + "::")
                if t == TID_STRUCT:
                    lst.append(indent + "{")
                else:
                    lst.append(indent + "[")
                self.stepin()
                self.ionwalk(t, indent + " ", lst)
                self.stepout()
                if t == TID_STRUCT:
                    lst.append(indent + "}")
                else:
                    lst.append(indent + "]")
            else:
                if t == TID_STRING:
                    text += ('"%s"' % self.stringvalue())
                elif t in [TID_CLOB, TID_BLOB]:
                    text += ("{%s}" % self.printlob(self.lobvalue()))
                elif t == TID_POSINT:
                    text += str(self.intvalue())
                elif t == TID_SYMBOL:
                    tn = self.gettypename()
                    if tn != "":
                        tn += "::"
                    text += tn + self.symbolvalue()
                elif t == TID_DECIMAL:
                    text += str(self.decimalvalue())
                else:
                    text += ("TID %d" % t)
                lst.append(indent + text)

    def print_(self, lst):
        """Rewind and append a text dump of the whole stream to *lst*."""
        self.reset()
        self.ionwalk(-1, "", lst)
# Symbol names of the "ProtectedData" shared symbol table, in symbol-id order.
SYM_NAMES = [
    'com.amazon.drm.Envelope@1.0',
    'com.amazon.drm.EnvelopeMetadata@1.0',
    'size',
    'page_size',
    'encryption_key',
    'encryption_transformation',
    'encryption_voucher',
    'signing_key',
    'signing_algorithm',
    'signing_voucher',
    'com.amazon.drm.EncryptedPage@1.0',
    'cipher_text',
    'cipher_iv',
    'com.amazon.drm.Signature@1.0',
    'data',
    'com.amazon.drm.EnvelopeIndexTable@1.0',
    'length',
    'offset',
    'algorithm',
    'encoded',
    'encryption_algorithm',
    'hashing_algorithm',
    'expires',
    'format',
    'id',
    'lock_parameters',
    'strategy',
    'com.amazon.drm.Key@1.0',
    'com.amazon.drm.KeySet@1.0',
    'com.amazon.drm.PIDv3@1.0',
    'com.amazon.drm.PlainTextPage@1.0',
    'com.amazon.drm.PlainText@1.0',
    'com.amazon.drm.PrivateKey@1.0',
    'com.amazon.drm.PublicKey@1.0',
    'com.amazon.drm.SecretKey@1.0',
    'com.amazon.drm.Voucher@1.0',
    'public_key',
    'private_key',
    'com.amazon.drm.KeyPair@1.0',
    'com.amazon.drm.ProtectedData@1.0',
    'doctype',
    'com.amazon.drm.EnvelopeIndexTableOffset@1.0',
    'enddoc',
    'license_type',
    'license',
    'watermark',
    'key',
    'value',
    'com.amazon.drm.License@1.0',
    'category',
    'metadata',
    'categorized_metadata',
    'com.amazon.drm.CategorizedMetadata@1.0',
    'com.amazon.drm.VoucherEnvelope@1.0',
    'mac',
    'voucher',
    'com.amazon.drm.ProtectedData@2.0',
    'com.amazon.drm.Envelope@2.0',
    'com.amazon.drm.EnvelopeMetadata@2.0',
    'com.amazon.drm.EncryptedPage@2.0',
    'com.amazon.drm.PlainText@2.0',
    'compression_algorithm',
    'com.amazon.drm.Compressed@1.0',
    'page_index_table',
    'com.amazon.drm.VoucherEnvelope@2.0',
    'com.amazon.drm.VoucherEnvelope@3.0',
]
def addprottable(ion):
    """Register SYM_NAMES with parser *ion* as catalog table "ProtectedData" v1."""
    catalog_name, catalog_version = "ProtectedData", 1
    ion.addtocatalog(catalog_name, catalog_version, SYM_NAMES)
def pkcs7pad(msg, blocklen):
    """Return *msg* with PKCS#7 padding appended.

    Between 1 and *blocklen* bytes are added, each holding the pad length,
    so the result length is a multiple of *blocklen*.
    """
    fill = blocklen - (len(msg) % blocklen)
    return msg + fill * bchr(fill)
def pkcs7unpad(msg, blocklen):
    """Strip PKCS#7 padding from *msg* and return the remaining bytes.

    Raises Exception when *msg* is empty, is not a whole number of blocks,
    or does not end in valid padding (the usual sign of a wrong key).
    """
    _assert(len(msg) % blocklen == 0)
    # An empty message passes the length check above but has no pad byte.
    _assert(len(msg) > 0, "Incorrect padding - Empty message")
    # Go through bytearray so the pad length is an int for any bytes-like msg.
    paddinglen = bytearray(msg[-1:])[0]
    _assert(paddinglen > 0 and paddinglen <= blocklen, "Incorrect padding - Wrong key")
    _assert(msg[-paddinglen:] == bchr(paddinglen) * paddinglen, "Incorrect padding - Wrong key")
    return msg[:-paddinglen]
# every VoucherEnvelope version has a corresponding "word" and magic number, used in obfuscating the shared secret
VOUCHER_VERSION_INFOS = {
    2: [b'Antidisestablishmentarianism', 5],
    3: [b'Floccinaucinihilipilification', 8],
}


# obfuscate shared secret according to the VoucherEnvelope version
def obfuscate(secret, version):
    """Return *secret* obfuscated for VoucherEnvelope *version*.

    Version 1 returns *secret* unchanged.  Other versions zero-pad the
    secret to a multiple of the version's magic number, permute the bytes
    and XOR each with a byte of the SHA-256 hash of the version's word; the
    result is a bytearray.  Raises KeyError for an unknown version.
    """
    if version == 1:  # v1 does not use obfuscation
        return secret
    word, magic = VOUCHER_VERSION_INFOS[version]
    # extend secret so that its length is divisible by the magic number
    shortfall = len(secret) % magic
    if shortfall != 0:
        secret = secret + b'\x00' * (magic - shortfall)
    plain = bytearray(secret)
    rows = len(plain) // magic
    wordhash = bytearray(hashlib.sha256(word).digest())
    scrambled = bytearray(len(plain))
    # shuffle secret and xor it with the first half of the word hash
    for i, byte in enumerate(plain):
        index = i // rows + magic * (i % rows)
        scrambled[index] = byte ^ wordhash[index % 16]
    return scrambled
class DrmIonVoucher(object):
    """Parse a DRM voucher envelope and decrypt the secret key it carries.

    Call parse() and then decryptvoucher(); the key is left in ``secretkey``.
    """

    envelope = None
    version = None
    voucher = None
    drmkey = None
    license_type = "Unknown"
    encalgorithm = ""
    enctransformation = ""
    hashalgorithm = ""
    lockparams = None
    ciphertext = b""
    cipheriv = b""
    secretkey = b""

    def __init__(self, voucherenv, dsn, secret):
        """Wrap the voucher envelope stream *voucherenv*.

        *dsn* and *secret* are the strings mixed into the shared secret for
        the CLIENT_ID and ACCOUNT_SECRET lock parameters respectively.
        """
        self.dsn, self.secret = dsn, secret
        self.lockparams = []
        self.envelope = BinaryIonParser(voucherenv)
        addprottable(self.envelope)

    def decryptvoucher(self):
        """Derive the voucher key, decrypt the key set and store the secret key.

        Requires parse() to have run.  Raises Exception on an unknown lock
        parameter, bad padding (wrong key) or an unexpected key set layout.
        """
        shared = "PIDv3" + self.encalgorithm + self.enctransformation + self.hashalgorithm
        self.lockparams.sort()
        for param in self.lockparams:
            if param == "ACCOUNT_SECRET":
                shared += param + self.secret
            elif param == "CLIENT_ID":
                shared += param + self.dsn
            else:
                _assert(False, "Unknown lock parameter: %s" % param)
        sharedsecret = obfuscate(shared.encode('ASCII'), self.version)
        key = hmac.new(sharedsecret, b"PIDv3", digestmod=hashlib.sha256).digest()
        aes = AES.new(key[:32], AES.MODE_CBC, self.cipheriv[:16])
        b = aes.decrypt(self.ciphertext)
        b = pkcs7unpad(b, 16)
        self.drmkey = BinaryIonParser(BytesIO(b))
        addprottable(self.drmkey)
        _assert(self.drmkey.hasnext() and self.drmkey.next() == TID_LIST and self.drmkey.gettypename() == "com.amazon.drm.KeySet@1.0",
                "Expected KeySet, got %s" % self.drmkey.gettypename())
        self.drmkey.stepin()
        while self.drmkey.hasnext():
            self.drmkey.next()
            if self.drmkey.gettypename() != "com.amazon.drm.SecretKey@1.0":
                continue
            # Only the first SecretKey entry of the key set is used.
            self.drmkey.stepin()
            while self.drmkey.hasnext():
                self.drmkey.next()
                fieldname = self.drmkey.getfieldname()
                if fieldname == "algorithm":
                    algorithm = self.drmkey.stringvalue()
                    _assert(algorithm == "AES", "Unknown cipher algorithm: %s" % algorithm)
                elif fieldname == "format":
                    keyformat = self.drmkey.stringvalue()
                    _assert(keyformat == "RAW", "Unknown key format: %s" % keyformat)
                elif fieldname == "encoded":
                    self.secretkey = self.drmkey.lobvalue()
            self.drmkey.stepout()
            break
        self.drmkey.stepout()

    def parse(self):
        """Read the envelope: version, strategy parameters and inner voucher.

        Raises Exception when the envelope is empty, has an unexpected type
        or strategy, or contains no voucher.
        """
        self.envelope.reset()
        # Start from a clean list so a repeated parse() does not duplicate
        # lock parameters, which would corrupt the shared secret.
        self.lockparams = []
        _assert(self.envelope.hasnext(), "Envelope is empty")
        _assert(self.envelope.next() == TID_STRUCT and self.envelope.gettypename().startswith("com.amazon.drm.VoucherEnvelope@"),
                "Unknown type encountered in envelope, expected VoucherEnvelope")
        # The type name ends in "@<major>.0"; keep the major version only.
        self.version = int(self.envelope.gettypename().split('@')[1][:-2])
        self.envelope.stepin()
        while self.envelope.hasnext():
            self.envelope.next()
            field = self.envelope.getfieldname()
            if field == "voucher":
                self.voucher = BinaryIonParser(BytesIO(self.envelope.lobvalue()))
                addprottable(self.voucher)
                continue
            elif field != "strategy":
                continue
            _assert(self.envelope.gettypename() == "com.amazon.drm.PIDv3@1.0", "Unknown strategy: %s" % self.envelope.gettypename())
            self.envelope.stepin()
            while self.envelope.hasnext():
                self.envelope.next()
                field = self.envelope.getfieldname()
                if field == "encryption_algorithm":
                    self.encalgorithm = self.envelope.stringvalue()
                elif field == "encryption_transformation":
                    self.enctransformation = self.envelope.stringvalue()
                elif field == "hashing_algorithm":
                    self.hashalgorithm = self.envelope.stringvalue()
                elif field == "lock_parameters":
                    self.envelope.stepin()
                    while self.envelope.hasnext():
                        _assert(self.envelope.next() == TID_STRING, "Expected string list for lock_parameters")
                        self.lockparams.append(self.envelope.stringvalue())
                    self.envelope.stepout()
            self.envelope.stepout()
        # Fail with a clear message instead of an AttributeError on None.
        _assert(self.voucher is not None, "Envelope does not contain a voucher")
        self.parsevoucher()

    def parsevoucher(self):
        """Read cipher IV, cipher text and license type from the inner voucher."""
        _assert(self.voucher.hasnext(), "Voucher is empty")
        _assert(self.voucher.next() == TID_STRUCT and self.voucher.gettypename() == "com.amazon.drm.Voucher@1.0",
                "Unknown type, expected Voucher")
        self.voucher.stepin()
        while self.voucher.hasnext():
            self.voucher.next()
            fieldname = self.voucher.getfieldname()
            if fieldname == "cipher_iv":
                self.cipheriv = self.voucher.lobvalue()
            elif fieldname == "cipher_text":
                self.ciphertext = self.voucher.lobvalue()
            elif fieldname == "license":
                _assert(self.voucher.gettypename() == "com.amazon.drm.License@1.0",
                        "Unknown license: %s" % self.voucher.gettypename())
                self.voucher.stepin()
                while self.voucher.hasnext():
                    self.voucher.next()
                    if self.voucher.getfieldname() == "license_type":
                        self.license_type = self.voucher.stringvalue()
                self.voucher.stepout()

    def printenvelope(self, lst):
        """Append a text dump of the envelope to *lst*."""
        self.envelope.print_(lst)

    def printkey(self, lst):
        """Append a text dump of the decrypted key set to *lst*."""
        if self.voucher is None:
            self.parse()
        if self.drmkey is None:
            self.decryptvoucher()
        self.drmkey.print_(lst)

    def printvoucher(self, lst):
        """Append a text dump of the inner voucher to *lst*."""
        if self.voucher is None:
            self.parse()
        self.voucher.print_(lst)

    def getlicensetype(self):
        """Return the license type read by parse() ("Unknown" by default)."""
        return self.license_type
class DrmIon(object):
    """Decrypt the encrypted pages of a DRMION container."""

    ion = None
    voucher = None
    vouchername = ""
    key = b""
    onvoucherrequired = None

    def __init__(self, ionstream, onvoucherrequired):
        """Wrap *ionstream*.

        *onvoucherrequired* is called with the voucher name found in the
        envelope metadata and must return an object whose ``secretkey``
        attribute holds the decryption key.
        """
        self.ion = BinaryIonParser(ionstream)
        addprottable(self.ion)
        self.onvoucherrequired = onvoucherrequired

    def parse(self, outpages):
        """Decrypt every encrypted page and write the result to *outpages*.

        Raises Exception when the container layout is unexpected, the key
        cannot be obtained, or envelopes name different vouchers.
        """
        self.ion.reset()
        _assert(self.ion.hasnext(), "DRMION envelope is empty")
        _assert(self.ion.next() == TID_SYMBOL and self.ion.gettypename() == "doctype", "Expected doctype symbol")
        _assert(self.ion.next() == TID_LIST and self.ion.gettypename() in ["com.amazon.drm.Envelope@1.0", "com.amazon.drm.Envelope@2.0"],
                "Unknown type encountered in DRMION envelope, expected Envelope, got %s" % self.ion.gettypename())
        while True:
            if self.ion.gettypename() == "enddoc":
                break
            self.ion.stepin()
            while self.ion.hasnext():
                self.ion.next()
                typename = self.ion.gettypename()
                if typename in ["com.amazon.drm.EnvelopeMetadata@1.0", "com.amazon.drm.EnvelopeMetadata@2.0"]:
                    self.ion.stepin()
                    while self.ion.hasnext():
                        self.ion.next()
                        if self.ion.getfieldname() != "encryption_voucher":
                            continue
                        if self.vouchername == "":
                            # First voucher reference: fetch the key once.
                            self.vouchername = self.ion.stringvalue()
                            self.voucher = self.onvoucherrequired(self.vouchername)
                            self.key = self.voucher.secretkey
                            _assert(self.key is not None, "Unable to obtain secret key from voucher")
                        else:
                            _assert(self.vouchername == self.ion.stringvalue(),
                                    "Unexpected: Different vouchers required for same file?")
                    self.ion.stepout()
                elif typename in ["com.amazon.drm.EncryptedPage@1.0", "com.amazon.drm.EncryptedPage@2.0"]:
                    decompress = False
                    ct = None
                    civ = None
                    self.ion.stepin()
                    while self.ion.hasnext():
                        self.ion.next()
                        if self.ion.gettypename() == "com.amazon.drm.Compressed@1.0":
                            decompress = True
                        if self.ion.getfieldname() == "cipher_text":
                            ct = self.ion.lobvalue()
                        elif self.ion.getfieldname() == "cipher_iv":
                            civ = self.ion.lobvalue()
                    if ct is not None and civ is not None:
                        self.processpage(ct, civ, outpages, decompress)
                    self.ion.stepout()
            self.ion.stepout()
            if not self.ion.hasnext():
                break
            self.ion.next()

    def print_(self, lst):
        """Append a text dump of the container to *lst*."""
        self.ion.print_(lst)

    def processpage(self, ct, civ, outpages, decompress):
        """Decrypt one page and write it, decompressed if asked, to *outpages*.

        Raises Exception on bad padding (wrong key), an empty or unsupported
        compressed page, or a truncated LZMA stream.
        """
        aes = AES.new(self.key[:16], AES.MODE_CBC, civ[:16])
        msg = pkcs7unpad(aes.decrypt(ct), 16)
        if not decompress:
            outpages.write(msg)
            return
        _assert(len(msg) > 0, "Empty compressed page")
        # The first byte is a flag; only the value 0 is supported.
        _assert(msg[0] == 0, "LZMA UseFilter not supported")
        if calibre_lzma is not None:
            with calibre_lzma.decompress(msg[1:], bufsize=0x1000000) as f:
                f.seek(0)
                outpages.write(f.read())
            return
        decomp = lzma.LZMADecompressor(format=lzma.FORMAT_ALONE)
        pending = msg[1:]
        while not decomp.eof:
            segment = decomp.decompress(pending)
            outpages.write(segment)
            # With no input left and no output produced the decompressor can
            # never reach end-of-stream; abort instead of looping forever.
            _assert(decomp.eof or len(pending) > 0 or len(segment) > 0,
                    "Truncated LZMA stream")
            pending = b""  # Contents were internally buffered after the first call