diff --git a/impl/graphsync_test.go b/impl/graphsync_test.go index 9a50fb4f..921f4215 100644 --- a/impl/graphsync_test.go +++ b/impl/graphsync_test.go @@ -63,9 +63,9 @@ var protocolsForTest = map[string]struct { host1Protocols []protocol.ID host2Protocols []protocol.ID }{ - "(v1.1 -> v1.1)": {nil, nil}, - "(v1.0 -> v1.1)": {[]protocol.ID{gsnet.ProtocolGraphsync_1_0_0}, nil}, - "(v1.1 -> v1.0)": {nil, []protocol.ID{gsnet.ProtocolGraphsync_1_0_0}}, + "(v2.0 -> v2.0)": {nil, nil}, + "(v1.0 -> v2.0)": {[]protocol.ID{gsnet.ProtocolGraphsync_1_0_0}, nil}, + "(v2.0 -> v1.0)": {nil, []protocol.ID{gsnet.ProtocolGraphsync_1_0_0}}, "(v1.0 -> v1.0)": {[]protocol.ID{gsnet.ProtocolGraphsync_1_0_0}, []protocol.ID{gsnet.ProtocolGraphsync_1_0_0}}, } diff --git a/message/v1/message.go b/message/v1/message.go index c554e281..e8b458a4 100644 --- a/message/v1/message.go +++ b/message/v1/message.go @@ -8,16 +8,17 @@ import ( blocks "github.com/ipfs/go-block-format" "github.com/ipfs/go-cid" - "github.com/ipfs/go-graphsync" - "github.com/ipfs/go-graphsync/ipldutil" - "github.com/ipfs/go-graphsync/message" - pb "github.com/ipfs/go-graphsync/message/pb" "github.com/ipld/go-ipld-prime/datamodel" pool "github.com/libp2p/go-buffer-pool" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-msgio" "google.golang.org/protobuf/proto" + + "github.com/ipfs/go-graphsync" + "github.com/ipfs/go-graphsync/ipldutil" + "github.com/ipfs/go-graphsync/message" + pb "github.com/ipfs/go-graphsync/message/pb" ) type MessagePartWithExtensions interface { diff --git a/message/v2/message.go b/message/v2/message.go index 0e41f503..7c63a3a7 100644 --- a/message/v2/message.go +++ b/message/v2/message.go @@ -7,15 +7,16 @@ import ( blocks "github.com/ipfs/go-block-format" "github.com/ipfs/go-cid" - "github.com/ipfs/go-graphsync" - "github.com/ipfs/go-graphsync/message" - "github.com/ipfs/go-graphsync/message/ipldbind" "github.com/ipld/go-ipld-prime/codec/dagcbor" "github.com/ipld/go-ipld-prime/datamodel" "github.com/ipld/go-ipld-prime/node/bindnode" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-msgio" + + "github.com/ipfs/go-graphsync" + "github.com/ipfs/go-graphsync/message" + "github.com/ipfs/go-graphsync/message/ipldbind" ) type MessageHandler struct{} diff --git a/network/libp2p_impl.go b/network/libp2p_impl.go index bb04456e..a423ab54 100644 --- a/network/libp2p_impl.go +++ b/network/libp2p_impl.go @@ -16,6 +16,7 @@ import ( gsmsg "github.com/ipfs/go-graphsync/message" gsmsgv1 "github.com/ipfs/go-graphsync/message/v1" + gsmsgv2 "github.com/ipfs/go-graphsync/message/v2" ) var log = logging.Logger("graphsync_network") @@ -35,10 +36,14 @@ func GraphsyncProtocols(protocols []protocol.ID) Option { // NewFromLibp2pHost returns a GraphSyncNetwork supported by underlying Libp2p host. func NewFromLibp2pHost(host host.Host, options ...Option) GraphSyncNetwork { + messageHandlerSelector := messageHandlerSelector{ + v1MessageHandler: gsmsgv1.NewMessageHandler(), + v2MessageHandler: gsmsgv2.NewMessageHandler(), + } graphSyncNetwork := libp2pGraphSyncNetwork{ - host: host, - messageHandler: gsmsgv1.NewMessageHandler(), - protocols: []protocol.ID{ProtocolGraphsync_1_0_0, ProtocolGraphsync_2_0_0}, + host: host, + messageHandlerSelector: &messageHandlerSelector, + protocols: []protocol.ID{ProtocolGraphsync_2_0_0, ProtocolGraphsync_1_0_0}, } for _, option := range options { @@ -48,20 +53,53 @@ func NewFromLibp2pHost(host host.Host, options ...Option) GraphSyncNetwork { return &graphSyncNetwork } +// a message.MessageHandler that simply returns an error for any of the calls, allows +// us to simplify erroring on bad protocol within the messageHandlerSelector#Select() +// call so we only have one place to be strict about allowed versions +type messageHandlerErrorer struct { + err error +} + +func (mhe messageHandlerErrorer) FromNet(peer.ID, io.Reader) (gsmsg.GraphSyncMessage, error) { + return gsmsg.GraphSyncMessage{}, mhe.err +} +func (mhe messageHandlerErrorer) FromMsgReader(peer.ID, msgio.Reader) (gsmsg.GraphSyncMessage, error) { + return gsmsg.GraphSyncMessage{}, mhe.err +} +func (mhe messageHandlerErrorer) ToNet(peer.ID, gsmsg.GraphSyncMessage, io.Writer) error { + return mhe.err +} + +type messageHandlerSelector struct { + v1MessageHandler gsmsg.MessageHandler + v2MessageHandler gsmsg.MessageHandler +} + +func (smh messageHandlerSelector) Select(protocol protocol.ID) gsmsg.MessageHandler { + switch protocol { + case ProtocolGraphsync_1_0_0: + return smh.v1MessageHandler + case ProtocolGraphsync_2_0_0: + return smh.v2MessageHandler + default: + return messageHandlerErrorer{fmt.Errorf("unrecognized protocol version: %s", protocol)} + } +} + // libp2pGraphSyncNetwork transforms the libp2p host interface, which sends and receives // NetMessage objects, into the graphsync network interface. type libp2pGraphSyncNetwork struct { host host.Host // inbound messages from the network are forwarded to the receiver - receiver Receiver - messageHandler gsmsg.MessageHandler - protocols []protocol.ID + receiver Receiver + protocols []protocol.ID + messageHandlerSelector *messageHandlerSelector } type streamMessageSender struct { - s network.Stream - opts MessageSenderOpts - messageHandler gsmsg.MessageHandler + s network.Stream + opts MessageSenderOpts + messageHandlerSelector *messageHandlerSelector } func (s *streamMessageSender) Close() error { @@ -73,10 +111,10 @@ func (s *streamMessageSender) Reset() error { } func (s *streamMessageSender) SendMsg(ctx context.Context, msg gsmsg.GraphSyncMessage) error { - return msgToStream(ctx, s.s, s.messageHandler, msg, s.opts.SendTimeout) + return msgToStream(ctx, s.s, s.messageHandlerSelector, msg, s.opts.SendTimeout) } -func msgToStream(ctx context.Context, s network.Stream, mh gsmsg.MessageHandler, msg gsmsg.GraphSyncMessage, timeout time.Duration) error { +func msgToStream(ctx context.Context, s network.Stream, mh *messageHandlerSelector, msg gsmsg.GraphSyncMessage, timeout time.Duration) error { log.Debugf("Outgoing message with %d requests, %d responses, and %d blocks", len(msg.Requests()), len(msg.Responses()), len(msg.Blocks())) @@ -88,19 +126,9 @@ func msgToStream(ctx context.Context, s network.Stream, mh gsmsg.MessageHandler, log.Warnf("error setting deadline: %s", err) } - switch s.Protocol() { - case ProtocolGraphsync_1_0_0: - if err := mh.ToNet(s.Conn().RemotePeer(), msg, s); err != nil { - log.Debugf("error: %s", err) - return err - } - case ProtocolGraphsync_2_0_0: - if err := mh.ToNet(s.Conn().RemotePeer(), msg, s); err != nil { - log.Debugf("error: %s", err) - return err - } - default: - return fmt.Errorf("unrecognized protocol on remote: %s", s.Protocol()) + if err := mh.Select(s.Protocol()).ToNet(s.Conn().RemotePeer(), msg, s); err != nil { + log.Debugf("error: %s", err) + return err } if err := s.SetWriteDeadline(time.Time{}); err != nil { @@ -116,9 +144,9 @@ func (gsnet *libp2pGraphSyncNetwork) NewMessageSender(ctx context.Context, p pee } return &streamMessageSender{ - s: s, - opts: setDefaults(opts), - messageHandler: gsnet.messageHandler, + s: s, + opts: setDefaults(opts), + messageHandlerSelector: gsnet.messageHandlerSelector, }, nil } @@ -136,7 +164,7 @@ func (gsnet *libp2pGraphSyncNetwork) SendMessage( return err } - if err = msgToStream(ctx, s, gsnet.messageHandler, outgoing, sendMessageTimeout); err != nil { + if err = msgToStream(ctx, s, gsnet.messageHandlerSelector, outgoing, sendMessageTimeout); err != nil { _ = s.Reset() return err } @@ -167,16 +195,7 @@ func (gsnet *libp2pGraphSyncNetwork) handleNewStream(s network.Stream) { reader := msgio.NewVarintReaderSize(s, network.MessageSizeMax) for { - var received gsmsg.GraphSyncMessage - var err error - switch s.Protocol() { - case ProtocolGraphsync_1_0_0: - received, err = gsnet.messageHandler.FromMsgReader(s.Conn().RemotePeer(), reader) - case ProtocolGraphsync_2_0_0: - received, err = gsnet.messageHandler.FromMsgReader(s.Conn().RemotePeer(), reader) - default: - err = fmt.Errorf("unexpected protocol version %s", s.Protocol()) - } + received, err := gsnet.messageHandlerSelector.Select(s.Protocol()).FromMsgReader(s.Conn().RemotePeer(), reader) p := s.Conn().RemotePeer() if err != nil { diff --git a/network/libp2p_impl_test.go b/network/libp2p_impl_test.go index b8811e80..a7cdebc3 100644 --- a/network/libp2p_impl_test.go +++ b/network/libp2p_impl_test.go @@ -106,8 +106,7 @@ func TestMessageSendAndReceive(t *testing.T) { receivedRequests := received.Requests() require.Len(t, receivedRequests, 1, "did not add request to received message") receivedRequest := receivedRequests[0] - // TODO: for protocol v1 this shouldn't match, but for v2 it should - // require.Equal(t, sentRequest.ID(), receivedRequest.ID()) + require.Equal(t, sentRequest.ID(), receivedRequest.ID()) require.Equal(t, sentRequest.IsCancel(), receivedRequest.IsCancel()) require.Equal(t, sentRequest.Priority(), receivedRequest.Priority()) require.Equal(t, sentRequest.Root().String(), receivedRequest.Root().String()) @@ -120,8 +119,7 @@ func TestMessageSendAndReceive(t *testing.T) { require.Len(t, receivedResponses, 1, "did not add response to received message") receivedResponse := receivedResponses[0] extensionData, found := receivedResponse.Extension(extensionName) - // TODO: for protocol v1 this shouldn't match, but for v2 it should - // require.Equal(t, sentResponse.RequestID(), receivedResponse.RequestID()) + require.Equal(t, sentResponse.RequestID(), receivedResponse.RequestID()) require.Equal(t, sentResponse.Status(), receivedResponse.Status()) require.True(t, found) require.Equal(t, extension.Data, extensionData)