Multiplexing thread for the netlink socket to make the class thread safe.
This commit is contained in:
50
nl80211.py
50
nl80211.py
@@ -2,9 +2,11 @@
|
|||||||
|
|
||||||
import fcntl
|
import fcntl
|
||||||
import os
|
import os
|
||||||
import random
|
import Queue
|
||||||
|
import threading
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
|
||||||
class Iterator(object):
|
class Iterator(object):
|
||||||
@@ -197,35 +199,62 @@ class Netlink(object):
|
|||||||
|
|
||||||
_nlmsghdr = StructParser('LHHLL', ('length', 'type', 'flags', 'seq', 'pid'))
|
_nlmsghdr = StructParser('LHHLL', ('length', 'type', 'flags', 'seq', 'pid'))
|
||||||
|
|
||||||
|
_seq = 0
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._sock = socket.socket(socket.AF_NETLINK, socket.SOCK_DGRAM, 16)
|
self._sock = socket.socket(socket.AF_NETLINK, socket.SOCK_DGRAM, 16)
|
||||||
self._sock.bind((0, 0))
|
self._sock.bind((0, 0))
|
||||||
|
self._seq_lock = threading.Lock()
|
||||||
|
self._response_queues = {}
|
||||||
|
thread = threading.Thread(
|
||||||
|
target=self._Receiver,
|
||||||
|
args=(weakref.proxy(self),))
|
||||||
|
thread.daemon = True
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _Receiver(self):
|
||||||
|
while True:
|
||||||
|
data = self._sock.recv(4096)
|
||||||
|
iterator = Iterator(data)
|
||||||
|
hdr = self._nlmsghdr.Unpack(iterator)
|
||||||
|
self._response_queues[hdr['seq']].put(data)
|
||||||
|
|
||||||
|
def _NextSeq(self):
|
||||||
|
with self._seq_lock:
|
||||||
|
self._seq += 1
|
||||||
|
return self._seq
|
||||||
|
|
||||||
def Send(self, msgtype, flags, msg):
|
def Send(self, msgtype, flags, msg):
|
||||||
flagint = self._NLMSG_F_REQUEST
|
flagint = self._NLMSG_F_REQUEST
|
||||||
for flag in flags:
|
for flag in flags:
|
||||||
flagint |= self.flags[flag]
|
flagint |= self.flags[flag]
|
||||||
accumulator = Accumulator()
|
accumulator = Accumulator()
|
||||||
|
seq = self._NextSeq()
|
||||||
self._nlmsghdr.Pack(
|
self._nlmsghdr.Pack(
|
||||||
accumulator,
|
accumulator,
|
||||||
length=len(msg) + self._nlmsghdr.size,
|
length=len(msg) + self._nlmsghdr.size,
|
||||||
type=msgtype,
|
type=msgtype,
|
||||||
flags=flagint,
|
flags=flagint,
|
||||||
seq=random.randint(0, 2 ** 32 - 1),
|
seq=seq,
|
||||||
pid=os.getpid())
|
pid=os.getpid())
|
||||||
accumulator.Append(msg)
|
accumulator.Append(msg)
|
||||||
|
self._response_queues[seq] = Queue.Queue()
|
||||||
self._sock.send(str(accumulator))
|
self._sock.send(str(accumulator))
|
||||||
|
return seq
|
||||||
|
|
||||||
def Recv(self):
|
def Recv(self, seq):
|
||||||
while True:
|
while True:
|
||||||
data = self._sock.recv(4096)
|
data = self._response_queues[seq].get()
|
||||||
iterator = Iterator(data)
|
iterator = Iterator(data)
|
||||||
while not iterator.AtEnd():
|
while not iterator.AtEnd():
|
||||||
myhdr = self._nlmsghdr.Unpack(iterator)
|
myhdr = self._nlmsghdr.Unpack(iterator)
|
||||||
if myhdr['type'] == self._NLMSG_DONE:
|
if myhdr['type'] == self._NLMSG_DONE:
|
||||||
|
del self._response_queues[seq]
|
||||||
return
|
return
|
||||||
yield (myhdr['type'], iterator.ExtractIterator(myhdr['length'] - self._nlmsghdr.size))
|
yield (myhdr['type'], iterator.ExtractIterator(myhdr['length'] - self._nlmsghdr.size))
|
||||||
if not myhdr['flags'] & self._NLMSG_F_MULTI:
|
if not myhdr['flags'] & self._NLMSG_F_MULTI:
|
||||||
|
del self._response_queues[seq]
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
@@ -267,8 +296,7 @@ class GenericNetlink(object):
|
|||||||
|
|
||||||
self._netlink = Netlink()
|
self._netlink = Netlink()
|
||||||
self._UpdateMsgTypes()
|
self._UpdateMsgTypes()
|
||||||
self.Send('nlctrl', ['dump'], 'getfamily', 1)
|
for msg in self.Query('nlctrl', ['dump'], 'getfamily', 1):
|
||||||
for msg in self.Recv():
|
|
||||||
msgtype, attrs = msg
|
msgtype, attrs = msg
|
||||||
assert msgtype == self._msgtypes_by_name['nlctrl']['commands']['newfamily'], msgtype
|
assert msgtype == self._msgtypes_by_name['nlctrl']['commands']['newfamily'], msgtype
|
||||||
family_name = attrs['family_name'].rstrip('\0')
|
family_name = attrs['family_name'].rstrip('\0')
|
||||||
@@ -305,17 +333,17 @@ class GenericNetlink(object):
|
|||||||
accumulator,
|
accumulator,
|
||||||
**attrs)
|
**attrs)
|
||||||
|
|
||||||
self._netlink.Send(msgtype['id'], flags, str(accumulator))
|
return self._netlink.Send(msgtype['id'], flags, str(accumulator))
|
||||||
|
|
||||||
def Recv(self):
|
def Recv(self, seq):
|
||||||
for msgtype, iterator in self._netlink.Recv():
|
for msgtype, iterator in self._netlink.Recv(seq):
|
||||||
genlhdr = self._genlmsghdr.Unpack(iterator)
|
genlhdr = self._genlmsghdr.Unpack(iterator)
|
||||||
parser = self._msgtypes_by_id[msgtype]['parser']
|
parser = self._msgtypes_by_id[msgtype]['parser']
|
||||||
yield (genlhdr['cmd'], parser.Unpack(iterator))
|
yield (genlhdr['cmd'], parser.Unpack(iterator))
|
||||||
|
|
||||||
def Query(self, msgtype, flags, cmd, version, **attrs):
|
def Query(self, msgtype, flags, cmd, version, **attrs):
|
||||||
self.Send(msgtype, flags, cmd, version, **attrs)
|
seq = self.Send(msgtype, flags, cmd, version, **attrs)
|
||||||
return self.Recv()
|
return self.Recv(seq)
|
||||||
|
|
||||||
|
|
||||||
def RegisterNL80211(gnl):
|
def RegisterNL80211(gnl):
|
||||||
|
|||||||
Reference in New Issue
Block a user