236 lines
6.6 KiB
Python
236 lines
6.6 KiB
Python
"""Minimal example of Python websocket server
|
|
handling OTA updates for ESP32 and ESP8266
|
|
|
|
Check and upload of firmware works.
|
|
The register and state functions are just examples.
|
|
"""
|
|
# pylint: disable=W0703,E1101
|
|
import asyncio
|
|
import copy
|
|
import json
|
|
import logging
|
|
import subprocess
|
|
import threading
|
|
import time
|
|
from os import listdir
|
|
from os.path import join as join_pth
|
|
from pathlib import Path
|
|
|
|
import websockets
|
|
from packaging import version
|
|
|
|
# Logger settings: basicConfig attaches a file handler (ws_server.log) to the
# root logger, which receives this logger's records through propagation; the
# extra StreamHandler additionally echoes them to stderr.
logging.basicConfig(filename="ws_server.log")
Logger = logging.getLogger('WS-OTA')
Logger.setLevel(logging.INFO)
Logger.addHandler(logging.StreamHandler())

# Directory holding the firmware images, resolved against the current
# working directory at import time.
fw_path = str(Path.cwd() / "firmware")
|
|
|
|
|
|
def create_path(path: str) -> None:
    """Ensure that the directory *path* exists.

    Missing parent directories are created as needed; an already existing
    directory is accepted silently.
    """
    directory = Path(path)
    directory.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
def shell(command):
    """Run *command* and return its exit status.

    Every line the command writes to stdout is logged at DEBUG level;
    stderr is not captured and goes wherever this process's stderr goes.

    :param command: either a command string, which is run through the shell
        as before, or a list/tuple of program arguments, which is run
        directly without a shell so arguments need no quoting
    :return: the process return code (0 on success)
    """
    # Only explicit argument sequences bypass the shell; anything else keeps
    # the original shell=True behavior.
    use_shell = not isinstance(command, (list, tuple))
    with subprocess.Popen(command, shell=use_shell,
                          stdout=subprocess.PIPE,
                          universal_newlines=True
                          ) as process:
        for stdout_line in process.stdout:
            Logger.debug(stdout_line.rstrip("\n"))
        # Popen's context manager closes the pipe on exit; wait() reaps the child.
        return_code = process.wait()
    Logger.debug("Shell returned: %s", return_code)
    return return_code
|
|
|
|
|
|
async def binary_send(websocket, fw_file, chunk_size=4096, delay=0.2):
    """Read a firmware file and send it over the websocket in chunks.

    :param websocket: connection whose ``send`` coroutine accepts bytes
    :param fw_file: path of the firmware file to send
    :param chunk_size: maximum number of bytes per websocket message
    :param delay: seconds to pause after each chunk (presumably to let the
        device keep up with the upload)
    :return: True when the whole file was sent, False if a send failed
        (the exception is logged)
    """
    with open(fw_file, "rb") as binaryfile:
        # read() returns b"" at end of file, which stops the iteration.
        for chunk in iter(lambda: binaryfile.read(chunk_size), b""):
            try:
                await websocket.send(chunk)
            except Exception as exception:
                Logger.exception(exception)
                return False
            # Must be awaited: calling asyncio.sleep() without await only
            # creates a coroutine object and never pauses.
            await asyncio.sleep(delay)
    return True
|
|
|
|
|
|
def version_checker(name, vdev, vapp):
    """Decide whether a device should receive the available firmware.

    :param name: device name, used only for logging
    :param vdev: version string reported by the device
    :param vapp: version string of the firmware found on disk
    :return: True when the device version is strictly older than the
        firmware version, otherwise False
    """
    needs_update = version.parse(vdev) < version.parse(vapp)
    if needs_update:
        Logger.info("Client(%s) version %s is smaller than %s: Go for update", name, vdev, vapp)
    else:
        Logger.info("Client(%s) version %s is greater or equal to %s: Not updating", name, vdev, vapp)
    return needs_update
|
|
|
|
|
|
class WsOtaHandler(threading.Thread):
    """Thread handling an OTA update for one device.

    Running the OTA directly from the message handler would block the
    websocket message loop while the firmware is sent ("message bus would
    timeout"), so the check/upload runs in a daemon thread with its own
    event loop (``asyncio.run`` in ``run``).

    NOTE(review): the websocket was created by the server's event loop, but
    ``send`` is awaited here from a second loop in another thread. asyncio
    objects are generally not thread-safe; if this causes trouble, schedule
    ``_ota`` as a task on the server loop instead of using a thread.
    """

    def __init__(self, name, message, websocket):
        """Store what the OTA run needs.

        :param name: thread name
        :param message: parsed "greetings" message; must contain the keys
            ``name``, ``chip`` and ``version``
        :param websocket: connection the OTA reply and firmware are sent over
        """
        super().__init__(daemon=True)
        self.name = name
        self.msg = message
        self.websocket = websocket

    def run(self):
        """Thread entry point: run the OTA coroutine in a fresh event loop."""
        try:
            asyncio.run(self.start_())
        except Exception as exception:
            # Boundary of the daemon thread: log instead of dying silently.
            Logger.exception(exception)

    async def start_(self):
        """Run ``_ota`` as an asyncio task and wait for it to finish.

        Any exception raised by ``_ota`` is re-raised here so ``run`` logs
        it, instead of staying unretrieved inside the finished task.
        """
        msg_task = asyncio.ensure_future(self._ota())
        done, _ = await asyncio.wait([msg_task])
        Logger.info("WS Ota Handler done: %s", done)
        msg_task.result()

    async def _ota(self):
        """Check for newer firmware for the device and upload it if found.

        Firmware files in ``fw_path`` are matched by splitting the file name
        on "-": the first component must equal the device name and the
        second is taken as the firmware version. The first matching file
        (in ``listdir`` order) is used.
        """
        device_name = self.msg['name']
        device_chip = self.msg['chip']
        device_version = self.msg['version']
        # Reply sent whenever no update will be performed.
        no_update_msg = '{"type": "ota", "value":"ok"}'

        fw_version = ''
        fw_name = ''
        for filename in listdir(fw_path):
            fw_info = filename.split("-")
            # Names without a "-" carry no version; skip them instead of
            # raising IndexError.
            if len(fw_info) < 2:
                continue
            if fw_info[0] == device_name:
                fw_version = fw_info[1]
                fw_name = filename
                break

        if not fw_version:
            Logger.info("Client(%s): No fw found!", device_name)
            await self.websocket.send(no_update_msg)
            return

        if not version_checker(device_name, device_version, fw_version):
            # Answer like the "no firmware" case so the device gets a reply
            # on every path that does not update.
            await self.websocket.send(no_update_msg)
            return

        fw_file = join_pth(fw_path, fw_name)
        if device_chip == 'esp8266' and not fw_file.endswith('.gz'):
            # We can compress fw to make it smaller for upload.
            # "gzip -9 FILE" replaces FILE with FILE.gz on disk.
            import shlex  # local import: only needed on this path
            fw_uncompressed = fw_file
            fw_file = fw_uncompressed + ".gz"
            # Quote the path so names with spaces or shell metacharacters
            # cannot break or alter the shell command.
            if shell(f"gzip -9 {shlex.quote(fw_uncompressed)}"):
                Logger.error("Cannot compress firmware: %s", fw_name)
                return

        size = Path(fw_file).stat().st_size

        # Request ota mode, announcing the number of bytes that will follow.
        msg = '{"type": "ota", "value":"go", "size":' + str(size) + '}'
        await self.websocket.send(msg)

        # Send the file in chunks through the websocket; binary_send signals
        # a failed send by returning False.
        if await binary_send(self.websocket, fw_file) is False:
            Logger.error("Client(%s): firmware upload failed", device_name)
|
|
|
|
|
async def _register(websocket, message):
    """Handle a "register" message: log the client and acknowledge it.

    The registration logic itself is only a placeholder in this example.
    """
    name, mac = message.get('name'), message.get('mac')
    Logger.info("Client(%s) mac: %s", name, mac)
    # Placeholder: real registration handling would go here.

    await websocket.send(json.dumps({'type': 'registry', 'state': 'ok'}))
|
|
|
|
|
|
async def _state(websocket, message):
    """Handle a "state" message: log the client and acknowledge it.

    The state handling itself is only a placeholder in this example.
    """
    name, mac = message.get('name'), message.get('mac')
    Logger.info("Client(%s) mac: %s", name, mac)
    # Placeholder: real state handling would go here.

    await websocket.send(json.dumps({'type': 'state', 'state': 'ok'}))
|
|
|
|
|
|
async def _unhandled(websocket, msg):
    """Fallback handler: log the unknown message and reply with "nok"."""
    Logger.info("Unhandled message from device: %s", str(msg))
    reply = json.dumps({'type': 'response', 'state': 'nok'})
    await websocket.send(reply)
|
|
|
|
|
|
async def _greetings(websocket, message):
    """Handle a "greetings" message by starting a background OTA check.

    The message is deep-copied so the OTA thread works on its own copy of
    the data rather than the caller's dict.
    """
    ota_thread = WsOtaHandler('thread_ota', copy.deepcopy(message), websocket)
    ota_thread.start()
|
|
|
|
|
|
async def message_received(websocket, message) -> None:
    """Handle one incoming message.

    Text messages starting with "{" are parsed as JSON and dispatched on
    their ``type`` field to the matching handler ("greetings", "register",
    "state"); unknown types go to ``_unhandled``. Anything else (other text,
    binary frames) is ignored. Errors raised by a handler are logged with
    their traceback so the connection loop keeps running.
    """
    switcher = {
        "greetings": _greetings,
        "register": _register,
        "state": _state,
    }

    # Binary frames arrive as bytes and never match here, as before.
    if not isinstance(message, str) or not message.startswith("{"):
        Logger.debug("Ignoring non-JSON message: %.80r", message)
        return

    try:
        msg_json = json.loads(message)
    except (ValueError, RecursionError) as exception:
        # ValueError covers json.JSONDecodeError; RecursionError guards
        # against absurdly nested input.
        Logger.error("Invalid JSON message: %s", exception)
        return

    type_ = msg_json.get('type')
    name = msg_json.get('name')
    handler = switcher.get(type_, _unhandled)
    Logger.debug("Client(%s) said: %s", name, type_)

    try:
        await handler(websocket, msg_json)
    except Exception:
        # Boundary for handler failures: keep serving, but keep the traceback.
        Logger.exception("Client(%s): handler for %r failed", name, type_)
|
|
|
|
|
|
# pylint: disable=W0613
async def ws_server(websocket, path=None) -> None:
    """Per-connection handler: process every message until the client disconnects.

    :param websocket: the client connection; iterating it yields messages
    :param path: request path, unused. It defaults to None because newer
        websockets releases call the handler with the connection only, while
        older ones also pass the path.
    """
    async for message in websocket:
        await message_received(websocket, message)
|
|
|
|
|
|
async def main(host="10.0.1.5", port=8081):
    """Start the websocket server and serve until the process is stopped.

    :param host: address to bind; defaults to the LAN address this example
        was written for
    :param port: TCP port to listen on. On most Unix-like systems an
        unprivileged user can only bind ports 1024 and above.
    """
    async with websockets.serve(ws_server, host, port):
        await asyncio.Future()  # never completes: run forever
|
|
|
|
|
|
if __name__ == "__main__":
    # Make sure the firmware directory exists, then serve forever. Guarded
    # so importing this module does not start the server.
    create_path(fw_path)
    asyncio.run(main())
|