diff --git a/src/useRTGTransitionProps.ts b/src/useRTGTransitionProps.ts index 9428196..839357f 100644 --- a/src/useRTGTransitionProps.ts +++ b/src/useRTGTransitionProps.ts @@ -4,7 +4,7 @@ import { TransitionProps as RTGTransitionProps, TransitionStatus, } from 'react-transition-group/Transition'; -import { getReactVersion } from './utils'; +import { getChildRef } from './utils'; export type TransitionProps = RTGTransitionProps & { children: @@ -33,15 +33,8 @@ export default function useRTGTransitionProps({ children, ...props }: TransitionProps) { - const { major } = getReactVersion(); - const childRef = - major >= 19 ? (children as any).props.ref : (children as any).ref; - const nodeRef = useRef(null); - const mergedRef = useMergedRefs( - nodeRef, - typeof children === 'function' ? null : childRef, - ); + const mergedRef = useMergedRefs(nodeRef, getChildRef(children)); const normalize = (callback?: (node: HTMLElement, param: any) => void) => (param: any) => { diff --git a/src/utils.ts b/src/utils.ts index 1eda92f..612f93b 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -12,3 +12,14 @@ export function getReactVersion() { patch: +parts[2], }; } + +export function getChildRef( + element?: React.ReactElement | ((...args: any[]) => React.ReactNode) | null, +) { + if (!element || typeof element === 'function') { + return null; + } + const { major } = getReactVersion(); + const childRef = major >= 19 ? element.props.ref : (element as any).ref; + return childRef; +} diff --git a/test/utilsSpec.tsx b/test/utilsSpec.tsx new file mode 100644 index 0000000..f918016 --- /dev/null +++ b/test/utilsSpec.tsx @@ -0,0 +1,18 @@ +import { describe, expect, it } from 'vitest'; +import { getChildRef } from '../src/utils'; + +describe('utils', () => { + describe('getChildRef', () => { + it('should return null if ref is null', () => { + expect(getChildRef(null)).to.equal(null); + }); + + it('should return null if ref is undefined', () => { + expect(getChildRef(undefined)).to.equal(null); + }); + + it('should return null if ref is a function', () => { + expect(getChildRef(() => null)).to.equal(null); + }); + }); +});