|
| 1 | +package di |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "reflect" |
| 6 | + "sync" |
| 7 | +) |
| 8 | + |
| 9 | +// DIContainer is a simple dependency injection container for singleton instances |
| 10 | +type DIContainer struct { |
| 11 | + mu sync.RWMutex |
| 12 | + instances map[reflect.Type]interface{} |
| 13 | + factories map[reflect.Type]Factory |
| 14 | +} |
| 15 | + |
| 16 | +type Factory func(container *DIContainer) (interface{}, error) |
| 17 | + |
| 18 | +func NewDIContainer() *DIContainer { |
| 19 | + return &DIContainer{ |
| 20 | + instances: make(map[reflect.Type]interface{}), |
| 21 | + factories: make(map[reflect.Type]Factory), |
| 22 | + } |
| 23 | +} |
| 24 | + |
| 25 | +// Get retrieves or creates a singleton instance of the specified type |
| 26 | +func (c *DIContainer) Get(targetType reflect.Type) (interface{}, error) { |
| 27 | + c.mu.RLock() |
| 28 | + // Check if we already have a singleton instance |
| 29 | + if instance, exists := c.instances[targetType]; exists { |
| 30 | + c.mu.RUnlock() |
| 31 | + return instance, nil |
| 32 | + } |
| 33 | + c.mu.RUnlock() |
| 34 | + |
| 35 | + // Check if we have a factory |
| 36 | + c.mu.RLock() |
| 37 | + factory, exists := c.factories[targetType] |
| 38 | + c.mu.RUnlock() |
| 39 | + |
| 40 | + if !exists { |
| 41 | + return nil, fmt.Errorf("no registration found for type %s", targetType.String()) |
| 42 | + } |
| 43 | + |
| 44 | + // Create singleton instance using factory |
| 45 | + c.mu.Lock() |
| 46 | + |
| 47 | + // Double-check if instance was created while we were waiting for the lock |
| 48 | + if instance, exists := c.instances[targetType]; exists { |
| 49 | + c.mu.Unlock() |
| 50 | + return instance, nil |
| 51 | + } |
| 52 | + |
| 53 | + // Release lock before calling factory to avoid deadlock when factory resolves dependencies |
| 54 | + c.mu.Unlock() |
| 55 | + |
| 56 | + instance, err := factory(c) |
| 57 | + if err != nil { |
| 58 | + return nil, fmt.Errorf("failed to create instance of type %s: %w", targetType.String(), err) |
| 59 | + } |
| 60 | + |
| 61 | + // Re-acquire lock to cache the instance |
| 62 | + c.mu.Lock() |
| 63 | + // Check again in case another goroutine created it |
| 64 | + if existing, exists := c.instances[targetType]; exists { |
| 65 | + c.mu.Unlock() |
| 66 | + return existing, nil |
| 67 | + } |
| 68 | + c.instances[targetType] = instance |
| 69 | + c.mu.Unlock() |
| 70 | + |
| 71 | + return instance, nil |
| 72 | +} |
| 73 | + |
| 74 | +// GetTyped is a generic helper to get a singleton instance with type safety |
| 75 | +func GetTyped[T any](c *DIContainer) (T, error) { |
| 76 | + var zero T |
| 77 | + targetType := reflect.TypeOf((*T)(nil)).Elem() |
| 78 | + |
| 79 | + instance, err := c.Get(targetType) |
| 80 | + if err != nil { |
| 81 | + return zero, err |
| 82 | + } |
| 83 | + |
| 84 | + typed, ok := instance.(T) |
| 85 | + if !ok { |
| 86 | + return zero, fmt.Errorf("instance is not of expected type %T", zero) |
| 87 | + } |
| 88 | + |
| 89 | + return typed, nil |
| 90 | +} |
| 91 | + |
| 92 | +// RegisterTyped is a generic helper to register a factory with type safety |
| 93 | +// Usage: RegisterTyped[MyInterface](container, func(c *DIContainer) (MyInterface, error) { ... }) |
| 94 | +func RegisterTyped[T any](c *DIContainer, factory func(*DIContainer) (T, error)) { |
| 95 | + c.mu.Lock() |
| 96 | + defer c.mu.Unlock() |
| 97 | + |
| 98 | + targetType := reflect.TypeOf((*T)(nil)).Elem() |
| 99 | + c.factories[targetType] = func(container *DIContainer) (interface{}, error) { |
| 100 | + return factory(container) |
| 101 | + } |
| 102 | +} |
0 commit comments