Multiplexing thread for the netlink socket to make the class thread safe.

This commit is contained in:
Ian Gulliver
2015-01-11 06:46:28 +00:00
parent d0709b885c
commit 841abaa612

View File

@@ -2,9 +2,11 @@
import fcntl
import os
import random
import Queue
import threading
import socket
import struct
import weakref
class Iterator(object):
@@ -197,35 +199,62 @@ class Netlink(object):
_nlmsghdr = StructParser('LHHLL', ('length', 'type', 'flags', 'seq', 'pid'))
_seq = 0
def __init__(self):
self._sock = socket.socket(socket.AF_NETLINK, socket.SOCK_DGRAM, 16)
self._sock.bind((0, 0))
self._seq_lock = threading.Lock()
self._response_queues = {}
thread = threading.Thread(
target=self._Receiver,
args=(weakref.proxy(self),))
thread.daemon = True
thread.start()
@staticmethod
def _Receiver(self):
while True:
data = self._sock.recv(4096)
iterator = Iterator(data)
hdr = self._nlmsghdr.Unpack(iterator)
self._response_queues[hdr['seq']].put(data)
def _NextSeq(self):
with self._seq_lock:
self._seq += 1
return self._seq
def Send(self, msgtype, flags, msg):
flagint = self._NLMSG_F_REQUEST
for flag in flags:
flagint |= self.flags[flag]
accumulator = Accumulator()
seq = self._NextSeq()
self._nlmsghdr.Pack(
accumulator,
length=len(msg) + self._nlmsghdr.size,
type=msgtype,
flags=flagint,
seq=random.randint(0, 2 ** 32 - 1),
seq=seq,
pid=os.getpid())
accumulator.Append(msg)
self._response_queues[seq] = Queue.Queue()
self._sock.send(str(accumulator))
return seq
def Recv(self):
def Recv(self, seq):
while True:
data = self._sock.recv(4096)
data = self._response_queues[seq].get()
iterator = Iterator(data)
while not iterator.AtEnd():
myhdr = self._nlmsghdr.Unpack(iterator)
if myhdr['type'] == self._NLMSG_DONE:
del self._response_queues[seq]
return
yield (myhdr['type'], iterator.ExtractIterator(myhdr['length'] - self._nlmsghdr.size))
if not myhdr['flags'] & self._NLMSG_F_MULTI:
del self._response_queues[seq]
return
@@ -267,8 +296,7 @@ class GenericNetlink(object):
self._netlink = Netlink()
self._UpdateMsgTypes()
self.Send('nlctrl', ['dump'], 'getfamily', 1)
for msg in self.Recv():
for msg in self.Query('nlctrl', ['dump'], 'getfamily', 1):
msgtype, attrs = msg
assert msgtype == self._msgtypes_by_name['nlctrl']['commands']['newfamily'], msgtype
family_name = attrs['family_name'].rstrip('\0')
@@ -305,17 +333,17 @@ class GenericNetlink(object):
accumulator,
**attrs)
self._netlink.Send(msgtype['id'], flags, str(accumulator))
return self._netlink.Send(msgtype['id'], flags, str(accumulator))
def Recv(self):
for msgtype, iterator in self._netlink.Recv():
def Recv(self, seq):
for msgtype, iterator in self._netlink.Recv(seq):
genlhdr = self._genlmsghdr.Unpack(iterator)
parser = self._msgtypes_by_id[msgtype]['parser']
yield (genlhdr['cmd'], parser.Unpack(iterator))
def Query(self, msgtype, flags, cmd, version, **attrs):
self.Send(msgtype, flags, cmd, version, **attrs)
return self.Recv()
seq = self.Send(msgtype, flags, cmd, version, **attrs)
return self.Recv(seq)
def RegisterNL80211(gnl):