1、接口封装
import sys
import struct
from base64 import b64encode
from hashlib import sha1
import logging
from socket import error as SocketError
import errno
from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler
logger = logging.getLogger(__name__)
logging.basicConfig()
FIN = 0x80
OPCODE = 0x0f
MASKED = 0x80
PAYLOAD_LEN = 0x7f
PAYLOAD_LEN_EXT16 = 0x7e
PAYLOAD_LEN_EXT64 = 0x7f
OPCODE_CONTINUATION = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CLOSE_CONN = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xA
# -------------------------------- API ---------------------------------
class API():
def run_forever(self):
try:
logger.info("Listening on port %d for clients.." % self.port)
self.serve_forever()
except KeyboardInterrupt:
self.server_close()
logger.info("Server terminated.")
except Exception as e:
logger.error(str(e), exc_info=True)
exit(1)
def new_client(self, client, server):
pass
def client_left(self, client, server):
pass
def message_received(self, client, server, message):
pass
def set_fn_new_client(self, fn):
self.new_client = fn
def set_fn_client_left(self, fn):
self.client_left = fn
def set_fn_message_received(self, fn):
self.message_received = fn
def send_message(self, client, msg):
self._unicast_(client, msg)
def send_message_to_all(self, msg):
self._multicast_(msg)
# ------------------------- Implementation -----------------------------
class WebsocketServer(ThreadingMixIn, TCPServer, API):
"""
A websocket server waiting for clients to connect.
Args:
port(int): Port to bind to
host(str): Hostname or IP to listen for connections. By default 127.0.0.1
is being used. To accept connections from any client, you should use
0.0.0.0.
loglevel: Logging level from logging module to use for logging. By default
warnings and errors are being logged.
Properties:
clients(list): A list of connected clients. A client is a dictionary
like below.
{
# 'id' : id,
'handler' : handler,
'address' : (addr, port)
}
"""
allow_reuse_address = True
daemon_threads = True # comment to keep threads alive until finished
clients = []
# id_counter = 0
def __init__(self, port, host='0.0.0.0', loglevel=logging.WARNING):
logger.setLevel(loglevel)
TCPServer.__init__(self, (host, port), WebSocketHandler)
self.port = self.socket.getsockname()[1]
def _message_received_(self, handler, msg):
self.message_received(self.handler_to_client(handler), self, msg)
def _ping_received_(self, handler, msg):
handler.send_pong(msg)
def _pong_received_(self, handler, msg):
pass
def _new_client_(self, handler):
# self.id_counter += 1
client = {
# 'id': self.id_counter,
'handler': handler,
'address': handler.client_address
}
# print(client)
self.clients.append(client)
self.new_client(client, self)
def _client_left_(self, handler):
client = self.handler_to_client(handler)
self.client_left(client, self)
if client in self.clients:
self.clients.remove(client)
def _unicast_(self, to_client, msg):
# print(self.clients)
to_client['handler'].send_message(msg)
def _multicast_(self, msg):
for client in self.clients:
self._unicast_(client, str(self.clients))
self._unicast_(client, msg)
def handler_to_client(self, handler):
print(self.clients)
for client in self.clients:
if client['handler'] == handler:
return client
class WebSocketHandler(StreamRequestHandler):
def __init__(self, socket, addr, server):
self.server = server
StreamRequestHandler.__init__(self, socket, addr, server)
def setup(self):
StreamRequestHandler.setup(self)
self.keep_alive = True
self.handshake_done = False
self.valid_client = False
def handle(self):
while self.keep_alive:
if not self.handshake_done:
self.handshake()
elif self.valid_client:
self.read_next_message()
def read_bytes(self, num):
# python3 gives ordinal of byte directly
bytes = self.rfile.read(num)
if sys.version_info[0] < 3:
return map(ord, bytes)
else:
return bytes
def read_next_message(self):
try:
b1, b2 = self.read_bytes(2)
except SocketError as e: # to be replaced with ConnectionResetError for py3
if e.errno == errno.ECONNRESET:
logger.info("Client closed connection.")
print("Error: {}".format(e))
self.keep_alive = 0
return
b1, b2 = 0, 0
except ValueError as e:
b1, b2 = 0, 0
fin = b1 & FIN
opcode = b1 & OPCODE
masked = b2 & MASKED
payload_length = b2 & PAYLOAD_LEN
if opcode == OPCODE_CLOSE_CONN:
logger.info("Client asked to close connection.")
self.keep_alive = 0
return
if not masked:
logger.warn("Client must always be masked.")
self.keep_alive = 0
return
if opcode == OPCODE_CONTINUATION:
logger.warn("Continuation frames are not supported.")
return
elif opcode == OPCODE_BINARY:
logger.warn("Binary frames are not supported.")
return
elif opcode == OPCODE_TEXT:
opcode_handler = self.server._message_received_
elif opcode == OPCODE_PING:
opcode_handler = self.server._ping_received_
elif opcode == OPCODE_PONG:
opcode_handler = self.server._pong_received_
else:
logger.warn("Unknown opcode %#x." % opcode)
self.keep_alive = 0
return
if payload_length == 126:
payload_length = struct.unpack(">H", self.rfile.read(2))[0]
elif payload_length == 127:
payload_length = struct.unpack(">Q", self.rfile.read(8))[0]
masks = self.read_bytes(4)
message_bytes = bytearray()
for message_byte in self.read_bytes(payload_length):
message_byte ^= masks[len(message_bytes) % 4]
message_bytes.append(message_byte)
opcode_handler(self, message_bytes.decode('utf8'))
def send_message(self, message):
self.send_text(message)
def send_pong(self, message):
self.send_text(message, OPCODE_PONG)
def send_text(self, message, opcode=OPCODE_TEXT):
"""
Important: Fragmented(=continuation) messages are not supported since
their usage cases are limited - when we don't know the payload length.
"""
# Validate message
if isinstance(message, bytes):
message = try_decode_UTF8(message) # this is slower but ensures we have UTF-8
if not message:
logger.warning("Can\'t send message, message is not valid UTF-8")
return False
elif sys.version_info < (3, 0) and (isinstance(message, str)):
pass
elif isinstance(message, str):
pass
else:
logger.warning('Can\'t send message, message has to be a string or bytes. Given type is %s' % type(message))
return False
header = bytearray()
payload = encode_to_UTF8(message)
payload_length = len(payload)
# Normal payload
if payload_length <= 125:
header.append(FIN | opcode)
header.append(payload_length)
# Extended payload
elif payload_length >= 126 and payload_length <= 65535:
header.append(FIN | opcode)
header.append(PAYLOAD_LEN_EXT16)
header.extend(struct.pack(">H", payload_length))
# Huge extended payload
elif payload_length < 18446744073709551616:
header.append(FIN | opcode)
header.append(PAYLOAD_LEN_EXT64)
header.extend(struct.pack(">Q", payload_length))
else:
raise Exception("Message is too big. Consider breaking it into chunks.")
self.request.send(header + payload)
def read_http_headers(self):
headers = {}
# first line should be HTTP GET
http_get = self.rfile.readline().decode().strip()
assert http_get.upper().startswith('GET')
# remaining should be headers
while True:
header = self.rfile.readline().decode().strip()
if not header:
break
head, value = header.split(':', 1)
headers[head.lower().strip()] = value.strip()
return headers
def handshake(self):
headers = self.read_http_headers()
try:
assert headers['upgrade'].lower() == 'websocket'
except AssertionError:
self.keep_alive = False
return
try:
key = headers['sec-websocket-key']
except KeyError:
logger.warning("Client tried to connect but was missing a key")
self.keep_alive = False
return
response = self.make_handshake_response(key)
self.handshake_done = self.request.send(response.encode())
self.valid_client = True
self.server._new_client_(self)
@classmethod
def make_handshake_response(cls, key):
return \
'HTTP/1.1 101 Switching Protocols\r\n' \
'Upgrade: websocket\r\n' \
'Connection: Upgrade\r\n' \
'Sec-WebSocket-Accept: %s\r\n' \
'\r\n' % cls.calculate_response_key(key)
@classmethod
def calculate_response_key(cls, key):
GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
hash = sha1(key.encode() + GUID.encode())
response_key = b64encode(hash.digest()).strip()
return response_key.decode('ASCII')
def finish(self):
self.server._client_left_(self)
def encode_to_UTF8(data):
try:
return data.encode('UTF-8')
except UnicodeEncodeError as e:
logger.error("Could not encode data to UTF-8 -- %s" % e)
return False
except Exception as e:
raise (e)
def try_decode_UTF8(data):
try:
return data.decode('utf-8')
except UnicodeDecodeError:
return False
except Exception as e:
raise (e)
2、服务端
from datetime import datetime
from netlogs import Logger
from websocket_server import WebsocketServer
class WBServer(object):
def __init__(self, port):
self.PORT = port
@staticmethod
def new_client(client, server):
Logger().log("New client connected")
server.send_message_to_all('')
@staticmethod
def client_left(client, server):
Logger().log("Client disconnected")
@staticmethod
def message_received(client, server, message):
if len(message) > 200:
message = message[:200] + '..'
message = eval(message)
no = message['id']
content = message['content']
msg = '{time}-Client{id}: {msg}'.format(time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'), msg=content, id=no)
server.send_message(server.clients[no], msg)
def run(self):
server = WebsocketServer(self.PORT)
server.set_fn_new_client(self.new_client)
server.set_fn_client_left(self.client_left)
server.set_fn_message_received(self.message_received)
server.run_forever()
if __name__ == '__main__':
server = WBServer(9001)
server.run()
3、客户端
import threading
import sys
import websocket
from netlogs import Logger
def on_message(ws, message):
Logger().log('received:' + message, mode=True)
def on_error(ws, error):
Logger().log(error, mode=False)
def on_close(ws):
Logger().log("### closed ###", mode=True)
def send_message():
try:
while True:
print('请输入要发送的目标编号', end=':')
no = input()
print('请输入要发送的内容', end=':')
content = input()
message = {'id': int(no), 'content': content}
msg = str(message)
Logger().log('send:' + msg, mode=True)
ws.send(msg)
except Exception as e:
sys.exit()
def run_connect():
ws.run_forever()
ws = websocket.WebSocketApp("ws://localhost:9001/track",
on_message=on_message,
on_error=on_error,
on_close=on_close,
)
run = threading.Thread(target=run_connect)
run.start()
send_message()