Skip to content

Commit 93195fa

Browse files
committed
refactor: streamline message extraction and validation logic in PythAccumulator
1 parent dc2cd92 commit 93195fa

File tree

1 file changed

+70
-69
lines changed

1 file changed

+70
-69
lines changed

target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol

+70-69
Original file line numberDiff line numberDiff line change
@@ -229,44 +229,22 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
229229
uint64 prevPublishTime
230230
)
231231
{
232-
unchecked {
233-
bytes calldata encodedMessage;
234-
uint16 messageSize = UnsafeCalldataBytesLib.toUint16(
235-
encoded,
236-
offset
237-
);
238-
offset += 2;
239-
240-
encodedMessage = UnsafeCalldataBytesLib.slice(
241-
encoded,
242-
offset,
243-
messageSize
244-
);
245-
offset += messageSize;
246-
247-
bool valid;
248-
(valid, endOffset) = MerkleTree.isProofValid(
249-
encoded,
250-
offset,
251-
digest,
252-
encodedMessage
253-
);
254-
if (!valid) {
255-
revert PythErrors.InvalidUpdateData();
256-
}
257-
258-
MessageType messageType = MessageType(
259-
UnsafeCalldataBytesLib.toUint8(encodedMessage, 0)
232+
bytes calldata encodedMessage;
233+
MessageType messageType;
234+
(
235+
encodedMessage,
236+
messageType,
237+
endOffset
238+
) = extractAndValidateEncodedMessage(encoded, offset, digest);
239+
240+
if (messageType == MessageType.PriceFeed) {
241+
(priceInfo, priceId, prevPublishTime) = parsePriceFeedMessage(
242+
encodedMessage,
243+
1
260244
);
261-
if (messageType == MessageType.PriceFeed) {
262-
(priceInfo, priceId, prevPublishTime) = parsePriceFeedMessage(
263-
encodedMessage,
264-
1
265-
);
266-
} else revert PythErrors.InvalidUpdateData();
245+
} else revert PythErrors.InvalidUpdateData();
267246

268-
return (endOffset, priceInfo, priceId, prevPublishTime);
269-
}
247+
return (endOffset, priceInfo, priceId, prevPublishTime);
270248
}
271249

272250
function extractTwapPriceInfoFromMerkleProof(
@@ -281,46 +259,69 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
281259
PythInternalStructs.TwapPriceInfo memory twapPriceInfo,
282260
bytes32 priceId
283261
)
284-
// TODO: Add the logic to extract the twap price info from the merkle proof
285262
{
286-
unchecked {
287-
bytes calldata encodedMessage;
288-
uint16 messageSize = UnsafeCalldataBytesLib.toUint16(
289-
encoded,
290-
offset
263+
bytes calldata encodedMessage;
264+
MessageType messageType;
265+
(
266+
encodedMessage,
267+
messageType,
268+
endOffset
269+
) = extractAndValidateEncodedMessage(encoded, offset, digest);
270+
271+
if (messageType == MessageType.TwapPriceFeed) {
272+
(twapPriceInfo, priceId) = parseTwapPriceFeedMessage(
273+
encodedMessage,
274+
1
291275
);
292-
offset += 2;
276+
} else revert PythErrors.InvalidUpdateData();
293277

294-
encodedMessage = UnsafeCalldataBytesLib.slice(
295-
encoded,
296-
offset,
297-
messageSize
298-
);
299-
offset += messageSize;
278+
return (endOffset, twapPriceInfo, priceId);
279+
}
300280

301-
bool valid;
302-
(valid, endOffset) = MerkleTree.isProofValid(
303-
encoded,
304-
offset,
305-
digest,
306-
encodedMessage
307-
);
308-
if (!valid) {
309-
revert PythErrors.InvalidUpdateData();
310-
}
281+
function extractAndValidateEncodedMessage(
282+
bytes calldata encoded,
283+
uint offset,
284+
bytes20 digest
285+
)
286+
private
287+
pure
288+
returns (
289+
bytes calldata encodedMessage,
290+
MessageType messageType,
291+
uint endOffset
292+
)
293+
{
294+
uint16 messageSize = UnsafeCalldataBytesLib.toUint16(encoded, offset);
295+
offset += 2;
311296

312-
MessageType messageType = MessageType(
313-
UnsafeCalldataBytesLib.toUint8(encodedMessage, 0)
314-
);
315-
if (messageType == MessageType.TwapPriceFeed) {
316-
(twapPriceInfo, priceId) = parseTwapPriceFeedMessage(
317-
encodedMessage,
318-
1
319-
);
320-
} else revert PythErrors.InvalidUpdateData();
297+
encodedMessage = UnsafeCalldataBytesLib.slice(
298+
encoded,
299+
offset,
300+
messageSize
301+
);
302+
offset += messageSize;
303+
304+
bool valid;
305+
(valid, endOffset) = MerkleTree.isProofValid(
306+
encoded,
307+
offset,
308+
digest,
309+
encodedMessage
310+
);
311+
if (!valid) {
312+
revert PythErrors.InvalidUpdateData();
313+
}
321314

322-
return (endOffset, twapPriceInfo, priceId);
315+
messageType = MessageType(
316+
UnsafeCalldataBytesLib.toUint8(encodedMessage, 0)
317+
);
318+
if (
319+
messageType != MessageType.PriceFeed &&
320+
messageType != MessageType.TwapPriceFeed
321+
) {
322+
revert PythErrors.InvalidUpdateData();
323323
}
324+
return (encodedMessage, messageType, endOffset);
324325
}
325326

326327
function parsePriceFeedMessage(

0 commit comments

Comments
 (0)