From 846f256804f90ad51f6b954ec27d6f6764d7c984 Mon Sep 17 00:00:00 2001 From: Denys Smirnov Date: Tue, 2 Dec 2025 13:50:47 +0200 Subject: [PATCH] Improve SIP parsers. --- sip/header_params.go | 3 + sip/parse_address.go | 14 +- sip/parse_header.go | 60 +++---- sip/parse_via.go | 2 +- sip/parser.go | 365 ++++++++++++++++++++++---------------- sip/parser_stream.go | 234 +++++++++++++----------- sip/parser_stream_test.go | 125 +++++++++---- sip/parser_test.go | 36 ++-- sip/utils.go | 98 ++++++++++ sip/utils_test.go | 4 +- 10 files changed, 589 insertions(+), 352 deletions(-) diff --git a/sip/header_params.go b/sip/header_params.go index 88d20cd..e72953e 100644 --- a/sip/header_params.go +++ b/sip/header_params.go @@ -68,6 +68,9 @@ func (hp HeaderParams) Clone() HeaderParams { } func (hp HeaderParams) clone() HeaderParams { + if hp == nil { + return nil + } dup := make(HeaderParams, len(hp)) for k, v := range hp { diff --git a/sip/parse_address.go b/sip/parse_address.go index 3e19f9e..1158a76 100644 --- a/sip/parse_address.go +++ b/sip/parse_address.go @@ -188,7 +188,7 @@ func addressStateHeaderParams(a *nameAddress, s string) (addressFSM, string, err } // headerParserTo generates ToHeader -func headerParserTo(headerName string, headerText string) (header Header, err error) { +func headerParserTo(headerName []byte, headerText string) (header Header, err error) { h := &ToHeader{} return h, parseToHeader(headerText, h) } @@ -214,7 +214,7 @@ func parseToHeader(headerText string, h *ToHeader) error { } // headerParserFrom generates FromHeader -func headerParserFrom(headerName string, headerText string) (header Header, err error) { +func headerParserFrom(headerName []byte, headerText string) (header Header, err error) { h := &FromHeader{} return h, parseFromHeader(headerText, h) } @@ -240,7 +240,7 @@ func parseFromHeader(headerText string, h *FromHeader) error { return nil } -func headerParserContact(headerName string, headerText string) (header Header, err error) { +func headerParserContact(headerName []byte, headerText string) (header Header, err error) { h := ContactHeader{} return &h, parseContactHeader(headerText, &h) } @@ -285,7 +285,7 @@ func parseContactHeader(headerText string, h *ContactHeader) error { return err } -func headerParserRoute(headerName string, headerText string) (header Header, err error) { +func headerParserRoute(headerName []byte, headerText string) (header Header, err error) { // Append a comma to simplify the parsing code; we split address sections // on commas, so use a comma to signify the end of the final address section. h := RouteHeader{} @@ -298,7 +298,7 @@ func parseRouteHeader(headerText string, h *RouteHeader) error { } // parseRouteHeader generates RecordRouteHeader -func headerParserRecordRoute(headerName string, headerText string) (header Header, err error) { +func headerParserRecordRoute(headerName []byte, headerText string) (header Header, err error) { // Append a comma to simplify the parsing code; we split address sections // on commas, so use a comma to signify the end of the final address section. h := RecordRouteHeader{} @@ -309,7 +309,7 @@ func parseRecordRouteHeader(headerText string, h *RecordRouteHeader) error { return parseRouteAddress(headerText, &h.Address) } -func headerParserReferTo(headerName string, headerText string) (header Header, err error) { +func headerParserReferTo(headerName []byte, headerText string) (header Header, err error) { h := ReferToHeader{} return &h, parseReferToHeader(headerText, &h) } @@ -318,7 +318,7 @@ func parseReferToHeader(headerText string, h *ReferToHeader) error { return parseRouteAddress(headerText, &h.Address) // calling parseRouteAddress because the structure is same } -func headerParserReferredBy(headerName string, headerText string) (header Header, err error) { +func headerParserReferredBy(headerName []byte, headerText string) (header Header, err error) { h := &ReferredByHeader{} return h, parseReferredByHeader(headerText, h) } diff --git a/sip/parse_header.go b/sip/parse_header.go index a15b812..0459324 100644 --- a/sip/parse_header.go +++ b/sip/parse_header.go @@ -1,6 +1,7 @@ package sip import ( + "bytes" "fmt" "strconv" "strings" @@ -10,9 +11,9 @@ import ( // Some of headers parsing are moved to different files for better maintance // A HeaderParser is any function that turns raw header data into one or more Header objects. -type HeaderParser func(headerName string, headerData string) (Header, error) +type HeaderParser func(headerName []byte, headerData string) (Header, error) -type mapHeadersParser map[string]HeaderParser +type HeadersParser map[string]HeaderParser type errComaDetected int @@ -37,7 +38,7 @@ func (e errComaDetected) Error() string { // t To RFC 3261 // u Allow-Events -events- "understand" // v Via RFC 3261 -var headersParsers = mapHeadersParser{ +var headersParsers = HeadersParser{ "c": headerParserContentType, "content-type": headerParserContentType, "f": headerParserFrom, @@ -66,56 +67,49 @@ func DefaultHeadersParser() map[string]HeaderParser { return headersParsers } -// parseMsgHeader will append any parsed header -// In case comma seperated values it will add them as new in case comma is detected -func (headersParser mapHeadersParser) parseMsgHeader(msg Message, headerText string) (err error) { - // p.log.Tracef("parsing header \"%s\"", headerText) - - colonIdx := strings.Index(headerText, ":") +// ParseHeader parses a SIP header from the line and appends it to out. +func (headersParser HeadersParser) ParseHeader(out []Header, line []byte) ([]Header, error) { + colonIdx := bytes.IndexByte(line, ':') if colonIdx == -1 { - err = fmt.Errorf("field name with no value in header: %s", headerText) - return + return out, fmt.Errorf("field name with no value in header: %q", line) } - fieldName := strings.TrimSpace(headerText[:colonIdx]) - lowerFieldName := HeaderToLower(fieldName) - fieldText := strings.TrimSpace(headerText[colonIdx+1:]) + fieldName := bytes.TrimSpace(line[:colonIdx]) + lowerFieldName := headerToLower(fieldName) + fieldValue := bytes.TrimSpace(line[colonIdx+1:]) - headerParser, ok := headersParser[lowerFieldName] + headerParser, ok := headersParser[string(lowerFieldName)] if !ok { // We have no registered parser for this header type, // so we encapsulate the header data in a GenericHeader struct. // We do only forwarding on this with trimmed space. Validation and parsing is required by user - - header := NewHeader(fieldName, fieldText) - msg.AppendHeader(header) - return nil + h := NewHeader(string(fieldName), string(fieldValue)) + out = append(out, h) + return out, nil } - // Support comma seperated value + fieldText := string(fieldValue) + // Support comma separated values for { // We have a registered parser for this header type - use it. // headerParser should detect comma (,) and return as error - header, err := headerParser(lowerFieldName, fieldText) - - // Mostly we will run with no error + h, err := headerParser(lowerFieldName, fieldText) if err == nil { - msg.AppendHeader(header) - return nil + out = append(out, h) + return out, nil } commaErr, ok := err.(errComaDetected) if !ok { - return err + return out, err } - // Ok we detected we have comma in header value - msg.AppendHeader(header) + out = append(out, h) fieldText = fieldText[commaErr+1:] } } -func headerParserCallId(headerName string, headerText string) (header Header, err error) { +func headerParserCallId(headerName []byte, headerText string) (header Header, err error) { var callId CallIDHeader return &callId, parseCallIdHeader(headerText, &callId) } @@ -131,7 +125,7 @@ func parseCallIdHeader(headerText string, callId *CallIDHeader) error { return nil } -func headerParserMaxForwards(headerName string, headerText string) (header Header, err error) { +func headerParserMaxForwards(headerName []byte, headerText string) (header Header, err error) { var maxfwd MaxForwardsHeader return &maxfwd, parseMaxForwardsHeader(headerText, &maxfwd) } @@ -143,7 +137,7 @@ func parseMaxForwardsHeader(headerText string, maxfwd *MaxForwardsHeader) error return err } -func headerParserCSeq(headerName string, headerText string) (headers Header, err error) { +func headerParserCSeq(headerName []byte, headerText string) (headers Header, err error) { var cseq CSeqHeader return &cseq, parseCSeqHeader(headerText, &cseq) } @@ -169,7 +163,7 @@ func parseCSeqHeader(headerText string, cseq *CSeqHeader) error { return nil } -func headerParserContentLength(headerName string, headerText string) (header Header, err error) { +func headerParserContentLength(headerName []byte, headerText string) (header Header, err error) { var contentLength ContentLengthHeader return &contentLength, parseContentLengthHeader(headerText, &contentLength) } @@ -182,7 +176,7 @@ func parseContentLengthHeader(headerText string, contentLength *ContentLengthHea } // headerParserContentType parses ContentType header -func headerParserContentType(headerName string, headerText string) (headers Header, err error) { +func headerParserContentType(headerName []byte, headerText string) (headers Header, err error) { var contentType ContentTypeHeader return &contentType, parseContentTypeHeader(headerText, &contentType) } diff --git a/sip/parse_via.go b/sip/parse_via.go index afe4d61..885e560 100644 --- a/sip/parse_via.go +++ b/sip/parse_via.go @@ -6,7 +6,7 @@ import ( "strings" ) -func headerParserVia(headerName string, headerText string) ( +func headerParserVia(headerName []byte, headerText string) ( header Header, err error) { // sections := strings.Split(headerText, ",") h := ViaHeader{ diff --git a/sip/parser.go b/sip/parser.go index 71fcbb3..5da7023 100644 --- a/sip/parser.go +++ b/sip/parser.go @@ -24,7 +24,6 @@ var ( // Stream parse errors ErrParseSipPartial = errors.New("SIP partial data") ErrParseReadBodyIncomplete = errors.New("reading body incomplete") - ErrParseMoreMessages = errors.New("Stream has more message") ParseMaxMessageLength = 65535 ) @@ -38,7 +37,9 @@ func ParseMessage(msgData []byte) (Message, error) { // It is optimized with faster header parsing type Parser struct { // HeadersParsers uses default list of headers to be parsed. Smaller list parser will be faster - headersParsers mapHeadersParser + headersParsers HeadersParser + + MaxMessageLength int } // ParserOption are addition option for NewParser. Check WithParser... @@ -47,7 +48,8 @@ type ParserOption func(p *Parser) // Create a new Parser. func NewParser(options ...ParserOption) *Parser { p := &Parser{ - headersParsers: headersParsers, + headersParsers: DefaultHeadersParser(), + MaxMessageLength: ParseMaxMessageLength, } for _, o := range options { @@ -68,223 +70,288 @@ func WithHeadersParsers(m map[string]HeaderParser) ParserOption { } } -// ParseSIP converts data to sip message. Buffer must contain full sip message -func (p *Parser) ParseSIP(data []byte) (msg Message, err error) { - reader := bytes.NewBuffer(data) +// ParseHeaders parses all headers of a SIP message. It returns the number of bytes read. +// Data must contain a full SIP message header section, including double CRLF (\r\n). +// +// If the message is cut in the middle of a header or the first line, io.ErrUnexpectedEOF is returned. +// It may return an error wrapping ErrParseLineNoCRLF if one of the header lines is malformed, +// or if there's no CRLF (\r\n) delimiter after headers. +func (p *Parser) ParseHeaders(data []byte, stream bool) (Message, int, error) { + msg, _, n, err := p.parseHeaders(data, stream) + return msg, n, err +} + +func (p *Parser) parseHeaders(data []byte, stream bool) (Message, *ContentLengthHeader, int, error) { + msg, total, err := p.parseStartLine(data, stream) + if err != nil { + return msg, nil, total, err + } + data = data[total:] + + contentLength, n, err := p.parseHeadersOnly(msg, data) + total += n + return msg, contentLength, total, err +} - startLine, err := nextLine(reader) +func (p *Parser) parseStartLine(data []byte, stream bool) (Message, int, error) { + var ( + total int + skipped bool + ) + + if stream { + // RFC 3261 - 7.5. + // Implementations processing SIP messages over stream-oriented + // transports MUST ignore any CRLF appearing before the start-line. + for len(data) >= 2 && data[0] == '\r' && data[1] == '\n' { + data = data[2:] + total += 2 + skipped = true + } + } + + startLine, n, err := NextLine(data) if err != nil { - return nil, err + if err == io.EOF && skipped { + return nil, total, io.ErrUnexpectedEOF + } + return nil, total, err } + data = data[n:] + total += n - msg, err = parseLine(startLine) + msg, err := parseLine(string(startLine)) if err != nil { - return nil, err + return nil, total, err } + return msg, total, nil +} - for { - line, err := nextLine(reader) +var errParseNoMoreHeaders = errors.New("no more headers") - if err != nil { - if err == io.EOF { - return nil, ErrParseEOF +func (p *Parser) parseNextHeader(out []Header, data []byte) ([]Header, int, error) { + var total int + line, n, err := NextLine(data) + if err == io.EOF { + return out, total, io.ErrUnexpectedEOF + } else if err != nil { + return out, total, err + } + // Advance only after a successful parse. + data = data[n:] + total += n + if len(line) == 0 { + // We've hit the end of the header section. + return out, total, errParseNoMoreHeaders + } + out, err = p.headersParsers.ParseHeader(out, line) + if err != nil { + return out, total, err + } + return out, total, nil +} + +func (p *Parser) parseHeadersOnly(msg Message, data []byte) (*ContentLengthHeader, int, error) { + var ( + total, n int + headerBuf []Header + contentLength *ContentLengthHeader + err error + ) + for { + headerBuf, n, err = p.parseNextHeader(headerBuf[:0], data) + data = data[n:] + total += n + for _, h := range headerBuf { + switch h := h.(type) { + case *ContentLengthHeader: + contentLength = h } - return nil, err + msg.AppendHeader(h) } - size := len(line) - if size == 0 { - // We've hit the end of the header section. - break + if err == errParseNoMoreHeaders { + return contentLength, total, nil } - - err = p.headersParsers.parseMsgHeader(msg, line) if err != nil { - err := fmt.Errorf("parsing header failed line=%q: %w", line, err) - return nil, err + return contentLength, total, err } } +} - var contentLength int - if ct := msg.ContentLength(); ct != nil { - contentLength = int(*ct) - } else { - contentLength = getBodyLength(data) - } - - if contentLength <= 0 { - return msg, nil +// Parse data to a SIP message. It returns the number of bytes read. Data must contain a full SIP message. +// +// If the message is cut in the middle of a header or a first line, io.ErrUnexpectedEOF is returned. +// It may return an error wrapping ErrParseLineNoCRLF if one of the header lines is malformed, +// or if there's no CRLF (\r\n) delimiter after headers. +// +// In case the end of the body cannot be determined, or the body is incomplete, +// an ErrParseReadBodyIncomplete is returned. +func (p *Parser) Parse(data []byte, stream bool) (Message, int, error) { + if len(data) > p.MaxMessageLength { + return nil, 0, errors.New("Message exceeds ParseMaxMessageLength") } - - // p.log.Debugf("%s reads body with length = %d bytes", p, contentLength) - body := make([]byte, contentLength) - total, err := reader.Read(body) + msg, contentLength, total, err := p.parseHeaders(data, stream) if err != nil { - return nil, fmt.Errorf("read message body failed: %w", err) + return msg, total, err + } + data = data[total:] + bodySize := -1 + if contentLength != nil { + bodySize = int(*contentLength) + } else if !stream { + bodySize = len(data) + } + if bodySize < 0 { + // RFC 3261 - 7.5. + // The Content-Length header field value is used to locate the end of + // each SIP message in a stream. It will always be present when SIP + // messages are sent over stream-oriented transports. + return msg, total, ErrParseReadBodyIncomplete } + if bodySize == 0 { + return msg, total, nil + } + body := make([]byte, bodySize) + n := copy(body, data) + total += n + msg.SetBody(body) // RFC 3261 - 18.3. - if total != contentLength { - return nil, fmt.Errorf( - "incomplete message body: read %d bytes, expected %d bytes", - len(body), - contentLength, - ) + if n != bodySize { + return msg, total, ErrParseReadBodyIncomplete } + return msg, total, nil +} - // Should we trim this? - // if len(bytes.TrimSpace(body)) > 0 { - if len(body) > 0 { - msg.SetBody(body) +// ParseSIP converts data to sip message. Buffer must contain full sip message +func (p *Parser) ParseSIP(data []byte) (msg Message, err error) { + msg, _, err = p.Parse(data, false) + if err == io.ErrUnexpectedEOF { + err = ErrParseEOF } - return msg, nil + return msg, err } // NewSIPStream implements SIP parsing contructor for IO that stream SIP message // It should be created per each stream func (p *Parser) NewSIPStream() *ParserStream { + if p == nil { + p = NewParser() + } return &ParserStream{ - headersParsers: p.headersParsers, // safe as it read only + p: p, // safe as it read only } } func parseLine(startLine string) (msg Message, err error) { - if isRequest(startLine) { - recipient := Uri{} - method, sipVersion, err := parseRequestLine(startLine, &recipient) - if err != nil { - return nil, err - } - - m := NewRequest(method, recipient) - m.SipVersion = sipVersion - return m, nil - } + if parts, ok := split3(startLine); ok { + if isRequest(parts) { + recipient := Uri{} + method, sipVersion, err := parseRequestLine(parts, &recipient) + if err != nil { + return nil, err + } - if isResponse(startLine) { - sipVersion, statusCode, reason, err := parseStatusLine(startLine) - if err != nil { - return nil, err + m := NewRequest(method, recipient) + m.SipVersion = sipVersion + return m, nil } + if isResponse(parts) { + sipVersion, statusCode, reason, err := parseStatusLine(parts) + if err != nil { + return nil, err + } - m := NewResponse(statusCode, reason) - m.SipVersion = sipVersion - return m, nil + m := NewResponse(statusCode, reason) + m.SipVersion = sipVersion + return m, nil + } } return nil, fmt.Errorf("transmission beginning '%s' is not a SIP message", startLine) } -// nextLine should read until it hits CRLF -// ErrParseLineNoCRLF -> could not find CRLF in line +// NextLine reads the next line of a SIP message and the number of bytes read. // -// https://datatracker.ietf.org/doc/html/rfc3261#section-7 -// empty line MUST be -// terminated by a carriage-return line-feed sequence (CRLF). Note that -// the empty line MUST be present even if the message-body is not. -func nextLine(reader *bytes.Buffer) (line string, err error) { +// It returns io.ErrUnexpectedEOF is there's no CRLF (\r\n) in the data. +// If there's a CR (\r) which is not followed by LF (\n), a ErrParseLineNoCRLF is returned. +// As a special case, it returns io.EOF if data is empty. +func NextLine(data []byte) ([]byte, int, error) { + if len(data) == 0 { + return nil, 0, io.EOF + } // https://www.rfc-editor.org/rfc/rfc3261.html#section-7 // The start-line, each message-header line, and the empty line MUST be // terminated by a carriage-return line-feed sequence (CRLF). Note that // the empty line MUST be present even if the message-body is not. // Lines could be multiline as well so this is also acceptable - // TO : + // TO :\n // sip:vivekg@chair-dnrc.example.com ; tag = 1918181833n - line, err = reader.ReadString('\r') - if err != nil { - // We may get io.EOF and line till it was read - return line, err - } - br, err := reader.ReadByte() - if err != nil { - return line, err + i := bytes.IndexByte(data, '\r') + if i < 0 { + return data, len(data), io.ErrUnexpectedEOF } - - if br != '\n' { - return line, ErrParseLineNoCRLF - } - lenline := len(line) - if lenline < 1 { - return line, ErrParseLineNoCRLF + line := data[:i] + if i+1 >= len(data) { + return line, i + 1, io.ErrUnexpectedEOF } - - line = line[:lenline-1] - return line, nil -} - -// Calculate the size of a SIP message's body, given the entire contents of the message as a byte array. -func getBodyLength(data []byte) int { - // Body starts with first character following a double-CRLF. - idx := bytes.Index(data, []byte("\r\n\r\n")) - if idx == -1 { - return -1 + if data[i+1] != '\n' { + return line, i + 1, ErrParseLineNoCRLF } - - bodyStart := idx + 4 - - return len(data) - bodyStart + return line, i + 2, nil } -// detet is request by spaces -func isRequest(startLine string) bool { +// detect is request by spaces +func isRequest(parts [3]string) bool { // SIP request lines contain precisely two spaces. - ind := strings.IndexRune(startLine, ' ') - if ind <= 0 { - return false - } - - // part0 := startLine[:ind] - ind1 := strings.IndexRune(startLine[ind+1:], ' ') - if ind1 <= 0 { - return false - } - - part2 := startLine[ind+1+ind1+1:] - ind2 := strings.IndexRune(part2, ' ') - if ind2 >= 0 { + part2 := parts[2] + if len(part2) < 3 { return false } - - if len(part2) < 3 { + i := strings.IndexByte(part2, ' ') + if i >= 0 { return false } - return UriIsSIP(part2[:3]) } // Detect is response by spaces -func isResponse(startLine string) bool { - // SIP status lines contain at least two spaces. - ind := strings.IndexRune(startLine, ' ') - if ind <= 0 { +func isResponse(parts [3]string) bool { + part0 := parts[0] + if len(part0) < 3 { return false } + return UriIsSIP(part0[:3]) +} - // part0 := startLine[:ind] - ind1 := strings.IndexRune(startLine[ind+1:], ' ') - if ind1 <= 0 { - return false +func split3(s string) (parts [3]string, ok bool) { + i := strings.IndexByte(s, ' ') + if i < 0 { + return } + parts[0] = s[:i] + s = s[i+1:] - return UriIsSIP(startLine[:3]) + i = strings.IndexByte(s, ' ') + if i < 0 { + return + } + parts[1] = s[:i] + s = s[i+1:] + parts[2] = s + return parts, true } // Parse the first line of a SIP request, e.g: // // INVITE bob@example.com SIP/2.0 // REGISTER jane@telco.com SIP/1.0 -func parseRequestLine(requestLine string, recipient *Uri) ( - method RequestMethod, sipVersion string, err error) { - parts := strings.Split(requestLine, " ") - if len(parts) != 3 { - err = fmt.Errorf("request line should have 2 spaces: '%s'", requestLine) - return - } - +func parseRequestLine(parts [3]string, recipient *Uri) (method RequestMethod, sipVersion string, err error) { method = RequestMethod(strings.ToUpper(parts[0])) err = ParseUri(parts[1], recipient) sipVersion = parts[2] if recipient.Wildcard { - err = fmt.Errorf("wildcard URI '*' not permitted in request line: '%s'", requestLine) + err = fmt.Errorf("wildcard URI '*' not permitted in request line") return } @@ -295,18 +362,10 @@ func parseRequestLine(requestLine string, recipient *Uri) ( // // SIP/2.0 200 OK // SIP/1.0 403 Forbidden -func parseStatusLine(statusLine string) ( - sipVersion string, statusCode int, reasonPhrase string, err error) { - parts := strings.Split(statusLine, " ") - if len(parts) < 3 { - err = fmt.Errorf("status line has too few spaces: '%s'", statusLine) - return - } - +func parseStatusLine(parts [3]string) (sipVersion string, statusCode int, reasonPhrase string, err error) { sipVersion = parts[0] statusCodeRaw, err := strconv.ParseUint(parts[1], 10, 16) statusCode = int(statusCodeRaw) - reasonPhrase = strings.Join(parts[2:], " ") - + reasonPhrase = parts[2] return } diff --git a/sip/parser_stream.go b/sip/parser_stream.go index 6b67f86..6a509bb 100644 --- a/sip/parser_stream.go +++ b/sip/parser_stream.go @@ -2,20 +2,20 @@ package sip import ( "bytes" + "errors" "fmt" "io" "sync" ) +type parserState int + const ( - stateStartLine = 0 - stateHeader = 1 - stateContent = 2 - // stateParsed = 1 + stateStartLine = parserState(iota) + stateHeader + stateContent ) -var () - var streamBufReader = sync.Pool{ New: func() interface{} { // The Pool's New function should generally only return pointer @@ -26,159 +26,187 @@ var streamBufReader = sync.Pool{ } type ParserStream struct { - // HeadersParsers uses default list of headers to be parsed. Smaller list parser will be faster - headersParsers mapHeadersParser + p *Parser // runtime values - reader *bytes.Buffer - msg Message - readContentLength int - state int + buf *bytes.Buffer + state parserState + totalRead int + msg Message + headerBuf []Header + contentLength *ContentLengthHeader + contentOff int } func (p *ParserStream) reset() { p.state = stateStartLine - p.reader = nil + p.totalRead = 0 p.msg = nil - p.readContentLength = 0 + for i := range p.headerBuf { + p.headerBuf[i] = nil + } + p.headerBuf = p.headerBuf[:0] + p.contentLength = nil + p.contentOff = 0 +} + +// Reset the parser and the internal buffer. +func (p *ParserStream) Reset() { + p.reset() + if p.buf != nil { + p.buf.Reset() + } +} + +// Close the parser and free the associated resources. +func (p *ParserStream) Close() { + p.reset() + buf := p.buf + p.buf = nil + if buf != nil { + streamBufReader.Put(buf) + } } // ParseSIPStream parsing messages comming in stream // It has slight overhead vs parsing full message func (p *ParserStream) ParseSIPStream(data []byte) (msgs []Message, err error) { - return msgs, p.ParseSIPStreamEach(data, func(msg Message) { + err = p.ParseSIPStreamEach(data, func(msg Message) { msgs = append(msgs, msg) }) + return msgs, err } // ParseSIPStreamEach parses SIP stream and calls callback as soon first SIP message is parsed -func (p *ParserStream) ParseSIPStreamEach(data []byte, cb func(msg Message)) (err error) { - if p.reader == nil { - p.reader = streamBufReader.Get().(*bytes.Buffer) - p.reader.Reset() +func (p *ParserStream) ParseSIPStreamEach(data []byte, cb func(msg Message)) error { + if _, err := p.Write(data); err != nil { + return err } - - reader := p.reader - if reader.Len()+len(data) > ParseMaxMessageLength { - return fmt.Errorf("Message exceeds ParseMaxMessageLength") - } - - reader.Write(data) // This should append to our already buffer - - unparsed := reader.Bytes() // TODO find a better way as we only want to move our offset - for { - err := p.parseSingle(reader, &unparsed) - switch err { - case ErrParseLineNoCRLF, ErrParseReadBodyIncomplete: - reader.Reset() - reader.Write(unparsed) + for p.buf.Len() > 0 { + msg, _, err := p.ParseNext() + if errors.Is(err, io.ErrUnexpectedEOF) { return ErrParseSipPartial - } - - if err != nil { + } else if err != nil { return err } + cb(msg) + } + return nil +} - cb(p.msg) - if len(unparsed) == 0 { - // Maybe we need to check did empty spaces left - break - } - - p.reset() - reader.Reset() - reader.Write(unparsed) - p.reader = reader +// Buffer returns an internal buffer used by the parser. +// This allows to inspect the current parser state and possibly recover the stream with Discard. +func (p *ParserStream) Buffer() *bytes.Buffer { + if p.buf == nil { + p.buf = streamBufReader.Get().(*bytes.Buffer) + p.buf.Reset() } + return p.buf +} - // IN all other cases do reset - streamBufReader.Put(reader) +// Discard specified amount of data and reset the parser. +// Can be used to skip malformed messages and recover the stream. +func (p *ParserStream) Discard(n int) { p.reset() + if p.buf != nil { + _ = p.buf.Next(n) + } +} - return +// Write data to the internal buffer. Must be called before ParseNext. +func (p *ParserStream) Write(data []byte) (int, error) { + buf := p.Buffer() + if buf.Len()+len(data) > p.p.MaxMessageLength { + return 0, errors.New("Message exceeds ParseMaxMessageLength") + } + buf.Write(data) // This should append to our already buffer + return len(data), nil } -func (p *ParserStream) parseSingle(reader *bytes.Buffer, unparsed *[]byte) (err error) { - // TODO change this with functions and store last function state - switch p.state { - case stateStartLine: - startLine, err := nextLine(reader) +// ParseNext parses the next SIP message from an internal buffer. +// It may return io.ErrUnexpectedEOF, indicating that more data needs to be written with Write. +func (p *ParserStream) ParseNext() (Message, int, error) { + if p.buf == nil { + return nil, 0, io.ErrUnexpectedEOF + } + err := p.parseSingle() + msg, n := p.msg, p.totalRead + if err == nil { + p.reset() + } + return msg, n, err +} - if err != nil { - if err == io.EOF { - return ErrParseLineNoCRLF - } - return err - } +func (p *ParserStream) advance(n int) { + p.totalRead += n + _ = p.buf.Next(n) +} - msg, err := parseLine(startLine) +func (p *ParserStream) parseSingle() error { + if p.buf == nil { + return io.ErrUnexpectedEOF + } + var ( + n int + err error + ) + switch p.state { + case stateStartLine: + var msg Message + msg, n, err = p.p.parseStartLine(p.buf.Bytes(), true) + p.advance(n) if err != nil { return err } - - *unparsed = reader.Bytes() p.state = stateHeader p.msg = msg fallthrough case stateHeader: - msg := p.msg for { - line, err := nextLine(reader) - - if err != nil { - if err == io.EOF { - // No more to read - return ErrParseLineNoCRLF + p.headerBuf, n, err = p.p.parseNextHeader(p.headerBuf[:0], p.buf.Bytes()) + p.advance(n) + for _, h := range p.headerBuf { + switch h := h.(type) { + case *ContentLengthHeader: + p.contentLength = h } - return err + p.msg.AppendHeader(h) } - - if len(line) == 0 { - // We've hit second CRLF + if err == errParseNoMoreHeaders { break } - - err = p.headersParsers.parseMsgHeader(msg, line) if err != nil { - return fmt.Errorf("%s: %w", err.Error(), ErrParseEOF) - // log.Info().Err(err).Str("line", line).Msg("skip header due to error") + return err } - *unparsed = reader.Bytes() } - *unparsed = reader.Bytes() - - h := msg.ContentLength() - if h == nil { - return nil + if p.contentLength == nil || *p.contentLength < 0 { + // RFC 3261 - 7.5. + // The Content-Length header field value is used to locate the end of + // each SIP message in a stream. It will always be present when SIP + // messages are sent over stream-oriented transports. + return ErrParseReadBodyIncomplete } - - contentLength := int(*h) - if contentLength <= 0 { + contentLength := int(*p.contentLength) + if contentLength == 0 { + p.state = -1 return nil } - body := make([]byte, contentLength) - msg.SetBody(body) - + p.msg.SetBody(body) p.state = stateContent fallthrough case stateContent: - msg := p.msg - body := msg.Body() + body := p.msg.Body() contentLength := len(body) - n, err := reader.Read(body[p.readContentLength:]) - *unparsed = reader.Bytes() - if err != nil { - return fmt.Errorf("read message body failed: %w", err) - } - p.readContentLength += n + n = copy(body[p.contentOff:], p.buf.Bytes()) + p.advance(n) + p.contentOff += n - if p.readContentLength < contentLength { - return ErrParseReadBodyIncomplete + if p.contentOff < contentLength { + return io.ErrUnexpectedEOF } - - p.state = -1 // Clear state + p.state = -1 return nil default: return fmt.Errorf("Parser is in unknown state") diff --git a/sip/parser_stream_test.go b/sip/parser_stream_test.go index 5ad5eff..ec2c457 100644 --- a/sip/parser_stream_test.go +++ b/sip/parser_stream_test.go @@ -2,6 +2,7 @@ package sip import ( "fmt" + "math/rand/v2" "runtime" "strings" "testing" @@ -11,7 +12,7 @@ import ( ) func TestParserStreamBadMessage(t *testing.T) { - parser := ParserStream{} + parser := NewParser().NewSIPStream() // The start-line, each message-header line, and the empty line MUST be // terminated by a carriage-return line-feed sequence (CRLF). Note that @@ -27,7 +28,7 @@ func TestParserStreamBadMessage(t *testing.T) { } msgstr := strings.Join(rawMsg, "\r\n") _, err := parser.ParseSIPStream([]byte(msgstr)) - require.ErrorIs(t, err, ErrParseEOF) + require.Error(t, err) }) t.Run("finish empty line", func(t *testing.T) { rawMsg := []string{ @@ -47,6 +48,7 @@ func TestParserStreamMessage(t *testing.T) { parser := p.NewSIPStream() lines := []string{ + "", "", // Check that stream ignores CRLF in the beginning "INVITE sip:192.168.1.254:5060 SIP/2.0", "Via: SIP/2.0/TCP 192.168.1.155:44861;branch=z9hG4bK954690f3012120bc5d064d3f7b5d8a24;rport", "Call-ID: 25be1c3be64adb89fa2e86772dd99db1", @@ -153,46 +155,91 @@ func TestParserStreamMessage(t *testing.T) { "", // Content length includes last CRLF } data := []byte((strings.Join(lines, "\r\n"))) + const bodySize = 3119 + + for _, c := range []struct { + Name string + Skip int + Split []int + }{ + // arbitrary split points + {Split: []int{500, 1000}, Skip: 4}, + {Split: []int{300, 2000}, Skip: 4}, + {Split: []int{500, 1000}}, + {Split: []int{300, 2000}}, + // split at specific places + { + Name: "few bytes", + Split: []int{1, 2, 3, 4, 5, 6}, + }, + { + Name: "CRLF pings", + Split: []int{2, 4}, + }, + { + Name: "after start line", + Split: []int{37 + 2}, Skip: 4, + }, + { + Name: "after header", + Split: []int{37 + 2 + 89 + 2}, Skip: 4, + }, + { + Name: "after all headers", + Split: []int{702}, Skip: 4, + }, + { + Name: "before body", + Split: []int{704}, Skip: 4, + }, + // completely random split (try a few times) + {Split: []int{rand.IntN(len(data))}}, + {Split: []int{rand.IntN(len(data))}}, + {Split: []int{rand.IntN(len(data))}}, + } { + name := c.Name + if name == "" { + name = fmt.Sprintf("split_%v_skip_%d", c.Split, c.Skip) + name = strings.ReplaceAll(name, "[", "") + name = strings.ReplaceAll(name, "]", "") + } + t.Run(name, func(t *testing.T) { + data := data + if c.Skip != 0 { + data = data[c.Skip:] + } + var parts [][]byte + for i := range c.Split { + start := 0 + if i > 0 { + start = c.Split[i-1] + } + end := c.Split[i] + parts = append(parts, data[start:end]) + } + lastPart := data[c.Split[len(c.Split)-1]:] - // make partials - part1 := data[:500] - part2 := data[500:1000] - part3 := data[1000:] - - t.Run("first run", func(t *testing.T) { - t.Logf("Parsing part 1:\n%s", string(part1)) - _, err := parser.ParseSIPStream(part1) - require.Error(t, err) - require.ErrorIs(t, err, ErrParseSipPartial) - - t.Logf("Parsing part 2:\n%s", string(part2)) - _, err = parser.ParseSIPStream(part2) - require.Error(t, err) - require.ErrorIs(t, err, ErrParseSipPartial) - - t.Logf("Parsing part 3:\n%s", string(part3)) - msgs, err := parser.ParseSIPStream(part3) - msg := msgs[0] - require.NoError(t, err) - require.NotNil(t, msg) - require.Len(t, msg.Body(), 3119) - // Check is parser reset it self - require.Nil(t, parser.reader) - }) - - t.Run("second run", func(t *testing.T) { - part1 := data[:300] - part2 := data[300:2000] - part3 := data[2000:] + for i, part := range parts { + t.Logf("Parsing part %d:\n%s", i+1, string(part)) + _, err := parser.ParseSIPStream(part) + require.Error(t, err) + require.ErrorIs(t, err, ErrParseSipPartial) + } + t.Logf("Parsing final part:\n%s", string(lastPart)) + msgs, err := parser.ParseSIPStream(lastPart) + require.NoError(t, err) + msg := msgs[0] + require.NotNil(t, msg) + require.Len(t, msg.Body(), bodySize) + }) + } - parser.ParseSIPStream(part1) - parser.ParseSIPStream(part2) - msg, err := parser.ParseSIPStream(part3) - require.NoError(t, err) - require.NotNil(t, msg) - require.Nil(t, parser.reader) + t.Run("reset", func(t *testing.T) { + // Check is parser resets itself + require.True(t, parser.state == stateStartLine) + parser.Close() + require.Nil(t, parser.buf) }) - } func TestParserStreamChunky(t *testing.T) { diff --git a/sip/parser_test.go b/sip/parser_test.go index 4603d74..f7744a2 100644 --- a/sip/parser_test.go +++ b/sip/parser_test.go @@ -33,8 +33,11 @@ func testParseHeaderOnRequest(t *testing.T, parser *Parser, header string) (*Req // This is fake way to get parsing done. We use fake message and read first header msg := NewRequest(INVITE, Uri{}) name := strings.Split(header, ":")[0] - err := parser.headersParsers.parseMsgHeader(msg, header) + out, err := parser.headersParsers.ParseHeader(nil, []byte(header)) require.Nil(t, err) + for _, h := range out { + msg.AppendHeader(h) + } return msg, msg.GetHeader(name) } @@ -172,9 +175,10 @@ func BenchmarkParserHeaders(b *testing.B) { branch := GenerateBranch() header := "Via: SIP/2.0/UDP 127.0.0.2:5060;branch=" + branch colonIdx := strings.Index(header, ":") + name := []byte(header[:colonIdx]) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := headerParserVia(header[:colonIdx], header[colonIdx+2:]) + _, err := headerParserVia(name, header[colonIdx+2:]) if err != nil { b.Fatal(err) } @@ -184,9 +188,10 @@ func BenchmarkParserHeaders(b *testing.B) { b.Run("ToHeader", func(b *testing.B) { header := "To: \"Bob\" " colonIdx := strings.Index(header, ":") + name := []byte(header[:colonIdx]) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := headerParserTo(header[:colonIdx], header[colonIdx+2:]) + _, err := headerParserTo(name, header[colonIdx+2:]) if err != nil { b.Fatal(err) } @@ -196,9 +201,10 @@ func BenchmarkParserHeaders(b *testing.B) { b.Run("FromHeader", func(b *testing.B) { header := "From: \"Bob\" " colonIdx := strings.Index(header, ":") + name := []byte(header[:colonIdx]) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := headerParserFrom(header[:colonIdx], header[colonIdx+2:]) + _, err := headerParserFrom(name, header[colonIdx+2:]) if err != nil { b.Fatal(err) } @@ -208,9 +214,10 @@ func BenchmarkParserHeaders(b *testing.B) { b.Run("ContactHeader", func(b *testing.B) { header := "Contact: " colonIdx := strings.Index(header, ":") + name := []byte(header[:colonIdx]) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := headerParserContact(header[:colonIdx], header[colonIdx+2:]) + _, err := headerParserContact(name, header[colonIdx+2:]) if err != nil { b.Fatal(err) } @@ -220,9 +227,10 @@ func BenchmarkParserHeaders(b *testing.B) { b.Run("CSEQ", func(b *testing.B) { header := "CSEQ: 1234 INVITE" colonIdx := strings.Index(header, ":") + name := []byte(header[:colonIdx]) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := headerParserCSeq(header[:colonIdx], header[colonIdx+2:]) + _, err := headerParserCSeq(name, header[colonIdx+2:]) if err != nil { b.Fatal(err) } @@ -232,9 +240,10 @@ func BenchmarkParserHeaders(b *testing.B) { b.Run("Route", func(b *testing.B) { header := "Route: " colonIdx := strings.Index(header, ":") + name := []byte(header[:colonIdx]) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := headerParserRoute(header[:colonIdx], header[colonIdx+2:]) + _, err := headerParserRoute(name, header[colonIdx+2:]) if err != nil { b.Fatal(err) } @@ -260,6 +269,8 @@ func TestParseBadMessages(t *testing.T) { msgstr := strings.Join(rawMsg, "\r\n") _, err := parser.ParseSIP([]byte(msgstr)) require.ErrorIs(t, err, ErrParseEOF) + _, _, err = parser.Parse([]byte(msgstr), false) + require.Error(t, err, io.ErrUnexpectedEOF) }) t.Run("finish empty line", func(t *testing.T) { rawMsg := []string{ @@ -271,6 +282,8 @@ func TestParseBadMessages(t *testing.T) { msgstr := strings.Join(rawMsg, "\r\n") _, err := parser.ParseSIP([]byte(msgstr)) require.Error(t, err, ErrParseEOF) + _, _, err = parser.Parse([]byte(msgstr), false) + require.Error(t, err, io.ErrUnexpectedEOF) }) } @@ -282,12 +295,8 @@ func TestParseRequest(t *testing.T) { t.Run("NoCRLF", func(t *testing.T) { // https://www.rfc-editor.org/rfc/rfc3261.html#section-7 // In case of missing CRLF - m := "INVITE sip:10.5.0.10:5060;transport=udp SIP/2.0\nContent-Length: 0" - _, err := parser.ParseSIP([]byte(m)) - assert.ErrorIs(t, err, io.EOF) - for _, msgstr := range []string{ - // "INVITE sip:10.5.0.10:5060;transport=udp SIP/2.0\nContent-Length: 0", + "INVITE sip:10.5.0.10:5060;transport=udp SIP/2.0\nContent-Length: 0", "INVITE sip:10.5.0.10:5060;transport=udp SIP/2.0\r\nContent-Length: 0\n", "INVITE sip:10.5.0.10:5060;transport=udp SIP/2.0\r\nContent-Length: 0\r\n\n", "INVITE sip:10.5.0.10:5060;transport=udp SIP/2.0\r\nContent-Length: 10\r\nabcd\nefgh", @@ -560,9 +569,10 @@ func BenchmarkParseStartLine(b *testing.B) { func BenchmarkParserAddressValue(b *testing.B) { header := "To: \"Bob\" ;tag=1928301774;xxx=xxx;yyyy=yyyy" + name := []byte("To") b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := headerParserTo("To", header[4:]) + _, err := headerParserTo(name, header[4:]) if err != nil { b.Fatal(err) } diff --git a/sip/utils.go b/sip/utils.go index ca9505f..a9eddbc 100644 --- a/sip/utils.go +++ b/sip/utils.go @@ -42,6 +42,32 @@ func isASCII(c rune) bool { return 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' } +func asciiToLower(s []byte) []byte { + // first check is ascii already low to avoid alloc + nonLowInd := -1 + for i, c := range s { + if 'a' <= c && c <= 'z' { + continue + } + nonLowInd = i + break + } + if nonLowInd < 0 { + return s + } + + b := make([]byte, len(s)) + copy(b, s[:nonLowInd]) + for i := nonLowInd; i < len(s); i++ { + c := s[i] + if 'A' <= c && c <= 'Z' { + c += 'a' - 'A' + } + b[i] = c + } + return b +} + // ASCIIToLower is faster than go version. It avoids one more loop func ASCIIToLower(s string) string { // first check is ascii already low to avoid alloc @@ -107,6 +133,76 @@ func ASCIIToUpper(s string) string { return b.String() } +var ( + hdrVia = []byte("via") + hdrFrom = []byte("from") + hdrTo = []byte("to") + hdrCallID = []byte("call-id") + hdrContact = []byte("contact") + hdrCSeq = []byte("cseq") + hdrContentType = []byte("content-type") + hdrContentLength = []byte("content-length") + hdrRoute = []byte("route") + hdrRecordRoute = []byte("record-route") + hdrMaxForwards = []byte("max-forwards") + hdrTimestamp = []byte("timestamp") +) + +func headerToLower(s []byte) []byte { + if len(s) == 1 { + c := s[0] + if 'A' <= c && c <= 'Z' { + c += 'a' - 'A' + } + switch c { + case 't': + return hdrTo + case 'f': + return hdrFrom + case 'v': + return hdrVia + case 'i': + return hdrCallID + case 'l': + return hdrContentLength + case 'c': + return hdrContentType + case 'm': + return hdrContact + } + } + // Avoid allocations + switch string(s) { + case "Via", "via": + return hdrVia + case "From", "from": + return hdrFrom + case "To", "to": + return hdrTo + case "Call-ID", "call-id": + return hdrCallID + case "Contact", "contact": + return hdrContact + case "CSeq", "CSEQ", "cseq": + return hdrCSeq + case "Content-Type", "content-type": + return hdrContentType + case "Content-Length", "content-length": + return hdrContentLength + case "Route", "route": + return hdrRoute + case "Record-Route", "record-route": + return hdrRecordRoute + case "Max-Forwards": + return hdrMaxForwards + case "Timestamp", "timestamp": + return hdrTimestamp + } + + // This creates one allocation if we really need to lower + return asciiToLower(s) +} + // HeaderToLower is fast ASCII lower string func HeaderToLower(s string) string { // Avoid allocations @@ -125,6 +221,8 @@ func HeaderToLower(s string) string { return "cseq" case "Content-Type", "content-type": return "content-type" + case "Content-Length", "content-length": + return "content-length" case "Route", "route": return "route" case "Record-Route", "record-route": diff --git a/sip/utils_test.go b/sip/utils_test.go index a66275a..f70df5e 100644 --- a/sip/utils_test.go +++ b/sip/utils_test.go @@ -14,9 +14,7 @@ import ( func testCreateMessage(t testing.TB, rawMsg []string) Message { msg, err := ParseMessage([]byte(strings.Join(rawMsg, "\r\n"))) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) return msg }