diff --git a/nl80211.py b/nl80211.py index cbe27f3..9d36878 100755 --- a/nl80211.py +++ b/nl80211.py @@ -2,9 +2,11 @@ import fcntl import os -import random +import Queue +import threading import socket import struct +import weakref class Iterator(object): @@ -197,35 +199,62 @@ class Netlink(object): _nlmsghdr = StructParser('LHHLL', ('length', 'type', 'flags', 'seq', 'pid')) + _seq = 0 + def __init__(self): self._sock = socket.socket(socket.AF_NETLINK, socket.SOCK_DGRAM, 16) self._sock.bind((0, 0)) + self._seq_lock = threading.Lock() + self._response_queues = {} + thread = threading.Thread( + target=self._Receiver, + args=(weakref.proxy(self),)) + thread.daemon = True + thread.start() + + @staticmethod + def _Receiver(self): + while True: + data = self._sock.recv(4096) + iterator = Iterator(data) + hdr = self._nlmsghdr.Unpack(iterator) + self._response_queues[hdr['seq']].put(data) + + def _NextSeq(self): + with self._seq_lock: + self._seq += 1 + return self._seq def Send(self, msgtype, flags, msg): flagint = self._NLMSG_F_REQUEST for flag in flags: flagint |= self.flags[flag] accumulator = Accumulator() + seq = self._NextSeq() self._nlmsghdr.Pack( accumulator, length=len(msg) + self._nlmsghdr.size, type=msgtype, flags=flagint, - seq=random.randint(0, 2 ** 32 - 1), + seq=seq, pid=os.getpid()) accumulator.Append(msg) + self._response_queues[seq] = Queue.Queue() self._sock.send(str(accumulator)) + return seq - def Recv(self): + def Recv(self, seq): while True: - data = self._sock.recv(4096) + data = self._response_queues[seq].get() iterator = Iterator(data) while not iterator.AtEnd(): myhdr = self._nlmsghdr.Unpack(iterator) if myhdr['type'] == self._NLMSG_DONE: + del self._response_queues[seq] return yield (myhdr['type'], iterator.ExtractIterator(myhdr['length'] - self._nlmsghdr.size)) if not myhdr['flags'] & self._NLMSG_F_MULTI: + del self._response_queues[seq] return @@ -267,8 +296,7 @@ class GenericNetlink(object): self._netlink = Netlink() self._UpdateMsgTypes() - self.Send('nlctrl', ['dump'], 'getfamily', 1) - for msg in self.Recv(): + for msg in self.Query('nlctrl', ['dump'], 'getfamily', 1): msgtype, attrs = msg assert msgtype == self._msgtypes_by_name['nlctrl']['commands']['newfamily'], msgtype family_name = attrs['family_name'].rstrip('\0') @@ -305,17 +333,17 @@ class GenericNetlink(object): accumulator, **attrs) - self._netlink.Send(msgtype['id'], flags, str(accumulator)) + return self._netlink.Send(msgtype['id'], flags, str(accumulator)) - def Recv(self): - for msgtype, iterator in self._netlink.Recv(): + def Recv(self, seq): + for msgtype, iterator in self._netlink.Recv(seq): genlhdr = self._genlmsghdr.Unpack(iterator) parser = self._msgtypes_by_id[msgtype]['parser'] yield (genlhdr['cmd'], parser.Unpack(iterator)) def Query(self, msgtype, flags, cmd, version, **attrs): - self.Send(msgtype, flags, cmd, version, **attrs) - return self.Recv() + seq = self.Send(msgtype, flags, cmd, version, **attrs) + return self.Recv(seq) def RegisterNL80211(gnl):