Multiplexing thread for the netlink socket to make the class thread safe.
This commit is contained in:
50
nl80211.py
50
nl80211.py
@@ -2,9 +2,11 @@
|
||||
|
||||
import fcntl
|
||||
import os
|
||||
import random
|
||||
import Queue
|
||||
import threading
|
||||
import socket
|
||||
import struct
|
||||
import weakref
|
||||
|
||||
|
||||
class Iterator(object):
|
||||
@@ -197,35 +199,62 @@ class Netlink(object):
|
||||
|
||||
_nlmsghdr = StructParser('LHHLL', ('length', 'type', 'flags', 'seq', 'pid'))
|
||||
|
||||
_seq = 0
|
||||
|
||||
def __init__(self):
|
||||
self._sock = socket.socket(socket.AF_NETLINK, socket.SOCK_DGRAM, 16)
|
||||
self._sock.bind((0, 0))
|
||||
self._seq_lock = threading.Lock()
|
||||
self._response_queues = {}
|
||||
thread = threading.Thread(
|
||||
target=self._Receiver,
|
||||
args=(weakref.proxy(self),))
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
@staticmethod
|
||||
def _Receiver(self):
|
||||
while True:
|
||||
data = self._sock.recv(4096)
|
||||
iterator = Iterator(data)
|
||||
hdr = self._nlmsghdr.Unpack(iterator)
|
||||
self._response_queues[hdr['seq']].put(data)
|
||||
|
||||
def _NextSeq(self):
|
||||
with self._seq_lock:
|
||||
self._seq += 1
|
||||
return self._seq
|
||||
|
||||
def Send(self, msgtype, flags, msg):
|
||||
flagint = self._NLMSG_F_REQUEST
|
||||
for flag in flags:
|
||||
flagint |= self.flags[flag]
|
||||
accumulator = Accumulator()
|
||||
seq = self._NextSeq()
|
||||
self._nlmsghdr.Pack(
|
||||
accumulator,
|
||||
length=len(msg) + self._nlmsghdr.size,
|
||||
type=msgtype,
|
||||
flags=flagint,
|
||||
seq=random.randint(0, 2 ** 32 - 1),
|
||||
seq=seq,
|
||||
pid=os.getpid())
|
||||
accumulator.Append(msg)
|
||||
self._response_queues[seq] = Queue.Queue()
|
||||
self._sock.send(str(accumulator))
|
||||
return seq
|
||||
|
||||
def Recv(self):
|
||||
def Recv(self, seq):
|
||||
while True:
|
||||
data = self._sock.recv(4096)
|
||||
data = self._response_queues[seq].get()
|
||||
iterator = Iterator(data)
|
||||
while not iterator.AtEnd():
|
||||
myhdr = self._nlmsghdr.Unpack(iterator)
|
||||
if myhdr['type'] == self._NLMSG_DONE:
|
||||
del self._response_queues[seq]
|
||||
return
|
||||
yield (myhdr['type'], iterator.ExtractIterator(myhdr['length'] - self._nlmsghdr.size))
|
||||
if not myhdr['flags'] & self._NLMSG_F_MULTI:
|
||||
del self._response_queues[seq]
|
||||
return
|
||||
|
||||
|
||||
@@ -267,8 +296,7 @@ class GenericNetlink(object):
|
||||
|
||||
self._netlink = Netlink()
|
||||
self._UpdateMsgTypes()
|
||||
self.Send('nlctrl', ['dump'], 'getfamily', 1)
|
||||
for msg in self.Recv():
|
||||
for msg in self.Query('nlctrl', ['dump'], 'getfamily', 1):
|
||||
msgtype, attrs = msg
|
||||
assert msgtype == self._msgtypes_by_name['nlctrl']['commands']['newfamily'], msgtype
|
||||
family_name = attrs['family_name'].rstrip('\0')
|
||||
@@ -305,17 +333,17 @@ class GenericNetlink(object):
|
||||
accumulator,
|
||||
**attrs)
|
||||
|
||||
self._netlink.Send(msgtype['id'], flags, str(accumulator))
|
||||
return self._netlink.Send(msgtype['id'], flags, str(accumulator))
|
||||
|
||||
def Recv(self):
|
||||
for msgtype, iterator in self._netlink.Recv():
|
||||
def Recv(self, seq):
|
||||
for msgtype, iterator in self._netlink.Recv(seq):
|
||||
genlhdr = self._genlmsghdr.Unpack(iterator)
|
||||
parser = self._msgtypes_by_id[msgtype]['parser']
|
||||
yield (genlhdr['cmd'], parser.Unpack(iterator))
|
||||
|
||||
def Query(self, msgtype, flags, cmd, version, **attrs):
|
||||
self.Send(msgtype, flags, cmd, version, **attrs)
|
||||
return self.Recv()
|
||||
seq = self.Send(msgtype, flags, cmd, version, **attrs)
|
||||
return self.Recv(seq)
|
||||
|
||||
|
||||
def RegisterNL80211(gnl):
|
||||
|
||||
Reference in New Issue
Block a user