1
0
Fork 0
mirror of https://github.com/documize/community.git synced 2025-07-19 21:29:42 +02:00

Sync with Community

This commit is contained in:
HarveyKandola 2021-08-19 13:02:56 -04:00
parent df8f650319
commit 989b7cd62c
123 changed files with 5054 additions and 2015 deletions

View file

@ -16,6 +16,7 @@ import (
"unicode"
"github.com/denisenkom/go-mssqldb/internal/querytext"
"github.com/denisenkom/go-mssqldb/msdsn"
)
// ReturnStatus may be used to return the return value from a proc.
@ -31,12 +32,16 @@ var driverInstanceNoProcess = &Driver{processQueryText: false}
func init() {
sql.Register("mssql", driverInstance)
sql.Register("sqlserver", driverInstanceNoProcess)
createDialer = func(p *connectParams) Dialer {
return netDialer{&net.Dialer{KeepAlive: p.keepAlive}}
createDialer = func(p *msdsn.Config) Dialer {
ka := p.KeepAlive
if ka == 0 {
ka = 30 * time.Second
}
return netDialer{&net.Dialer{KeepAlive: ka}}
}
}
var createDialer func(p *connectParams) Dialer
var createDialer func(p *msdsn.Config) Dialer
type netDialer struct {
nd *net.Dialer
@ -54,10 +59,11 @@ type Driver struct {
// OpenConnector opens a new connector. Useful to dial with a context.
func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
params, err := parseConnectParams(dsn)
params, _, err := msdsn.Parse(dsn)
if err != nil {
return nil, err
}
return &Connector{
params: params,
driver: d,
@ -80,7 +86,7 @@ func (d *Driver) SetLogger(logger Logger) {
// NewConnector creates a new connector from a DSN.
// The returned connector may be used with sql.OpenDB.
func NewConnector(dsn string) (*Connector, error) {
params, err := parseConnectParams(dsn)
params, _, err := msdsn.Parse(dsn)
if err != nil {
return nil, err
}
@ -91,15 +97,34 @@ func NewConnector(dsn string) (*Connector, error) {
return c, nil
}
// NewConnectorConfig creates a new Connector for a DSN Config struct.
// The returned connector may be used with sql.OpenDB.
func NewConnectorConfig(config msdsn.Config) *Connector {
return &Connector{
params: config,
driver: driverInstanceNoProcess,
}
}
// Connector holds the parsed DSN and is ready to make a new connection
// at any time.
//
// In the future, settings that cannot be passed through a string DSN
// may be set directly on the connector.
type Connector struct {
params connectParams
params msdsn.Config
driver *Driver
fedAuthRequired bool
fedAuthLibrary int
fedAuthADALWorkflow byte
// callback that can provide a security token during login
securityTokenProvider func(ctx context.Context) (string, error)
// callback that can provide a security token during ADAL login
adalTokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)
// SessionInitSQL is executed after marking a given session to be reset.
// When not present, the next query will still reset the session to the
// database defaults.
@ -132,7 +157,7 @@ type Dialer interface {
DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
}
func (c *Connector) getDialer(p *connectParams) Dialer {
func (c *Connector) getDialer(p *msdsn.Config) Dialer {
if c != nil && c.Dialer != nil {
return c.Dialer
}
@ -148,34 +173,32 @@ type Conn struct {
processQueryText bool
connectionGood bool
outs map[string]interface{}
outs outputs
}
type outputs struct {
params map[string]interface{}
returnStatus *ReturnStatus
}
func (c *Conn) setReturnStatus(s ReturnStatus) {
if c.returnStatus == nil {
return
}
*c.returnStatus = s
// IsValid satisfies the driver.Validator interface.
func (c *Conn) IsValid() bool {
return c.connectionGood
}
func (c *Conn) checkBadConn(err error) error {
// this is a hack to address Issue #275
// we set connectionGood flag to false if
// error indicates that connection is not usable
// but we return actual error instead of ErrBadConn
// this will cause connection to stay in a pool
// but next request to this connection will return ErrBadConn
// it might be possible to revise this hack after
// https://github.com/golang/go/issues/20807
// is implemented
// checkBadConn marks the connection as bad based on the characteristics
// of the supplied error. Bad connections will be dropped from the connection
// pool rather than reused.
//
// If bad connection retry is enabled and the error + connection state permits
// retrying, checkBadConn will return a RetryableError that allows database/sql
// to automatically retry the query with another connection.
func (c *Conn) checkBadConn(err error, mayRetry bool) error {
switch err {
case nil:
return nil
case io.EOF:
c.connectionGood = false
return driver.ErrBadConn
case driver.ErrBadConn:
// It is an internal programming error if driver.ErrBadConn
// is ever passed to this function. driver.ErrBadConn should
@ -187,34 +210,33 @@ func (c *Conn) checkBadConn(err error) error {
switch err.(type) {
case net.Error:
c.connectionGood = false
return err
case StreamError:
c.connectionGood = false
return err
default:
return err
case ServerError:
c.connectionGood = false
}
if !c.connectionGood && mayRetry && !c.connector.params.DisableRetry {
return newRetryableError(err)
}
return err
}
func (c *Conn) clearOuts() {
c.outs = nil
c.outs = outputs{}
}
func (c *Conn) simpleProcessResp(ctx context.Context) error {
tokchan := make(chan tokenStruct, 5)
go processResponse(ctx, c.sess, tokchan, c.outs)
reader := startReading(c.sess, ctx, c.outs)
c.clearOuts()
for tok := range tokchan {
switch token := tok.(type) {
case doneStruct:
if token.isError() {
return c.checkBadConn(token.getError())
}
case error:
return c.checkBadConn(token)
}
var resultError error
err := reader.iterateResponse()
if err != nil {
return c.checkBadConn(err, false)
}
return nil
return resultError
}
func (c *Conn) Commit() error {
@ -222,7 +244,7 @@ func (c *Conn) Commit() error {
return driver.ErrBadConn
}
if err := c.sendCommitRequest(); err != nil {
return c.checkBadConn(err)
return c.checkBadConn(err, true)
}
return c.simpleProcessResp(c.transactionCtx)
}
@ -239,7 +261,7 @@ func (c *Conn) sendCommitRequest() error {
c.sess.log.Printf("Failed to send CommitXact with %v", err)
}
c.connectionGood = false
return fmt.Errorf("Faild to send CommitXact: %v", err)
return fmt.Errorf("faild to send CommitXact: %v", err)
}
return nil
}
@ -249,7 +271,7 @@ func (c *Conn) Rollback() error {
return driver.ErrBadConn
}
if err := c.sendRollbackRequest(); err != nil {
return c.checkBadConn(err)
return c.checkBadConn(err, true)
}
return c.simpleProcessResp(c.transactionCtx)
}
@ -266,7 +288,7 @@ func (c *Conn) sendRollbackRequest() error {
c.sess.log.Printf("Failed to send RollbackXact with %v", err)
}
c.connectionGood = false
return fmt.Errorf("Failed to send RollbackXact: %v", err)
return fmt.Errorf("failed to send RollbackXact: %v", err)
}
return nil
}
@ -281,11 +303,11 @@ func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx,
}
err = c.sendBeginRequest(ctx, tdsIsolation)
if err != nil {
return nil, c.checkBadConn(err)
return nil, c.checkBadConn(err, true)
}
tx, err = c.processBeginResponse(ctx)
if err != nil {
return nil, c.checkBadConn(err)
return nil, err
}
return
}
@ -303,7 +325,7 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro
c.sess.log.Printf("Failed to send BeginXact with %v", err)
}
c.connectionGood = false
return fmt.Errorf("Failed to send BeginXact: %v", err)
return fmt.Errorf("failed to send BeginXact: %v", err)
}
return nil
}
@ -318,25 +340,26 @@ func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
}
func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
params, err := parseConnectParams(dsn)
params, _, err := msdsn.Parse(dsn)
if err != nil {
return nil, err
}
return d.connect(ctx, nil, params)
c := &Connector{params: params}
return d.connect(ctx, c, params)
}
// connect to the server, using the provided context for dialing only.
func (d *Driver) connect(ctx context.Context, c *Connector, params connectParams) (*Conn, error) {
func (d *Driver) connect(ctx context.Context, c *Connector, params msdsn.Config) (*Conn, error) {
sess, err := connect(ctx, c, d.log, params)
if err != nil {
// main server failed, try fail-over partner
if params.failOverPartner == "" {
if params.FailOverPartner == "" {
return nil, err
}
params.host = params.failOverPartner
if params.failOverPort != 0 {
params.port = params.failOverPort
params.Host = params.FailOverPartner
if params.FailOverPort != 0 {
params.Port = params.FailOverPort
}
sess, err = connect(ctx, c, d.log, params)
@ -447,7 +470,8 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
reset := conn.resetSession
conn.resetSession = false
if len(args) == 0 {
isProc := isProc(s.query)
if len(args) == 0 && !isProc {
if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
if conn.sess.logFlags&logErrors != 0 {
conn.sess.log.Printf("Failed to send SqlBatch with %v", err)
@ -458,7 +482,7 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
} else {
proc := sp_ExecuteSql
var params []param
if isProc(s.query) {
if isProc {
proc.name = s.query
params, _, err = s.makeRPCParams(args, true)
if err != nil {
@ -478,7 +502,7 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
conn.sess.log.Printf("Failed to send Rpc with %v", err)
}
conn.connectionGood = false
return fmt.Errorf("Failed to send RPC: %v", err)
return fmt.Errorf("failed to send RPC: %v", err)
}
}
return
@ -500,30 +524,38 @@ func isProc(s string) bool {
for _, r := range s {
rPrev = rn1
rn1 = r
switch r {
// No newlines or string sequences.
case '\n', '\r', '\'', ';':
return false
if st != escaped {
switch r {
// No newlines or string sequences.
case '\n', '\r', '\'', ';':
return false
}
}
switch st {
case outside:
switch {
case unicode.IsSpace(r):
return false
case r == '[':
st = escaped
continue
case r == ']' && rPrev == ']':
st = escaped
continue
case unicode.IsLetter(r):
st = text
case r == '_':
st = text
case r == '#':
st = text
case r == '.':
default:
return false
}
case text:
switch {
case r == '.':
st = outside
continue
case r == '[':
return false
case r == '(':
return false
case unicode.IsSpace(r):
return false
}
@ -531,7 +563,6 @@ func isProc(s string) bool {
switch {
case r == ']':
st = outside
continue
}
}
}
@ -558,7 +589,13 @@ func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string,
name = fmt.Sprintf("@p%d", val.Ordinal)
}
params[i+offset].Name = name
decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+offset].ti))
const outputSuffix = " output"
var output string
if isOutputValue(val.Value) {
output = outputSuffix
}
decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(params[i+offset].ti), output)
}
return params, decls, nil
}
@ -581,6 +618,8 @@ func convertOldArgs(args []driver.Value) []namedValue {
}
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
defer s.c.clearOuts()
return s.queryContext(context.Background(), convertOldArgs(args))
}
@ -589,48 +628,60 @@ func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver
return nil, driver.ErrBadConn
}
if err = s.sendQuery(args); err != nil {
return nil, s.c.checkBadConn(err)
return nil, s.c.checkBadConn(err, true)
}
return s.processQueryResponse(ctx)
}
func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
tokchan := make(chan tokenStruct, 5)
ctx, cancel := context.WithCancel(ctx)
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
reader := startReading(s.c.sess, ctx, s.c.outs)
s.c.clearOuts()
// process metadata
var cols []columnStruct
loop:
for tok := range tokchan {
switch token := tok.(type) {
// By ignoring DONE token we effectively
// skip empty result-sets.
// This improves results in queries like that:
// set nocount on; select 1
// see TestIgnoreEmptyResults test
//case doneStruct:
//break loop
case []columnStruct:
cols = token
break loop
case doneStruct:
if token.isError() {
cancel()
return nil, s.c.checkBadConn(token.getError())
for {
tok, err := reader.nextToken()
if err == nil {
if tok == nil {
break
} else {
switch token := tok.(type) {
// By ignoring DONE token we effectively
// skip empty result-sets.
// This improves results in queries like that:
// set nocount on; select 1
// see TestIgnoreEmptyResults test
//case doneStruct:
//break loop
case []columnStruct:
cols = token
break loop
case doneStruct:
if token.isError() {
// need to cleanup cancellable context
cancel()
return nil, s.c.checkBadConn(token.getError(), false)
}
case ReturnStatus:
if reader.outs.returnStatus != nil {
*reader.outs.returnStatus = token
}
}
}
case ReturnStatus:
s.c.setReturnStatus(token)
case error:
} else {
// need to cleanup cancellable context
cancel()
return nil, s.c.checkBadConn(token)
return nil, s.c.checkBadConn(err, false)
}
}
res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
res = &Rows{stmt: s, reader: reader, cols: cols, cancel: cancel}
return
}
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
defer s.c.clearOuts()
return s.exec(context.Background(), convertOldArgs(args))
}
@ -639,57 +690,55 @@ func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result,
return nil, driver.ErrBadConn
}
if err = s.sendQuery(args); err != nil {
return nil, s.c.checkBadConn(err)
return nil, s.c.checkBadConn(err, true)
}
if res, err = s.processExec(ctx); err != nil {
return nil, s.c.checkBadConn(err)
return nil, err
}
return
}
func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
tokchan := make(chan tokenStruct, 5)
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
reader := startReading(s.c.sess, ctx, s.c.outs)
s.c.clearOuts()
var rowCount int64
for token := range tokchan {
switch token := token.(type) {
case doneInProcStruct:
if token.Status&doneCount != 0 {
rowCount += int64(token.RowCount)
}
case doneStruct:
if token.Status&doneCount != 0 {
rowCount += int64(token.RowCount)
}
if token.isError() {
return nil, token.getError()
}
case ReturnStatus:
s.c.setReturnStatus(token)
case error:
return nil, token
}
err = reader.iterateResponse()
if err != nil {
return nil, s.c.checkBadConn(err, false)
}
return &Result{s.c, rowCount}, nil
return &Result{s.c, reader.rowCount}, nil
}
type Rows struct {
stmt *Stmt
cols []columnStruct
tokchan chan tokenStruct
stmt *Stmt
cols []columnStruct
reader *tokenProcessor
nextCols []columnStruct
cancel func()
}
func (rc *Rows) Close() error {
// need to add a test which returns lots of rows
// and check closing after reading only few rows
rc.cancel()
for _ = range rc.tokchan {
for {
tok, err := rc.reader.nextToken()
if err == nil {
if tok == nil {
return nil
} else {
// continue consuming tokens
continue
}
} else {
if err == rc.reader.ctx.Err() {
return nil
} else {
return err
}
}
}
rc.tokchan = nil
return nil
}
func (rc *Rows) Columns() (res []string) {
@ -707,27 +756,36 @@ func (rc *Rows) Next(dest []driver.Value) error {
if rc.nextCols != nil {
return io.EOF
}
for tok := range rc.tokchan {
switch tokdata := tok.(type) {
case []columnStruct:
rc.nextCols = tokdata
return io.EOF
case []interface{}:
for i := range dest {
dest[i] = tokdata[i]
for {
tok, err := rc.reader.nextToken()
if err == nil {
if tok == nil {
return io.EOF
} else {
switch tokdata := tok.(type) {
case []columnStruct:
rc.nextCols = tokdata
return io.EOF
case []interface{}:
for i := range dest {
dest[i] = tokdata[i]
}
return nil
case doneStruct:
if tokdata.isError() {
return rc.stmt.c.checkBadConn(tokdata.getError(), false)
}
case ReturnStatus:
if rc.reader.outs.returnStatus != nil {
*rc.reader.outs.returnStatus = tokdata
}
}
}
return nil
case doneStruct:
if tokdata.isError() {
return rc.stmt.c.checkBadConn(tokdata.getError())
}
case ReturnStatus:
rc.stmt.c.setReturnStatus(tokdata)
case error:
return rc.stmt.c.checkBadConn(tokdata)
} else {
return rc.stmt.c.checkBadConn(err, false)
}
}
return io.EOF
}
func (rc *Rows) HasNextResultSet() bool {
@ -895,35 +953,41 @@ func (c *Conn) Ping(ctx context.Context) error {
var _ driver.ConnBeginTx = &Conn{}
func convertIsolationLevel(level sql.IsolationLevel) (isoLevel, error) {
switch level {
case sql.LevelDefault:
return isolationUseCurrent, nil
case sql.LevelReadUncommitted:
return isolationReadUncommited, nil
case sql.LevelReadCommitted:
return isolationReadCommited, nil
case sql.LevelWriteCommitted:
return isolationUseCurrent, errors.New("LevelWriteCommitted isolation level is not supported")
case sql.LevelRepeatableRead:
return isolationRepeatableRead, nil
case sql.LevelSnapshot:
return isolationSnapshot, nil
case sql.LevelSerializable:
return isolationSerializable, nil
case sql.LevelLinearizable:
return isolationUseCurrent, errors.New("LevelLinearizable isolation level is not supported")
default:
return isolationUseCurrent, errors.New("isolation level is not supported or unknown")
}
}
// BeginTx satisfies ConnBeginTx.
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
if opts.ReadOnly {
return nil, errors.New("Read-only transactions are not supported")
return nil, errors.New("read-only transactions are not supported")
}
var tdsIsolation isoLevel
switch sql.IsolationLevel(opts.Isolation) {
case sql.LevelDefault:
tdsIsolation = isolationUseCurrent
case sql.LevelReadUncommitted:
tdsIsolation = isolationReadUncommited
case sql.LevelReadCommitted:
tdsIsolation = isolationReadCommited
case sql.LevelWriteCommitted:
return nil, errors.New("LevelWriteCommitted isolation level is not supported")
case sql.LevelRepeatableRead:
tdsIsolation = isolationRepeatableRead
case sql.LevelSnapshot:
tdsIsolation = isolationSnapshot
case sql.LevelSerializable:
tdsIsolation = isolationSerializable
case sql.LevelLinearizable:
return nil, errors.New("LevelLinearizable isolation level is not supported")
default:
return nil, errors.New("Isolation level is not supported or unknown")
tdsIsolation, err := convertIsolationLevel(sql.IsolationLevel(opts.Isolation))
if err != nil {
return nil, err
}
return c.begin(ctx, tdsIsolation)
}
@ -940,6 +1004,8 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
}
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
defer s.c.clearOuts()
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
@ -951,6 +1017,8 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
}
func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
defer s.c.clearOuts()
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}