296 lines
10 KiB
Python
296 lines
10 KiB
Python
from collections import namedtuple
|
||
|
||
import network
|
||
import uerrno
|
||
import uio
|
||
import uselect as select
|
||
import usocket as socket
|
||
from config import config
|
||
from server_base import BaseServer
|
||
from wifi_manager import wifi_manager
|
||
|
||
# Per-connection response state: the body stream being sent, the send
# buffer, a memoryview over that buffer, and a mutable [start, end] list
# giving the slice of the buffer still waiting to be written.
WriteConn = namedtuple("WriteConn", ("body", "buff", "buffmv", "write_range"))

# Items of interest pulled out of a raw HTTP request: the method, the path
# without its query string, the query parameters, and the Host header value.
ReqInfo = namedtuple("ReqInfo", ("type", "path", "params", "host"))
|
||
|
||
import gc
|
||
|
||
|
||
def unquote(string):
    """Percent-decode a URL/form value to bytes.

    A stripped-down version of ``urllib.parse.unquote_to_bytes`` that also
    turns ``+`` into a space (form encoding). Accepts ``str`` (encoded as
    UTF-8 first) or ``bytes``; a falsy argument (``None``, ``""``, ``b""``)
    yields ``b""``. A ``%`` that is not followed by two hex digits is kept
    literally, as ``urllib.parse`` does, rather than raising or being
    mis-decoded.
    """
    if not string:
        return b""

    if isinstance(string, str):
        string = string.encode("utf-8")
    string = string.replace(b"+", b" ")

    # split into substrings on each escape character
    bits = string.split(b"%")
    if len(bits) == 1:
        return string  # there was no escape character

    hexdigits = b"0123456789abcdefABCDEF"
    res = [bits[0]]  # everything before the first escape character

    # every later piece started right after a "%": if its first two bytes
    # are hex digits, convert them to the byte they encode and keep the rest
    for item in bits[1:]:
        code = item[:2]
        # explicit digit check: int(code, 16) alone would accept a single
        # digit or surrounding whitespace and silently produce a wrong byte
        if len(code) == 2 and code[:1] in hexdigits and code[1:] in hexdigits:
            res.append(bytes([int(code, 16)]))
            res.append(item[2:])  # anything before the next escape character
        else:
            # not a valid escape: restore the "%" consumed by split()
            res.append(b"%")
            res.append(item)

    return b"".join(res)
|
||
|
||
|
||
class HTTPServer(BaseServer):
    """Captive-portal HTTP server that serves the Wi-Fi configuration page.

    A request whose Host header is not this device's IP, or whose path has
    no entry in ``self.routes``, is answered with a 307 redirect to
    ``http://<local_ip>/``. Responses are streamed in buffers of at most
    ``MSS`` bytes, one chunk per POLLOUT event from the shared poller.
    """

    # TCP/IP default MSS; responses are sent in buffers of this many bytes
    MSS = 536

    def __init__(self, poller, local_ip):
        super().__init__(poller, 80, socket.SOCK_STREAM, "HTTP Server")
        # keep the IP as bytes so it can be compared with / spliced into raw HTTP
        self.local_ip = local_ip if isinstance(local_ip, bytes) else local_ip.encode()
        # partially received requests, keyed by id(socket)
        self.request = dict()
        # in-progress responses (WriteConn), keyed by id(socket)
        self.conns = dict()
        # path -> file to serve (bytes) or handler(params) -> (body, headers)
        self.routes = {
            b"/": b"/rom/www/iwconfig.html",
            b"/login": self.login,
            b"/scan": self.scan_networks,
        }

        self.ssid = None

        # queue up to 2 connection requests before refusing (ESP8266 memory optimization)
        self.sock.listen(2)
        self.sock.setblocking(False)

    # @micropython.native
    def handle(self, sock, event, others):
        """Dispatch one poller event for ``sock``.

        Returns the request path (bytes) when a complete, valid request was
        read and its response queued; otherwise None.
        """
        if sock is self.sock:
            # client connecting on port 80, so spawn off a new
            # socket to handle this connection
            self.accept(sock)
        elif event & select.POLLIN:
            # socket has data to read in
            return self.read(sock)
        elif event & select.POLLOUT:
            # existing connection has space to send more data
            self.write_to(sock)
        elif event & (select.POLLERR | select.POLLHUP):
            # peer hung up or the socket errored; without closing it the
            # socket would stay registered and keep reporting this event
            self.close(sock)

    def accept(self, server_sock):
        """Accept a pending client connection and register it for reading."""
        try:
            client_sock, _addr = server_sock.accept()
        except OSError as e:
            if e.args[0] != uerrno.EAGAIN:
                print("HTTP accept failed:", e)
            # no connection to register (none pending, or accept failed)
            return

        client_sock.setblocking(False)
        client_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.poller.register(client_sock, select.POLLIN)

    def parse_request(self, req):
        """Parse a raw HTTP request into ``ReqInfo(type, path, params, host)``.

        ``params`` maps raw (still percent-encoded) query keys to values; a
        parameter without ``=`` maps to ``b""``. ``host`` is None when the
        request carries no ``Host:`` header.

        Raises ValueError if the request line is not "METHOD PATH VERSION".
        """
        req_lines = req.split(b"\r\n")
        req_type, full_path, _http_ver = req_lines[0].split(b" ")

        path = full_path.split(b"?", 1)
        base_path = path[0]

        query_params = {}
        if len(path) > 1:
            for param in path[1].split(b"&"):
                if not param:
                    continue  # empty field from "&&" or a trailing "&"
                # split once, so a value containing "=" does not break parsing
                kv = param.split(b"=", 1)
                query_params[kv[0]] = kv[1] if len(kv) > 1 else b""

        host = None
        for line in req_lines[1:]:
            if line.startswith(b"Host:"):
                host = line[5:].strip()
                break

        return ReqInfo(req_type, base_path, query_params, host)

    def _redirect_headers(self):
        """Return response headers for a 307 redirect to http://<local_ip>/."""
        # plain bytes concatenation behaves the same on MicroPython and
        # CPython (CPython's bytes has no .format())
        return (
            b"HTTP/1.1 307 Temporary Redirect\r\nLocation: http://"
            + self.local_ip
            + b"/\r\n"
        )

    def login(self, params):
        """Save the submitted settings and redirect back to the portal root.

        Percent-decodes the ``ssid``, ``password`` and ``city`` query
        parameters, stores them via the global ``config`` and writes it out.
        Returns ``(b"", headers)`` where ``headers`` is a 307 redirect.
        """
        # extract the form data from the URL parameters
        ssid = unquote(params.get(b"ssid", None))
        password = unquote(params.get(b"password", ""))
        city = unquote(params.get(b"city", None))
        # NOTE(review): this stores the raw (still percent-encoded) "city"
        # parameter as cityid; if the form sends a separate "cityid" field,
        # this key is probably wrong -- confirm against iwconfig.html
        cityid = params.get(b"city", None)

        # save the settings through the global config instance
        config.set("ssid", ssid)
        config.set("password", password)
        config.set("city", city)
        config.set("cityid", cityid)
        if config.write():
            print("Configuration saved successfully")
        else:
            print("Failed to save configuration, invalid data")

        # redirect the browser back to local_ip
        return b"", self._redirect_headers()

    def scan_networks(self, params):
        """Scan for Wi-Fi networks and return them as a JSON response.

        ``params`` is unused. If the scan raises, an empty network list is
        returned so the page still receives valid JSON.
        """
        # imported before the try so the error branch can always use it
        import ujson

        try:
            networks = wifi_manager.scan_networks()
            json_data = ujson.dumps({"networks": networks})
        except Exception as e:
            print(f"Error scanning networks: {e}")
            json_data = ujson.dumps({"networks": []})

        headers = b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nAccess-Control-Allow-Origin: *\r\n"
        return json_data.encode(), headers

    def get_response(self, req):
        """Build ``(body_stream, headers)`` for the route matching ``req.path``.

        - bytes route: that file opened in binary mode (404 if it can't be opened)
        - callable route: ``route(req.params)`` -> ``(body, headers)``; a falsy
          body becomes ``b""`` and falsy headers become a plain 200 status line
        - no route: an empty 404 response
        """
        headers = b"HTTP/1.1 200 OK\r\n"
        not_found = b"HTTP/1.1 404 Not Found\r\n"
        route = self.routes.get(req.path, None)

        if isinstance(route, bytes):
            # expect a filename, so stream the contents of that file
            try:
                return open(route, "rb"), headers
            except OSError as e:
                print("Cannot open", route, e)
                return uio.BytesIO(b""), not_found

        if callable(route):
            # call a function, which may or may not return a response
            response = route(req.params)
            body = response[0] or b""
            return uio.BytesIO(body), response[1] or headers

        return uio.BytesIO(b""), not_found

    def is_valid_req(self, req):
        """Return True if ``req`` targets this device's IP and a known route.

        Anything else (another host, or an unknown path) should be redirected.
        """
        return req.host == self.local_ip and req.path in self.routes

    def read(self, s):
        """Read client data from ``s`` and queue a response once it is complete.

        Returns the request path when a valid request was answered; returns
        None while the request is incomplete, after a redirect, or when the
        connection was closed.
        """
        data = s.read()
        if not data:
            # no data in the TCP stream, so close the socket
            self.close(s)
            return

        # add the new data to any previously received part of the request
        sid = id(s)
        raw = self.request.get(sid, b"") + data

        # the header block ends with a blank line; test the accumulated
        # request (not just this chunk) so a terminator split across two
        # reads is still detected
        if not raw.endswith(b"\r\n\r\n"):
            self.request[sid] = raw
            return
        self.request.pop(sid, None)

        try:
            req = self.parse_request(raw)
        except ValueError:
            # malformed request line: drop the connection
            self.close(s)
            return

        if not self.is_valid_req(req):
            # wrong host or unknown path: send the client to the portal root
            self.prepare_write(s, uio.BytesIO(b""), self._redirect_headers())
            return

        # by this point the request has the correct host and a valid route
        body, headers = self.get_response(req)
        self.prepare_write(s, body, headers)
        return req.path

    def prepare_write(self, s, body, headers):
        """Stage the first response chunk for ``s`` and switch it to POLLOUT.

        ``body`` must support ``readinto``. ``headers`` holds the status line
        and header lines, each ending in CRLF, without the final blank line.
        """
        if isinstance(headers, str):
            headers = headers.encode()
        # add the blank line that separates the headers from the body
        headers += b"\r\n"
        hlen = len(headers)

        # one-MSS buffer initially holding the headers (a header block longer
        # than MSS simply produces a larger buffer)
        buff = bytearray(headers + b"\x00" * (self.MSS - hlen))
        # memoryview lets the body be read directly into the buffer, no copy
        buffmv = memoryview(buff)
        # fill the space after the headers; the slice bounds the read size
        bw = body.readinto(buffmv[hlen:]) or 0

        # save place for the next write event
        self.conns[id(s)] = WriteConn(body, buff, buffmv, [0, hlen + bw])
        # ask the poller to report when the socket can accept data
        self.poller.modify(s, select.POLLOUT)

    def write_to(self, sock):
        """Send the next buffered slice of the pending response on ``sock``.

        Closes the socket once the final (short) chunk has been sent in full,
        when the socket accepts no bytes, or when the write fails.
        """
        c = self.conns.get(id(sock))
        if c is None:
            return

        start, end = c.write_range
        if start >= end:
            # nothing left in the buffer and the body is exhausted
            self.close(sock)
            return

        try:
            bytes_written = sock.write(c.buffmv[start:end])
        except OSError:
            print("cannot write to a closed socket")
            self.close(sock)
            return

        if bytes_written is None:
            # non-blocking write could not proceed; retry on the next POLLOUT
            return

        if not bytes_written or (bytes_written == end - start and end < self.MSS):
            # nothing was accepted, or the last (short) chunk went out whole
            self.close(sock)
        else:
            # partial write, or a full buffer with possibly more body to come
            self.buff_advance(c, bytes_written)

    def buff_advance(self, c, bytes_written):
        """Advance ``c``'s write window past ``bytes_written`` sent bytes.

        If the whole window was sent, refill the buffer from its start with up
        to MSS more body bytes; otherwise just move the window start forward.
        """
        if bytes_written == c.write_range[1] - c.write_range[0]:
            # wrote everything buffered: restart at the beginning of the buffer
            c.write_range[0] = 0
            # end = number of body bytes read in, up to one MSS
            c.write_range[1] = c.body.readinto(c.buffmv[: self.MSS]) or 0
        else:
            # partial write: resume where this write stopped
            c.write_range[0] += bytes_written

    def close(self, s):
        """Unregister and close ``s`` and release its per-connection state."""
        # unregister while the socket is still open, then close it
        self.poller.unregister(s)
        s.close()

        sid = id(s)
        self.request.pop(sid, None)
        c = self.conns.pop(sid, None)
        if c is not None and hasattr(c.body, "close"):
            # release the body stream (an open file handle or a BytesIO)
            c.body.close()
        gc.collect()
|