Create ion.py
This commit is contained in:
parent
ffd7d41bcd
commit
20e0850001
|
@ -0,0 +1,981 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Pascal implementation by lulzkabulz. Python translation by apprenticenaomi. DeDRM integration by anon.
|
||||
# BinaryIon.pas + DrmIon.pas + IonSymbols.pas
|
||||
|
||||
from __future__ import with_statement
|
||||
|
||||
import collections
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import os.path
|
||||
import struct
|
||||
|
||||
try:
|
||||
from cStringIO import StringIO
|
||||
except ImportError:
|
||||
from StringIO import StringIO
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.py3compat import bchr, bord
|
||||
|
||||
try:
|
||||
# lzma library from calibre 2.35.0 or later
|
||||
import lzma.lzma1 as calibre_lzma
|
||||
except:
|
||||
calibre_lzma = None
|
||||
try:
|
||||
import lzma
|
||||
except:
|
||||
# Need pip backports.lzma on Python <3.3
|
||||
from backports import lzma
|
||||
|
||||
|
||||
TID_NULL = 0
|
||||
TID_BOOLEAN = 1
|
||||
TID_POSINT = 2
|
||||
TID_NEGINT = 3
|
||||
TID_FLOAT = 4
|
||||
TID_DECIMAL = 5
|
||||
TID_TIMESTAMP = 6
|
||||
TID_SYMBOL = 7
|
||||
TID_STRING = 8
|
||||
TID_CLOB = 9
|
||||
TID_BLOB = 0xA
|
||||
TID_LIST = 0xB
|
||||
TID_SEXP = 0xC
|
||||
TID_STRUCT = 0xD
|
||||
TID_TYPEDECL = 0xE
|
||||
TID_UNUSED = 0xF
|
||||
|
||||
|
||||
SID_UNKNOWN = -1
|
||||
SID_ION = 1
|
||||
SID_ION_1_0 = 2
|
||||
SID_ION_SYMBOL_TABLE = 3
|
||||
SID_NAME = 4
|
||||
SID_VERSION = 5
|
||||
SID_IMPORTS = 6
|
||||
SID_SYMBOLS = 7
|
||||
SID_MAX_ID = 8
|
||||
SID_ION_SHARED_SYMBOL_TABLE = 9
|
||||
SID_ION_1_0_MAX = 10
|
||||
|
||||
|
||||
LEN_IS_VAR_LEN = 0xE
|
||||
LEN_IS_NULL = 0xF
|
||||
|
||||
|
||||
VERSION_MARKER = b"\x01\x00\xEA"
|
||||
|
||||
|
||||
# asserts must always raise exceptions for proper functioning
|
||||
def _assert(test, msg="Exception"):
|
||||
if not test:
|
||||
raise Exception(msg)
|
||||
|
||||
|
||||
class SystemSymbols(object):
|
||||
ION = '$ion'
|
||||
ION_1_0 = '$ion_1_0'
|
||||
ION_SYMBOL_TABLE = '$ion_symbol_table'
|
||||
NAME = 'name'
|
||||
VERSION = 'version'
|
||||
IMPORTS = 'imports'
|
||||
SYMBOLS = 'symbols'
|
||||
MAX_ID = 'max_id'
|
||||
ION_SHARED_SYMBOL_TABLE = '$ion_shared_symbol_table'
|
||||
|
||||
|
||||
class IonCatalogItem(object):
|
||||
name = ""
|
||||
version = 0
|
||||
symnames = []
|
||||
|
||||
def __init__(self, name, version, symnames):
|
||||
self.name = name
|
||||
self.version = version
|
||||
self.symnames = symnames
|
||||
|
||||
|
||||
class SymbolToken(object):
|
||||
text = ""
|
||||
sid = 0
|
||||
|
||||
def __init__(self, text, sid):
|
||||
if text == "" and sid == 0:
|
||||
raise ValueError("Symbol token must have Text or SID")
|
||||
|
||||
self.text = text
|
||||
self.sid = sid
|
||||
|
||||
|
||||
class SymbolTable(object):
|
||||
table = None
|
||||
|
||||
def __init__(self):
|
||||
self.table = [None] * SID_ION_1_0_MAX
|
||||
self.table[SID_ION] = SystemSymbols.ION
|
||||
self.table[SID_ION_1_0] = SystemSymbols.ION_1_0
|
||||
self.table[SID_ION_SYMBOL_TABLE] = SystemSymbols.ION_SYMBOL_TABLE
|
||||
self.table[SID_NAME] = SystemSymbols.NAME
|
||||
self.table[SID_VERSION] = SystemSymbols.VERSION
|
||||
self.table[SID_IMPORTS] = SystemSymbols.IMPORTS
|
||||
self.table[SID_SYMBOLS] = SystemSymbols.SYMBOLS
|
||||
self.table[SID_MAX_ID] = SystemSymbols.MAX_ID
|
||||
self.table[SID_ION_SHARED_SYMBOL_TABLE] = SystemSymbols.ION_SHARED_SYMBOL_TABLE
|
||||
|
||||
def findbyid(self, sid):
|
||||
if sid < 1:
|
||||
raise ValueError("Invalid symbol id")
|
||||
|
||||
if sid < len(self.table):
|
||||
return self.table[sid]
|
||||
else:
|
||||
return ""
|
||||
|
||||
def import_(self, table, maxid):
|
||||
for i in range(maxid):
|
||||
self.table.append(table.symnames[i])
|
||||
|
||||
def importunknown(self, name, maxid):
|
||||
for i in range(maxid):
|
||||
self.table.append("%s#%d" % (name, i + 1))
|
||||
|
||||
|
||||
class ParserState:
|
||||
Invalid,BeforeField,BeforeTID,BeforeValue,AfterValue,EOF = 1,2,3,4,5,6
|
||||
|
||||
ContainerRec = collections.namedtuple("ContainerRec", "nextpos, tid, remaining")
|
||||
|
||||
|
||||
class BinaryIonParser(object):
|
||||
eof = False
|
||||
state = None
|
||||
localremaining = 0
|
||||
needhasnext = False
|
||||
isinstruct = False
|
||||
valuetid = 0
|
||||
valuefieldid = 0
|
||||
parenttid = 0
|
||||
valuelen = 0
|
||||
valueisnull = False
|
||||
valueistrue = False
|
||||
value = None
|
||||
didimports = False
|
||||
|
||||
def __init__(self, stream):
|
||||
self.annotations = []
|
||||
self.catalog = []
|
||||
|
||||
self.stream = stream
|
||||
self.initpos = stream.tell()
|
||||
self.reset()
|
||||
self.symbols = SymbolTable()
|
||||
|
||||
def reset(self):
|
||||
self.state = ParserState.BeforeTID
|
||||
self.needhasnext = True
|
||||
self.localremaining = -1
|
||||
self.eof = False
|
||||
self.isinstruct = False
|
||||
self.containerstack = []
|
||||
self.stream.seek(self.initpos)
|
||||
|
||||
def addtocatalog(self, name, version, symbols):
|
||||
self.catalog.append(IonCatalogItem(name, version, symbols))
|
||||
|
||||
def hasnext(self):
|
||||
while self.needhasnext and not self.eof:
|
||||
self.hasnextraw()
|
||||
if len(self.containerstack) == 0 and not self.valueisnull:
|
||||
if self.valuetid == TID_SYMBOL:
|
||||
if self.value == SID_ION_1_0:
|
||||
self.needhasnext = True
|
||||
elif self.valuetid == TID_STRUCT:
|
||||
for a in self.annotations:
|
||||
if a == SID_ION_SYMBOL_TABLE:
|
||||
self.parsesymboltable()
|
||||
self.needhasnext = True
|
||||
break
|
||||
return not self.eof
|
||||
|
||||
def hasnextraw(self):
|
||||
self.clearvalue()
|
||||
while self.valuetid == -1 and not self.eof:
|
||||
self.needhasnext = False
|
||||
if self.state == ParserState.BeforeField:
|
||||
_assert(self.valuefieldid == SID_UNKNOWN)
|
||||
|
||||
self.valuefieldid = self.readfieldid()
|
||||
if self.valuefieldid != SID_UNKNOWN:
|
||||
self.state = ParserState.BeforeTID
|
||||
else:
|
||||
self.eof = True
|
||||
|
||||
elif self.state == ParserState.BeforeTID:
|
||||
self.state = ParserState.BeforeValue
|
||||
self.valuetid = self.readtypeid()
|
||||
if self.valuetid == -1:
|
||||
self.state = ParserState.EOF
|
||||
self.eof = True
|
||||
break
|
||||
|
||||
if self.valuetid == TID_TYPEDECL:
|
||||
if self.valuelen == 0:
|
||||
self.checkversionmarker()
|
||||
else:
|
||||
self.loadannotations()
|
||||
|
||||
elif self.state == ParserState.BeforeValue:
|
||||
self.skip(self.valuelen)
|
||||
self.state = ParserState.AfterValue
|
||||
|
||||
elif self.state == ParserState.AfterValue:
|
||||
if self.isinstruct:
|
||||
self.state = ParserState.BeforeField
|
||||
else:
|
||||
self.state = ParserState.BeforeTID
|
||||
|
||||
else:
|
||||
_assert(self.state == ParserState.EOF)
|
||||
|
||||
def next(self):
|
||||
if self.hasnext():
|
||||
self.needhasnext = True
|
||||
return self.valuetid
|
||||
else:
|
||||
return -1
|
||||
|
||||
def push(self, typeid, nextposition, nextremaining):
|
||||
self.containerstack.append(ContainerRec(nextpos=nextposition, tid=typeid, remaining=nextremaining))
|
||||
|
||||
def stepin(self):
|
||||
_assert(self.valuetid in [TID_STRUCT, TID_LIST, TID_SEXP] and not self.eof,
|
||||
"valuetid=%s eof=%s" % (self.valuetid, self.eof))
|
||||
_assert((not self.valueisnull or self.state == ParserState.AfterValue) and
|
||||
(self.valueisnull or self.state == ParserState.BeforeValue))
|
||||
|
||||
nextrem = self.localremaining
|
||||
if nextrem != -1:
|
||||
nextrem -= self.valuelen
|
||||
if nextrem < 0:
|
||||
nextrem = 0
|
||||
self.push(self.parenttid, self.stream.tell() + self.valuelen, nextrem)
|
||||
|
||||
self.isinstruct = (self.valuetid == TID_STRUCT)
|
||||
if self.isinstruct:
|
||||
self.state = ParserState.BeforeField
|
||||
else:
|
||||
self.state = ParserState.BeforeTID
|
||||
|
||||
self.localremaining = self.valuelen
|
||||
self.parenttid = self.valuetid
|
||||
self.clearvalue()
|
||||
self.needhasnext = True
|
||||
|
||||
def stepout(self):
|
||||
rec = self.containerstack.pop()
|
||||
|
||||
self.eof = False
|
||||
self.parenttid = rec.tid
|
||||
if self.parenttid == TID_STRUCT:
|
||||
self.isinstruct = True
|
||||
self.state = ParserState.BeforeField
|
||||
else:
|
||||
self.isinstruct = False
|
||||
self.state = ParserState.BeforeTID
|
||||
self.needhasnext = True
|
||||
|
||||
self.clearvalue()
|
||||
curpos = self.stream.tell()
|
||||
if rec.nextpos > curpos:
|
||||
self.skip(rec.nextpos - curpos)
|
||||
else:
|
||||
_assert(rec.nextpos == curpos)
|
||||
|
||||
self.localremaining = rec.remaining
|
||||
|
||||
def read(self, count=1):
|
||||
if self.localremaining != -1:
|
||||
self.localremaining -= count
|
||||
_assert(self.localremaining >= 0)
|
||||
|
||||
result = self.stream.read(count)
|
||||
if len(result) == 0:
|
||||
raise EOFError()
|
||||
return result
|
||||
|
||||
def readfieldid(self):
|
||||
if self.localremaining != -1 and self.localremaining < 1:
|
||||
return -1
|
||||
|
||||
try:
|
||||
return self.readvaruint()
|
||||
except EOFError:
|
||||
return -1
|
||||
|
||||
def readtypeid(self):
|
||||
if self.localremaining != -1:
|
||||
if self.localremaining < 1:
|
||||
return -1
|
||||
self.localremaining -= 1
|
||||
|
||||
b = self.stream.read(1)
|
||||
if len(b) < 1:
|
||||
return -1
|
||||
b = bord(b)
|
||||
result = b >> 4
|
||||
ln = b & 0xF
|
||||
|
||||
if ln == LEN_IS_VAR_LEN:
|
||||
ln = self.readvaruint()
|
||||
elif ln == LEN_IS_NULL:
|
||||
ln = 0
|
||||
self.state = ParserState.AfterValue
|
||||
elif result == TID_NULL:
|
||||
# Must have LEN_IS_NULL
|
||||
_assert(False)
|
||||
elif result == TID_BOOLEAN:
|
||||
_assert(ln <= 1)
|
||||
self.valueistrue = (ln == 1)
|
||||
ln = 0
|
||||
self.state = ParserState.AfterValue
|
||||
elif result == TID_STRUCT:
|
||||
if ln == 1:
|
||||
ln = self.readvaruint()
|
||||
|
||||
self.valuelen = ln
|
||||
return result
|
||||
|
||||
def readvarint(self):
|
||||
b = bord(self.read())
|
||||
negative = ((b & 0x40) != 0)
|
||||
result = (b & 0x3F)
|
||||
|
||||
i = 0
|
||||
while (b & 0x80) == 0 and i < 4:
|
||||
b = bord(self.read())
|
||||
result = (result << 7) | (b & 0x7F)
|
||||
i += 1
|
||||
|
||||
_assert(i < 4 or (b & 0x80) != 0, "int overflow")
|
||||
|
||||
if negative:
|
||||
return -result
|
||||
return result
|
||||
|
||||
def readvaruint(self):
|
||||
b = bord(self.read())
|
||||
result = (b & 0x7F)
|
||||
|
||||
i = 0
|
||||
while (b & 0x80) == 0 and i < 4:
|
||||
b = bord(self.read())
|
||||
result = (result << 7) | (b & 0x7F)
|
||||
i += 1
|
||||
|
||||
_assert(i < 4 or (b & 0x80) != 0, "int overflow")
|
||||
|
||||
return result
|
||||
|
||||
def readdecimal(self):
|
||||
if self.valuelen == 0:
|
||||
return 0.
|
||||
|
||||
rem = self.localremaining - self.valuelen
|
||||
self.localremaining = self.valuelen
|
||||
exponent = self.readvarint()
|
||||
|
||||
_assert(self.localremaining > 0, "Only exponent in ReadDecimal")
|
||||
_assert(self.localremaining <= 8, "Decimal overflow")
|
||||
|
||||
signed = False
|
||||
b = [bord(x) for x in self.read(self.localremaining)]
|
||||
if (b[0] & 0x80) != 0:
|
||||
b[0] = b[0] & 0x7F
|
||||
signed = True
|
||||
|
||||
# Convert variably sized network order integer into 64-bit little endian
|
||||
j = 0
|
||||
vb = [0] * 8
|
||||
for i in range(len(b), -1, -1):
|
||||
vb[i] = b[j]
|
||||
j += 1
|
||||
|
||||
v = struct.unpack("<Q", b"".join(bchr(x) for x in vb))[0]
|
||||
|
||||
result = v * (10 ** exponent)
|
||||
if signed:
|
||||
result = -result
|
||||
|
||||
self.localremaining = rem
|
||||
return result
|
||||
|
||||
def skip(self, count):
|
||||
if self.localremaining != -1:
|
||||
self.localremaining -= count
|
||||
if self.localremaining < 0:
|
||||
raise EOFError()
|
||||
|
||||
self.stream.seek(count, os.SEEK_CUR)
|
||||
|
||||
def parsesymboltable(self):
|
||||
self.next() # shouldn't do anything?
|
||||
|
||||
_assert(self.valuetid == TID_STRUCT)
|
||||
|
||||
if self.didimports:
|
||||
return
|
||||
|
||||
self.stepin()
|
||||
|
||||
fieldtype = self.next()
|
||||
while fieldtype != -1:
|
||||
if not self.valueisnull:
|
||||
_assert(self.valuefieldid == SID_IMPORTS, "Unsupported symbol table field id")
|
||||
|
||||
if fieldtype == TID_LIST:
|
||||
self.gatherimports()
|
||||
|
||||
fieldtype = self.next()
|
||||
|
||||
self.stepout()
|
||||
self.didimports = True
|
||||
|
||||
def gatherimports(self):
|
||||
self.stepin()
|
||||
|
||||
t = self.next()
|
||||
while t != -1:
|
||||
if not self.valueisnull and t == TID_STRUCT:
|
||||
self.readimport()
|
||||
|
||||
t = self.next()
|
||||
|
||||
self.stepout()
|
||||
|
||||
def readimport(self):
|
||||
version = -1
|
||||
maxid = -1
|
||||
name = ""
|
||||
|
||||
self.stepin()
|
||||
|
||||
t = self.next()
|
||||
while t != -1:
|
||||
if not self.valueisnull and self.valuefieldid != SID_UNKNOWN:
|
||||
if self.valuefieldid == SID_NAME:
|
||||
name = self.stringvalue()
|
||||
elif self.valuefieldid == SID_VERSION:
|
||||
version = self.intvalue()
|
||||
elif self.valuefieldid == SID_MAX_ID:
|
||||
maxid = self.intvalue()
|
||||
|
||||
t = self.next()
|
||||
|
||||
self.stepout()
|
||||
|
||||
if name == "" or name == SystemSymbols.ION:
|
||||
return
|
||||
|
||||
if version < 1:
|
||||
version = 1
|
||||
|
||||
table = self.findcatalogitem(name)
|
||||
if maxid < 0:
|
||||
_assert(table is not None and version == table.version, "Import %s lacks maxid" % name)
|
||||
maxid = len(table.symnames)
|
||||
|
||||
if table is not None:
|
||||
self.symbols.import_(table, min(maxid, len(table.symnames)))
|
||||
else:
|
||||
self.symbols.importunknown(name, maxid)
|
||||
|
||||
def intvalue(self):
|
||||
_assert(self.valuetid in [TID_POSINT, TID_NEGINT], "Not an int")
|
||||
|
||||
self.preparevalue()
|
||||
return self.value
|
||||
|
||||
def stringvalue(self):
|
||||
_assert(self.valuetid == TID_STRING, "Not a string")
|
||||
|
||||
if self.valueisnull:
|
||||
return ""
|
||||
|
||||
self.preparevalue()
|
||||
return self.value
|
||||
|
||||
def symbolvalue(self):
|
||||
_assert(self.valuetid == TID_SYMBOL, "Not a symbol")
|
||||
|
||||
self.preparevalue()
|
||||
result = self.symbols.findbyid(self.value)
|
||||
if result == "":
|
||||
result = "SYMBOL#%d" % self.value
|
||||
return result
|
||||
|
||||
def lobvalue(self):
|
||||
_assert(self.valuetid in [TID_CLOB, TID_BLOB], "Not a LOB type: %s" % self.getfieldname())
|
||||
|
||||
if self.valueisnull:
|
||||
return None
|
||||
|
||||
result = self.read(self.valuelen)
|
||||
self.state = ParserState.AfterValue
|
||||
return result
|
||||
|
||||
def decimalvalue(self):
|
||||
_assert(self.valuetid == TID_DECIMAL, "Not a decimal")
|
||||
|
||||
self.preparevalue()
|
||||
return self.value
|
||||
|
||||
def preparevalue(self):
|
||||
if self.value is None:
|
||||
self.loadscalarvalue()
|
||||
|
||||
def loadscalarvalue(self):
|
||||
if self.valuetid not in [TID_NULL, TID_BOOLEAN, TID_POSINT, TID_NEGINT,
|
||||
TID_FLOAT, TID_DECIMAL, TID_TIMESTAMP,
|
||||
TID_SYMBOL, TID_STRING]:
|
||||
return
|
||||
|
||||
if self.valueisnull:
|
||||
self.value = None
|
||||
return
|
||||
|
||||
if self.valuetid == TID_STRING:
|
||||
self.value = self.read(self.valuelen).decode("UTF-8")
|
||||
|
||||
elif self.valuetid in (TID_POSINT, TID_NEGINT, TID_SYMBOL):
|
||||
if self.valuelen == 0:
|
||||
self.value = 0
|
||||
else:
|
||||
_assert(self.valuelen <= 4, "int too long: %d" % self.valuelen)
|
||||
v = 0
|
||||
for i in range(self.valuelen - 1, -1, -1):
|
||||
v = (v | (bord(self.read()) << (i * 8)))
|
||||
|
||||
if self.valuetid == TID_NEGINT:
|
||||
self.value = -v
|
||||
else:
|
||||
self.value = v
|
||||
|
||||
elif self.valuetid == TID_DECIMAL:
|
||||
self.value = self.readdecimal()
|
||||
|
||||
#else:
|
||||
# _assert(False, "Unhandled scalar type %d" % self.valuetid)
|
||||
|
||||
self.state = ParserState.AfterValue
|
||||
|
||||
def clearvalue(self):
|
||||
self.valuetid = -1
|
||||
self.value = None
|
||||
self.valueisnull = False
|
||||
self.valuefieldid = SID_UNKNOWN
|
||||
self.annotations = []
|
||||
|
||||
def loadannotations(self):
|
||||
ln = self.readvaruint()
|
||||
maxpos = self.stream.tell() + ln
|
||||
while self.stream.tell() < maxpos:
|
||||
self.annotations.append(self.readvaruint())
|
||||
self.valuetid = self.readtypeid()
|
||||
|
||||
def checkversionmarker(self):
|
||||
for i in VERSION_MARKER:
|
||||
_assert(self.read() == i, "Unknown version marker")
|
||||
|
||||
self.valuelen = 0
|
||||
self.valuetid = TID_SYMBOL
|
||||
self.value = SID_ION_1_0
|
||||
self.valueisnull = False
|
||||
self.valuefieldid = SID_UNKNOWN
|
||||
self.state = ParserState.AfterValue
|
||||
|
||||
def findcatalogitem(self, name):
|
||||
for result in self.catalog:
|
||||
if result.name == name:
|
||||
return result
|
||||
|
||||
def forceimport(self, symbols):
|
||||
item = IonCatalogItem("Forced", 1, symbols)
|
||||
self.symbols.import_(item, len(symbols))
|
||||
|
||||
def getfieldname(self):
|
||||
if self.valuefieldid == SID_UNKNOWN:
|
||||
return ""
|
||||
return self.symbols.findbyid(self.valuefieldid)
|
||||
|
||||
def getfieldnamesymbol(self):
|
||||
return SymbolToken(self.getfieldname(), self.valuefieldid)
|
||||
|
||||
def gettypename(self):
|
||||
if len(self.annotations) == 0:
|
||||
return ""
|
||||
|
||||
return self.symbols.findbyid(self.annotations[0])
|
||||
|
||||
@staticmethod
|
||||
def printlob(b):
|
||||
if b is None:
|
||||
return "null"
|
||||
|
||||
result = ""
|
||||
for i in b:
|
||||
result += ("%02x " % bord(i))
|
||||
|
||||
if len(result) > 0:
|
||||
result = result[:-1]
|
||||
return result
|
||||
|
||||
def ionwalk(self, supert, indent, lst):
|
||||
while self.hasnext():
|
||||
if supert == TID_STRUCT:
|
||||
L = self.getfieldname() + ":"
|
||||
else:
|
||||
L = ""
|
||||
|
||||
t = self.next()
|
||||
if t in [TID_STRUCT, TID_LIST]:
|
||||
if L != "":
|
||||
lst.append(indent + L)
|
||||
L = self.gettypename()
|
||||
if L != "":
|
||||
lst.append(indent + L + "::")
|
||||
if t == TID_STRUCT:
|
||||
lst.append(indent + "{")
|
||||
else:
|
||||
lst.append(indent + "[")
|
||||
|
||||
self.stepin()
|
||||
self.ionwalk(t, indent + " ", lst)
|
||||
self.stepout()
|
||||
|
||||
if t == TID_STRUCT:
|
||||
lst.append(indent + "}")
|
||||
else:
|
||||
lst.append(indent + "]")
|
||||
|
||||
else:
|
||||
if t == TID_STRING:
|
||||
L += ('"%s"' % self.stringvalue())
|
||||
elif t in [TID_CLOB, TID_BLOB]:
|
||||
L += ("{%s}" % self.printlob(self.lobvalue()))
|
||||
elif t == TID_POSINT:
|
||||
L += str(self.intvalue())
|
||||
elif t == TID_SYMBOL:
|
||||
tn = self.gettypename()
|
||||
if tn != "":
|
||||
tn += "::"
|
||||
L += tn + self.symbolvalue()
|
||||
elif t == TID_DECIMAL:
|
||||
L += str(self.decimalvalue())
|
||||
else:
|
||||
L += ("TID %d" % t)
|
||||
lst.append(indent + L)
|
||||
|
||||
def print_(self, lst):
|
||||
self.reset()
|
||||
self.ionwalk(-1, "", lst)
|
||||
|
||||
|
||||
SYM_NAMES = [ 'com.amazon.drm.Envelope@1.0',
|
||||
'com.amazon.drm.EnvelopeMetadata@1.0', 'size', 'page_size',
|
||||
'encryption_key', 'encryption_transformation',
|
||||
'encryption_voucher', 'signing_key', 'signing_algorithm',
|
||||
'signing_voucher', 'com.amazon.drm.EncryptedPage@1.0',
|
||||
'cipher_text', 'cipher_iv', 'com.amazon.drm.Signature@1.0',
|
||||
'data', 'com.amazon.drm.EnvelopeIndexTable@1.0', 'length',
|
||||
'offset', 'algorithm', 'encoded', 'encryption_algorithm',
|
||||
'hashing_algorithm', 'expires', 'format', 'id',
|
||||
'lock_parameters', 'strategy', 'com.amazon.drm.Key@1.0',
|
||||
'com.amazon.drm.KeySet@1.0', 'com.amazon.drm.PIDv3@1.0',
|
||||
'com.amazon.drm.PlainTextPage@1.0',
|
||||
'com.amazon.drm.PlainText@1.0', 'com.amazon.drm.PrivateKey@1.0',
|
||||
'com.amazon.drm.PublicKey@1.0', 'com.amazon.drm.SecretKey@1.0',
|
||||
'com.amazon.drm.Voucher@1.0', 'public_key', 'private_key',
|
||||
'com.amazon.drm.KeyPair@1.0', 'com.amazon.drm.ProtectedData@1.0',
|
||||
'doctype', 'com.amazon.drm.EnvelopeIndexTableOffset@1.0',
|
||||
'enddoc', 'license_type', 'license', 'watermark', 'key', 'value',
|
||||
'com.amazon.drm.License@1.0', 'category', 'metadata',
|
||||
'categorized_metadata', 'com.amazon.drm.CategorizedMetadata@1.0',
|
||||
'com.amazon.drm.VoucherEnvelope@1.0', 'mac', 'voucher',
|
||||
'com.amazon.drm.ProtectedData@2.0',
|
||||
'com.amazon.drm.Envelope@2.0',
|
||||
'com.amazon.drm.EnvelopeMetadata@2.0',
|
||||
'com.amazon.drm.EncryptedPage@2.0',
|
||||
'com.amazon.drm.PlainText@2.0', 'compression_algorithm',
|
||||
'com.amazon.drm.Compressed@1.0', 'priority', 'refines']
|
||||
|
||||
def addprottable(ion):
|
||||
ion.addtocatalog("ProtectedData", 1, SYM_NAMES)
|
||||
|
||||
|
||||
def pkcs7pad(msg, blocklen):
|
||||
paddinglen = blocklen - len(msg) % blocklen
|
||||
padding = bchr(paddinglen) * paddinglen
|
||||
return msg + padding
|
||||
|
||||
|
||||
def pkcs7unpad(msg, blocklen):
|
||||
_assert(len(msg) % blocklen == 0)
|
||||
|
||||
paddinglen = bord(msg[-1])
|
||||
_assert(paddinglen > 0 and paddinglen <= blocklen, "Incorrect padding - Wrong key")
|
||||
_assert(msg[-paddinglen:] == bchr(paddinglen) * paddinglen, "Incorrect padding - Wrong key")
|
||||
|
||||
return msg[:-paddinglen]
|
||||
|
||||
|
||||
class DrmIonVoucher(object):
|
||||
envelope = None
|
||||
voucher = None
|
||||
drmkey = None
|
||||
license_type = "Unknown"
|
||||
|
||||
encalgorithm = ""
|
||||
enctransformation = ""
|
||||
hashalgorithm = ""
|
||||
|
||||
lockparams = None
|
||||
|
||||
ciphertext = b""
|
||||
cipheriv = b""
|
||||
secretkey = b""
|
||||
|
||||
def __init__(self, voucherenv, dsn, secret):
|
||||
self.dsn,self.secret = dsn,secret
|
||||
|
||||
self.lockparams = []
|
||||
|
||||
self.envelope = BinaryIonParser(voucherenv)
|
||||
addprottable(self.envelope)
|
||||
|
||||
def decryptvoucher(self):
|
||||
shared = "PIDv3" + self.encalgorithm + self.enctransformation + self.hashalgorithm
|
||||
|
||||
self.lockparams.sort()
|
||||
for param in self.lockparams:
|
||||
if param == "ACCOUNT_SECRET":
|
||||
shared += param + self.secret
|
||||
elif param == "CLIENT_ID":
|
||||
shared += param + self.dsn
|
||||
else:
|
||||
_assert(False, "Unknown lock parameter: %s" % param)
|
||||
|
||||
sharedsecret = shared.encode("UTF-8")
|
||||
|
||||
key = hmac.new(sharedsecret, sharedsecret[:5], digestmod=hashlib.sha256).digest()
|
||||
aes = AES.new(key[:32], AES.MODE_CBC, self.cipheriv[:16])
|
||||
b = aes.decrypt(self.ciphertext)
|
||||
b = pkcs7unpad(b, 16)
|
||||
|
||||
self.drmkey = BinaryIonParser(StringIO(b))
|
||||
addprottable(self.drmkey)
|
||||
|
||||
_assert(self.drmkey.hasnext() and self.drmkey.next() == TID_LIST and self.drmkey.gettypename() == "com.amazon.drm.KeySet@1.0",
|
||||
"Expected KeySet, got %s" % self.drmkey.gettypename())
|
||||
|
||||
self.drmkey.stepin()
|
||||
while self.drmkey.hasnext():
|
||||
self.drmkey.next()
|
||||
if self.drmkey.gettypename() != "com.amazon.drm.SecretKey@1.0":
|
||||
continue
|
||||
|
||||
self.drmkey.stepin()
|
||||
while self.drmkey.hasnext():
|
||||
self.drmkey.next()
|
||||
if self.drmkey.getfieldname() == "algorithm":
|
||||
_assert(self.drmkey.stringvalue() == "AES", "Unknown cipher algorithm: %s" % self.drmkey.stringvalue())
|
||||
elif self.drmkey.getfieldname() == "format":
|
||||
_assert(self.drmkey.stringvalue() == "RAW", "Unknown key format: %s" % self.drmkey.stringvalue())
|
||||
elif self.drmkey.getfieldname() == "encoded":
|
||||
self.secretkey = self.drmkey.lobvalue()
|
||||
|
||||
self.drmkey.stepout()
|
||||
break
|
||||
|
||||
self.drmkey.stepout()
|
||||
|
||||
def parse(self):
|
||||
self.envelope.reset()
|
||||
_assert(self.envelope.hasnext(), "Envelope is empty")
|
||||
_assert(self.envelope.next() == TID_STRUCT and self.envelope.gettypename() == "com.amazon.drm.VoucherEnvelope@1.0",
|
||||
"Unknown type encountered in envelope, expected VoucherEnvelope")
|
||||
|
||||
self.envelope.stepin()
|
||||
while self.envelope.hasnext():
|
||||
self.envelope.next()
|
||||
field = self.envelope.getfieldname()
|
||||
if field == "voucher":
|
||||
self.voucher = BinaryIonParser(StringIO(self.envelope.lobvalue()))
|
||||
addprottable(self.voucher)
|
||||
continue
|
||||
elif field != "strategy":
|
||||
continue
|
||||
|
||||
_assert(self.envelope.gettypename() == "com.amazon.drm.PIDv3@1.0", "Unknown strategy: %s" % self.envelope.gettypename())
|
||||
|
||||
self.envelope.stepin()
|
||||
while self.envelope.hasnext():
|
||||
self.envelope.next()
|
||||
field = self.envelope.getfieldname()
|
||||
if field == "encryption_algorithm":
|
||||
self.encalgorithm = self.envelope.stringvalue()
|
||||
elif field == "encryption_transformation":
|
||||
self.enctransformation = self.envelope.stringvalue()
|
||||
elif field == "hashing_algorithm":
|
||||
self.hashalgorithm = self.envelope.stringvalue()
|
||||
elif field == "lock_parameters":
|
||||
self.envelope.stepin()
|
||||
while self.envelope.hasnext():
|
||||
_assert(self.envelope.next() == TID_STRING, "Expected string list for lock_parameters")
|
||||
self.lockparams.append(self.envelope.stringvalue())
|
||||
self.envelope.stepout()
|
||||
|
||||
self.envelope.stepout()
|
||||
|
||||
self.parsevoucher()
|
||||
|
||||
def parsevoucher(self):
|
||||
_assert(self.voucher.hasnext(), "Voucher is empty")
|
||||
_assert(self.voucher.next() == TID_STRUCT and self.voucher.gettypename() == "com.amazon.drm.Voucher@1.0",
|
||||
"Unknown type, expected Voucher")
|
||||
|
||||
self.voucher.stepin()
|
||||
while self.voucher.hasnext():
|
||||
self.voucher.next()
|
||||
|
||||
if self.voucher.getfieldname() == "cipher_iv":
|
||||
self.cipheriv = self.voucher.lobvalue()
|
||||
elif self.voucher.getfieldname() == "cipher_text":
|
||||
self.ciphertext = self.voucher.lobvalue()
|
||||
elif self.voucher.getfieldname() == "license":
|
||||
_assert(self.voucher.gettypename() == "com.amazon.drm.License@1.0",
|
||||
"Unknown license: %s" % self.voucher.gettypename())
|
||||
self.voucher.stepin()
|
||||
while self.voucher.hasnext():
|
||||
self.voucher.next()
|
||||
if self.voucher.getfieldname() == "license_type":
|
||||
self.license_type = self.voucher.stringvalue()
|
||||
self.voucher.stepout()
|
||||
|
||||
def printenvelope(self, lst):
|
||||
self.envelope.print_(lst)
|
||||
|
||||
def printkey(self, lst):
|
||||
if self.voucher is None:
|
||||
self.parse()
|
||||
if self.drmkey is None:
|
||||
self.decryptvoucher()
|
||||
|
||||
self.drmkey.print_(lst)
|
||||
|
||||
def printvoucher(self, lst):
|
||||
if self.voucher is None:
|
||||
self.parse()
|
||||
|
||||
self.voucher.print_(lst)
|
||||
|
||||
def getlicensetype(self):
|
||||
return self.license_type
|
||||
|
||||
|
||||
class DrmIon(object):
|
||||
ion = None
|
||||
voucher = None
|
||||
vouchername = ""
|
||||
key = b""
|
||||
onvoucherrequired = None
|
||||
|
||||
def __init__(self, ionstream, onvoucherrequired):
|
||||
self.ion = BinaryIonParser(ionstream)
|
||||
addprottable(self.ion)
|
||||
self.onvoucherrequired = onvoucherrequired
|
||||
|
||||
def parse(self, outpages):
|
||||
self.ion.reset()
|
||||
|
||||
_assert(self.ion.hasnext(), "DRMION envelope is empty")
|
||||
_assert(self.ion.next() == TID_SYMBOL and self.ion.gettypename() == "doctype", "Expected doctype symbol")
|
||||
_assert(self.ion.next() == TID_LIST and self.ion.gettypename() in ["com.amazon.drm.Envelope@1.0", "com.amazon.drm.Envelope@2.0"],
|
||||
"Unknown type encountered in DRMION envelope, expected Envelope, got %s" % self.ion.gettypename())
|
||||
|
||||
while True:
|
||||
if self.ion.gettypename() == "enddoc":
|
||||
break
|
||||
|
||||
self.ion.stepin()
|
||||
while self.ion.hasnext():
|
||||
self.ion.next()
|
||||
|
||||
if self.ion.gettypename() in ["com.amazon.drm.EnvelopeMetadata@1.0", "com.amazon.drm.EnvelopeMetadata@2.0"]:
|
||||
self.ion.stepin()
|
||||
while self.ion.hasnext():
|
||||
self.ion.next()
|
||||
if self.ion.getfieldname() != "encryption_voucher":
|
||||
continue
|
||||
|
||||
if self.vouchername == "":
|
||||
self.vouchername = self.ion.stringvalue()
|
||||
self.voucher = self.onvoucherrequired(self.vouchername)
|
||||
self.key = self.voucher.secretkey
|
||||
_assert(self.key is not None, "Unable to obtain secret key from voucher")
|
||||
else:
|
||||
_assert(self.vouchername == self.ion.stringvalue(),
|
||||
"Unexpected: Different vouchers required for same file?")
|
||||
|
||||
self.ion.stepout()
|
||||
|
||||
elif self.ion.gettypename() in ["com.amazon.drm.EncryptedPage@1.0", "com.amazon.drm.EncryptedPage@2.0"]:
|
||||
decompress = False
|
||||
ct = None
|
||||
civ = None
|
||||
self.ion.stepin()
|
||||
while self.ion.hasnext():
|
||||
self.ion.next()
|
||||
if self.ion.gettypename() == "com.amazon.drm.Compressed@1.0":
|
||||
decompress = True
|
||||
if self.ion.getfieldname() == "cipher_text":
|
||||
ct = self.ion.lobvalue()
|
||||
elif self.ion.getfieldname() == "cipher_iv":
|
||||
civ = self.ion.lobvalue()
|
||||
|
||||
if ct is not None and civ is not None:
|
||||
self.processpage(ct, civ, outpages, decompress)
|
||||
self.ion.stepout()
|
||||
|
||||
self.ion.stepout()
|
||||
if not self.ion.hasnext():
|
||||
break
|
||||
self.ion.next()
|
||||
|
||||
def print_(self, lst):
|
||||
self.ion.print_(lst)
|
||||
|
||||
def processpage(self, ct, civ, outpages, decompress):
|
||||
aes = AES.new(self.key[:16], AES.MODE_CBC, civ[:16])
|
||||
msg = pkcs7unpad(aes.decrypt(ct), 16)
|
||||
|
||||
if not decompress:
|
||||
outpages.write(msg)
|
||||
return
|
||||
|
||||
_assert(msg[0] == b"\x00", "LZMA UseFilter not supported")
|
||||
|
||||
if calibre_lzma is not None:
|
||||
with calibre_lzma.decompress(msg[1:], bufsize=0x1000000) as f:
|
||||
f.seek(0)
|
||||
outpages.write(f.read())
|
||||
return
|
||||
|
||||
decomp = lzma.LZMADecompressor(format=lzma.FORMAT_ALONE)
|
||||
while not decomp.eof:
|
||||
segment = decomp.decompress(msg[1:])
|
||||
msg = b"" # Contents were internally buffered after the first call
|
||||
outpages.write(segment)
|
Loading…
Reference in New Issue