mql: Go Coverage Report (original) (raw)

// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0

package mql

import ( "fmt" "reflect" )

// isNil reports whether a is nil: either an untyped nil interface value, or
// an interface holding a nil pointer, map, channel, slice or func.
func isNil(a any) bool {
	if a == nil {
		return true
	}
	v := reflect.ValueOf(a)
	switch v.Kind() {
	case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice, reflect.Func:
		return v.IsNil()
	default:
		return false
	}
}

// panicIfNil will panic if a is nil func panicIfNil(a any, caller, missing string) { if isNil(a) { panic(fmt.Sprintf("%s: missing %s", caller, missing)) } }

// pointer returns a pointer to a copy of input, which is convenient for
// taking the address of literals and other non-addressable values.
func pointer[T any](input T) *T {
	p := new(T)
	*p = input
	return p
}

// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0

package mql

import ( "fmt" )

// exprType identifies the concrete kind of an expr node.
type exprType int

const (
	unknownExprType exprType = iota
	comparisonExprType
	logicalExprType
)

// expr is a node in the expression tree built by the parser.
type expr interface {
	// Type reports which kind of node this is.
	Type() exprType
	// String returns a debug representation of the node.
	String() string
}

// ComparisonOp defines a set of comparison operators.
type ComparisonOp string

// Supported comparison operators.
const (
	GreaterThanOp        ComparisonOp = ">"
	GreaterThanOrEqualOp ComparisonOp = ">="
	LessThanOp           ComparisonOp = "<"
	LessThanOrEqualOp    ComparisonOp = "<="
	EqualOp              ComparisonOp = "="
	NotEqualOp           ComparisonOp = "!="
	ContainsOp           ComparisonOp = "%"
)

func newComparisonOp(s string) (ComparisonOp, error) { const op = "newComparisonOp" switch ComparisonOp(s) { case GreaterThanOp, GreaterThanOrEqualOp, LessThanOp, LessThanOrEqualOp, EqualOp, NotEqualOp, ContainsOp: return ComparisonOp(s), nil default: return "", fmt.Errorf("%s: %w %q", op, ErrInvalidComparisonOp, s) } }

type comparisonExpr struct { column string comparisonOp ComparisonOp value *string }

// Type returns the expr type func (e *comparisonExpr) Type() exprType { return comparisonExprType }

// String returns a string rep of the expr func (e *comparisonExpr) String() string { switch e.value { case nil: return fmt.Sprintf("(comparisonExpr: %s %s nil)", e.column, e.comparisonOp) default: return fmt.Sprintf("(comparisonExpr: %s %s %s)", e.column, e.comparisonOp, *e.value) } }

func (e *comparisonExpr) isComplete() bool { return e.column != "" && e.comparisonOp != "" && e.value != nil }

// defaultValidateConvert will validate the comparison expr value, and then convert the // expr to its SQL equivalence. func defaultValidateConvert(columnName string, comparisonOp ComparisonOp, columnValue *string, validator validator, opt ...Option) (*WhereClause, error) { const op = "mql.(comparisonExpr).convertToSql" switch { case columnName == "": return nil, fmt.Errorf("%s: %w", op, ErrMissingColumn) case comparisonOp == "": return nil, fmt.Errorf("%s: %w", op, ErrMissingComparisonOp) case isNil(columnValue): return nil, fmt.Errorf("%s: %w", op, ErrMissingComparisonValue) case validator.fn == nil: return nil, fmt.Errorf("%s: missing validator function: %w", op, ErrInvalidParameter) case validator.typ == "": return nil, fmt.Errorf("%s: missing validator type: %w", op, ErrInvalidParameter) }

    // everything was validated at the start, so we know this is a valid/complete comparisonExpr
    e := &comparisonExpr{
            column:       columnName,
            comparisonOp: comparisonOp,
            value:        columnValue,
    }

    v, err := validator.fn(*e.value)
    if err != nil {
            return nil, fmt.Errorf("%s: %q in %s: %w", op, *e.value, e.String(), ErrInvalidParameter)
    }

    opts, err := getOpts(opt...)
    if err != nil {
            return nil, fmt.Errorf("%s: %w", op, err)
    }
    if n, ok := opts.withTableColumnMap[columnName]; ok {
            // override our column name with the mapped column name
            columnName = n
    }

    if validator.typ == "time" {
            columnName = fmt.Sprintf("%s::date", columnName)
    }
    switch e.comparisonOp {
    case ContainsOp:
            return &WhereClause{
                    Condition: fmt.Sprintf("%s like ?", columnName),
                    Args:      []any{fmt.Sprintf("%%%s%%", v)},
            }, nil
    default:
            return &WhereClause{
                    Condition: fmt.Sprintf("%s%s?", columnName, e.comparisonOp),
                    Args:      []any{v},
            }, nil
    }

}

type logicalOp string

const ( andOp logicalOp = "and" orOp logicalOp = "or" )

func newLogicalOp(s string) (logicalOp, error) { const op = "newLogicalOp" switch logicalOp(s) { case andOp, orOp: return logicalOp(s), nil default: return "", fmt.Errorf("%s: %w %q", op, ErrInvalidLogicalOp, s) } }

type logicalExpr struct { leftExpr expr logicalOp logicalOp rightExpr expr }

// Type returns the expr type func (l *logicalExpr) Type() exprType { return logicalExprType }

// String returns a string rep of the expr func (l *logicalExpr) String() string { return fmt.Sprintf("(logicalExpr: %s %s %s)", l.leftExpr, l.logicalOp, l.rightExpr) }

// root will return the root of the expr tree func root(lExpr *logicalExpr, raw string) (expr, error) { const op = "mql.root" switch { // intentionally not checking raw, since can be an empty string case lExpr == nil: return nil, fmt.Errorf("%s: %w (missing expression)", op, ErrInvalidParameter) } logicalOp := lExpr.logicalOp if logicalOp != "" && lExpr.rightExpr == nil { return nil, fmt.Errorf("%s: %w in: %q", op, ErrMissingRightSideExpr, raw) }

    for lExpr.logicalOp == "" {
            switch {
            case lExpr.leftExpr == nil:
                    return nil, fmt.Errorf("%s: %w nil in: %q", op, ErrMissingExpr, raw)
            case lExpr.leftExpr.Type() == comparisonExprType:
                    return lExpr.leftExpr, nil
            default:
                    lExpr = lExpr.leftExpr.(*logicalExpr)
            }
    }
    return lExpr, nil

}

// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0

package mql

import ( "bufio" "bytes" "fmt" "strings" "unicode" )

// Delimiter is a rune used to quote strings in a query.
type Delimiter rune

// Supported string delimiters.
const (
	DoubleQuote Delimiter = '"'
	SingleQuote Delimiter = '\''
	Backtick    Delimiter = '`'

	// backslash escapes the delimiter (or another backslash) inside a
	// quoted string.
	backslash = '\\'
)

type lexStateFunc func(*lexer) (lexStateFunc, error)

type lexer struct { source *bufio.Reader current stack[rune] tokens chan token state lexStateFunc }

func newLexer(s string) *lexer { l := &lexer{ source: bufio.NewReader(strings.NewReader(s)), state: lexStartState, tokens: make(chan token, 1), // define a ring buffer for emitted tokens } return l }

// nextToken is the external api for the lexer and it simply returns the next // token or an error. If EOF is encountered while scanning, nextToken will keep // returning an eofToken no matter how many times you call nextToken. func (l *lexer) nextToken() (token, error) { for { select { case tk := <-l.tokens: // return a token if one has been emitted return tk, nil default: // otherwise, keep scanning via the next state var err error if l.state, err = l.state(l); err != nil { return token{}, err }

            }
    }

}

// lexStartState is the start state. It doesn't emit tokens, but rather // transitions to other states. Other states typically transition back to // lexStartState after they emit a token. func lexStartState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexStartState", "lexer") r := l.read() switch { // wait, if it's eof we're done case r == eof: l.emit(eofToken, "") return lexEofState, nil

    // start with finding all tokens that can have a trailing "="
    case r == '>':
            return lexGreaterState, nil
    case r == '<':
            return lexLesserState, nil

            // now, we can just look at the next rune...
    case r == '%':
            return lexContainsState, nil
    case r == '=':
            return lexEqualState, nil
    case r == '!':
            return lexNotEqualState, nil
    case r == ')':
            return lexRightParenState, nil
    case r == '(':
            return lexLeftParenState, nil
    case isSpace(r):
            return lexWhitespaceState, nil
    case unicode.IsDigit(r) || r == '.':
            l.unread()
            return lexNumberState, nil
    case isDelimiter(r):
            l.unread()
            return lexStringState, nil
    default:
            l.unread()
            return lexSymbolState, nil
    }

}

// lexStringState scans for strings and can emit a stringToken func lexStringState(l *lexer) (lexStateFunc, error) { const op = "mql.lexStringState" panicIfNil(l, "lexStringState", "lexer") defer l.current.clear()

    // we'll push the runes we read into this buffer and when appropriate will
    // emit tokens using the buffer's data.
    var tokenBuf bytes.Buffer

    // before we start looping, let's found out if we're scanning a quoted string
    r := l.read()
    delimiter := r
    if !isDelimiter(delimiter) {
            return nil, fmt.Errorf("%s: %w %q", op, ErrInvalidDelimiter, delimiter)
    }
    finalDelimiter := false

WriteToBuf: // keep reading runes into the buffer until we encounter eof or the final delimiter. for { r = l.read() switch { case r == eof: break WriteToBuf case r == backslash: nextR := l.read() switch { case nextR == eof: tokenBuf.WriteRune(r) return nil, fmt.Errorf("%s: %w in %q", op, ErrInvalidTrailingBackslash, tokenBuf.String()) case nextR == backslash: tokenBuf.WriteRune(nextR) case nextR == delimiter: tokenBuf.WriteRune(nextR) default: tokenBuf.WriteRune(r) tokenBuf.WriteRune(nextR) } case r == delimiter: // end of the quoted string we're scanning finalDelimiter = true break WriteToBuf default: // otherwise, write the rune into the keyword buffer tokenBuf.WriteRune(r) } } switch { case !finalDelimiter: return nil, fmt.Errorf("%s: %w for "%s", op, ErrMissingEndOfStringTokenDelimiter, tokenBuf.String()) default: l.emit(stringToken, tokenBuf.String()) return lexStartState, nil } }

// lexSymbolState scans for strings and can emit the following tokens: // orToken, andToken, containsToken func lexSymbolState(l *lexer) (lexStateFunc, error) { const op = "mql.lexSymbolState" panicIfNil(l, "lexSymbolState", "lexer") defer l.current.clear()

ReadRunes: // keep reading runes into the buffer until we encounter eof of non-text runes. for { r := l.read() switch { case r == eof: break ReadRunes case (isSpace(r) || isSpecial(r)): // whitespace or a special char l.unread() break ReadRunes default: continue ReadRunes } }

    switch strings.ToLower(runesToString(l.current)) {
    case "and":
            l.emit(andToken, "and")
            return lexStartState, nil
    case "or":
            l.emit(orToken, "or")
            return lexStartState, nil
    default:
            l.emit(symbolToken, runesToString(l.current))
            return lexStartState, nil
    }

}

func lexNumberState(l *lexer) (lexStateFunc, error) { const op = "mql.lexNumberState" defer l.current.clear()

    isFloat := false

    // we'll push the runes we read into this buffer and when appropriate will
    // emit tokens using the buffer's data.
    var buf []rune

WriteToBuf: // keep reading runes into the buffer until we encounter eof of non-number runes. for { r := l.read() switch { case r == eof: break WriteToBuf case r == '.' && isFloat: buf = append(buf, r) return nil, fmt.Errorf("%s: %w in %q", op, ErrInvalidNumber, string(buf)) case r == '.' && !isFloat: isFloat = true buf = append(buf, r) case unicode.IsDigit(r) || (r == '.' && len(buf) == 0): buf = append(buf, r) default: l.unread() break WriteToBuf } } l.emit(numberToken, string(buf)) return lexStartState, nil }

// lexContainsState emits an containsToken and returns to the lexStartState func lexContainsState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexContainsState", "lexer") defer l.current.clear() l.emit(containsToken, "%") return lexStartState, nil }

// lexEqualState emits an equalToken and returns to the lexStartState func lexEqualState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexEqualState", "lexer") defer l.current.clear() l.emit(equalToken, "=") return lexStartState, nil }

// lexNotEqualState scans for a notEqualToken and return either to the lexStartState or // lexErrorState func lexNotEqualState(l *lexer) (lexStateFunc, error) { const op = "mql.lexNotEqualState" panicIfNil(l, "lexNotEqualState", "lexer") defer l.current.clear() nextRune := l.read() switch nextRune { case '=': l.emit(notEqualToken, "!=") return lexStartState, nil default: return nil, fmt.Errorf("%s: %w, got %q", op, ErrInvalidNotEqual, fmt.Sprintf("%s%s", "!", string(nextRune))) } }

// lexLeftParenState emits a startLogicalExprToken and returns to the // lexStartState func lexLeftParenState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexLeftParenState", "lexer") defer l.current.clear() l.emit(startLogicalExprToken, runesToString(l.current)) return lexStartState, nil }

// lexRightParenState emits an endLogicalExprToken and returns to the // lexStartState func lexRightParenState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexRightParenState", "lexer") defer l.current.clear() l.emit(endLogicalExprToken, runesToString(l.current)) return lexStartState, nil }

// lexWhitespaceState emits a whitespaceToken and returns to the lexStartState func lexWhitespaceState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexWhitespaceState", "lexer") defer l.current.clear() ReadWhitespace: for { ch := l.read() switch { case ch == eof: break ReadWhitespace case !isSpace(ch): l.unread() break ReadWhitespace } } l.emit(whitespaceToken, "") return lexStartState, nil }

// lexGreaterState will emit either a greaterThanToken or a // greaterThanOrEqualToken and return to the lexStartState func lexGreaterState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexGreaterState", "lexer") defer l.current.clear() next := l.read() switch next { case '=': l.emit(greaterThanOrEqualToken, ">=") return lexStartState, nil default: l.unread() l.emit(greaterThanToken, ">") return lexStartState, nil } }

// lexLesserState will emit either a lessThanToken or a lessThanOrEqualToken and // return to the lexStartState func lexLesserState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexLesserState", "lexer") defer l.current.clear() next := l.read() switch next { case '=': l.emit(lessThanOrEqualToken, "<=") return lexStartState, nil default: l.unread() l.emit(lessThanToken, "<") return lexStartState, nil } }

// lexEofState will emit an eofToken and returns right back to the lexEofState func lexEofState(l *lexer) (lexStateFunc, error) { panicIfNil(l, "lexEofState", "lexer") l.emit(eofToken, "") return lexEofState, nil }

// emit send a token to the lexer's token channel func (l *lexer) emit(t tokenType, v string) { l.tokens <- token{ Type: t, Value: v, } }

// isSpace reports whether r is an ASCII space, tab, carriage return or
// newline. Other Unicode whitespace is not treated as space by the lexer.
func isSpace(r rune) bool {
	switch r {
	case ' ', '\t', '\r', '\n':
		return true
	default:
		return false
	}
}

// isSpecial reports whether r is one of the runes that begin an operator or
// parenthesis token: = > ! < ( ) %
func isSpecial(r rune) bool {
	switch r {
	case '=', '>', '!', '<', '(', ')', '%':
		return true
	}
	return false
}

// read the next rune func (l *lexer) read() rune { ch, _, err := l.source.ReadRune() if err != nil { return eof } l.current.push(ch) return ch }

// unread the last rune read which means that rune will be returned the next // time lexer.read() is called. unread also removes the last rune from the // lexer's stack of current runes func (l *lexer) unread() { _ = l.source.UnreadRune() // error ignore which only occurs when nothing has been previously read _, _ = l.current.pop() }

func isDelimiter(r rune) bool { switch Delimiter(r) { case DoubleQuote, SingleQuote, Backtick: return true default: return false } }

// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0

package mql

import ( "fmt" "reflect" "strings" )

// WhereClause contains a SQL where clause condition and its arguments.
type WhereClause struct {
	// Condition is the where clause condition. It uses "?" placeholders, or
	// "$n" placeholders when Parse is called with WithPgPlaceholders.
	Condition string
	// Args are the values for the condition's placeholders, in order.
	Args []any
}

// Parse will parse the query and use the provided database model to create a // where clause. Supported options: WithColumnMap, WithIgnoreFields, // WithConverter, WithPgPlaceholder func Parse(query string, model any, opt ...Option) (*WhereClause, error) { const op = "mql.Parse" switch { case query == "": return nil, fmt.Errorf("%s: missing query: %w", op, ErrInvalidParameter) case isNil(model): return nil, fmt.Errorf("%s: missing model: %w", op, ErrInvalidParameter) } p := newParser(query) expr, err := p.parse() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } fValidators, err := fieldValidators(reflect.ValueOf(model), opt...) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } e, err := exprToWhereClause(expr, fValidators, opt...) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } opts, err := getOpts(opt...) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } if opts.withPgPlaceholder { for i := 0; i < len(e.Args); i++ { placeholder := fmt.Sprintf("$%d", i+1) e.Condition = strings.Replace(e.Condition, "?", placeholder, 1) } } return e, nil }

// exprToWhereClause generates the where clause condition along with its // required arguments. Supported options: WithColumnMap, WithConverter func exprToWhereClause(e expr, fValidators map[string]validator, opt ...Option) (*WhereClause, error) { const op = "mql.exprToWhereClause" switch { case isNil(e): return nil, fmt.Errorf("%s: missing expression: %w", op, ErrInvalidParameter) case isNil(fValidators): return nil, fmt.Errorf("%s: missing validators: %w", op, ErrInvalidParameter) }

    switch v := e.(type) {
    case *comparisonExpr:
            opts, err := getOpts(opt...)
            if err != nil {
                    return nil, fmt.Errorf("%s: %w", op, err)
            }
            switch validateConvertFn, ok := opts.withValidateConvertFns[v.column]; {
            case ok && !isNil(validateConvertFn):
                    return validateConvertFn(v.column, v.comparisonOp, v.value)
            default:
                    var ok bool
                    var validator validator
                    columnName := v.column
                    switch {
                    case opts.withColumnFieldTag != "":
                            validator, ok = fValidators[columnName]
                    default:
                            columnName = strings.ToLower(v.column)
                            if n, ok := opts.withColumnMap[columnName]; ok {
                                    columnName = n
                            }

                            validator, ok = fValidators[strings.ToLower(strings.ReplaceAll(columnName, "_", ""))]
                    }

                    if !ok {
                            cols := make([]string, len(fValidators))
                            for c := range fValidators {
                                    cols = append(cols, c)
                            }

                            return nil, fmt.Errorf("%s: %w %q %s", op, ErrInvalidColumn, columnName, cols)
                    }

                    w, err := defaultValidateConvert(columnName, v.comparisonOp, v.value, validator, opt...)
                    if err != nil {
                            return nil, fmt.Errorf("%s: %w", op, err)
                    }
                    return w, nil
            }
    case *logicalExpr:
            left, err := exprToWhereClause(v.leftExpr, fValidators, opt...)
            if err != nil {
                    return nil, fmt.Errorf("%s: invalid left expr: %w", op, err)
            }
            if v.logicalOp == "" {
                    return nil, fmt.Errorf("%s: %w that stated with left expr condition: %q args: %q", op, ErrMissingLogicalOp, left.Condition, left.Args)
            }
            right, err := exprToWhereClause(v.rightExpr, fValidators, opt...)
            if err != nil {
                    return nil, fmt.Errorf("%s: invalid right expr: %w", op, err)
            }
            return &WhereClause{
                    Condition: fmt.Sprintf("(%s %s %s)", left.Condition, v.logicalOp, right.Condition),
                    Args:      append(left.Args, right.Args...),
            }, nil
    default:
            return nil, fmt.Errorf("%s: unexpected expr type %T: %w", op, v, ErrInternal)
    }

}

// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0

package mql

import ( "fmt" )

type options struct { withSkipWhitespace bool withColumnMap map[string]string withColumnFieldTag string withValidateConvertFns map[string]ValidateConvertFunc withIgnoredFields []string withPgPlaceholder bool withTableColumnMap map[string]string // map of model field names to their table.column name }

// Option - how options are passed as args type Option func(*options) error

func getDefaultOptions() options { return options{ withColumnMap: make(map[string]string), withColumnFieldTag: "", withValidateConvertFns: make(map[string]ValidateConvertFunc), withTableColumnMap: make(map[string]string), } }

func getOpts(opt ...Option) (options, error) { opts := getDefaultOptions() for _, o := range opt { if err := o(&opts); err != nil { return opts, err } } return opts, nil }

// withSkipWhitespace provides an option to request that whitespace be skipped func withSkipWhitespace() Option { return func(o *options) error { o.withSkipWhitespace = true return nil } }

// WithColumnMap provides an optional map of columns from the user // provided query to a field in the given model func WithColumnMap(m map[string]string) Option { const op = "mql.WithColumnMap" return func(o *options) error { if !isNil(m) { if o.withColumnFieldTag != "" { return fmt.Errorf("%s: cannot be used with WithColumnFieldTag: %w", op, ErrInvalidParameter) } o.withColumnMap = m } return nil } }

// WithColumnFieldTag provides an optional struct tag to use for field mapping // If a field has this tag, the tag value will be used instead of the field name func WithColumnFieldTag(tagName string) Option { const op = "mql.WithColumnFieldTag" return func(o *options) error { if tagName == "" { return fmt.Errorf("%s: empty tag name: %w", op, ErrInvalidParameter) } if len(o.withColumnMap) > 0 { return fmt.Errorf("%s: cannot be used with WithColumnMap: %w", op, ErrInvalidParameter) } o.withColumnFieldTag = tagName return nil } }

// ValidateConvertFunc validates the value and then converts the columnName, // comparisonOp and value to a WhereClause type ValidateConvertFunc func(columnName string, comparisonOp ComparisonOp, value *string) (*WhereClause, error)

// WithConverter provides an optional ConvertFunc for a column identifier in the // query. This allows you to provide whatever custom validation+conversion you // need on a per column basis. See: DefaultValidateConvert(...) for inspiration. func WithConverter(fieldName string, fn ValidateConvertFunc) Option { const op = "mql.WithSqlConverter" return func(o *options) error { switch { case fieldName != "" && !isNil(fn): if _, exists := o.withValidateConvertFns[fieldName]; exists { return fmt.Errorf("%s: duplicated convert: %w", op, ErrInvalidParameter) } o.withValidateConvertFns[fieldName] = fn case fieldName == "" && !isNil(fn): return fmt.Errorf("%s: missing field name: %w", op, ErrInvalidParameter) case fieldName != "" && isNil(fn): return fmt.Errorf("%s: missing ConvertToSqlFunc: %w", op, ErrInvalidParameter) } return nil } }

// WithIgnoredFields provides an optional list of fields to ignore in the model // (your Go struct) when parsing. Note: Field names are case sensitive. func WithIgnoredFields(fieldName ...string) Option { return func(o *options) error { o.withIgnoredFields = fieldName return nil } }

// WithPgPlaceholders will use parameters placeholders that are compatible with // the postgres pg driver which requires a placeholder like $1 instead of ?. // See: // - https://pkg.go.dev/github.com/jackc/pgx/v5 // - https://pkg.go.dev/github.com/lib/pq func WithPgPlaceholders() Option { return func(o *options) error { o.withPgPlaceholder = true return nil } }

// WithTableColumnMap provides an optional map of columns from the // model to the table.column name in the generated where clause // // For example, if you need to map the language field name to something // more complex in your SQL statement then you can use this map: // // WithTableColumnMap(map[string]string{"language":"preferences->>'language'"}) // // In the example above we're mapping "language" field to a json field in // the "preferences" column. A user can say language="blah" and the // mql-created SQL where clause will contain preferences->>'language'="blah" // // The field names in the keys to the map should always be lower case. func WithTableColumnMap(m map[string]string) Option { return func(o *options) error { if !isNil(m) { o.withTableColumnMap = m } return nil } }

// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0

package mql

import ( "fmt" "strings" "unicode" )

type parser struct { l *lexer raw string currentToken token openLogicalExpr stack[struct{}] // something very simple to make sure every logical expr that's opened is closed. }

func newParser(s string) *parser { var fixedUp string { // remove any leading/trailing whitespace fixedUp = strings.TrimSpace(s) // remove any leading space before a right parenthesis (issue #42) fixedUp = removeSpacesBeforeParen(fixedUp) } return &parser{ l: newLexer(fixedUp), raw: s, } }

func (p *parser) parse() (expr, error) { const op = "mql.(parser).parse" lExpr, err := p.parseLogicalExpr() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } r, err := root(lExpr, p.raw) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } return r, nil }

// parseLogicalExpr parses tokens until an eofToken is reached, or until the
// ")" that closes the current group. Along the way it parses comparisonExprs
// and recursively parses parenthesised groups.
//
// Each time a logicalExpr receives its right side, it is wrapped as the left
// side of a new logicalExpr with an empty logicalOp, so a following and/or
// can extend the chain. Chains therefore group left-to-right: for example,
// `a=1 and b=2 or c=3` becomes ((a=1 and b=2) or c=3), and "and" has no
// precedence over "or". The returned tree may contain these empty-operator
// wrappers; root(...) unwraps them.
//
// Parens that are still open at eof produce ErrMissingClosingParen.
func (p *parser) parseLogicalExpr() (*logicalExpr, error) {
	const op = "parseLogicalExpr"
	logicExpr := &logicalExpr{}

	// load the first non-whitespace token
	if err := p.scan(withSkipWhitespace()); err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}

TkLoop:
	for p.currentToken.Type != eofToken {
		switch p.currentToken.Type {
		case startLogicalExprToken:
			// there's a opening paren: (
			// so we've found a new logical expr to parse
			e, err := p.parseLogicalExpr()
			if err != nil {
				return nil, fmt.Errorf("%s: %w", op, err)
			}
			switch {
			// start by assigning the left expr
			case logicExpr.leftExpr == nil:
				logicExpr.leftExpr = e
				break TkLoop
			// we should have a logical operator before the right side expr is assigned
			case logicExpr.logicalOp == "":
				return nil, fmt.Errorf("%s: %w before right side expression in: %q", op, ErrMissingLogicalOp, p.raw)
			// finally, assign the right expr
			case logicExpr.rightExpr == nil:
				if e.rightExpr != nil {
					// if e.rightExpr isn't nil, then we've got a complete
					// expr (left + op + right) and we need to assign this to
					// our rightExpr
					logicExpr.rightExpr = e
					break TkLoop
				}
				// otherwise, we need to assign the left side of e
				logicExpr.rightExpr = e.leftExpr
				break TkLoop
			}
		case stringToken, numberToken, symbolToken:
			// A comparison may only start when there is no left side yet, or
			// when an operator is waiting for its right side.
			if (logicExpr.leftExpr != nil && logicExpr.logicalOp == "") ||
				(logicExpr.leftExpr != nil && logicExpr.rightExpr != nil) {
				return nil, fmt.Errorf("%s: %w starting at %q in: %q", op, ErrUnexpectedExpr, p.currentToken.Value, p.raw)
			}
			cmpExpr, err := p.parseComparisonExpr()
			if err != nil {
				return nil, fmt.Errorf("%s: %w", op, err)
			}
			switch {
			case logicExpr.leftExpr == nil:
				logicExpr.leftExpr = cmpExpr
			case logicExpr.rightExpr == nil:
				logicExpr.rightExpr = cmpExpr
				// Wrap the now-complete expr as the left side of a fresh
				// node with no operator yet, so a following and/or chains
				// onto it (left-to-right grouping).
				tmpExpr := &logicalExpr{
					leftExpr:  logicExpr,
					logicalOp: "",
					rightExpr: nil,
				}
				logicExpr = tmpExpr
			default:
				return nil, fmt.Errorf("%s: %w at %q, but both left and right expressions already exist in: %q", op, ErrUnexpectedExpr, p.currentToken.Value, p.raw)
			}
		case endLogicalExprToken:
			// end of the current parenthesised group
			if logicExpr.leftExpr == nil {
				return nil, fmt.Errorf("%s: %w %q but we haven't parsed a left side expression in: %q", op, ErrUnexpectedClosingParen, p.currentToken.Value, p.raw)
			}
			return logicExpr, nil
		case andToken, orToken:
			if logicExpr.logicalOp != "" {
				return nil, fmt.Errorf("%s: %w %q when we've already parsed one for expr in: %q", op, ErrUnexpectedLogicalOp, p.currentToken.Value, p.raw)
			}
			o, err := newLogicalOp(p.currentToken.Value)
			if err != nil {
				return nil, fmt.Errorf("%s: %w", op, err)
			}
			logicExpr.logicalOp = o
		default:
			return nil, fmt.Errorf("%s: %w %q in: %q", op, ErrUnexpectedToken, p.currentToken.Value, p.raw)
		}
		if err := p.scan(withSkipWhitespace()); err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
	}
	// scan pushes on "(" and pops on ")", so anything left here was never closed
	if p.openLogicalExpr.len() > 0 {
		return nil, fmt.Errorf("%s: %w in: %q", op, ErrMissingClosingParen, p.raw)
	}
	return logicExpr, nil
}

// parseComparisonExpr parses a single comparisonExpr (column, comparisonOp,
// value, in that order), starting at the current token. It returns as soon
// as whitespace follows a complete expr, or when an eofToken is reached. It
// never parses a logicalExpr: a "(" always yields ErrUnexpectedOpeningParen.
//
// Tokens are read with scan() without skipping whitespace. A ")" after a
// complete expr matches no case below, so the next scan consumes it.
//
// At eof, an expr that has a column but no comparisonOp is an error
// (ErrMissingComparisonOp). An expr that is missing only its value is
// returned as-is, so callers must cope with a nil value
// (defaultValidateConvert reports ErrMissingComparisonValue for it).
func (p *parser) parseComparisonExpr() (expr, error) {
	const op = "mql.(parser).parseComparisonExpr"
	cmpExpr := &comparisonExpr{}

	// our language (and this parser) def requires the tokens to be in the
	// correct order: column, comparisonOp, value. Swapping this order where the
	// value comes first (value, comparisonOp, column) is not supported
	for p.currentToken.Type != eofToken {
		switch {
		case p.currentToken.Type == startLogicalExprToken:
			switch {
			case cmpExpr.isComplete():
				return nil, fmt.Errorf("%s: %w after %s in: %q", op, ErrUnexpectedOpeningParen, cmpExpr, p.raw)
			default:
				return nil, fmt.Errorf("%s: %w in: %q", op, ErrUnexpectedOpeningParen, p.raw)
			}

		// we already have a complete comparisonExpr, so only whitespace or
		// a closing paren may follow it
		case cmpExpr.isComplete() &&
			(p.currentToken.Type != whitespaceToken && p.currentToken.Type != endLogicalExprToken):
			return nil, fmt.Errorf("%s: %w %s:%q in: %s", op, ErrUnexpectedToken, p.currentToken.Type, p.currentToken.Value, p.raw)

		// we found whitespace: return if the comparisonExpr is complete,
		// otherwise just skip it
		case p.currentToken.Type == whitespaceToken:
			if cmpExpr.column != "" && cmpExpr.comparisonOp != "" && cmpExpr.value != nil {
				return cmpExpr, nil
			}

		// columns must come first, so handle those conditions
		case cmpExpr.column == "" && p.currentToken.Type != symbolToken:
			// parseLogicalExpr calls parseComparisonExpr(...) for a
			// stringToken or numberToken as well as a symbolToken, so this
			// IS reachable, e.g. when the query starts with a quoted string
			return nil, fmt.Errorf("%s: %w: we expected a %s and got %s == %s in: %q", op, ErrUnexpectedToken, symbolToken, p.currentToken.Type, p.currentToken.Value, p.raw)
		case cmpExpr.column == "": // has to be a symbolToken (the case above rejects anything else) representing the column
			cmpExpr.column = p.currentToken.Value

		// after columns, comparison operators must come next
		case cmpExpr.comparisonOp == "":
			c, err := newComparisonOp(p.currentToken.Value)
			if err != nil {
				return nil, fmt.Errorf("%s: %w %q in: %q", op, err, p.currentToken.Value, p.raw)
			}
			cmpExpr.comparisonOp = c

		// finally, values must come at the end
		case cmpExpr.value == nil && (p.currentToken.Type != stringToken && p.currentToken.Type != numberToken && p.currentToken.Type != symbolToken):
			return nil, fmt.Errorf("%s: %w %q in: %q", op, ErrUnexpectedToken, p.currentToken.Value, p.raw)
		case cmpExpr.value == nil:
			switch {
			// bare (unquoted) words are not accepted as values
			case p.currentToken.Type == symbolToken:
				return nil, fmt.Errorf("%s: %w %s == %s (expected: %s or %s) in %q", op, ErrInvalidComparisonValueType, p.currentToken.Type, p.currentToken.Value, stringToken, numberToken, p.raw)
			case p.currentToken.Type == stringToken, p.currentToken.Type == numberToken:
				s := p.currentToken.Value
				cmpExpr.value = &s
			default:
				return nil, fmt.Errorf("%s: %w of %s == %s", op, ErrUnexpectedToken, p.currentToken.Type, p.currentToken.Value)
			}
		}
		if err := p.scan(); err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
	}

	switch {
	case cmpExpr.column != "" && cmpExpr.comparisonOp == "":
		return nil, fmt.Errorf("%s: %w in: %q", op, ErrMissingComparisonOp, p.raw)
	default:
		return cmpExpr, nil
	}
}

// scan will get the next token from the lexer. Supported options: // withSkipWhitespace func (p *parser) scan(opt ...Option) error { const op = "mql.(parser).scan"

    opts, err := getOpts(opt...)
    if err != nil {
            return fmt.Errorf("%s: %w", op, err)
    }

    if p.currentToken, err = p.l.nextToken(); err != nil {
            return fmt.Errorf("%s: %w", op, err)
    }

    if opts.withSkipWhitespace {
            for p.currentToken.Type == whitespaceToken {
                    if p.currentToken, err = p.l.nextToken(); err != nil {
                            return fmt.Errorf("%s: %w", op, err)
                    }
            }
    }

    switch p.currentToken.Type {
    case startLogicalExprToken:
            p.openLogicalExpr.push(struct{}{})
    case endLogicalExprToken:
            p.openLogicalExpr.pop()
    }

    return nil

}

// removeSpacesBeforeParen returns s with every run of Unicode whitespace that
// sits immediately before a ')' removed. Whitespace anywhere else, including
// trailing whitespace, is kept as-is.
//
// NOTE(review): the scan has no notion of quoting, so whitespace before a ')'
// inside a quoted value is removed as well — confirm callers expect that.
func removeSpacesBeforeParen(s string) string {
	if s == "" {
		return s
	}

	var out strings.Builder
	// pending holds whitespace that has been seen but not yet emitted; it is
	// dropped if the next rune is ')' and flushed otherwise.
	var pending []rune
	for _, r := range s {
		switch {
		case unicode.IsSpace(r):
			pending = append(pending, r)
		case r == ')':
			pending = pending[:0]
			out.WriteRune(r)
		default:
			out.WriteString(string(pending))
			pending = pending[:0]
			out.WriteRune(r)
		}
	}
	out.WriteString(string(pending))
	return out.String()
}

// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0

package mql

// stack is a minimal generic LIFO container backed by a slice. The zero value
// is an empty stack ready for use.
type stack[T any] struct {
	data []T
}

// push places v on top of the stack.
func (s *stack[T]) push(v T) {
	s.data = append(s.data, v)
}

// pop removes and returns the top element. When the stack is empty it returns
// T's zero value and false.
func (s *stack[T]) pop() (T, bool) {
	var zero T
	n := len(s.data)
	if n == 0 {
		return zero, false
	}
	top := s.data[n-1]
	// Clear the vacated slot so the backing array does not keep the popped
	// value reachable (relevant when T is a pointer, slice, map, etc.).
	s.data[n-1] = zero
	s.data = s.data[:n-1]
	return top, true
}

// clear empties the stack and drops its backing array.
func (s *stack[T]) clear() {
	s.data = nil
}

// len reports the number of elements currently on the stack.
func (s *stack[T]) len() int {
	return len(s.data)
}

// runesToString returns the runes held by s, from bottom to top, as a string.
func runesToString(s stack[rune]) string {
	return string(s.data)
}

// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0

package mql

// token is a single lexical unit: its classification plus the text it carries.
type token struct {
	Type  tokenType
	Value string
}

// tokenType classifies a token.
type tokenType int

// eof is the rune value -1; presumably the lexer uses it to signal end of
// input — TODO confirm against the lexer.
const eof rune = -1

// Token types. Values are assigned with iota, so the order of this block is
// significant: inserting or reordering entries renumbers every later type.
const (
	unknownToken tokenType = iota
	eofToken
	whitespaceToken
	stringToken
	startLogicalExprToken
	endLogicalExprToken
	greaterThanToken
	greaterThanOrEqualToken
	lessThanToken
	lessThanOrEqualToken
	equalToken
	notEqualToken
	containsToken
	numberToken
	symbolToken

	// keywords
	andToken
	orToken
)

// tokenTypeToString maps every tokenType to the short name reported by
// tokenType.String.
var tokenTypeToString = map[tokenType]string{
	unknownToken:            "unknown",
	eofToken:                "eof",
	whitespaceToken:         "ws",
	stringToken:             "str",
	startLogicalExprToken:   "lparen",
	endLogicalExprToken:     "rparen",
	greaterThanToken:        "gt",
	greaterThanOrEqualToken: "gte",
	lessThanToken:           "lt",
	lessThanOrEqualToken:    "lte",
	equalToken:              "eq",
	notEqualToken:           "neq",
	containsToken:           "contains",
	numberToken:             "num",
	symbolToken:             "symbol",
	andToken:                "and",
	orToken:                 "or",
}

// String returns the short name of t (for example "str" or "lparen"). A value
// with no entry in tokenTypeToString is reported as "unknown".
func (t tokenType) String() string {
	if name, found := tokenTypeToString[t]; found {
		return name
	}
	return tokenTypeToString[unknownToken]
}

// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0

package mql

import ( "fmt" "reflect" "strconv" "strings"

    "golang.org/x/exp/slices"

)

// validator pairs a validateFunc with a short label naming the kind of value
// it validates.
type validator struct {
	// fn converts and validates a raw column value.
	fn validateFunc
	// typ labels the validated kind; fieldValidators assigns one of
	// "float", "int", "time", or "default".
	typ string
}

// validateFunc is used to validate a column value by converting it as needed,
// validating the value, and returning the converted value.
type validateFunc func(columnValue string) (columnVal any, err error)

// fieldValidators takes a model and returns a map of field names to validate // functions. Supported options: WithIgnoreFields func fieldValidators(model reflect.Value, opt ...Option) (map[string]validator, error) { const op = "mql.fieldValidators" switch { case !model.IsValid(): return nil, fmt.Errorf("%s: missing model: %w", op, ErrInvalidParameter) case (model.Kind() != reflect.Struct && model.Kind() != reflect.Pointer), model.Kind() == reflect.Pointer && model.Elem().Kind() != reflect.Struct: return nil, fmt.Errorf("%s: model must be a struct or a pointer to a struct: %w", op, ErrInvalidParameter) } var m reflect.Value = model if m.Kind() != reflect.Struct { m = model.Elem() }

    opts, err := getOpts(opt...)
    if err != nil {
            return nil, fmt.Errorf("%s: %w", op, err)
    }

    fValidators := make(map[string]validator)
    for i := 0; i < m.NumField(); i++ {
            field := m.Type().Field(i)
            if slices.Contains(opts.withIgnoredFields, field.Name) {
                    continue
            }

            var fName string
            switch {
            case opts.withColumnFieldTag != "":
                    tagValue := field.Tag.Get(opts.withColumnFieldTag)
                    if tagValue != "" {
                            parts := strings.SplitN(tagValue, ",", 2)
                            fName = parts[0]
                    }
                    if fName == "" {
                            return nil, fmt.Errorf("%s: field %q has an invalid tag %q: %w", op, field.Name, opts.withColumnFieldTag, ErrInvalidParameter)
                    }
            default:
                    fName = strings.ToLower(field.Name)
            }

            // get a string val of the field type, then strip any leading '*' so we
            // can simplify the switch below when dealing with types like *int and int.
            fType := strings.TrimPrefix(m.Type().Field(i).Type.String(), "*")
            switch fType {
            case "float32", "float64":
                    fValidators[fName] = validator{fn: validateFloat, typ: "float"}
            case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64":
                    fValidators[fName] = validator{fn: validateInt, typ: "int"}
            case "time.Time":
                    fValidators[fName] = validator{fn: validateDefault, typ: "time"}
            default:
                    fValidators[fName] = validator{fn: validateDefault, typ: "default"}
            }
    }
    return fValidators, nil

}

// validateDefault is the no-op validateFunc that fieldValidators uses for
// time.Time fields and for any field type it does not recognise: the column
// value is returned unchanged, as a string, and the error is always nil.
func validateDefault(s string) (any, error) {
	var unchanged any = s
	return unchanged, nil
}

func validateInt(s string) (any, error) { const op = "mql.validateInt" i, err := strconv.Atoi(s) if err != nil { return 0, fmt.Errorf("%s: value %q is not an int: %w", op, s, ErrInvalidParameter) } return i, nil }

func validateFloat(s string) (any, error) { const op = "mql.validateFloat" f, err := strconv.ParseFloat(s, 64) if err != nil { return nil, fmt.Errorf("%s: value %q is not float: %w", op, s, ErrInvalidParameter) } return f, nil }