Skip to content

Commit ed8ad2f

Browse files
committed
Enabled automatic ATA creation in CW
1 parent aa71d84 commit ed8ad2f

File tree

4 files changed

+126
-238
lines changed

4 files changed

+126
-238
lines changed

pkg/solana/chainwriter/chain_writer.go

+87-3
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ package chainwriter
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"math/big"
89

910
"github.com/gagliardetto/solana-go"
1011
addresslookuptable "github.com/gagliardetto/solana-go/programs/address-lookup-table"
1112
"github.com/gagliardetto/solana-go/rpc"
1213

14+
"github.com/smartcontractkit/chainlink-ccip/chains/solana/utils/tokens"
1315
commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec"
1416
"github.com/smartcontractkit/chainlink-common/pkg/logger"
1517
"github.com/smartcontractkit/chainlink-common/pkg/services"
@@ -55,6 +57,7 @@ type MethodConfig struct {
5557
FromAddress string
5658
InputModifications commoncodec.ModifiersConfig
5759
ChainSpecificName string
60+
ATAs []ATALookup
5861
LookupTables LookupTables
5962
Accounts []Lookup
6063
// Location in the args where the debug ID is stored
@@ -214,6 +217,79 @@ func (s *SolanaChainWriterService) FilterLookupTableAddresses(
214217
return filteredLookupTables
215218
}
216219

220+
// CreateATAs first checks if a specified location exists, then checks if the accounts derived from the
221+
// ATALookups in the ChainWriter's configuration exist on-chain and creates them if they do not.
222+
func CreateATAs(ctx context.Context, args any, lookups []ATALookup, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader, idl string, feePayer solana.PublicKey) ([]solana.Instruction, error) {
223+
createATAInstructions := []solana.Instruction{}
224+
for _, lookup := range lookups {
225+
// Check if location exists
226+
if lookup.Location != "" {
227+
// TODO refactor GetValuesAtLocation to not return an error if the field doesn't exist
228+
_, err := GetValuesAtLocation(args, lookup.Location)
229+
if err != nil {
230+
// field doesn't exist, so ignore ATA creation
231+
if errors.Is(err, errFieldNotFound) {
232+
continue
233+
}
234+
return nil, fmt.Errorf("error getting values at location: %w", err)
235+
}
236+
}
237+
walletAddresses, err := GetAddresses(ctx, args, []Lookup{lookup.WalletAddress}, derivedTableMap, reader, idl)
238+
if err != nil {
239+
return nil, fmt.Errorf("error resolving wallet address: %w", err)
240+
}
241+
if len(walletAddresses) != 1 {
242+
return nil, fmt.Errorf("expected exactly one wallet address, got %d", len(walletAddresses))
243+
}
244+
wallet := walletAddresses[0].PublicKey
245+
246+
tokenPrograms, err := GetAddresses(ctx, args, []Lookup{lookup.TokenProgram}, derivedTableMap, reader, idl)
247+
if err != nil {
248+
return nil, fmt.Errorf("error resolving token program address: %w", err)
249+
}
250+
251+
mints, err := GetAddresses(ctx, args, []Lookup{lookup.MintAddress}, derivedTableMap, reader, idl)
252+
if err != nil {
253+
return nil, fmt.Errorf("error resolving mint address: %w", err)
254+
}
255+
256+
if len(tokenPrograms) != len(mints) {
257+
return nil, fmt.Errorf("expected equal number of token programs and mints, got %d tokenPrograms and %d mints", len(tokenPrograms), len(mints))
258+
}
259+
260+
for i := range tokenPrograms {
261+
tokenProgram := tokenPrograms[i].PublicKey
262+
mint := mints[i].PublicKey
263+
264+
ataAddress, _, err := tokens.FindAssociatedTokenAddress(tokenProgram, mint, wallet)
265+
if err != nil {
266+
return nil, fmt.Errorf("error deriving ATA: %w", err)
267+
}
268+
269+
accountInfo, err := reader.GetAccountInfoWithOpts(ctx, ataAddress, &rpc.GetAccountInfoOpts{
270+
Encoding: "base64",
271+
Commitment: rpc.CommitmentFinalized,
272+
})
273+
if err != nil {
274+
return nil, fmt.Errorf("error checking ATA %s on-chain: %w", ataAddress, err)
275+
}
276+
277+
// Check if account exists on-chain
278+
if accountInfo.Value != nil {
279+
continue
280+
}
281+
282+
ins, _, err := tokens.CreateAssociatedTokenAccount(tokenProgram, mint, wallet, feePayer)
283+
if err != nil {
284+
return nil, fmt.Errorf("error creating associated token account: %w", err)
285+
}
286+
createATAInstructions = append(createATAInstructions, ins)
287+
}
288+
}
289+
290+
return createATAInstructions, nil
291+
}
292+
217293
// SubmitTransaction builds, encodes, and enqueues a transaction using the provided program
218294
// configuration and method details. It relies on the configured IDL, account lookups, and
219295
// lookup tables to gather the necessary accounts and data. The function retrieves the latest
@@ -274,6 +350,11 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
274350
return errorWithDebugID(fmt.Errorf("error parsing fee payer address: %w", err), debugID)
275351
}
276352

353+
createATAinstructions, err := CreateATAs(ctx, args, methodConfig.ATAs, derivedTableMap, s.reader, programConfig.IDL, feePayer)
354+
if err != nil {
355+
return errorWithDebugID(fmt.Errorf("error resolving account addresses: %w", err), debugID)
356+
}
357+
277358
// Filter the lookup table addresses based on which accounts are actually used
278359
filteredLookupTableMap := s.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap)
279360

@@ -310,10 +391,13 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
310391
discriminator := GetDiscriminator(methodConfig.ChainSpecificName)
311392
encodedPayload = append(discriminator[:], encodedPayload...)
312393

394+
// Combine the two sets of instructions into one slice
395+
var instructions []solana.Instruction
396+
instructions = append(instructions, createATAinstructions...)
397+
instructions = append(instructions, solana.NewInstruction(programID, accounts, encodedPayload))
398+
313399
tx, err := solana.NewTransaction(
314-
[]solana.Instruction{
315-
solana.NewInstruction(programID, accounts, encodedPayload),
316-
},
400+
instructions,
317401
blockhash.Value.Blockhash,
318402
solana.TransactionPayer(feePayer),
319403
solana.TransactionAddressTables(filteredLookupTableMap),

pkg/solana/chainwriter/chain_writer_test.go

-215
Original file line numberDiff line numberDiff line change
@@ -799,221 +799,6 @@ func TestChainWriter_CCIPRouter(t *testing.T) {
799799
})
800800
}
801801

802-
func TestChainWriter_CCIPRouter(t *testing.T) {
803-
t.Parallel()
804-
805-
// setup admin key
806-
adminPk, err := solana.NewRandomPrivateKey()
807-
require.NoError(t, err)
808-
admin := adminPk.PublicKey()
809-
810-
routerAddr := chainwriter.GetRandomPubKey(t)
811-
destTokenAddr := chainwriter.GetRandomPubKey(t)
812-
813-
poolKeys := []solana.PublicKey{destTokenAddr}
814-
poolKeys = append(poolKeys, chainwriter.CreateTestPubKeys(t, 3)...)
815-
816-
// simplified CCIP Config - does not contain full account list
817-
ccipCWConfig := chainwriter.ChainWriterConfig{
818-
Programs: map[string]chainwriter.ProgramConfig{
819-
"ccip_router": {
820-
Methods: map[string]chainwriter.MethodConfig{
821-
"execute": {
822-
FromAddress: admin.String(),
823-
InputModifications: []codec.ModifierConfig{
824-
&codec.RenameModifierConfig{
825-
Fields: map[string]string{"ReportContextByteWords": "ReportContext"},
826-
},
827-
&codec.RenameModifierConfig{
828-
Fields: map[string]string{"RawExecutionReport": "Report"},
829-
},
830-
},
831-
ChainSpecificName: "execute",
832-
ArgsTransform: "CCIP",
833-
LookupTables: chainwriter.LookupTables{},
834-
Accounts: []chainwriter.Lookup{
835-
chainwriter.AccountConstant{
836-
Name: "testAcc1",
837-
Address: chainwriter.GetRandomPubKey(t).String(),
838-
},
839-
chainwriter.AccountConstant{
840-
Name: "testAcc2",
841-
Address: chainwriter.GetRandomPubKey(t).String(),
842-
},
843-
chainwriter.AccountConstant{
844-
Name: "testAcc3",
845-
Address: chainwriter.GetRandomPubKey(t).String(),
846-
},
847-
chainwriter.AccountConstant{
848-
Name: "poolAddr1",
849-
Address: poolKeys[0].String(),
850-
},
851-
chainwriter.AccountConstant{
852-
Name: "poolAddr2",
853-
Address: poolKeys[1].String(),
854-
},
855-
chainwriter.AccountConstant{
856-
Name: "poolAddr3",
857-
Address: poolKeys[2].String(),
858-
},
859-
chainwriter.AccountConstant{
860-
Name: "poolAddr4",
861-
Address: poolKeys[3].String(),
862-
},
863-
},
864-
},
865-
"commit": {
866-
FromAddress: admin.String(),
867-
InputModifications: []codec.ModifierConfig{
868-
&codec.RenameModifierConfig{
869-
Fields: map[string]string{"ReportContextByteWords": "ReportContext"},
870-
},
871-
&codec.RenameModifierConfig{
872-
Fields: map[string]string{"RawReport": "Report"},
873-
},
874-
},
875-
ChainSpecificName: "commit",
876-
ArgsTransform: "",
877-
LookupTables: chainwriter.LookupTables{},
878-
Accounts: []chainwriter.Lookup{
879-
chainwriter.AccountConstant{
880-
Name: "testAcc1",
881-
Address: chainwriter.GetRandomPubKey(t).String(),
882-
},
883-
chainwriter.AccountConstant{
884-
Name: "testAcc2",
885-
Address: chainwriter.GetRandomPubKey(t).String(),
886-
},
887-
chainwriter.AccountConstant{
888-
Name: "testAcc3",
889-
Address: chainwriter.GetRandomPubKey(t).String(),
890-
},
891-
},
892-
},
893-
},
894-
IDL: ccipRouterIDL,
895-
},
896-
},
897-
}
898-
899-
ctx := tests.Context(t)
900-
// mock client
901-
rw := clientmocks.NewReaderWriter(t)
902-
// mock estimator
903-
ge := feemocks.NewEstimator(t)
904-
905-
t.Run("CCIP execute is encoded successfully and ArgsTransform is applied correctly.", func(t *testing.T) {
906-
// mock txm
907-
txm := txmMocks.NewTxManager(t)
908-
// initialize chain writer
909-
cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, ccipCWConfig)
910-
require.NoError(t, err)
911-
912-
recentBlockHash := solana.Hash{}
913-
rw.On("LatestBlockhash", mock.Anything).Return(&rpc.GetLatestBlockhashResult{Value: &rpc.LatestBlockhashResult{Blockhash: recentBlockHash, LastValidBlockHeight: uint64(100)}}, nil).Once()
914-
915-
pda, _, err := solana.FindProgramAddress([][]byte{[]byte("token_admin_registry"), destTokenAddr.Bytes()}, routerAddr)
916-
require.NoError(t, err)
917-
918-
lookupTable := mockTokenAdminRegistryLookupTable(t, rw, pda)
919-
920-
mockFetchLookupTableAddresses(t, rw, lookupTable, poolKeys)
921-
922-
txID := uuid.NewString()
923-
txm.On("Enqueue", mock.Anything, admin.String(), mock.MatchedBy(func(tx *solana.Transaction) bool {
924-
txData := tx.Message.Instructions[0].Data
925-
payload := txData[8:]
926-
var decoded ccip_router.Execute
927-
dec := ag_binary.NewBorshDecoder(payload)
928-
err = dec.Decode(&decoded)
929-
require.NoError(t, err)
930-
931-
tokenIndexes := *decoded.TokenIndexes
932-
933-
require.Len(t, tokenIndexes, 1)
934-
require.Equal(t, uint8(3), tokenIndexes[0])
935-
return true
936-
}), &txID, mock.Anything).Return(nil).Once()
937-
938-
// stripped back report just for purposes of example
939-
abstractReport := ccipocr3.ExecutePluginReportSingleChain{
940-
Messages: []ccipocr3.Message{
941-
{
942-
TokenAmounts: []ccipocr3.RampTokenAmount{
943-
{
944-
DestTokenAddress: destTokenAddr.Bytes(),
945-
},
946-
},
947-
},
948-
},
949-
}
950-
951-
// Marshal the abstract report to json just for testing purposes.
952-
encodedReport, err := json.Marshal(abstractReport)
953-
require.NoError(t, err)
954-
955-
args := chainwriter.ReportPreTransform{
956-
ReportContext: [2][32]byte{{0x01}, {0x02}},
957-
Report: encodedReport,
958-
Info: ccipocr3.ExecuteReportInfo{
959-
MerkleRoots: []ccipocr3.MerkleRootChain{},
960-
AbstractReports: []ccipocr3.ExecutePluginReportSingleChain{abstractReport},
961-
},
962-
}
963-
964-
submitErr := cw.SubmitTransaction(ctx, "ccip_router", "execute", args, txID, routerAddr.String(), nil, nil)
965-
require.NoError(t, submitErr)
966-
})
967-
968-
t.Run("CCIP commit is encoded successfully", func(t *testing.T) {
969-
// mock txm
970-
txm := txmMocks.NewTxManager(t)
971-
// initialize chain writer
972-
cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, ccipCWConfig)
973-
require.NoError(t, err)
974-
975-
recentBlockHash := solana.Hash{}
976-
rw.On("LatestBlockhash", mock.Anything).Return(&rpc.GetLatestBlockhashResult{Value: &rpc.LatestBlockhashResult{Blockhash: recentBlockHash, LastValidBlockHeight: uint64(100)}}, nil).Once()
977-
978-
type CommitArgs struct {
979-
ReportContext [2][32]byte
980-
Report []byte
981-
Rs [][32]byte
982-
Ss [][32]byte
983-
RawVs [32]byte
984-
Info ccipocr3.CommitReportInfo
985-
}
986-
987-
txID := uuid.NewString()
988-
989-
// TODO: Replace with actual type from ccipocr3
990-
args := CommitArgs{
991-
ReportContext: [2][32]byte{{0x01}, {0x02}},
992-
Report: []byte{0x01, 0x02},
993-
Rs: [][32]byte{{0x01, 0x02}},
994-
Ss: [][32]byte{{0x01, 0x02}},
995-
RawVs: [32]byte{0x01, 0x02},
996-
Info: ccipocr3.CommitReportInfo{
997-
RemoteF: 1,
998-
MerkleRoots: []ccipocr3.MerkleRootChain{},
999-
},
1000-
}
1001-
1002-
txm.On("Enqueue", mock.Anything, admin.String(), mock.MatchedBy(func(tx *solana.Transaction) bool {
1003-
txData := tx.Message.Instructions[0].Data
1004-
payload := txData[8:]
1005-
var decoded ccip_router.Commit
1006-
dec := ag_binary.NewBorshDecoder(payload)
1007-
err := dec.Decode(&decoded)
1008-
require.NoError(t, err)
1009-
return true
1010-
}), &txID, mock.Anything).Return(nil).Once()
1011-
1012-
submitErr := cw.SubmitTransaction(ctx, "ccip_router", "commit", args, txID, routerAddr.String(), nil, nil)
1013-
require.NoError(t, submitErr)
1014-
})
1015-
}
1016-
1017802
func TestChainWriter_GetTransactionStatus(t *testing.T) {
1018803
t.Parallel()
1019804

0 commit comments

Comments
 (0)