mirror of
https://github.com/documize/community.git
synced 2025-07-25 16:19:46 +02:00
Database and LDAP upgrades
Bumped underlying dependencies affecting database and LDAP connectivity. Bumped to Go v1.14.3 and released v3.8.0.
This commit is contained in:
parent
aaa8c3282d
commit
4fe022aa0c
310 changed files with 36835 additions and 16448 deletions
640
vendor/github.com/denisenkom/go-mssqldb/tds.go
generated
vendored
640
vendor/github.com/denisenkom/go-mssqldb/tds.go
generated
vendored
|
@ -10,13 +10,9 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
"unicode/utf16"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
@ -51,15 +47,13 @@ func parseInstances(msg []byte) map[string]map[string]string {
|
|||
}
|
||||
|
||||
func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) {
|
||||
maxTime := 5 * time.Second
|
||||
ctx, cancel := context.WithTimeout(ctx, maxTime)
|
||||
defer cancel()
|
||||
conn, err := d.DialContext(ctx, "udp", address+":1434")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetDeadline(time.Now().Add(maxTime))
|
||||
deadline, _ := ctx.Deadline()
|
||||
conn.SetDeadline(deadline)
|
||||
_, err = conn.Write([]byte{3})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -106,13 +100,15 @@ const (
|
|||
// prelogin fields
|
||||
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
|
||||
const (
|
||||
preloginVERSION = 0
|
||||
preloginENCRYPTION = 1
|
||||
preloginINSTOPT = 2
|
||||
preloginTHREADID = 3
|
||||
preloginMARS = 4
|
||||
preloginTRACEID = 5
|
||||
preloginTERMINATOR = 0xff
|
||||
preloginVERSION = 0
|
||||
preloginENCRYPTION = 1
|
||||
preloginINSTOPT = 2
|
||||
preloginTHREADID = 3
|
||||
preloginMARS = 4
|
||||
preloginTRACEID = 5
|
||||
preloginFEDAUTHREQUIRED = 6
|
||||
preloginNONCEOPT = 7
|
||||
preloginTERMINATOR = 0xff
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -251,6 +247,12 @@ const (
|
|||
fReadOnlyIntent = 32
|
||||
)
|
||||
|
||||
// OptionFlags3
|
||||
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac
|
||||
const (
|
||||
fExtension = 0x10
|
||||
)
|
||||
|
||||
type login struct {
|
||||
TDSVersion uint32
|
||||
PacketSize uint32
|
||||
|
@ -275,6 +277,89 @@ type login struct {
|
|||
SSPI []byte
|
||||
AtchDBFile string
|
||||
ChangePassword string
|
||||
FeatureExt featureExts
|
||||
}
|
||||
|
||||
type featureExts struct {
|
||||
features map[byte]featureExt
|
||||
}
|
||||
|
||||
type featureExt interface {
|
||||
featureID() byte
|
||||
toBytes() []byte
|
||||
}
|
||||
|
||||
func (e *featureExts) Add(f featureExt) error {
|
||||
if f == nil {
|
||||
return nil
|
||||
}
|
||||
id := f.featureID()
|
||||
if _, exists := e.features[id]; exists {
|
||||
f := "Login error: Feature with ID '%v' is already present in FeatureExt block."
|
||||
return fmt.Errorf(f, id)
|
||||
}
|
||||
if e.features == nil {
|
||||
e.features = make(map[byte]featureExt)
|
||||
}
|
||||
e.features[id] = f
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e featureExts) toBytes() []byte {
|
||||
if len(e.features) == 0 {
|
||||
return nil
|
||||
}
|
||||
var d []byte
|
||||
for featureID, f := range e.features {
|
||||
featureData := f.toBytes()
|
||||
|
||||
hdr := make([]byte, 5)
|
||||
hdr[0] = featureID // FedAuth feature extension BYTE
|
||||
binary.LittleEndian.PutUint32(hdr[1:], uint32(len(featureData))) // FeatureDataLen DWORD
|
||||
d = append(d, hdr...)
|
||||
|
||||
d = append(d, featureData...) // FeatureData *BYTE
|
||||
}
|
||||
if d != nil {
|
||||
d = append(d, 0xff) // Terminator
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
type featureExtFedAuthSTS struct {
|
||||
FedAuthEcho bool
|
||||
FedAuthToken string
|
||||
Nonce []byte
|
||||
}
|
||||
|
||||
func (e *featureExtFedAuthSTS) featureID() byte {
|
||||
return 0x02
|
||||
}
|
||||
|
||||
func (e *featureExtFedAuthSTS) toBytes() []byte {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
options := byte(0x01) << 1 // 0x01 => STS bFedAuthLibrary 7BIT
|
||||
if e.FedAuthEcho {
|
||||
options |= 1 // fFedAuthEcho
|
||||
}
|
||||
|
||||
d := make([]byte, 5)
|
||||
d[0] = options
|
||||
|
||||
// looks like string in
|
||||
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508
|
||||
tokenBytes := str2ucs2(e.FedAuthToken)
|
||||
binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work
|
||||
d = append(d, tokenBytes...)
|
||||
|
||||
if len(e.Nonce) == 32 {
|
||||
d = append(d, e.Nonce...)
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
type loginHeader struct {
|
||||
|
@ -301,7 +386,7 @@ type loginHeader struct {
|
|||
ServerNameOffset uint16
|
||||
ServerNameLength uint16
|
||||
ExtensionOffset uint16
|
||||
ExtensionLenght uint16
|
||||
ExtensionLength uint16
|
||||
CtlIntNameOffset uint16
|
||||
CtlIntNameLength uint16
|
||||
LanguageOffset uint16
|
||||
|
@ -363,6 +448,8 @@ func sendLogin(w *tdsBuffer, login login) error {
|
|||
database := str2ucs2(login.Database)
|
||||
atchdbfile := str2ucs2(login.AtchDBFile)
|
||||
changepassword := str2ucs2(login.ChangePassword)
|
||||
featureExt := login.FeatureExt.toBytes()
|
||||
|
||||
hdr := loginHeader{
|
||||
TDSVersion: login.TDSVersion,
|
||||
PacketSize: login.PacketSize,
|
||||
|
@ -411,7 +498,18 @@ func sendLogin(w *tdsBuffer, login login) error {
|
|||
offset += uint16(len(atchdbfile))
|
||||
hdr.ChangePasswordOffset = offset
|
||||
offset += uint16(len(changepassword))
|
||||
hdr.Length = uint32(offset)
|
||||
|
||||
featureExtOffset := uint32(0)
|
||||
featureExtLen := len(featureExt)
|
||||
if featureExtLen > 0 {
|
||||
hdr.OptionFlags3 |= fExtension
|
||||
hdr.ExtensionOffset = offset
|
||||
hdr.ExtensionLength = 4
|
||||
offset += hdr.ExtensionLength // DWORD
|
||||
featureExtOffset = uint32(offset)
|
||||
}
|
||||
hdr.Length = uint32(offset) + uint32(featureExtLen)
|
||||
|
||||
var err error
|
||||
err = binary.Write(w, binary.LittleEndian, &hdr)
|
||||
if err != nil {
|
||||
|
@ -461,6 +559,16 @@ func sendLogin(w *tdsBuffer, login login) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if featureExtOffset > 0 {
|
||||
err = binary.Write(w, binary.LittleEndian, featureExtOffset)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(featureExt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return w.FinishPacket()
|
||||
}
|
||||
|
||||
|
@ -474,10 +582,9 @@ func readUcs2(r io.Reader, numchars int) (res string, err error) {
|
|||
}
|
||||
|
||||
func readUsVarChar(r io.Reader) (res string, err error) {
|
||||
var numchars uint16
|
||||
err = binary.Read(r, binary.LittleEndian, &numchars)
|
||||
numchars, err := readUshort(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return
|
||||
}
|
||||
return readUcs2(r, int(numchars))
|
||||
}
|
||||
|
@ -497,8 +604,7 @@ func writeUsVarChar(w io.Writer, s string) (err error) {
|
|||
}
|
||||
|
||||
func readBVarChar(r io.Reader) (res string, err error) {
|
||||
var numchars uint8
|
||||
err = binary.Read(r, binary.LittleEndian, &numchars)
|
||||
numchars, err := readByte(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -525,8 +631,7 @@ func writeBVarChar(w io.Writer, s string) (err error) {
|
|||
}
|
||||
|
||||
func readBVarByte(r io.Reader) (res []byte, err error) {
|
||||
var length uint8
|
||||
err = binary.Read(r, binary.LittleEndian, &length)
|
||||
length, err := readByte(r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -654,458 +759,6 @@ func sendAttention(buf *tdsBuffer) error {
|
|||
return buf.FinishPacket()
|
||||
}
|
||||
|
||||
type connectParams struct {
|
||||
logFlags uint64
|
||||
port uint64
|
||||
host string
|
||||
instance string
|
||||
database string
|
||||
user string
|
||||
password string
|
||||
dial_timeout time.Duration
|
||||
conn_timeout time.Duration
|
||||
keepAlive time.Duration
|
||||
encrypt bool
|
||||
disableEncryption bool
|
||||
trustServerCertificate bool
|
||||
certificate string
|
||||
hostInCertificate string
|
||||
hostInCertificateProvided bool
|
||||
serverSPN string
|
||||
workstation string
|
||||
appname string
|
||||
typeFlags uint8
|
||||
failOverPartner string
|
||||
failOverPort uint64
|
||||
packetSize uint16
|
||||
}
|
||||
|
||||
func splitConnectionString(dsn string) (res map[string]string) {
|
||||
res = map[string]string{}
|
||||
parts := strings.Split(dsn, ";")
|
||||
for _, part := range parts {
|
||||
if len(part) == 0 {
|
||||
continue
|
||||
}
|
||||
lst := strings.SplitN(part, "=", 2)
|
||||
name := strings.TrimSpace(strings.ToLower(lst[0]))
|
||||
if len(name) == 0 {
|
||||
continue
|
||||
}
|
||||
var value string = ""
|
||||
if len(lst) > 1 {
|
||||
value = strings.TrimSpace(lst[1])
|
||||
}
|
||||
res[name] = value
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// Splits a URL in the ODBC format
|
||||
func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
|
||||
res := map[string]string{}
|
||||
|
||||
type parserState int
|
||||
const (
|
||||
// Before the start of a key
|
||||
parserStateBeforeKey parserState = iota
|
||||
|
||||
// Inside a key
|
||||
parserStateKey
|
||||
|
||||
// Beginning of a value. May be bare or braced
|
||||
parserStateBeginValue
|
||||
|
||||
// Inside a bare value
|
||||
parserStateBareValue
|
||||
|
||||
// Inside a braced value
|
||||
parserStateBracedValue
|
||||
|
||||
// A closing brace inside a braced value.
|
||||
// May be the end of the value or an escaped closing brace, depending on the next character
|
||||
parserStateBracedValueClosingBrace
|
||||
|
||||
// After a value. Next character should be a semicolon or whitespace.
|
||||
parserStateEndValue
|
||||
)
|
||||
|
||||
var state = parserStateBeforeKey
|
||||
|
||||
var key string
|
||||
var value string
|
||||
|
||||
for i, c := range dsn {
|
||||
switch state {
|
||||
case parserStateBeforeKey:
|
||||
switch {
|
||||
case c == '=':
|
||||
return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
|
||||
case !unicode.IsSpace(c) && c != ';':
|
||||
state = parserStateKey
|
||||
key += string(c)
|
||||
}
|
||||
|
||||
case parserStateKey:
|
||||
switch c {
|
||||
case '=':
|
||||
key = normalizeOdbcKey(key)
|
||||
if len(key) == 0 {
|
||||
return res, fmt.Errorf("Unexpected end of key at index %d.", i)
|
||||
}
|
||||
|
||||
state = parserStateBeginValue
|
||||
|
||||
case ';':
|
||||
// Key without value
|
||||
key = normalizeOdbcKey(key)
|
||||
if len(key) == 0 {
|
||||
return res, fmt.Errorf("Unexpected end of key at index %d.", i)
|
||||
}
|
||||
|
||||
res[key] = value
|
||||
key = ""
|
||||
value = ""
|
||||
state = parserStateBeforeKey
|
||||
|
||||
default:
|
||||
key += string(c)
|
||||
}
|
||||
|
||||
case parserStateBeginValue:
|
||||
switch {
|
||||
case c == '{':
|
||||
state = parserStateBracedValue
|
||||
case c == ';':
|
||||
// Empty value
|
||||
res[key] = value
|
||||
key = ""
|
||||
state = parserStateBeforeKey
|
||||
case unicode.IsSpace(c):
|
||||
// Ignore whitespace
|
||||
default:
|
||||
state = parserStateBareValue
|
||||
value += string(c)
|
||||
}
|
||||
|
||||
case parserStateBareValue:
|
||||
if c == ';' {
|
||||
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
|
||||
key = ""
|
||||
value = ""
|
||||
state = parserStateBeforeKey
|
||||
} else {
|
||||
value += string(c)
|
||||
}
|
||||
|
||||
case parserStateBracedValue:
|
||||
if c == '}' {
|
||||
state = parserStateBracedValueClosingBrace
|
||||
} else {
|
||||
value += string(c)
|
||||
}
|
||||
|
||||
case parserStateBracedValueClosingBrace:
|
||||
if c == '}' {
|
||||
// Escaped closing brace
|
||||
value += string(c)
|
||||
state = parserStateBracedValue
|
||||
continue
|
||||
}
|
||||
|
||||
// End of braced value
|
||||
res[key] = value
|
||||
key = ""
|
||||
value = ""
|
||||
|
||||
// This character is the first character past the end,
|
||||
// so it needs to be parsed like the parserStateEndValue state.
|
||||
state = parserStateEndValue
|
||||
switch {
|
||||
case c == ';':
|
||||
state = parserStateBeforeKey
|
||||
case unicode.IsSpace(c):
|
||||
// Ignore whitespace
|
||||
default:
|
||||
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
|
||||
}
|
||||
|
||||
case parserStateEndValue:
|
||||
switch {
|
||||
case c == ';':
|
||||
state = parserStateBeforeKey
|
||||
case unicode.IsSpace(c):
|
||||
// Ignore whitespace
|
||||
default:
|
||||
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch state {
|
||||
case parserStateBeforeKey: // Okay
|
||||
case parserStateKey: // Unfinished key. Treat as key without value.
|
||||
key = normalizeOdbcKey(key)
|
||||
if len(key) == 0 {
|
||||
return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))
|
||||
}
|
||||
res[key] = value
|
||||
case parserStateBeginValue: // Empty value
|
||||
res[key] = value
|
||||
case parserStateBareValue:
|
||||
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
|
||||
case parserStateBracedValue:
|
||||
return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
|
||||
case parserStateBracedValueClosingBrace: // End of braced value
|
||||
res[key] = value
|
||||
case parserStateEndValue: // Okay
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Normalizes the given string as an ODBC-format key
|
||||
func normalizeOdbcKey(s string) string {
|
||||
return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
|
||||
}
|
||||
|
||||
// Splits a URL of the form sqlserver://username:password@host/instance?param1=value¶m2=value
|
||||
func splitConnectionStringURL(dsn string) (map[string]string, error) {
|
||||
res := map[string]string{}
|
||||
|
||||
u, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if u.Scheme != "sqlserver" {
|
||||
return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
|
||||
}
|
||||
|
||||
if u.User != nil {
|
||||
res["user id"] = u.User.Username()
|
||||
p, exists := u.User.Password()
|
||||
if exists {
|
||||
res["password"] = p
|
||||
}
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(u.Host)
|
||||
if err != nil {
|
||||
host = u.Host
|
||||
}
|
||||
|
||||
if len(u.Path) > 0 {
|
||||
res["server"] = host + "\\" + u.Path[1:]
|
||||
} else {
|
||||
res["server"] = host
|
||||
}
|
||||
|
||||
if len(port) > 0 {
|
||||
res["port"] = port
|
||||
}
|
||||
|
||||
query := u.Query()
|
||||
for k, v := range query {
|
||||
if len(v) > 1 {
|
||||
return res, fmt.Errorf("key %s provided more than once", k)
|
||||
}
|
||||
res[strings.ToLower(k)] = v[0]
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func parseConnectParams(dsn string) (connectParams, error) {
|
||||
var p connectParams
|
||||
|
||||
var params map[string]string
|
||||
if strings.HasPrefix(dsn, "odbc:") {
|
||||
parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
|
||||
if err != nil {
|
||||
return p, err
|
||||
}
|
||||
params = parameters
|
||||
} else if strings.HasPrefix(dsn, "sqlserver://") {
|
||||
parameters, err := splitConnectionStringURL(dsn)
|
||||
if err != nil {
|
||||
return p, err
|
||||
}
|
||||
params = parameters
|
||||
} else {
|
||||
params = splitConnectionString(dsn)
|
||||
}
|
||||
|
||||
strlog, ok := params["log"]
|
||||
if ok {
|
||||
var err error
|
||||
p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
|
||||
if err != nil {
|
||||
return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
|
||||
}
|
||||
}
|
||||
server := params["server"]
|
||||
parts := strings.SplitN(server, `\`, 2)
|
||||
p.host = parts[0]
|
||||
if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
|
||||
p.host = "localhost"
|
||||
}
|
||||
if len(parts) > 1 {
|
||||
p.instance = parts[1]
|
||||
}
|
||||
p.database = params["database"]
|
||||
p.user = params["user id"]
|
||||
p.password = params["password"]
|
||||
|
||||
p.port = 1433
|
||||
strport, ok := params["port"]
|
||||
if ok {
|
||||
var err error
|
||||
p.port, err = strconv.ParseUint(strport, 10, 16)
|
||||
if err != nil {
|
||||
f := "Invalid tcp port '%v': %v"
|
||||
return p, fmt.Errorf(f, strport, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
|
||||
// Default packet size remains at 4096 bytes
|
||||
p.packetSize = 4096
|
||||
strpsize, ok := params["packet size"]
|
||||
if ok {
|
||||
var err error
|
||||
psize, err := strconv.ParseUint(strpsize, 0, 16)
|
||||
if err != nil {
|
||||
f := "Invalid packet size '%v': %v"
|
||||
return p, fmt.Errorf(f, strpsize, err.Error())
|
||||
}
|
||||
|
||||
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
|
||||
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
|
||||
// a higher packet size, the server will respond with an ENVCHANGE request to
|
||||
// alter the packet size to 16383 bytes.
|
||||
p.packetSize = uint16(psize)
|
||||
if p.packetSize < 512 {
|
||||
p.packetSize = 512
|
||||
} else if p.packetSize > 32767 {
|
||||
p.packetSize = 32767
|
||||
}
|
||||
}
|
||||
|
||||
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
|
||||
//
|
||||
// Do not set a connection timeout. Use Context to manage such things.
|
||||
// Default to zero, but still allow it to be set.
|
||||
if strconntimeout, ok := params["connection timeout"]; ok {
|
||||
timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
|
||||
if err != nil {
|
||||
f := "Invalid connection timeout '%v': %v"
|
||||
return p, fmt.Errorf(f, strconntimeout, err.Error())
|
||||
}
|
||||
p.conn_timeout = time.Duration(timeout) * time.Second
|
||||
}
|
||||
p.dial_timeout = 15 * time.Second
|
||||
if strdialtimeout, ok := params["dial timeout"]; ok {
|
||||
timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
|
||||
if err != nil {
|
||||
f := "Invalid dial timeout '%v': %v"
|
||||
return p, fmt.Errorf(f, strdialtimeout, err.Error())
|
||||
}
|
||||
p.dial_timeout = time.Duration(timeout) * time.Second
|
||||
}
|
||||
|
||||
// default keep alive should be 30 seconds according to spec:
|
||||
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
|
||||
p.keepAlive = 30 * time.Second
|
||||
if keepAlive, ok := params["keepalive"]; ok {
|
||||
timeout, err := strconv.ParseUint(keepAlive, 10, 64)
|
||||
if err != nil {
|
||||
f := "Invalid keepAlive value '%s': %s"
|
||||
return p, fmt.Errorf(f, keepAlive, err.Error())
|
||||
}
|
||||
p.keepAlive = time.Duration(timeout) * time.Second
|
||||
}
|
||||
encrypt, ok := params["encrypt"]
|
||||
if ok {
|
||||
if strings.EqualFold(encrypt, "DISABLE") {
|
||||
p.disableEncryption = true
|
||||
} else {
|
||||
var err error
|
||||
p.encrypt, err = strconv.ParseBool(encrypt)
|
||||
if err != nil {
|
||||
f := "Invalid encrypt '%s': %s"
|
||||
return p, fmt.Errorf(f, encrypt, err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
p.trustServerCertificate = true
|
||||
}
|
||||
trust, ok := params["trustservercertificate"]
|
||||
if ok {
|
||||
var err error
|
||||
p.trustServerCertificate, err = strconv.ParseBool(trust)
|
||||
if err != nil {
|
||||
f := "Invalid trust server certificate '%s': %s"
|
||||
return p, fmt.Errorf(f, trust, err.Error())
|
||||
}
|
||||
}
|
||||
p.certificate = params["certificate"]
|
||||
p.hostInCertificate, ok = params["hostnameincertificate"]
|
||||
if ok {
|
||||
p.hostInCertificateProvided = true
|
||||
} else {
|
||||
p.hostInCertificate = p.host
|
||||
p.hostInCertificateProvided = false
|
||||
}
|
||||
|
||||
serverSPN, ok := params["serverspn"]
|
||||
if ok {
|
||||
p.serverSPN = serverSPN
|
||||
} else {
|
||||
p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port)
|
||||
}
|
||||
|
||||
workstation, ok := params["workstation id"]
|
||||
if ok {
|
||||
p.workstation = workstation
|
||||
} else {
|
||||
workstation, err := os.Hostname()
|
||||
if err == nil {
|
||||
p.workstation = workstation
|
||||
}
|
||||
}
|
||||
|
||||
appname, ok := params["app name"]
|
||||
if !ok {
|
||||
appname = "go-mssqldb"
|
||||
}
|
||||
p.appname = appname
|
||||
|
||||
appintent, ok := params["applicationintent"]
|
||||
if ok {
|
||||
if appintent == "ReadOnly" {
|
||||
p.typeFlags |= fReadOnlyIntent
|
||||
}
|
||||
}
|
||||
|
||||
failOverPartner, ok := params["failoverpartner"]
|
||||
if ok {
|
||||
p.failOverPartner = failOverPartner
|
||||
}
|
||||
|
||||
failOverPort, ok := params["failoverport"]
|
||||
if ok {
|
||||
var err error
|
||||
p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
|
||||
if err != nil {
|
||||
f := "Invalid tcp port '%v': %v"
|
||||
return p, fmt.Errorf(f, failOverPort, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
type auth interface {
|
||||
InitialBytes() ([]byte, error)
|
||||
NextBytes([]byte) ([]byte, error)
|
||||
|
@ -1127,14 +780,14 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne
|
|||
}
|
||||
if len(ips) == 1 {
|
||||
d := c.getDialer(&p)
|
||||
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
|
||||
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(resolveServerPort(p.port))))
|
||||
conn, err = d.DialContext(ctx, "tcp", addr)
|
||||
|
||||
} else {
|
||||
//Try Dials in parallel to avoid waiting for timeouts.
|
||||
connChan := make(chan net.Conn, len(ips))
|
||||
errChan := make(chan error, len(ips))
|
||||
portStr := strconv.Itoa(int(p.port))
|
||||
portStr := strconv.Itoa(int(resolveServerPort(p.port)))
|
||||
for _, ip := range ips {
|
||||
go func(ip net.IP) {
|
||||
d := c.getDialer(&p)
|
||||
|
@ -1172,7 +825,7 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne
|
|||
// Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
|
||||
if conn == nil {
|
||||
f := "Unable to open tcp connection with host '%v:%v': %v"
|
||||
return nil, fmt.Errorf(f, p.host, p.port, err.Error())
|
||||
return nil, fmt.Errorf(f, p.host, resolveServerPort(p.port), err.Error())
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
@ -1185,7 +838,7 @@ func connect(ctx context.Context, c *Connector, log optionalLogger, p connectPar
|
|||
defer cancel()
|
||||
}
|
||||
// if instance is specified use instance resolution service
|
||||
if p.instance != "" {
|
||||
if p.instance != "" && p.port == 0 {
|
||||
p.instance = strings.ToUpper(p.instance)
|
||||
d := c.getDialer(&p)
|
||||
instances, err := getInstances(dialCtx, d, p.host)
|
||||
|
@ -1198,11 +851,12 @@ func connect(ctx context.Context, c *Connector, log optionalLogger, p connectPar
|
|||
f := "No instance matching '%v' returned from host '%v'"
|
||||
return nil, fmt.Errorf(f, p.instance, p.host)
|
||||
}
|
||||
p.port, err = strconv.ParseUint(strport, 0, 16)
|
||||
port, err := strconv.ParseUint(strport, 0, 16)
|
||||
if err != nil {
|
||||
f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
|
||||
return nil, fmt.Errorf(f, strport, err.Error())
|
||||
}
|
||||
p.port = port
|
||||
}
|
||||
|
||||
initiate_connection:
|
||||
|
@ -1277,12 +931,12 @@ initiate_connection:
|
|||
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
|
||||
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
|
||||
config.DynamicRecordSizingDisabled = true
|
||||
outbuf.transport = conn
|
||||
toconn.buf = outbuf
|
||||
tlsConn := tls.Client(toconn, &config)
|
||||
// setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
|
||||
handshakeConn := tlsHandshakeConn{buf: outbuf}
|
||||
passthrough := passthroughConn{c: &handshakeConn}
|
||||
tlsConn := tls.Client(&passthrough, &config)
|
||||
err = tlsConn.Handshake()
|
||||
|
||||
toconn.buf = nil
|
||||
passthrough.c = toconn
|
||||
outbuf.transport = tlsConn
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
|
||||
|
@ -1304,15 +958,23 @@ initiate_connection:
|
|||
AppName: p.appname,
|
||||
TypeFlags: p.typeFlags,
|
||||
}
|
||||
auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation)
|
||||
if auth_ok {
|
||||
auth, authOk := getAuth(p.user, p.password, p.serverSPN, p.workstation)
|
||||
switch {
|
||||
case p.fedAuthAccessToken != "": // accesstoken ignores user/password
|
||||
featurext := &featureExtFedAuthSTS{
|
||||
FedAuthEcho: len(fields[preloginFEDAUTHREQUIRED]) > 0 && fields[preloginFEDAUTHREQUIRED][0] == 1,
|
||||
FedAuthToken: p.fedAuthAccessToken,
|
||||
Nonce: fields[preloginNONCEOPT],
|
||||
}
|
||||
login.FeatureExt.Add(featurext)
|
||||
case authOk:
|
||||
login.SSPI, err = auth.InitialBytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
login.OptionFlags2 |= fIntSecurity
|
||||
defer auth.Free()
|
||||
} else {
|
||||
default:
|
||||
login.UserName = p.user
|
||||
login.Password = p.password
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue