diff --git a/nl80211.py b/nl80211.py index 6a32d74..3487421 100755 --- a/nl80211.py +++ b/nl80211.py @@ -7,26 +7,37 @@ import struct class Iterator(object): - def __init__(self, data): + def __init__(self, data, offset=0, length=None): self.data = data - self.offset = 0 + self.offset = offset + self.length = len(self.data) if length is None else length + assert self.length <= len(self.data) def __str__(self): - data = self.data[self.offset:] + data = self.data[self.offset:self.length] return '(%d bytes): %r' % (len(data), data) def Advance(self, offset_incr): - assert self.offset + offset_incr <= len(self.data) + assert offset_incr <= self.Remaining(), 'Want %d bytes, have %d' % (offset_incr, self.Remaining()) self.offset += offset_incr def Extract(self, length): - assert self.offset + length <= len(self.data), 'Want %d bytes, have %d' % (length, len(self.data) - self.offset) + assert length <= self.Remaining(), 'Want %d bytes, have %d' % (length, self.Remaining()) ret = self.data[self.offset:self.offset + length] self.Advance(length) return ret + def ExtractIterator(self, length): + assert length <= self.Remaining(), 'Want %d bytes, have %d' % (length, self.Remaining()) + ret = Iterator(self.data, self.offset, self.offset + length) + self.Advance(length) + return ret + + def Remaining(self): + return self.length - self.offset + def AtEnd(self): - return self.offset == len(self.data) + return not self.Remaining() class Accumulator(object): @@ -44,9 +55,7 @@ class Accumulator(object): class SingleStructParser(struct.Struct): - def Unpack(self, iterator, targetlen=None): - if targetlen is not None: - assert self.size == targetlen, 'Actual bytes: %d, expected bytes: %d' % (targetlen, self.size) + def Unpack(self, iterator): values = self.unpack_from(iterator.data, iterator.offset) iterator.Advance(self.size) assert len(values) == 1 @@ -61,9 +70,7 @@ class StructParser(struct.Struct): super(StructParser, self).__init__(format) self._fields = fields - def Unpack(self, iterator, targetlen=None): - if targetlen is not None: - assert self.size == targetlen, 'Actual bytes: %d, expected bytes: %d' % (targetlen, self.size) + def Unpack(self, iterator): values = self.unpack_from(iterator.data, iterator.offset) iterator.Advance(self.size) return dict(zip(self._fields, values)) @@ -76,16 +83,15 @@ class StructParser(struct.Struct): class StringParser(object): - def Unpack(self, iterator, targetlen): - return iterator.Extract(targetlen) + def Unpack(self, iterator): + return iterator.Extract(iterator.Remaining()) def Pack(self, accumulator, value): accumulator.Append(value) class EmptyParser(object): - def Unpack(self, iterator, targetlen=None): - assert not targetlen + def Unpack(self, iterator): return True def Pack(self, accumulator, value=None): @@ -103,17 +109,17 @@ class Attribute(object): super(Attribute, self).__init__() self._attributes = attributes - def Unpack(self, iterator, targetlen=None): + def Unpack(self, iterator): nlattr = self._nlattr.Unpack(iterator) - if targetlen is not None: - assert nlattr['len'] == targetlen value = iterator.data[iterator.offset:iterator.offset + nlattr['len'] - self._nlattr.size] name, sub_parser = self._attributes.get(nlattr['type'], (None, None)) assert sub_parser, 'Unknown attribute type %d, len %d' % (nlattr['type'], nlattr['len']) sub_len = nlattr['len'] - self._nlattr.size + sub_iterator = iterator.ExtractIterator(sub_len) ret = { - name: sub_parser.Unpack(iterator, sub_len) + name: sub_parser.Unpack(sub_iterator) } + assert sub_iterator.AtEnd(), '%d bytes remaining' % sub_iterator.Remaining() padding = ((nlattr['len'] + 4 - 1) & ~3) - nlattr['len'] iterator.Advance(padding) @@ -134,9 +140,7 @@ class Attributes(object): self._attribute_idx = dict((v[0], k) for k, v in attributes.iteritems()) self._attribute = Attribute(attributes) - def Unpack(self, iterator, targetlen=None): - if targetlen is not None: - iterator = Iterator(iterator.Extract(targetlen)) + def Unpack(self, iterator): ret = {} while not iterator.AtEnd(): ret.update(self._attribute.Unpack(iterator)) @@ -154,13 +158,14 @@ class Array(object): super(Array, self).__init__() self._child = child - def Unpack(self, iterator, targetlen=None): - if targetlen is not None: - iterator = Iterator(iterator.Extract(targetlen)) + def Unpack(self, iterator): ret = [] while not iterator.AtEnd(): hdr = self._arrayhdr.Unpack(iterator) - ret.append(self._child.Unpack(iterator, hdr['len'] - self._arrayhdr.size)) + sub_len = hdr['len'] - self._arrayhdr.size + sub_iterator = iterator.ExtractIterator(sub_len) + ret.append(self._child.Unpack(sub_iterator)) + assert sub_iterator.AtEnd(), '%d bytes remaining' % sub_iterator.Remaining() return ret @@ -258,42 +263,52 @@ STA_FLAG_TDLS_PEER = 1 << 5 STA_FLAG_ASSOCIATED = 1 << 6 -int_genquery = Accumulator() -genlmsghdr.Pack( - int_genquery, - cmd=CMD_GET_STATION, - version=0, - reserved=0) -nl80211_attr.Pack( - int_genquery, - ifindex=6) -genquery = Accumulator() -nlmsghdr.Pack( - genquery, - length=nlmsghdr.size + len(int_genquery), - type=20, # XXX - flags=F_REQUEST | F_ACK | F_DUMP, - seq=random.randint(0, 2 ** 32 - 1), - pid=os.getpid()) -genquery.Append(str(int_genquery)) +class Connection(object): + def __init__(self): + self._sock = socket.socket(socket.AF_NETLINK, socket.SOCK_DGRAM, 16) + self._sock.bind((0, 0)) -sock = socket.socket(socket.AF_NETLINK, socket.SOCK_DGRAM, 16) -sock.bind((0, 0)) + def Send(self, msg): + self._sock.send(msg) + + def RecvAndUnpack(self): + data = self._sock.recv(4096) + iterator = Iterator(data) + myhdr = nlmsghdr.Unpack(iterator) + print 'nlmsghdr: %s' % myhdr + int_iterator = iterator.ExtractIterator(myhdr['length'] - nlmsghdr.size) + print 'genlmsghdr: %s' % genlmsghdr.Unpack(int_iterator) + print 'ctrl_attr: %s' % ctrl_attr.Unpack(int_iterator) + + +#int_genquery = Accumulator() +#genlmsghdr.Pack( +# int_genquery, +# cmd=CMD_GET_STATION, +# version=0, +# reserved=0) +#nl80211_attr.Pack( +# int_genquery, +# ifindex=6) +#genquery = Accumulator() +#nlmsghdr.Pack( +# genquery, +# length=nlmsghdr.size + len(int_genquery), +# type=20, # XXX +# flags=F_REQUEST | F_ACK | F_DUMP, +# seq=random.randint(0, 2 ** 32 - 1), +# pid=os.getpid()) +#genquery.Append(str(int_genquery)) #sock.send(str(genquery)) #data = sock.recv(4096) -# + #iterator = Iterator(data) #print 'nlmsghdr: %s' % nlmsghdr.Unpack(iterator) #print 'genlmsghdr: %s' % genlmsghdr.Unpack(iterator) #print 'nl80211_attr: %s' % nl80211_attr.Unpack(iterator) query = '\24\0\0\0\20\0\5\3a\6\256T\v\17\0\0\3\1\0\0' -sock.send(query) -data = sock.recv(4096) -iterator = Iterator(data) -myhdr = nlmsghdr.Unpack(iterator) -print 'nlmsghdr: %s' % myhdr -int_iterator = Iterator(iterator.Extract(myhdr['length'] - nlmsghdr.size)) -print 'genlmsghdr: %s' % genlmsghdr.Unpack(int_iterator) -print 'ctrl_attr: %s' % ctrl_attr.Unpack(int_iterator) +conn = Connection() +conn.Send(query) +conn.RecvAndUnpack()