Skip to content

Commit 3d26810

Browse files
committed
Merge pull request #1056 from me-no-dev/async-ota
Async ota
2 parents b5ca4fe + fe9dc91 commit 3d26810

File tree

2 files changed

+159
-99
lines changed

2 files changed

+159
-99
lines changed

libraries/ArduinoOTA/ArduinoOTA.cpp

+151-97
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
1-
#include <ESP8266WiFi.h>
2-
#include <ESP8266mDNS.h>
1+
#define LWIP_OPEN_SRC
2+
#include <functional>
33
#include <WiFiUdp.h>
44
#include "ArduinoOTA.h"
55
#include "MD5Builder.h"
66

7+
extern "C" {
8+
#include "osapi.h"
9+
#include "ets_sys.h"
10+
#include "user_interface.h"
11+
}
12+
13+
#include "lwip/opt.h"
14+
#include "lwip/udp.h"
15+
#include "lwip/inet.h"
16+
#include "lwip/igmp.h"
17+
#include "lwip/mem.h"
18+
#include "include/UdpContext.h"
19+
#include <ESP8266mDNS.h>
20+
721
//#define OTA_DEBUG 1
822

923
ArduinoOTAClass::ArduinoOTAClass()
@@ -16,9 +30,17 @@ ArduinoOTAClass::ArduinoOTAClass()
1630
, _end_callback(NULL)
1731
, _progress_callback(NULL)
1832
, _error_callback(NULL)
33+
, _udp_ota(0)
1934
{
2035
}
2136

37+
ArduinoOTAClass::~ArduinoOTAClass(){
38+
if(_udp_ota){
39+
_udp_ota->unref();
40+
_udp_ota = 0;
41+
}
42+
}
43+
2244
void ArduinoOTAClass::onStart(OTA_CALLBACK(fn)) {
2345
_start_callback = fn;
2446
}
@@ -35,9 +57,6 @@ void ArduinoOTAClass::onError(OTA_CALLBACK_ERROR(fn)) {
3557
_error_callback = fn;
3658
}
3759

38-
ArduinoOTAClass::~ArduinoOTAClass() {
39-
}
40-
4160
void ArduinoOTAClass::setPort(uint16_t port) {
4261
if (!_initialized && !_port && port) {
4362
_port = port;
@@ -59,7 +78,6 @@ void ArduinoOTAClass::setPassword(const char * password) {
5978
void ArduinoOTAClass::begin() {
6079
if (_initialized)
6180
return;
62-
_initialized = true;
6381

6482
if (!_hostname.length()) {
6583
char tmp[15];
@@ -70,20 +88,141 @@ void ArduinoOTAClass::begin() {
7088
_port = 8266;
7189
}
7290

73-
_udp_ota.begin(_port);
91+
if(_udp_ota){
92+
_udp_ota->unref();
93+
_udp_ota = 0;
94+
}
95+
96+
_udp_ota = new UdpContext;
97+
_udp_ota->ref();
98+
99+
if(!_udp_ota->listen(*IP_ADDR_ANY, _port))
100+
return;
101+
_udp_ota->onRx(std::bind(&ArduinoOTAClass::_onRx, this));
74102
MDNS.begin(_hostname.c_str());
75103

76104
if (_password.length()) {
77105
MDNS.enableArduino(_port, true);
78106
} else {
79107
MDNS.enableArduino(_port);
80108
}
109+
_initialized = true;
81110
_state = OTA_IDLE;
82111
#if OTA_DEBUG
83112
Serial.printf("OTA server at: %s.local:%u\n", _hostname.c_str(), _port);
84113
#endif
85114
}
86115

116+
int ArduinoOTAClass::parseInt(){
117+
char data[16];
118+
uint8_t index = 0;
119+
char value;
120+
while(_udp_ota->peek() == ' ') _udp_ota->read();
121+
while(true){
122+
value = _udp_ota->peek();
123+
if(value < '0' || value > '9'){
124+
data[index++] = '\0';
125+
return atoi(data);
126+
}
127+
data[index++] = _udp_ota->read();
128+
}
129+
return 0;
130+
}
131+
132+
String ArduinoOTAClass::readStringUntil(char end){
133+
String res = "";
134+
char value;
135+
while(true){
136+
value = _udp_ota->read();
137+
if(value == '\0' || value == end){
138+
return res;
139+
}
140+
res += value;
141+
}
142+
return res;
143+
}
144+
145+
void ArduinoOTAClass::_onRx(){
146+
if(!_udp_ota->next()) return;
147+
ip_addr_t ota_ip;
148+
149+
if (_state == OTA_IDLE) {
150+
int cmd = parseInt();
151+
if (cmd != U_FLASH && cmd != U_SPIFFS)
152+
return;
153+
_ota_ip = _udp_ota->getRemoteAddress();
154+
_cmd = cmd;
155+
_ota_port = parseInt();
156+
_size = parseInt();
157+
_udp_ota->read();
158+
_md5 = readStringUntil('\n');
159+
_md5.trim();
160+
if(_md5.length() != 32)
161+
return;
162+
163+
ota_ip.addr = (uint32_t)_ota_ip;
164+
165+
if (_password.length()){
166+
MD5Builder nonce_md5;
167+
nonce_md5.begin();
168+
nonce_md5.add(String(micros()));
169+
nonce_md5.calculate();
170+
_nonce = nonce_md5.toString();
171+
172+
char auth_req[38];
173+
sprintf(auth_req, "AUTH %s", _nonce.c_str());
174+
_udp_ota->append((const char *)auth_req, strlen(auth_req));
175+
_udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
176+
_state = OTA_WAITAUTH;
177+
return;
178+
} else {
179+
_udp_ota->append("OK", 2);
180+
_udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
181+
_state = OTA_RUNUPDATE;
182+
}
183+
} else if (_state == OTA_WAITAUTH) {
184+
int cmd = parseInt();
185+
if (cmd != U_AUTH) {
186+
_state = OTA_IDLE;
187+
return;
188+
}
189+
_udp_ota->read();
190+
String cnonce = readStringUntil(' ');
191+
String response = readStringUntil('\n');
192+
if (cnonce.length() != 32 || response.length() != 32) {
193+
_state = OTA_IDLE;
194+
return;
195+
}
196+
197+
MD5Builder _passmd5;
198+
_passmd5.begin();
199+
_passmd5.add(_password);
200+
_passmd5.calculate();
201+
String passmd5 = _passmd5.toString();
202+
203+
String challenge = passmd5 + ":" + String(_nonce) + ":" + cnonce;
204+
MD5Builder _challengemd5;
205+
_challengemd5.begin();
206+
_challengemd5.add(challenge);
207+
_challengemd5.calculate();
208+
String result = _challengemd5.toString();
209+
210+
ota_ip.addr = (uint32_t)_ota_ip;
211+
if(result.equals(response)){
212+
_udp_ota->append("OK", 2);
213+
_udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
214+
_state = OTA_RUNUPDATE;
215+
} else {
216+
_udp_ota->append("Authentication Failed", 21);
217+
_udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
218+
if (_error_callback) _error_callback(OTA_AUTH_ERROR);
219+
_state = OTA_IDLE;
220+
}
221+
}
222+
223+
while(_udp_ota->next()) _udp_ota->flush();
224+
}
225+
87226
void ArduinoOTAClass::_runUpdate() {
88227
if (!Update.begin(_size, _cmd)) {
89228
#if OTA_DEBUG
@@ -92,7 +231,7 @@ void ArduinoOTAClass::_runUpdate() {
92231
if (_error_callback) {
93232
_error_callback(OTA_BEGIN_ERROR);
94233
}
95-
_udp_ota.begin(_port);
234+
_udp_ota->listen(*IP_ADDR_ANY, _port);
96235
_state = OTA_IDLE;
97236
return;
98237
}
@@ -112,7 +251,7 @@ void ArduinoOTAClass::_runUpdate() {
112251
#if OTA_DEBUG
113252
Serial.printf("Connect Failed\n");
114253
#endif
115-
_udp_ota.begin(_port);
254+
_udp_ota->listen(*IP_ADDR_ANY, _port);
116255
if (_error_callback) {
117256
_error_callback(OTA_CONNECT_ERROR);
118257
}
@@ -128,7 +267,7 @@ void ArduinoOTAClass::_runUpdate() {
128267
#if OTA_DEBUG
129268
Serial.printf("Recieve Failed\n");
130269
#endif
131-
_udp_ota.begin(_port);
270+
_udp_ota->listen(*IP_ADDR_ANY, _port);
132271
if (_error_callback) {
133272
_error_callback(OTA_RECIEVE_ERROR);
134273
}
@@ -156,7 +295,7 @@ void ArduinoOTAClass::_runUpdate() {
156295
}
157296
ESP.restart();
158297
} else {
159-
_udp_ota.begin(_port);
298+
_udp_ota->listen(*IP_ADDR_ANY, _port);
160299
if (_error_callback) {
161300
_error_callback(OTA_END_ERROR);
162301
}
@@ -169,94 +308,9 @@ void ArduinoOTAClass::_runUpdate() {
169308
}
170309

171310
void ArduinoOTAClass::handle() {
172-
if (!_udp_ota) {
173-
_udp_ota.begin(_port);
174-
#if OTA_DEBUG
175-
Serial.println("OTA restarted");
176-
#endif
177-
}
178-
179-
if (!_udp_ota.parsePacket()) return;
180-
181-
if (_state == OTA_IDLE) {
182-
int cmd = _udp_ota.parseInt();
183-
if (cmd != U_FLASH && cmd != U_SPIFFS)
184-
return;
185-
_ota_ip = _udp_ota.remoteIP();
186-
_cmd = cmd;
187-
_ota_port = _udp_ota.parseInt();
188-
_size = _udp_ota.parseInt();
189-
_udp_ota.read();
190-
_md5 = _udp_ota.readStringUntil('\n');
191-
_md5.trim();
192-
if(_md5.length() != 32)
193-
return;
194-
195-
#if OTA_DEBUG
196-
Serial.print("Update Start: ip:");
197-
Serial.print(_ota_ip);
198-
Serial.printf(", port:%d, size:%d, md5:%s\n", _ota_port, _size, _md5.c_str());
199-
#endif
200-
201-
_udp_ota.beginPacket(_ota_ip, _udp_ota.remotePort());
202-
if (_password.length()){
203-
MD5Builder nonce_md5;
204-
nonce_md5.begin();
205-
nonce_md5.add(String(micros()));
206-
nonce_md5.calculate();
207-
_nonce = nonce_md5.toString();
208-
_udp_ota.printf("AUTH %s", _nonce.c_str());
209-
_udp_ota.endPacket();
210-
_state = OTA_WAITAUTH;
211-
return;
212-
} else {
213-
_udp_ota.print("OK");
214-
_udp_ota.endPacket();
215-
_state = OTA_RUNUPDATE;
216-
}
217-
} else if (_state == OTA_WAITAUTH) {
218-
int cmd = _udp_ota.parseInt();
219-
if (cmd != U_AUTH) {
220-
_state = OTA_IDLE;
221-
return;
222-
}
223-
_udp_ota.read();
224-
String cnonce = _udp_ota.readStringUntil(' ');
225-
String response = _udp_ota.readStringUntil('\n');
226-
if (cnonce.length() != 32 || response.length() != 32) {
227-
_state = OTA_IDLE;
228-
return;
229-
}
230-
231-
MD5Builder _passmd5;
232-
_passmd5.begin();
233-
_passmd5.add(_password);
234-
_passmd5.calculate();
235-
String passmd5 = _passmd5.toString();
236-
237-
String challenge = passmd5 + ":" + String(_nonce) + ":" + cnonce;
238-
MD5Builder _challengemd5;
239-
_challengemd5.begin();
240-
_challengemd5.add(challenge);
241-
_challengemd5.calculate();
242-
String result = _challengemd5.toString();
243-
244-
if(result.equals(response)){
245-
_udp_ota.beginPacket(_ota_ip, _udp_ota.remotePort());
246-
_udp_ota.print("OK");
247-
_udp_ota.endPacket();
248-
_state = OTA_RUNUPDATE;
249-
} else {
250-
_udp_ota.beginPacket(_ota_ip, _udp_ota.remotePort());
251-
_udp_ota.print("Authentication Failed");
252-
_udp_ota.endPacket();
253-
if (_error_callback) _error_callback(OTA_AUTH_ERROR);
254-
_state = OTA_IDLE;
255-
}
256-
}
257-
258311
if (_state == OTA_RUNUPDATE) {
259312
_runUpdate();
313+
_state = OTA_IDLE;
260314
}
261315
}
262316

libraries/ArduinoOTA/ArduinoOTA.h

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
#ifndef __ARDUINO_OTA_H
22
#define __ARDUINO_OTA_H
33

4-
class WiFiUDP;
4+
#include <ESP8266WiFi.h>
5+
#include <WiFiUdp.h>
6+
7+
class UdpContext;
58

69
#define OTA_CALLBACK(callback) void (*callback)()
710
#define OTA_CALLBACK_PROGRESS(callback) void (*callback)(unsigned int, unsigned int)
@@ -41,7 +44,7 @@ class ArduinoOTAClass
4144
String _password;
4245
String _hostname;
4346
String _nonce;
44-
WiFiUDP _udp_ota;
47+
UdpContext *_udp_ota;
4548
bool _initialized;
4649
ota_state_t _state;
4750
int _size;
@@ -56,6 +59,9 @@ class ArduinoOTAClass
5659
OTA_CALLBACK_PROGRESS(_progress_callback);
5760

5861
void _runUpdate(void);
62+
void _onRx(void);
63+
int parseInt(void);
64+
String readStringUntil(char end);
5965
};
6066

6167
extern ArduinoOTAClass ArduinoOTA;

0 commit comments

Comments
 (0)