1
0
Fork 0
mirror of https://github.com/documize/community.git synced 2025-07-19 05:09:42 +02:00
documize/vendor/github.com/denisenkom/go-mssqldb/batch/batch.go
2019-06-25 15:37:19 +01:00

287 lines
5.1 KiB
Go

// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package batch splits a single script containing multiple batches separated by
// a keyword into multiple scripts.
package batch
import (
"bytes"
"fmt"
"strconv"
"strings"
"unicode"
)
// Split breaks sql into separate scripts wherever the separator keyword
// (often "GO") is recognized and returns them in order of appearance.
// A backslash placed directly before a line break escapes that line break.
//
// If separator is empty, or sql is shorter than separator, sql is returned
// unchanged as the only element.
func Split(sql, separator string) []string {
	if separator == "" || len(separator) > len(sql) {
		return []string{sql}
	}

	lex := &lexer{Sql: sql, Sep: separator}

	// Run the state machine until a state reports that it has no successor.
	for next := stateFn(stateWhitespace); next != nil; {
		next = next(lex)
	}

	// Emit whatever text is still pending after the last separator.
	lex.AddCurrent(1)
	return lex.Batch
}
// debugPrintStateName turns on tracing of lexer state entry when true.
const debugPrintStateName = false

// printStateName writes the given state name and the lexer's current offset
// to standard output. It does nothing unless debugPrintStateName is true.
func printStateName(name string, l *lexer) {
	if !debugPrintStateName {
		return
	}
	fmt.Printf("state %s At=%d\n", name, l.At)
}
// hasPrefixFold reports whether the first len(sep) bytes of s equal sep
// under Unicode case folding. It is false when s is shorter than sep.
func hasPrefixFold(s, sep string) bool {
	n := len(sep)
	return n <= len(s) && strings.EqualFold(s[:n], sep)
}
// lexer carries the scanning state shared by the state functions while a
// script is being split into batches.
type lexer struct {
	Sql   string   // complete input script
	Sep   string   // batch separator keyword
	At    int      // byte offset of the current scan position in Sql
	Start int      // byte offset in Sql at which the pending batch begins
	Skip  []int    // ascending byte offsets in Sql to omit from the pending batch
	Batch []string // batches emitted so far
}

// Add appends b to the emitted batches. Empty batches are dropped.
func (l *lexer) Add(b string) {
	if b == "" {
		return
	}
	l.Batch = append(l.Batch, b)
}

// Next advances the scan position by one byte and reports whether the new
// position is still inside the input.
func (l *lexer) Next() bool {
	l.At++
	return l.At < len(l.Sql)
}

// AddCurrent emits the pending batch, Sql[Start:At] with every offset listed
// in Skip removed, count times. Counts above 1000 are capped at 1000 and
// non-positive counts emit nothing. Afterwards Skip is cleared, At is moved
// forward by len(Sep) (the caller is expected to be positioned on the
// separator) and Start is set to the new At.
//
// It reports whether the new position is still inside the input.
func (l *lexer) AddCurrent(count int64) bool {
	const maxRepeat = 1000 // sanity limit on the repeat count

	// Clamp both ends so the slice expression below can never panic.
	if l.At > len(l.Sql) {
		l.At = len(l.Sql)
	}
	if l.Start > l.At {
		l.Start = l.At
	}
	text := l.Sql[l.Start:l.At]

	if len(l.Skip) > 0 {
		// Skip holds offsets into Sql, whereas text begins at Start, so each
		// offset has to be rebased before it is applied. Copying the kept
		// segments byte-wise leaves all other bytes of the batch untouched.
		var buf bytes.Buffer
		buf.Grow(len(text))
		kept := 0
		for _, abs := range l.Skip {
			rel := abs - l.Start
			if rel < kept || rel >= len(text) {
				continue // outside the pending batch
			}
			buf.WriteString(text[kept:rel])
			kept = rel + 1
		}
		buf.WriteString(text[kept:])
		text = buf.String()
		l.Skip = nil
	}

	if count > maxRepeat {
		count = maxRepeat
	}
	for i := int64(0); i < count; i++ {
		l.Add(text)
	}

	l.At += len(l.Sep)
	l.Start = l.At
	return l.At < len(l.Sql)
}
// stateFn is a single state of the lexer's state machine. It works on the
// given lexer and returns the state to run next; Split stops when it gets nil.
type stateFn func(*lexer) stateFn
// Comment delimiters recognized while scanning statement text.
const (
lineComment = "--"
leftComment = "/*"
rightComment = "*/"
)
// stateSep is entered when the text at l.At starts with the separator
// keyword. It decides whether the keyword really ends a batch, parses the
// optional repeat count that may follow it (as in "GO 5"), and emits the
// pending batch that many times.
//
// It returns stateWhitespace after emitting a batch at a line break,
// stateText when the match is not a separator after all (the keyword is only
// the beginning of a longer word, the count cannot be parsed, or other text
// follows the count), and nil when the input ends before a line break.
func stateSep(l *lexer) stateFn {
	printStateName("sep", l)
	if l.At+len(l.Sep) >= len(l.Sql) {
		return nil
	}
	s := l.Sql[l.At+len(l.Sep):]

	// A separator that ends in a word character only counts when it is not
	// glued to further word characters; otherwise "GOTO" or "google_id" at
	// the start of a line would be cut in two.
	var last rune
	for _, r := range l.Sep {
		last = r
	}
	needBoundary := last == '_' || unicode.IsLetter(last) || unicode.IsDigit(last)

	numStart := -1
scan:
	for i, r := range s {
		switch {
		case i == 0 && needBoundary && (r == '_' || unicode.IsLetter(r)):
			return stateText
		case r == '\n', r == '\r':
			l.AddCurrent(1)
			return stateWhitespace
		case unicode.IsSpace(r):
			// Blanks between the separator and a count or line break.
		case unicode.IsNumber(r):
			numStart = i
			break scan
		}
	}
	if numStart < 0 {
		return nil
	}

	// The count extends up to the first rune that is not a number.
	numEnd := len(s)
	for i, r := range s[numStart:] {
		if !unicode.IsNumber(r) {
			numEnd = numStart + i
			break
		}
	}
	count, err := strconv.ParseInt(s[numStart:numEnd], 10, 64)
	if err != nil {
		return stateText
	}

	for _, r := range s[numEnd:] {
		switch {
		case r == '\n', r == '\r':
			l.AddCurrent(count)
			// AddCurrent already stepped over the separator; also step over
			// the blanks and the count that followed it.
			l.At += numEnd
			l.Start = l.At
			return stateWhitespace
		case unicode.IsSpace(r):
		default:
			return stateText
		}
	}

	// The input ends right after the count: honor the count here and leave
	// nothing pending for the caller's final flush.
	l.AddCurrent(count)
	l.At = len(l.Sql)
	l.Start = l.At
	return nil
}
// stateText scans ordinary statement text one byte at a time. It hands over
// to stateLineComment or stateMultiComment when a comment opener is found, to
// stateString at a single quote, and to stateWhitespace after a line break.
// It returns nil once the end of the input is reached.
func stateText(l *lexer) stateFn {
	printStateName("text", l)
	for {
		cur := l.Sql[l.At]
		rest := l.Sql[l.At:]

		if strings.HasPrefix(rest, lineComment) {
			l.At += len(lineComment)
			return stateLineComment
		}
		if strings.HasPrefix(rest, leftComment) {
			l.At += len(leftComment)
			return stateMultiComment
		}

		switch cur {
		case '\'':
			l.At++
			return stateString
		case '\r', '\n':
			l.At++
			return stateWhitespace
		}

		if !l.Next() {
			return nil
		}
	}
}
// stateWhitespace consumes a single whitespace byte per invocation. At the
// first non-whitespace byte it moves to stateSep when the text there starts
// with the separator (compared case-insensitively) and to stateText
// otherwise. It returns nil at the end of the input.
func stateWhitespace(l *lexer) stateFn {
	printStateName("whitespace", l)
	if l.At >= len(l.Sql) {
		return nil
	}
	if unicode.IsSpace(rune(l.Sql[l.At])) {
		l.At++
		return stateWhitespace
	}
	if hasPrefixFold(l.Sql[l.At:], l.Sep) {
		return stateSep
	}
	return stateText
}
// stateLineComment skips the remainder of a "--" comment. It steps past the
// first line-break byte it finds and continues in stateWhitespace, or returns
// nil when the input ends inside the comment.
func stateLineComment(l *lexer) stateFn {
	printStateName("line-comment", l)
	for l.At < len(l.Sql) {
		switch l.Sql[l.At] {
		case '\r', '\n':
			l.At++
			return stateWhitespace
		}
		// The loop condition re-checks the bound, so the result is not needed.
		l.Next()
	}
	return nil
}
// stateMultiComment skips the body of a "/* ... */" comment. After the
// closing delimiter it continues in stateWhitespace; it returns nil when the
// input ends before the comment is closed.
func stateMultiComment(l *lexer) stateFn {
	printStateName("multi-line-comment", l)
	for {
		// Same end-of-input guard as the other scanning states.
		if l.At >= len(l.Sql) {
			return nil
		}
		if strings.HasPrefix(l.Sql[l.At:], rightComment) {
			// Step over the closing delimiter that was just matched.
			l.At += len(rightComment)
			return stateWhitespace
		}
		if !l.Next() {
			return nil
		}
	}
}
// stateString scans the inside of a single-quoted string literal.
//
// A backslash directly followed by a line break ("\n", "\r" or "\r\n") is a
// line continuation: the offsets of the backslash and of the line-break bytes
// are appended to l.Skip and scanning resumes after them. Two consecutive
// single quotes are an escaped quote and do not end the literal. A lone
// single quote ends it, and scanning continues in stateWhitespace. It returns
// nil when the input ends inside the literal.
func stateString(l *lexer) stateFn {
	printStateName("string", l)
	for l.At < len(l.Sql) {
		cur := l.Sql[l.At]
		peek := rune(-1) // -1 marks "no following byte"
		if l.At+1 < len(l.Sql) {
			peek = rune(l.Sql[l.At+1])
		}

		switch {
		case cur == '\\' && (peek == '\r' || peek == '\n'):
			l.Skip = append(l.Skip, l.At, l.At+1)
			width := 2
			if peek == '\r' && l.At+2 < len(l.Sql) && l.Sql[l.At+2] == '\n' {
				l.Skip = append(l.Skip, l.At+2)
				width = 3
			}
			l.At += width
		case cur == '\'':
			if peek == '\'' {
				l.At += 2
				continue
			}
			l.At++
			return stateWhitespace
		default:
			if !l.Next() {
				return nil
			}
		}
	}
	return nil
}