Skip to content

Commit 3fbfb75

Browse files
authored
Use functional options for creating and joining rooms (#767)
* Use functional options for creating and joining rooms This is much more ergonomic and clear from call-sites what is happening. * Fix chicken/egg problem with ServerRoom and ServerRoomImpl To make the impl you need a room. To make a room you need an impl. Instead, pass in the room to the impl so you do not need a room to make an impl. This makes the way you call the impl a bit less elegant as you refer to the room twice e.g `room.EventCreator(room, ...)` but allows custom impls to use functional options. * Assign into the rooms map with the actual room ID
1 parent 81c0d9a commit 3fbfb75

File tree

4 files changed

+108
-65
lines changed

4 files changed

+108
-65
lines changed

federation/handle.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func MakeJoinRequestsHandler(s *Server, w http.ResponseWriter, req *http.Request
5959
// or dealing with HTTP responses itself.
6060
func MakeRespMakeJoin(s *Server, room *ServerRoom, userID string) (resp fclient.RespMakeJoin, err error) {
6161
// Generate a join event
62-
proto, err := room.ProtoEventCreator(Event{
62+
proto, err := room.ProtoEventCreator(room, Event{
6363
Type: "m.room.member",
6464
StateKey: &userID,
6565
Content: map[string]interface{}{
@@ -84,7 +84,7 @@ func MakeRespMakeJoin(s *Server, room *ServerRoom, userID string) (resp fclient.
8484
// or dealing with HTTP responses itself.
8585
func MakeRespMakeKnock(s *Server, room *ServerRoom, userID string) (resp fclient.RespMakeKnock, err error) {
8686
// Generate a knock event
87-
proto, err := room.ProtoEventCreator(Event{
87+
proto, err := room.ProtoEventCreator(room, Event{
8888
Type: "m.room.member",
8989
StateKey: &userID,
9090
Content: map[string]interface{}{
@@ -159,7 +159,7 @@ func SendJoinRequestsHandler(s *Server, w http.ResponseWriter, req *http.Request
159159
return
160160
}
161161

162-
resp := room.GenerateSendJoinResponse(s, event, expectPartialState, omitServersInRoom)
162+
resp := room.GenerateSendJoinResponse(room, s, event, expectPartialState, omitServersInRoom)
163163
b, err := json.Marshal(resp)
164164
if err != nil {
165165
w.WriteHeader(500)

federation/server.go

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (s *Server) MakeAliasMapping(aliasLocalpart, roomID string) string {
172172

173173
// MustMakeRoom will add a room to this server so it is accessible to other servers when prompted via federation.
174174
// The `events` will be added to this room. Returns the created room.
175-
func (s *Server) MustMakeRoom(t ct.TestLike, roomVer gomatrixserverlib.RoomVersion, events []Event) *ServerRoom {
175+
func (s *Server) MustMakeRoom(t ct.TestLike, roomVer gomatrixserverlib.RoomVersion, events []Event, opts ...ServerRoomOpt) *ServerRoom {
176176
if !s.listening {
177177
ct.Fatalf(s.t, "MustMakeRoom() called before Listen() - this is not supported because Listen() chooses a high-numbered port and thus changes the server name and thus changes the room ID. Ensure you Listen() first!")
178178
}
@@ -184,13 +184,16 @@ func (s *Server) MustMakeRoom(t ct.TestLike, roomVer gomatrixserverlib.RoomVersi
184184
roomID := fmt.Sprintf("!%d-%s:%s", len(s.rooms), util.RandomString(18), s.serverName)
185185
t.Logf("Creating room %s with version %s", roomID, roomVer)
186186
room := NewServerRoom(roomVer, roomID)
187+
for _, opt := range opts {
188+
opt(room)
189+
}
187190

188191
// sign all these events
189192
for _, ev := range events {
190193
signedEvent := s.MustCreateEvent(t, room, ev)
191194
room.AddEvent(signedEvent)
192195
}
193-
s.rooms[roomID] = room
196+
s.rooms[room.RoomID] = room
194197
return room
195198
}
196199

@@ -303,11 +306,11 @@ func (s *Server) DoFederationRequest(
303306
// It does not insert this event into the room however. See ServerRoom.AddEvent for that.
304307
func (s *Server) MustCreateEvent(t ct.TestLike, room *ServerRoom, ev Event) gomatrixserverlib.PDU {
305308
t.Helper()
306-
proto, err := room.ProtoEventCreator(ev)
309+
proto, err := room.ProtoEventCreator(room, ev)
307310
if err != nil {
308311
ct.Fatalf(t, "MustCreateEvent: failed to create proto event: %v", err)
309312
}
310-
pdu, err := room.EventCreator(s, proto)
313+
pdu, err := room.EventCreator(room, s, proto)
311314
if err != nil {
312315
ct.Fatalf(t, "MustCreateEvent: failed to create PDU: %v", err)
313316
}
@@ -316,8 +319,12 @@ func (s *Server) MustCreateEvent(t ct.TestLike, room *ServerRoom, ev Event) goma
316319

317320
// MustJoinRoom will make the server send a make_join and a send_join to join a room
318321
// It returns the resultant room.
319-
func (s *Server) MustJoinRoom(t ct.TestLike, deployment FederationDeployment, remoteServer spec.ServerName, roomID string, userID string, partialState ...bool) *ServerRoom {
322+
func (s *Server) MustJoinRoom(t ct.TestLike, deployment FederationDeployment, remoteServer spec.ServerName, roomID string, userID string, opts ...JoinRoomOpt) *ServerRoom {
320323
t.Helper()
324+
var jr joinRoom
325+
for _, opt := range opts {
326+
opt(&jr)
327+
}
321328
origin := spec.ServerName(s.serverName)
322329
fedClient := s.FederationClient(deployment)
323330
makeJoinResp, err := fedClient.MakeJoin(context.Background(), origin, remoteServer, roomID, userID)
@@ -372,7 +379,7 @@ func (s *Server) MustJoinRoom(t ct.TestLike, deployment FederationDeployment, re
372379
ct.Fatalf(t, "MustJoinRoom: failed to sign event: %v", err)
373380
}
374381
var sendJoinResp fclient.RespSendJoin
375-
if len(partialState) == 0 || !partialState[0] {
382+
if !jr.partialState {
376383
// Default to doing a regular join.
377384
sendJoinResp, err = fedClient.SendJoin(context.Background(), origOrigin, remoteServer, joinEvent)
378385
} else {
@@ -382,10 +389,13 @@ func (s *Server) MustJoinRoom(t ct.TestLike, deployment FederationDeployment, re
382389
ct.Fatalf(t, "MustJoinRoom: send_join failed: %v", err)
383390
}
384391
room := NewServerRoom(roomVer, roomID)
385-
room.PopulateFromSendJoinResponse(joinEvent, sendJoinResp)
386-
s.rooms[roomID] = room
392+
for _, opt := range jr.roomOpts {
393+
opt(room)
394+
}
395+
room.PopulateFromSendJoinResponse(room, joinEvent, sendJoinResp)
396+
s.rooms[room.RoomID] = room
387397

388-
t.Logf("Server.MustJoinRoom joined room ID %s", roomID)
398+
t.Logf("Server.MustJoinRoom joined room ID %s", room.RoomID)
389399

390400
return room
391401
}
@@ -433,11 +443,6 @@ func (s *Server) MustLeaveRoom(t ct.TestLike, deployment FederationDeployment, r
433443
t.Logf("Server.MustLeaveRoom left room ID %s", roomID)
434444
}
435445

436-
// AddRoom is a low-level function to add a custom room to the server. Useful to mix custom logic with helper functions.
437-
func (s *Server) AddRoom(room *ServerRoom) {
438-
s.rooms[room.RoomID] = room
439-
}
440-
441446
// ValidFederationRequest is a wrapper around http.HandlerFunc which automatically validates the incoming
442447
// federation request and supports sending back JSON. Fails the test if the request is not valid.
443448
func (s *Server) ValidFederationRequest(t ct.TestLike, handler func(fr *fclient.FederationRequest, pathParams map[string]string) util.JSONResponse) http.HandlerFunc {
@@ -513,6 +518,28 @@ func (s *Server) Listen() (cancel func()) {
513518
}
514519
}
515520

521+
type joinRoom struct {
522+
partialState bool
523+
roomOpts []ServerRoomOpt
524+
}
525+
526+
// JoinRoomOpt is an option for configuring how the server should join the room
527+
type JoinRoomOpt func(jr *joinRoom)
528+
529+
// WithPartialState tells the server to join the room with partial state
530+
func WithPartialState() JoinRoomOpt {
531+
return func(jr *joinRoom) {
532+
jr.partialState = true
533+
}
534+
}
535+
536+
// WithRoomOpts controls how the newly joined room is created
537+
func WithRoomOpts(opts ...ServerRoomOpt) JoinRoomOpt {
538+
return func(jr *joinRoom) {
539+
jr.roomOpts = opts
540+
}
541+
}
542+
516543
// federationServer creates a federation server with the given handler
517544
func federationServer(cfg *config.Complement, h http.Handler) (*http.Server, string, string, error) {
518545
var derBytes []byte

federation/server_room.go

Lines changed: 58 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,24 @@ type Event struct {
3434
Redacts string
3535
}
3636

37+
// ServerRoomOpt are options that can configure ServerRooms
38+
type ServerRoomOpt func(r *ServerRoom)
39+
40+
// WithRoomID configures the room to have the given room ID
41+
func WithRoomID(roomID string) ServerRoomOpt {
42+
return func(r *ServerRoom) {
43+
r.RoomID = roomID
44+
}
45+
}
46+
47+
// WithImpl configures the room to have the given ServerRoomImpl.
48+
// Useful for custom rooms.
49+
func WithImpl(impl ServerRoomImpl) ServerRoomOpt {
50+
return func(r *ServerRoom) {
51+
r.ServerRoomImpl = impl
52+
}
53+
}
54+
3755
// EXPERIMENTAL
3856
// ServerRoom represents a room on this test federation server
3957
type ServerRoom struct {
@@ -67,7 +85,7 @@ func NewServerRoom(roomVer gomatrixserverlib.RoomVersion, roomId string) *Server
6785
waiters: make(map[string][]*helpers.Waiter),
6886
waitersMu: &sync.Mutex{},
6987
}
70-
room.ServerRoomImpl = &ServerRoomImplDefault{Room: room}
88+
room.ServerRoomImpl = &ServerRoomImplDefault{}
7189
return room
7290
}
7391

@@ -354,58 +372,56 @@ type ServerRoomImpl interface {
354372
// ProtoEventCreator converts a Complement Event into a gomatrixserverlib proto event, ready to be signed.
355373
// This function is used in /make_x endpoints to create proto events to return to other servers.
356374
// This function is one of two used when creating events, the other being EventCreator.
357-
ProtoEventCreator(ev Event) (*gomatrixserverlib.ProtoEvent, error)
375+
ProtoEventCreator(room *ServerRoom, ev Event) (*gomatrixserverlib.ProtoEvent, error)
358376
// EventCreator converts a proto event into a signed PDU.
359-
EventCreator(s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error)
377+
EventCreator(room *ServerRoom, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error)
360378
// PopulateFromSendJoinResponse should replace the state of this ServerRoom with the information contained
361379
// in RespSendJoin and the join event.
362-
PopulateFromSendJoinResponse(joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin)
380+
PopulateFromSendJoinResponse(room *ServerRoom, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin)
363381
// GenerateSendJoinResponse generates a /send_join response to send back to a server.
364-
GenerateSendJoinResponse(s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin
382+
GenerateSendJoinResponse(room *ServerRoom, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin
365383
}
366384

367385
type ServerRoomImplCustom struct {
368386
ServerRoomImplDefault
369-
ProtoEventCreatorFn func(def ServerRoomImpl, ev Event) (*gomatrixserverlib.ProtoEvent, error)
370-
EventCreatorFn func(def ServerRoomImpl, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error)
371-
PopulateFromSendJoinResponseFn func(def ServerRoomImpl, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin)
372-
GenerateSendJoinResponseFn func(def ServerRoomImpl, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin
387+
ProtoEventCreatorFn func(def ServerRoomImpl, room *ServerRoom, ev Event) (*gomatrixserverlib.ProtoEvent, error)
388+
EventCreatorFn func(def ServerRoomImpl, room *ServerRoom, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error)
389+
PopulateFromSendJoinResponseFn func(def ServerRoomImpl, room *ServerRoom, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin)
390+
GenerateSendJoinResponseFn func(def ServerRoomImpl, room *ServerRoom, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin
373391
}
374392

375-
func (i *ServerRoomImplCustom) ProtoEventCreator(ev Event) (*gomatrixserverlib.ProtoEvent, error) {
393+
func (i *ServerRoomImplCustom) ProtoEventCreator(room *ServerRoom, ev Event) (*gomatrixserverlib.ProtoEvent, error) {
376394
if i.ProtoEventCreatorFn != nil {
377-
return i.ProtoEventCreatorFn(&i.ServerRoomImplDefault, ev)
395+
return i.ProtoEventCreatorFn(&i.ServerRoomImplDefault, room, ev)
378396
}
379-
return i.ServerRoomImplDefault.ProtoEventCreator(ev)
397+
return i.ServerRoomImplDefault.ProtoEventCreator(room, ev)
380398
}
381399

382-
func (i *ServerRoomImplCustom) EventCreator(s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) {
400+
func (i *ServerRoomImplCustom) EventCreator(room *ServerRoom, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) {
383401
if i.EventCreatorFn != nil {
384-
return i.EventCreatorFn(&i.ServerRoomImplDefault, s, proto)
402+
return i.EventCreatorFn(&i.ServerRoomImplDefault, room, s, proto)
385403
}
386-
return i.ServerRoomImplDefault.EventCreator(s, proto)
404+
return i.ServerRoomImplDefault.EventCreator(room, s, proto)
387405
}
388406

389-
func (i *ServerRoomImplCustom) PopulateFromSendJoinResponse(joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) {
407+
func (i *ServerRoomImplCustom) PopulateFromSendJoinResponse(room *ServerRoom, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) {
390408
if i.PopulateFromSendJoinResponseFn != nil {
391-
i.PopulateFromSendJoinResponseFn(&i.ServerRoomImplDefault, joinEvent, resp)
409+
i.PopulateFromSendJoinResponseFn(&i.ServerRoomImplDefault, room, joinEvent, resp)
392410
return
393411
}
394-
i.ServerRoomImplDefault.PopulateFromSendJoinResponse(joinEvent, resp)
412+
i.ServerRoomImplDefault.PopulateFromSendJoinResponse(room, joinEvent, resp)
395413
}
396414

397-
func (i *ServerRoomImplCustom) GenerateSendJoinResponse(s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin {
415+
func (i *ServerRoomImplCustom) GenerateSendJoinResponse(room *ServerRoom, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin {
398416
if i.GenerateSendJoinResponseFn != nil {
399-
return i.GenerateSendJoinResponseFn(&i.ServerRoomImplDefault, s, joinEvent, expectPartialState, omitServersInRoom)
417+
return i.GenerateSendJoinResponseFn(&i.ServerRoomImplDefault, room, s, joinEvent, expectPartialState, omitServersInRoom)
400418
}
401-
return i.ServerRoomImplDefault.GenerateSendJoinResponse(s, joinEvent, expectPartialState, omitServersInRoom)
419+
return i.ServerRoomImplDefault.GenerateSendJoinResponse(room, s, joinEvent, expectPartialState, omitServersInRoom)
402420
}
403421

404-
type ServerRoomImplDefault struct {
405-
Room *ServerRoom
406-
}
422+
type ServerRoomImplDefault struct{}
407423

408-
func (i *ServerRoomImplDefault) ProtoEventCreator(ev Event) (*gomatrixserverlib.ProtoEvent, error) {
424+
func (i *ServerRoomImplDefault) ProtoEventCreator(room *ServerRoom, ev Event) (*gomatrixserverlib.ProtoEvent, error) {
409425
var prevEvents interface{}
410426
if ev.PrevEvents != nil {
411427
// We deliberately want to set the prev events.
@@ -414,14 +430,14 @@ func (i *ServerRoomImplDefault) ProtoEventCreator(ev Event) (*gomatrixserverlib.
414430
// No other prev events were supplied so we'll just
415431
// use the forward extremities of the room, which is
416432
// the usual behaviour.
417-
prevEvents = i.Room.ForwardExtremities
433+
prevEvents = room.ForwardExtremities
418434
}
419435
proto := gomatrixserverlib.ProtoEvent{
420436
SenderID: ev.Sender,
421-
Depth: int64(i.Room.Depth + 1), // depth starts at 1
437+
Depth: int64(room.Depth + 1), // depth starts at 1
422438
Type: ev.Type,
423439
StateKey: ev.StateKey,
424-
RoomID: i.Room.RoomID,
440+
RoomID: room.RoomID,
425441
PrevEvents: prevEvents,
426442
AuthEvents: ev.AuthEvents,
427443
Redacts: ev.Redacts,
@@ -438,13 +454,13 @@ func (i *ServerRoomImplDefault) ProtoEventCreator(ev Event) (*gomatrixserverlib.
438454
if err != nil {
439455
return nil, fmt.Errorf("EventCreator: failed to work out auth_events : %s", err)
440456
}
441-
proto.AuthEvents = i.Room.AuthEvents(stateNeeded)
457+
proto.AuthEvents = room.AuthEvents(stateNeeded)
442458
}
443459
return &proto, nil
444460
}
445461

446-
func (i *ServerRoomImplDefault) EventCreator(s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) {
447-
verImpl, err := gomatrixserverlib.GetRoomVersion(i.Room.Version)
462+
func (i *ServerRoomImplDefault) EventCreator(room *ServerRoom, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) {
463+
verImpl, err := gomatrixserverlib.GetRoomVersion(room.Version)
448464
if err != nil {
449465
return nil, fmt.Errorf("EventCreator: invalid room version: %s", err)
450466
}
@@ -456,19 +472,19 @@ func (i *ServerRoomImplDefault) EventCreator(s *Server, proto *gomatrixserverlib
456472
return signedEvent, nil
457473
}
458474

459-
func (i *ServerRoomImplDefault) PopulateFromSendJoinResponse(joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) {
460-
stateEvents := resp.StateEvents.UntrustedEvents(i.Room.Version)
475+
func (i *ServerRoomImplDefault) PopulateFromSendJoinResponse(room *ServerRoom, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) {
476+
stateEvents := resp.StateEvents.UntrustedEvents(room.Version)
461477
for _, ev := range stateEvents {
462-
i.Room.ReplaceCurrentState(ev)
478+
room.ReplaceCurrentState(ev)
463479
}
464-
i.Room.AddEvent(joinEvent)
480+
room.AddEvent(joinEvent)
465481
}
466482

467-
func (i *ServerRoomImplDefault) GenerateSendJoinResponse(s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin {
483+
func (i *ServerRoomImplDefault) GenerateSendJoinResponse(room *ServerRoom, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin {
468484
// build the state list *before* we insert the new event
469485
var stateEvents []gomatrixserverlib.PDU
470-
i.Room.StateMutex.RLock()
471-
for _, ev := range i.Room.State {
486+
room.StateMutex.RLock()
487+
for _, ev := range room.State {
472488
// filter out non-critical memberships if this is a partial-state join
473489
if expectPartialState {
474490
if ev.Type() == "m.room.member" && ev.StateKey() != joinEvent.StateKey() {
@@ -477,18 +493,18 @@ func (i *ServerRoomImplDefault) GenerateSendJoinResponse(s *Server, joinEvent go
477493
}
478494
stateEvents = append(stateEvents, ev)
479495
}
480-
i.Room.StateMutex.RUnlock()
496+
room.StateMutex.RUnlock()
481497

482-
authEvents := i.Room.AuthChainForEvents(stateEvents)
498+
authEvents := room.AuthChainForEvents(stateEvents)
483499

484500
// get servers in room *before* the join event
485501
serversInRoom := []string{s.serverName}
486502
if !omitServersInRoom {
487-
serversInRoom = i.Room.ServersInRoom()
503+
serversInRoom = room.ServersInRoom()
488504
}
489505

490506
// insert the join event into the room state
491-
i.Room.AddEvent(joinEvent)
507+
room.AddEvent(joinEvent)
492508
log.Printf("Received send-join of event %s", joinEvent.EventID())
493509

494510
// return state and auth chain

0 commit comments

Comments
 (0)