mirror of
https://github.com/documize/community.git
synced 2025-07-19 13:19:43 +02:00
1427 lines
37 KiB
Go
1427 lines
37 KiB
Go
package mssql
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
"unicode/utf16"
|
|
"unicode/utf8"
|
|
|
|
"github.com/microsoft/go-mssqldb/aecmk"
|
|
"github.com/microsoft/go-mssqldb/integratedauth"
|
|
"github.com/microsoft/go-mssqldb/msdsn"
|
|
)
|
|
|
|
func parseDAC(msg []byte, instance string) msdsn.BrowserData {
|
|
results := msdsn.BrowserData{}
|
|
if len(msg) == 6 && msg[0] == 5 {
|
|
results[strings.ToUpper(instance)]["tcp"] = fmt.Sprint(binary.LittleEndian.Uint16(msg[5:]))
|
|
}
|
|
return results
|
|
}
|
|
|
|
func parseInstances(msg []byte) msdsn.BrowserData {
|
|
results := msdsn.BrowserData{}
|
|
if len(msg) > 3 && msg[0] == 5 {
|
|
out_s := string(msg[3:])
|
|
tokens := strings.Split(out_s, ";")
|
|
instdict := map[string]string{}
|
|
got_name := false
|
|
var name string
|
|
for _, token := range tokens {
|
|
if got_name {
|
|
instdict[name] = token
|
|
got_name = false
|
|
} else {
|
|
name = token
|
|
if len(name) == 0 {
|
|
if len(instdict) == 0 {
|
|
break
|
|
}
|
|
results[strings.ToUpper(instdict["InstanceName"])] = instdict
|
|
instdict = map[string]string{}
|
|
continue
|
|
}
|
|
got_name = true
|
|
}
|
|
}
|
|
}
|
|
return results
|
|
}
|
|
|
|
func getInstances(ctx context.Context, d Dialer, address string, browserMsg msdsn.BrowserMsg, instance string) (msdsn.BrowserData, error) {
|
|
emptyInstances := msdsn.BrowserData{}
|
|
var bmsg []byte
|
|
var resp []byte
|
|
if browserMsg == msdsn.BrowserDAC {
|
|
bmsg = make([]byte, 3+len(instance))
|
|
bmsg[0] = byte(msdsn.BrowserDAC)
|
|
bmsg[1] = 1
|
|
_ = copy(bmsg[3:], instance)
|
|
resp = make([]byte, 6)
|
|
} else { // default to AllInstances
|
|
bmsg = []byte{byte(msdsn.BrowserAllInstances)}
|
|
resp = make([]byte, 16*1024-1)
|
|
}
|
|
conn, err := d.DialContext(ctx, "udp", net.JoinHostPort(address, "1434"))
|
|
if err != nil {
|
|
return emptyInstances, err
|
|
}
|
|
defer conn.Close()
|
|
deadline, _ := ctx.Deadline()
|
|
conn.SetDeadline(deadline)
|
|
_, err = conn.Write(bmsg)
|
|
if err != nil {
|
|
return emptyInstances, err
|
|
}
|
|
|
|
read, err := conn.Read(resp)
|
|
if err != nil {
|
|
return emptyInstances, err
|
|
}
|
|
if browserMsg == msdsn.BrowserDAC {
|
|
return parseDAC(resp[:read], instance), nil
|
|
}
|
|
return parseInstances(resp[:read]), nil
|
|
}
|
|
|
|
// tds versions
// TDS protocol version numbers. prepareLogin puts verTDS80 into the LOGIN7
// TDSVersion field when strict encryption is configured, and verTDS74
// otherwise.
|
|
const (
|
|
verTDS70 = 0x70000000
|
|
verTDS71 = 0x71000000
|
|
verTDS71rev1 = 0x71000001
|
|
verTDS72 = 0x72090002
|
|
verTDS73A = 0x730A0003
|
|
// verTDS73 is an alias of verTDS73A.
verTDS73 = verTDS73A
|
|
verTDS73B = 0x730B0003
|
|
verTDS74 = 0x74000004
|
|
verTDS80 = 0x08000000
|
|
)
|
|
|
|
// packet types
// TDS packet header types passed to tdsBuffer.BeginPacket. readPrelogin
// requires the server's PRELOGIN answer to arrive as packReply.
|
|
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
|
|
const (
|
|
packSQLBatch packetType = 1
|
|
packRPCRequest packetType = 3
|
|
packReply packetType = 4
|
|
|
|
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
|
|
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
|
|
packAttention packetType = 6
|
|
|
|
packBulkLoadBCP packetType = 7
|
|
packFedAuthToken packetType = 8
|
|
packTransMgrReq packetType = 14
|
|
packNormal packetType = 15
|
|
packLogin7 packetType = 16
|
|
packSSPIMessage packetType = 17
|
|
packPrelogin packetType = 18
|
|
)
|
|
|
|
// prelogin fields
// PRELOGIN option tokens. writePrelogin emits options in ascending token
// order, and preloginTERMINATOR ends the option table in both directions.
|
|
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
|
|
const (
|
|
preloginVERSION = 0
|
|
preloginENCRYPTION = 1
|
|
preloginINSTOPT = 2
|
|
preloginTHREADID = 3
|
|
preloginMARS = 4
|
|
preloginTRACEID = 5
|
|
preloginFEDAUTHREQUIRED = 6
|
|
preloginNONCEOPT = 7
|
|
preloginTERMINATOR = 0xff
|
|
)
|
|
|
|
// Values of the PRELOGIN ENCRYPTION option. preparePreloginFields maps the
// configured msdsn encryption mode onto one of these, and
// interpretPreloginResponse reads the server's answer back as one of them.
const (
|
|
encryptOff = 0 // Encryption is available but off.
|
|
encryptOn = 1 // Encryption is available and on.
|
|
encryptNotSup = 2 // Encryption is not available.
|
|
encryptReq = 3 // Encryption is required.
|
|
// encryptStrict is sent when msdsn.EncryptionStrict is configured.
encryptStrict = 4
|
|
)
|
|
|
|
// Feature IDs for the LOGIN7 FeatureExt block. Each entry starts with one of
// these as its FeatureId byte; featExtTERMINATOR's value (0xFF) is what
// featureExts.toBytes appends to end the block.
const (
|
|
featExtSESSIONRECOVERY byte = 0x01
|
|
featExtFEDAUTH byte = 0x02
|
|
featExtCOLUMNENCRYPTION byte = 0x04
|
|
featExtGLOBALTRANSACTIONS byte = 0x05
|
|
featExtAZURESQLSUPPORT byte = 0x08
|
|
featExtDATACLASSIFICATION byte = 0x09
|
|
featExtUTF8SUPPORT byte = 0x0A
|
|
featExtTERMINATOR byte = 0xFF
|
|
)
|
|
|
|
// tdsSession holds the protocol state of one connection; connect returns it.
type tdsSession struct {
|
|
// buf is the packet buffer used to read and write TDS packets.
buf *tdsBuffer
|
|
// loginAck presumably holds the server's LOGINACK data — NOTE(review): confirm where it is filled in.
loginAck loginAckStruct
|
|
// database is presumably the current database name, kept up to date from server notifications — confirm.
database string
|
|
// partner is presumably the failover partner reported by the server — confirm.
partner string
|
|
// columns presumably holds the column metadata of the current result set.
columns []columnStruct
|
|
// tranid is presumably the current transaction descriptor (compare transDescrHdr.transDescr) — confirm.
tranid uint64
|
|
logFlags uint64
|
|
logger ContextLogger
|
|
// routedServer and routedPort presumably record a routing redirect sent by the server — confirm.
routedServer string
|
|
routedPort uint16
|
|
// alwaysEncrypted and aeSettings presumably hold the session's Always Encrypted configuration.
alwaysEncrypted bool
|
|
aeSettings *alwaysEncryptedSettings
|
|
}
|
|
|
|
// alwaysEncryptedSettings groups the Always Encrypted options of a session.
type alwaysEncryptedSettings struct {
|
|
enclaveType string
|
|
// keyProviders holds the column encryption key providers (type from the aecmk package).
keyProviders aecmk.ColumnEncryptionKeyProviderMap
|
|
}
|
|
|
|
const (
|
|
// Default packet size for a TDS buffer.
// connect substitutes it when the configured packet size is 0, before
// clamping the size to the 512..32767 range.
|
|
defaultPacketSize = 4096
|
|
|
|
// Default port if no port given.
|
|
defaultServerPort = 1433
|
|
)
|
|
|
|
// columnStruct describes one column; tdsSession.columns is a slice of these.
type columnStruct struct {
|
|
UserType uint32
|
|
// Flags holds the column flag bits; isEncrypted tests colFlagEncrypted in it.
Flags uint16
|
|
ColName string
|
|
// ti is the column's type info. For encrypted columns originalTypeInfo
// returns cryptoMeta.typeInfo instead.
ti typeInfo
|
|
// cryptoMeta is dereferenced by originalTypeInfo whenever the encrypted
// flag is set, so it must be non-nil for encrypted columns.
cryptoMeta *cryptoMetadata
|
|
}
|
|
|
|
func (c columnStruct) isEncrypted() bool {
|
|
return isEncryptedFlag(c.Flags)
|
|
}
|
|
|
|
func isEncryptedFlag(flags uint16) bool {
|
|
return colFlagEncrypted == (flags & colFlagEncrypted)
|
|
}
|
|
|
|
func (c columnStruct) originalTypeInfo() typeInfo {
|
|
if c.isEncrypted() {
|
|
return c.cryptoMeta.typeInfo
|
|
}
|
|
return c.ti
|
|
}
|
|
|
|
// keySlice implements sort.Interface over PRELOGIN option tokens so that
// writePrelogin can emit options in ascending token order.
type keySlice []uint8

// Len returns the number of tokens.
func (ks keySlice) Len() int { return len(ks) }

// Less orders tokens numerically.
func (ks keySlice) Less(a, b int) bool { return ks[a] < ks[b] }

// Swap exchanges two tokens.
func (ks keySlice) Swap(a, b int) { ks[b], ks[a] = ks[a], ks[b] }
|
|
|
|
// preloginOption is one entry of the PRELOGIN option table: a token plus the
// offset and length of the option's data within the PRELOGIN payload. Offset
// and length are big-endian on the wire (see writePrelogin and
// readPreloginOption).
type preloginOption struct {
|
|
token byte
|
|
offset uint16
|
|
length uint16
|
|
}
|
|
|
|
// preloginOptionSize is the encoded size of one option-table entry: 1-byte
// token + 2-byte offset + 2-byte length = 5 bytes. readPrelogin advances by
// this much per entry.
var preloginOptionSize = binary.Size(preloginOption{})
|
|
|
|
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
|
|
func writePrelogin(packetType packetType, w *tdsBuffer, fields map[uint8][]byte) error {
|
|
var err error
|
|
|
|
w.BeginPacket(packetType, false)
|
|
offset := uint16(5*len(fields) + 1)
|
|
keys := make(keySlice, 0, len(fields))
|
|
for k := range fields {
|
|
keys = append(keys, k)
|
|
}
|
|
sort.Sort(keys)
|
|
// writing header
|
|
for _, k := range keys {
|
|
err = w.WriteByte(k)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = binary.Write(w, binary.BigEndian, offset)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
v := fields[k]
|
|
size := uint16(len(v))
|
|
err = binary.Write(w, binary.BigEndian, size)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
offset += size
|
|
}
|
|
err = w.WriteByte(preloginTERMINATOR)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// writing values
|
|
for _, k := range keys {
|
|
v := fields[k]
|
|
written, err := w.Write(v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if written != len(v) {
|
|
return errors.New("Write method didn't write the whole value")
|
|
}
|
|
}
|
|
return w.FinishPacket()
|
|
}
|
|
|
|
func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
|
|
packet_type, err := r.BeginRead()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
struct_buf, err := ioutil.ReadAll(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if packet_type != packReply {
|
|
return nil, errors.New("invalid respones, expected packet type 4, PRELOGIN RESPONSE")
|
|
}
|
|
if len(struct_buf) == 0 {
|
|
return nil, errors.New("invalid empty PRELOGIN response, it must contain at least one byte")
|
|
}
|
|
offset := 0
|
|
results := map[uint8][]byte{}
|
|
for {
|
|
// read prelogin option
|
|
plOption, err := readPreloginOption(struct_buf, offset)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if plOption.token == preloginTERMINATOR {
|
|
break
|
|
}
|
|
|
|
// read prelogin option data
|
|
value, err := readPreloginOptionData(plOption, struct_buf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
results[plOption.token] = value
|
|
|
|
offset += preloginOptionSize
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func readPreloginOption(buffer []byte, offset int) (*preloginOption, error) {
|
|
buffer_length := len(buffer)
|
|
|
|
// check if prelogin option record exists in buffer
|
|
if offset >= buffer_length {
|
|
return nil, fmt.Errorf("invalid buffer, invalid prelogin option")
|
|
}
|
|
|
|
rec_type := buffer[offset]
|
|
if rec_type == preloginTERMINATOR {
|
|
return &preloginOption{token: rec_type}, nil
|
|
}
|
|
|
|
// check if prelogin option exists in buffer
|
|
if offset+preloginOptionSize >= buffer_length {
|
|
return nil, fmt.Errorf("invalid buffer, invalid prelogin option")
|
|
}
|
|
|
|
plOption := &preloginOption{
|
|
token: rec_type,
|
|
offset: binary.BigEndian.Uint16(buffer[offset+1:]),
|
|
length: binary.BigEndian.Uint16(buffer[offset+3:]),
|
|
}
|
|
|
|
return plOption, nil
|
|
}
|
|
|
|
func readPreloginOptionData(plOption *preloginOption, buffer []byte) ([]byte, error) {
|
|
buffer_length := len(buffer)
|
|
// check if prelogin option data exists in buffer
|
|
if plOption == nil || int(plOption.length+plOption.offset) > buffer_length ||
|
|
int(plOption.offset) >= buffer_length {
|
|
return nil, fmt.Errorf("invalid buffer, invalid prelogin option")
|
|
}
|
|
|
|
if plOption.token == preloginTERMINATOR {
|
|
return nil, fmt.Errorf("cannot read data for prelogin terminator record")
|
|
}
|
|
|
|
value := buffer[plOption.offset : plOption.length+plOption.offset]
|
|
return value, nil
|
|
}
|
|
|
|
// OptionFlags1
// Bits of the LOGIN7 OptionFlags1 byte; prepareLogin sets fUseDB | fSetLang.
|
|
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
|
|
const (
|
|
fUseDB = 0x20
|
|
fSetLang = 0x80
|
|
)
|
|
|
|
// OptionFlags2
// Bits of the LOGIN7 OptionFlags2 byte. prepareLogin always sets fODBC and
// adds fIntSecurity for integrated (SSPI) authentication.
|
|
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
|
|
const (
|
|
fLanguageFatal = 1
|
|
fODBC = 2
|
|
fTransBoundary = 4
|
|
fCacheConnect = 8
|
|
fIntSecurity = 0x80
|
|
)
|
|
|
|
// OptionFlags3
// Bits of the LOGIN7 OptionFlags3 byte. sendLogin sets fChangePassword when
// a new password is supplied and fExtension when a FeatureExt block follows.
|
|
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
|
|
const (
|
|
fChangePassword = 1
|
|
fSendYukonBinaryXML = 2
|
|
fUserInstance = 4
|
|
fUnknownCollationHandling = 8
|
|
fExtension = 0x10
|
|
)
|
|
|
|
// TypeFlags
// Bits of the LOGIN7 TypeFlags byte.
|
|
const (
|
|
// 4 bits for fSQLType
|
|
// 1 bit for fOLEDB
|
|
// fReadOnlyIntent (0x20) is set by prepareLogin when p.ReadOnlyIntent is true.
fReadOnlyIntent = 32
|
|
)
|
|
|
|
// login holds the values of a LOGIN7 record before sendLogin serializes it.
// sendLogin sends the string fields as UTF-16LE (str2ucs2). Password and
// ChangePassword are additionally obfuscated with manglePassword.
type login struct {
|
|
TDSVersion uint32
|
|
PacketSize uint32
|
|
ClientProgVer uint32
|
|
ClientPID uint32
|
|
ConnectionID uint32
|
|
OptionFlags1 uint8
|
|
OptionFlags2 uint8
|
|
TypeFlags uint8
|
|
OptionFlags3 uint8
|
|
ClientTimeZone int32
|
|
ClientLCID uint32
|
|
HostName string
|
|
UserName string
|
|
Password string
|
|
AppName string
|
|
// ServerName includes the instance ("host\instance") when one is configured.
ServerName string
|
|
CtlIntName string
|
|
Language string
|
|
Database string
|
|
ClientID [6]byte
|
|
// SSPI is the initial integrated-authentication token (auth.InitialBytes in prepareLogin).
SSPI []byte
|
|
AtchDBFile string
|
|
ChangePassword string
|
|
// FeatureExt collects optional feature extensions; when non-empty sendLogin
// appends the encoded block and sets fExtension.
FeatureExt featureExts
|
|
}
|
|
|
|
// featureExts is a set of LOGIN7 feature extensions keyed by feature ID.
// The zero value is usable: Add allocates the map on first use.
type featureExts struct {
|
|
features map[byte]featureExt
|
|
}
|
|
|
|
// featureExt is one LOGIN7 feature extension. featureID is written as the
// entry's FeatureId byte and toBytes supplies its FeatureData payload.
type featureExt interface {
|
|
featureID() byte
|
|
toBytes() []byte
|
|
}
|
|
|
|
func (e *featureExts) Add(f featureExt) error {
|
|
if f == nil {
|
|
return nil
|
|
}
|
|
id := f.featureID()
|
|
if _, exists := e.features[id]; exists {
|
|
f := "login error: Feature with ID '%v' is already present in FeatureExt block"
|
|
return fmt.Errorf(f, id)
|
|
}
|
|
if e.features == nil {
|
|
e.features = make(map[byte]featureExt)
|
|
}
|
|
e.features[id] = f
|
|
return nil
|
|
}
|
|
|
|
func (e featureExts) toBytes() []byte {
|
|
if len(e.features) == 0 {
|
|
return nil
|
|
}
|
|
var d []byte
|
|
for featureID, f := range e.features {
|
|
featureData := f.toBytes()
|
|
|
|
hdr := make([]byte, 5)
|
|
hdr[0] = featureID // FedAuth feature extension BYTE
|
|
binary.LittleEndian.PutUint32(hdr[1:], uint32(len(featureData))) // FeatureDataLen DWORD
|
|
d = append(d, hdr...)
|
|
|
|
d = append(d, featureData...) // FeatureData *BYTE
|
|
}
|
|
if d != nil {
|
|
d = append(d, 0xff) // Terminator
|
|
}
|
|
return d
|
|
}
|
|
|
|
// featureExtFedAuth tracks federated authentication state before and during login
// It is also the FEDAUTH feature extension itself (see featureID/toBytes).
|
|
type featureExtFedAuth struct {
|
|
// FedAuthLibrary is populated by the federated authentication provider.
// toBytes stores it in the upper seven bits of the options byte and uses it
// to choose the FeatureData layout.
|
|
FedAuthLibrary int
|
|
|
|
// ADALWorkflow is populated by the federated authentication provider.
|
|
ADALWorkflow byte
|
|
|
|
// FedAuthEcho is populated from the prelogin response
// (interpretPreloginResponse sets it when FEDAUTHREQUIRED is non-zero).
|
|
FedAuthEcho bool
|
|
|
|
// FedAuthToken is populated during login with the value from the provider.
|
|
FedAuthToken string
|
|
|
|
// Nonce is populated during login with the value from the provider.
// toBytes only includes it when it is exactly 32 bytes long.
|
|
Nonce []byte
|
|
|
|
// Signature is populated during login with the value from the server.
|
|
Signature []byte
|
|
}
|
|
|
|
func (e *featureExtFedAuth) featureID() byte {
|
|
return featExtFEDAUTH
|
|
}
|
|
|
|
func (e *featureExtFedAuth) toBytes() []byte {
|
|
if e == nil {
|
|
return nil
|
|
}
|
|
|
|
options := byte(e.FedAuthLibrary) << 1
|
|
if e.FedAuthEcho {
|
|
options |= 1 // fFedAuthEcho
|
|
}
|
|
|
|
// Feature extension format depends on the federated auth library.
|
|
// Options are described at
|
|
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac
|
|
var d []byte
|
|
|
|
switch e.FedAuthLibrary {
|
|
case FedAuthLibrarySecurityToken:
|
|
d = make([]byte, 5)
|
|
d[0] = options
|
|
|
|
// looks like string in
|
|
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508
|
|
tokenBytes := str2ucs2(e.FedAuthToken)
|
|
binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work
|
|
d = append(d, tokenBytes...)
|
|
|
|
if len(e.Nonce) == 32 {
|
|
d = append(d, e.Nonce...)
|
|
}
|
|
|
|
case FedAuthLibraryADAL:
|
|
d = []byte{options, e.ADALWorkflow}
|
|
}
|
|
|
|
return d
|
|
}
|
|
|
|
// loginHeader is the fixed-size part of a LOGIN7 record. sendLogin writes it
// with binary.Write in little-endian order, so the field order and sizes ARE
// the wire layout and must not be changed. Every *Offset is measured from the
// start of the record; the variable data begins at binary.Size(loginHeader{}).
type loginHeader struct {
|
|
// Length is the total record length including FeatureExt data.
Length uint32
|
|
TDSVersion uint32
|
|
PacketSize uint32
|
|
ClientProgVer uint32
|
|
ClientPID uint32
|
|
ConnectionID uint32
|
|
OptionFlags1 uint8
|
|
OptionFlags2 uint8
|
|
TypeFlags uint8
|
|
OptionFlags3 uint8
|
|
ClientTimeZone int32
|
|
ClientLCID uint32
|
|
HostNameOffset uint16
|
|
HostNameLength uint16
|
|
UserNameOffset uint16
|
|
UserNameLength uint16
|
|
PasswordOffset uint16
|
|
PasswordLength uint16
|
|
AppNameOffset uint16
|
|
AppNameLength uint16
|
|
ServerNameOffset uint16
|
|
ServerNameLength uint16
|
|
// ExtensionOffset/ExtensionLength point at the DWORD that holds the
// FeatureExt offset; sendLogin sets them only when extensions are present.
ExtensionOffset uint16
|
|
ExtensionLength uint16
|
|
CtlIntNameOffset uint16
|
|
CtlIntNameLength uint16
|
|
LanguageOffset uint16
|
|
LanguageLength uint16
|
|
DatabaseOffset uint16
|
|
DatabaseLength uint16
|
|
ClientID [6]byte
|
|
// SSPILength is a byte count, unlike the character counts used for strings.
SSPIOffset uint16
|
|
SSPILength uint16
|
|
AtchDBFileOffset uint16
|
|
AtchDBFileLength uint16
|
|
ChangePasswordOffset uint16
|
|
ChangePasswordLength uint16
|
|
// SSPILongLength is not set by sendLogin (always written as 0).
SSPILongLength uint32
|
|
}
|
|
|
|
// str2ucs2 converts a Go string to UTF-16 little-endian bytes.
//
// Runes outside the Basic Multilingual Plane become surrogate pairs (four
// bytes), and invalid UTF-8 bytes become U+FFFD. The result is never nil.
// Bytes are packed by hand rather than through the binary/bytes packages
// for performance reasons.
func str2ucs2(s string) []byte {
	units := utf16.Encode([]rune(s))
	out := make([]byte, 0, 2*len(units))
	for _, u := range units {
		out = append(out, byte(u), byte(u>>8))
	}
	return out
}
|
|
|
|
// NOTE(review): these masks are not referenced in the code visible here.
// Every 16-bit lane is 0xFF80, so ANDing a UTF-16 code unit with it gives a
// non-zero result exactly when the unit is >= 0x80 (non-ASCII). They are
// presumably meant for an ASCII fast path over 1, 2 or 4 code units at a
// time — confirm their use elsewhere before changing them.
const (
|
|
mask64 uint64 = 0xFF80FF80FF80FF80
|
|
mask32 uint32 = 0xFF80FF80
|
|
mask16 uint16 = 0xFF80
|
|
)
|
|
|
|
func manglePassword(password string) []byte {
|
|
var ucs2password []byte = str2ucs2(password)
|
|
for i, ch := range ucs2password {
|
|
ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5
|
|
}
|
|
return ucs2password
|
|
}
|
|
|
|
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
|
|
func sendLogin(w *tdsBuffer, login *login) error {
|
|
w.BeginPacket(packLogin7, false)
|
|
hostname := str2ucs2(login.HostName)
|
|
username := str2ucs2(login.UserName)
|
|
password := manglePassword(login.Password)
|
|
appname := str2ucs2(login.AppName)
|
|
servername := str2ucs2(login.ServerName)
|
|
ctlintname := str2ucs2(login.CtlIntName)
|
|
language := str2ucs2(login.Language)
|
|
database := str2ucs2(login.Database)
|
|
atchdbfile := str2ucs2(login.AtchDBFile)
|
|
changepassword := manglePassword(login.ChangePassword)
|
|
featureExt := login.FeatureExt.toBytes()
|
|
|
|
hdr := loginHeader{
|
|
TDSVersion: login.TDSVersion,
|
|
PacketSize: login.PacketSize,
|
|
ClientProgVer: login.ClientProgVer,
|
|
ClientPID: login.ClientPID,
|
|
ConnectionID: login.ConnectionID,
|
|
OptionFlags1: login.OptionFlags1,
|
|
OptionFlags2: login.OptionFlags2,
|
|
TypeFlags: login.TypeFlags,
|
|
OptionFlags3: login.OptionFlags3,
|
|
ClientTimeZone: login.ClientTimeZone,
|
|
ClientLCID: login.ClientLCID,
|
|
HostNameLength: uint16(utf8.RuneCountInString(login.HostName)),
|
|
UserNameLength: uint16(utf8.RuneCountInString(login.UserName)),
|
|
PasswordLength: uint16(utf8.RuneCountInString(login.Password)),
|
|
AppNameLength: uint16(utf8.RuneCountInString(login.AppName)),
|
|
ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)),
|
|
CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)),
|
|
LanguageLength: uint16(utf8.RuneCountInString(login.Language)),
|
|
DatabaseLength: uint16(utf8.RuneCountInString(login.Database)),
|
|
ClientID: login.ClientID,
|
|
SSPILength: uint16(len(login.SSPI)),
|
|
AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)),
|
|
ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),
|
|
}
|
|
offset := uint16(binary.Size(hdr))
|
|
hdr.HostNameOffset = offset
|
|
offset += uint16(len(hostname))
|
|
hdr.UserNameOffset = offset
|
|
offset += uint16(len(username))
|
|
hdr.PasswordOffset = offset
|
|
offset += uint16(len(password))
|
|
hdr.AppNameOffset = offset
|
|
offset += uint16(len(appname))
|
|
hdr.ServerNameOffset = offset
|
|
offset += uint16(len(servername))
|
|
hdr.CtlIntNameOffset = offset
|
|
offset += uint16(len(ctlintname))
|
|
hdr.LanguageOffset = offset
|
|
offset += uint16(len(language))
|
|
hdr.DatabaseOffset = offset
|
|
offset += uint16(len(database))
|
|
hdr.SSPIOffset = offset
|
|
offset += uint16(len(login.SSPI))
|
|
hdr.AtchDBFileOffset = offset
|
|
offset += uint16(len(atchdbfile))
|
|
hdr.ChangePasswordOffset = offset
|
|
offset += uint16(len(changepassword))
|
|
|
|
featureExtOffset := uint32(0)
|
|
featureExtLen := len(featureExt)
|
|
if featureExtLen > 0 {
|
|
hdr.OptionFlags3 |= fExtension
|
|
hdr.ExtensionOffset = offset
|
|
hdr.ExtensionLength = 4
|
|
offset += hdr.ExtensionLength // DWORD
|
|
featureExtOffset = uint32(offset)
|
|
}
|
|
if len(changepassword) > 0 {
|
|
hdr.OptionFlags3 |= fChangePassword
|
|
}
|
|
hdr.Length = uint32(offset) + uint32(featureExtLen)
|
|
|
|
var err error
|
|
err = binary.Write(w, binary.LittleEndian, &hdr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(hostname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(username)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(password)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(appname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(servername)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(ctlintname)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(language)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(database)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(login.SSPI)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(atchdbfile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(changepassword)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if featureExtOffset > 0 {
|
|
err = binary.Write(w, binary.LittleEndian, featureExtOffset)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(featureExt)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return w.FinishPacket()
|
|
}
|
|
|
|
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/827d9632-2957-4d54-b9ea-384530ae79d0
|
|
func sendFedAuthInfo(w *tdsBuffer, fedAuth *featureExtFedAuth) (err error) {
|
|
fedauthtoken := str2ucs2(fedAuth.FedAuthToken)
|
|
tokenlen := len(fedauthtoken)
|
|
datalen := 4 + tokenlen + len(fedAuth.Nonce)
|
|
|
|
w.BeginPacket(packFedAuthToken, false)
|
|
err = binary.Write(w, binary.LittleEndian, uint32(datalen))
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
err = binary.Write(w, binary.LittleEndian, uint32(tokenlen))
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
_, err = w.Write(fedauthtoken)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
_, err = w.Write(fedAuth.Nonce)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
return w.FinishPacket()
|
|
}
|
|
|
|
func readUcs2(r io.Reader, numchars int) (res string, err error) {
|
|
buf := make([]byte, numchars*2)
|
|
_, err = io.ReadFull(r, buf)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return ucs22str(buf)
|
|
}
|
|
|
|
func readUsVarChar(r io.Reader) (res string, err error) {
|
|
numchars, err := readUshort(r)
|
|
if err != nil {
|
|
return
|
|
}
|
|
return readUcs2(r, int(numchars))
|
|
}
|
|
|
|
func writeUsVarChar(w io.Writer, s string) (err error) {
|
|
buf := str2ucs2(s)
|
|
var numchars int = len(buf) / 2
|
|
if numchars > 0xffff {
|
|
panic("invalid size for US_VARCHAR")
|
|
}
|
|
err = binary.Write(w, binary.LittleEndian, uint16(numchars))
|
|
if err != nil {
|
|
return
|
|
}
|
|
_, err = w.Write(buf)
|
|
return
|
|
}
|
|
|
|
func readBVarChar(r io.Reader) (string, error) {
|
|
numchars, err := readByte(r)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// A zero length could be returned, return an empty string
|
|
if numchars == 0 {
|
|
return "", nil
|
|
}
|
|
return readUcs2(r, int(numchars))
|
|
}
|
|
|
|
func writeBVarChar(w io.Writer, s string) (err error) {
|
|
buf := str2ucs2(s)
|
|
var numchars int = len(buf) / 2
|
|
if numchars > 0xff {
|
|
panic("invalid size for B_VARCHAR")
|
|
}
|
|
err = binary.Write(w, binary.LittleEndian, uint8(numchars))
|
|
if err != nil {
|
|
return
|
|
}
|
|
_, err = w.Write(buf)
|
|
return
|
|
}
|
|
|
|
func readBVarByte(r io.Reader) (res []byte, err error) {
|
|
length, err := readByte(r)
|
|
if err != nil {
|
|
return
|
|
}
|
|
res = make([]byte, length)
|
|
_, err = io.ReadFull(r, res)
|
|
return
|
|
}
|
|
|
|
// readUshort reads a little-endian uint16 from r. If fewer than two bytes
// are available it returns 0 and the error reported by io.ReadFull.
func readUshort(r io.Reader) (res uint16, err error) {
	var raw [2]byte
	if _, err = io.ReadFull(r, raw[:]); err != nil {
		return 0, err
	}
	return binary.LittleEndian.Uint16(raw[:]), nil
}
|
|
|
|
// readByte reads exactly one byte from r.
//
// It uses io.ReadFull rather than a single r.Read. io.Reader allows Read to
// return (0, nil), which must not be mistaken for a zero byte. It also
// allows Read to return the last byte together with io.EOF, and that byte
// must not be discarded. ReadFull retries the former and reports success for
// the latter. If no byte is available, the reader's error (e.g. io.EOF) is
// returned with a zero byte.
func readByte(r io.Reader) (res byte, err error) {
	var b [1]byte
	if _, err = io.ReadFull(r, b[:]); err != nil {
		return 0, err
	}
	return b[0], nil
}
|
|
|
|
// Packet Data Stream Headers
// headerStruct is one entry of the ALL_HEADERS block that writeAllHeaders
// emits: hdrtype is one of the dataStmHdr* values and data is the encoded
// header body (e.g. queryNotifHdr.pack or transDescrHdr.pack).
|
|
// http://msdn.microsoft.com/en-us/library/dd304953.aspx
|
|
type headerStruct struct {
|
|
hdrtype uint16
|
|
data []byte
|
|
}
|
|
|
|
// Header type IDs used in headerStruct.hdrtype.
const (
|
|
dataStmHdrQueryNotif = 1 // query notifications
|
|
dataStmHdrTransDescr = 2 // MARS transaction descriptor (required)
|
|
dataStmHdrTraceActivity = 3
|
|
)
|
|
|
|
// Query Notifications Header
// Encoded by queryNotifHdr.pack; the two strings are sent as UTF-16LE.
|
|
// http://msdn.microsoft.com/en-us/library/dd304949.aspx
|
|
type queryNotifHdr struct {
|
|
notifyId string
|
|
ssbDeployment string
|
|
notifyTimeout uint32
|
|
}
|
|
|
|
func (hdr queryNotifHdr) pack() (res []byte) {
|
|
notifyId := str2ucs2(hdr.notifyId)
|
|
ssbDeployment := str2ucs2(hdr.ssbDeployment)
|
|
|
|
res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)
|
|
b := res
|
|
|
|
binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))
|
|
b = b[2:]
|
|
copy(b, notifyId)
|
|
b = b[len(notifyId):]
|
|
|
|
binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))
|
|
b = b[2:]
|
|
copy(b, ssbDeployment)
|
|
b = b[len(ssbDeployment):]
|
|
|
|
binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)
|
|
|
|
return res
|
|
}
|
|
|
|
// MARS Transaction Descriptor Header
// http://msdn.microsoft.com/en-us/library/dd340515.aspx
type transDescrHdr struct {
	transDescr        uint64 // transaction descriptor returned from ENVCHANGE
	outstandingReqCnt uint32 // outstanding request count
}

// pack encodes the header body as 12 little-endian bytes: the 8-byte
// transaction descriptor followed by the 4-byte outstanding request count.
func (hdr transDescrHdr) pack() (res []byte) {
	var buf [8 + 4]byte
	binary.LittleEndian.PutUint64(buf[0:8], hdr.transDescr)
	binary.LittleEndian.PutUint32(buf[8:12], hdr.outstandingReqCnt)
	return buf[:]
}
|
|
|
|
func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
|
|
// Calculating total length.
|
|
var totallen uint32 = 4
|
|
for _, hdr := range headers {
|
|
totallen += 4 + 2 + uint32(len(hdr.data))
|
|
}
|
|
// writing
|
|
err = binary.Write(w, binary.LittleEndian, totallen)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, hdr := range headers {
|
|
var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))
|
|
err = binary.Write(w, binary.LittleEndian, headerlen)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = w.Write(hdr.data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) {
|
|
buf.BeginPacket(packSQLBatch, resetSession)
|
|
|
|
if err = writeAllHeaders(buf, headers); err != nil {
|
|
return
|
|
}
|
|
|
|
_, err = buf.Write(str2ucs2(sqltext))
|
|
if err != nil {
|
|
return
|
|
}
|
|
return buf.FinishPacket()
|
|
}
|
|
|
|
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
|
|
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
|
|
func sendAttention(buf *tdsBuffer) error {
|
|
buf.BeginPacket(packAttention, false)
|
|
return buf.FinishPacket()
|
|
}
|
|
|
|
// Makes an attempt to connect with each available protocol, in order, until one succeeds or the timeout elapses
|
|
func dialConnection(ctx context.Context, c *Connector, p *msdsn.Config, logger ContextLogger) (conn net.Conn, err error) {
|
|
var instances msdsn.BrowserData
|
|
for _, protocol := range p.Protocols {
|
|
dialer := msdsn.ProtocolDialers[protocol]
|
|
if dialer.CallBrowser(p) {
|
|
if instances == nil {
|
|
d := c.getDialer(p)
|
|
instances, err = getInstances(ctx, d, p.Host, p.BrowserMessage, p.Instance)
|
|
if err != nil && logger != nil && uint64(p.LogFlags)&logErrors != 0 {
|
|
e := fmt.Sprintf("unable to get instances from Sql Server Browser on host %v: %v", p.Host, err.Error())
|
|
logger.Log(ctx, msdsn.Log(logErrors), e)
|
|
}
|
|
}
|
|
err = dialer.ParseBrowserData(instances, p)
|
|
if err != nil {
|
|
if logger != nil && uint64(p.LogFlags)&logErrors != 0 {
|
|
logger.Log(ctx, msdsn.Log(logErrors), "Skipping protocol "+protocol+". Error:"+err.Error())
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
sqlDialer, ok := dialer.(MssqlProtocolDialer)
|
|
if logger != nil && uint64(p.LogFlags)&logDebug != 0 {
|
|
logger.Log(ctx, msdsn.LogDebug, "Dialing with protocol "+protocol)
|
|
}
|
|
if !ok {
|
|
conn, err = dialer.DialConnection(ctx, p)
|
|
} else {
|
|
conn, err = sqlDialer.DialSqlConnection(ctx, c, p)
|
|
}
|
|
if err != nil && logger != nil && uint64(p.LogFlags)&logErrors != 0 {
|
|
logger.Log(ctx, msdsn.LogErrors, "Unable to connect with protocol "+protocol+":"+err.Error())
|
|
}
|
|
if conn != nil {
|
|
if logger != nil && uint64(p.LogFlags)&logDebug != 0 {
|
|
logger.Log(ctx, msdsn.LogDebug, "Returning connection from protocol "+protocol)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func preparePreloginFields(p msdsn.Config, fe *featureExtFedAuth) map[uint8][]byte {
|
|
instance_buf := []byte(p.Instance)
|
|
instance_buf = append(instance_buf, 0) // zero terminate instance name
|
|
|
|
var encrypt byte
|
|
switch p.Encryption {
|
|
default:
|
|
panic(fmt.Errorf("Unsupported Encryption Config %v", p.Encryption))
|
|
case msdsn.EncryptionDisabled:
|
|
encrypt = encryptNotSup
|
|
case msdsn.EncryptionRequired:
|
|
encrypt = encryptOn
|
|
case msdsn.EncryptionOff:
|
|
encrypt = encryptOff
|
|
case msdsn.EncryptionStrict:
|
|
encrypt = encryptStrict
|
|
}
|
|
v := getDriverVersion(driverVersion)
|
|
fields := map[uint8][]byte{
|
|
// 4 bytes for version and 2 bytes for minor version
|
|
preloginVERSION: {byte(v), byte(v >> 8), byte(v >> 16), byte(v >> 24), 0, 0},
|
|
preloginENCRYPTION: {encrypt},
|
|
preloginINSTOPT: instance_buf,
|
|
preloginTHREADID: {0, 0, 0, 0},
|
|
preloginMARS: {0}, // MARS disabled
|
|
}
|
|
|
|
if fe.FedAuthLibrary != FedAuthLibraryReserved {
|
|
fields[preloginFEDAUTHREQUIRED] = []byte{1}
|
|
}
|
|
|
|
return fields
|
|
}
|
|
|
|
func interpretPreloginResponse(p msdsn.Config, fe *featureExtFedAuth, fields map[uint8][]byte) (encrypt byte, err error) {
|
|
// If the server returns the preloginFEDAUTHREQUIRED field, then federated authentication
|
|
// is supported. The actual value may be 0 or 1, where 0 means either SSPI or federated
|
|
// authentication is allowed, while 1 means only federated authentication is allowed.
|
|
if fedAuthSupport, ok := fields[preloginFEDAUTHREQUIRED]; ok {
|
|
if len(fedAuthSupport) != 1 {
|
|
return 0, fmt.Errorf("federated authentication flag length should be 1: is %d", len(fedAuthSupport))
|
|
}
|
|
|
|
// We need to be able to echo the value back to the server
|
|
fe.FedAuthEcho = fedAuthSupport[0] != 0
|
|
} else if fe.FedAuthLibrary != FedAuthLibraryReserved && fe.ADALWorkflow > 0 {
|
|
return 0, fmt.Errorf("federated authentication is not supported by the server")
|
|
}
|
|
|
|
encryptBytes, ok := fields[preloginENCRYPTION]
|
|
if !ok {
|
|
return 0, fmt.Errorf("encrypt negotiation failed")
|
|
}
|
|
encrypt = encryptBytes[0]
|
|
if p.Encryption == msdsn.EncryptionRequired && (encrypt == encryptNotSup || encrypt == encryptOff) {
|
|
return 0, fmt.Errorf("server does not support encryption")
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger ContextLogger, auth integratedauth.IntegratedAuthenticator, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) {
|
|
var TDSVersion uint32
|
|
if p.Encryption == msdsn.EncryptionStrict {
|
|
TDSVersion = verTDS80
|
|
} else {
|
|
TDSVersion = verTDS74
|
|
}
|
|
var typeFlags uint8
|
|
if p.ReadOnlyIntent {
|
|
typeFlags |= fReadOnlyIntent
|
|
}
|
|
// We need to include Instance in ServerName field of LOGIN7 record
|
|
var serverName string
|
|
if len(p.Instance) > 0 {
|
|
serverName = p.Host + "\\" + p.Instance
|
|
} else {
|
|
serverName = p.Host
|
|
}
|
|
l = &login{
|
|
TDSVersion: TDSVersion,
|
|
PacketSize: packetSize,
|
|
Database: p.Database,
|
|
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
|
|
OptionFlags1: fUseDB | fSetLang,
|
|
HostName: p.Workstation,
|
|
ServerName: serverName,
|
|
AppName: p.AppName,
|
|
TypeFlags: typeFlags,
|
|
CtlIntName: "go-mssqldb",
|
|
ClientProgVer: getDriverVersion(driverVersion),
|
|
ChangePassword: p.ChangePassword,
|
|
}
|
|
if p.ColumnEncryption {
|
|
_ = l.FeatureExt.Add(&featureExtColumnEncryption{})
|
|
}
|
|
switch {
|
|
case fe.FedAuthLibrary == FedAuthLibrarySecurityToken:
|
|
if uint64(p.LogFlags)&logDebug != 0 {
|
|
logger.Log(ctx, msdsn.LogDebug, "Starting federated authentication using security token")
|
|
}
|
|
|
|
fe.FedAuthToken, err = c.securityTokenProvider(ctx)
|
|
if err != nil {
|
|
if uint64(p.LogFlags)&logDebug != 0 {
|
|
logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("Failed to retrieve service principal token for federated authentication security token library: %v", err))
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
_ = l.FeatureExt.Add(fe)
|
|
|
|
case fe.FedAuthLibrary == FedAuthLibraryADAL:
|
|
if uint64(p.LogFlags)&logDebug != 0 {
|
|
logger.Log(ctx, msdsn.LogDebug, "Starting federated authentication using ADAL")
|
|
}
|
|
|
|
_ = l.FeatureExt.Add(fe)
|
|
|
|
case auth != nil:
|
|
if uint64(p.LogFlags)&logDebug != 0 {
|
|
logger.Log(ctx, msdsn.LogDebug, "Starting SSPI login")
|
|
}
|
|
|
|
l.SSPI, err = auth.InitialBytes()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
l.OptionFlags2 |= fIntSecurity
|
|
return l, nil
|
|
|
|
default:
|
|
// Default to SQL server authentication with user and password
|
|
l.UserName = p.User
|
|
l.Password = p.Password
|
|
}
|
|
|
|
return l, nil
|
|
}
|
|
|
|
func getTLSConn(conn *timeoutConn, p msdsn.Config, alpnSeq string) (tlsConn *tls.Conn, err error) {
|
|
var config *tls.Config
|
|
if pc := p.TLSConfig; pc != nil {
|
|
config = pc
|
|
}
|
|
if config == nil {
|
|
config, err = msdsn.SetupTLS("", false, p.Host, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
//Set ALPN Sequence
|
|
config.NextProtos = []string{alpnSeq}
|
|
tlsConn = tls.Client(conn.c, config)
|
|
err = tlsConn.Handshake()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("TLS Handshake failed: %w", err)
|
|
}
|
|
return tlsConn, nil
|
|
}
|
|
|
|
func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) {
|
|
isTransportEncrypted := false
|
|
// if instance is specified use instance resolution service
|
|
if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 {
|
|
// both instance name and port specified
|
|
// when port is specified instance name is not used
|
|
// you should not provide instance name when you provide port
|
|
logger.Log(ctx, msdsn.LogDebug, "WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored")
|
|
}
|
|
|
|
packetSize := p.PacketSize
|
|
if packetSize == 0 {
|
|
packetSize = defaultPacketSize
|
|
}
|
|
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
|
|
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
|
|
// a higher packet size, the server will respond with an ENVCHANGE request to
|
|
// alter the packet size to 16383 bytes.
|
|
if packetSize < 512 {
|
|
packetSize = 512
|
|
} else if packetSize > 32767 {
|
|
packetSize = 32767
|
|
}
|
|
|
|
initiate_connection:
|
|
dialCtx := ctx
|
|
if p.DialTimeout >= 0 {
|
|
dt := p.DialTimeout
|
|
if dt == 0 {
|
|
dt = time.Duration(15*len(p.Protocols)) * time.Second
|
|
}
|
|
var cancel func()
|
|
dialCtx, cancel = context.WithTimeout(ctx, dt)
|
|
defer cancel()
|
|
}
|
|
conn, err := dialConnection(dialCtx, c, &p, logger)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
toconn := newTimeoutConn(conn, p.ConnTimeout)
|
|
outbuf := newTdsBuffer(packetSize, toconn)
|
|
|
|
if p.Encryption == msdsn.EncryptionStrict {
|
|
outbuf.transport, err = getTLSConn(toconn, p, "tds/8.0")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
isTransportEncrypted = true
|
|
}
|
|
sess := tdsSession{
|
|
buf: outbuf,
|
|
logger: logger,
|
|
logFlags: uint64(p.LogFlags),
|
|
aeSettings: &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()},
|
|
}
|
|
|
|
for i, p := range c.keyProviders {
|
|
sess.aeSettings.keyProviders[i] = p
|
|
}
|
|
fedAuth := &featureExtFedAuth{
|
|
FedAuthLibrary: FedAuthLibraryReserved,
|
|
}
|
|
if c.fedAuthRequired {
|
|
fedAuth.FedAuthLibrary = c.fedAuthLibrary
|
|
fedAuth.ADALWorkflow = c.fedAuthADALWorkflow
|
|
}
|
|
|
|
fields := preparePreloginFields(p, fedAuth)
|
|
|
|
err = writePrelogin(packPrelogin, outbuf, fields)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fields, err = readPrelogin(outbuf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
encrypt, err := interpretPreloginResponse(p, fedAuth, fields)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
//We need not perform TLS handshake if the communication channel is already encrypted (encrypt=strict)
|
|
if !isTransportEncrypted {
|
|
if encrypt != encryptNotSup {
|
|
var config *tls.Config
|
|
if pc := p.TLSConfig; pc != nil {
|
|
config = pc
|
|
if !config.DynamicRecordSizingDisabled {
|
|
config = config.Clone()
|
|
|
|
// fix for https://github.com/microsoft/go-mssqldb/issues/166
|
|
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
|
|
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
|
|
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
|
|
config.DynamicRecordSizingDisabled = true
|
|
}
|
|
}
|
|
if config == nil {
|
|
config, err = msdsn.SetupTLS("", false, p.Host, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
}
|
|
|
|
// setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
|
|
handshakeConn := tlsHandshakeConn{buf: outbuf}
|
|
passthrough := passthroughConn{c: &handshakeConn}
|
|
tlsConn := tls.Client(&passthrough, config)
|
|
err = tlsConn.Handshake()
|
|
passthrough.c = toconn
|
|
outbuf.transport = tlsConn
|
|
if err != nil {
|
|
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
|
|
}
|
|
if encrypt == encryptOff {
|
|
outbuf.afterFirst = func() {
|
|
outbuf.transport = toconn
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
auth, err := integratedauth.GetIntegratedAuthenticator(p)
|
|
if err != nil {
|
|
if uint64(p.LogFlags)&logDebug != 0 {
|
|
logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("Error while creating integrated authenticator: %v", err))
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
if auth != nil {
|
|
defer auth.Free()
|
|
}
|
|
|
|
login, err := prepareLogin(ctx, c, p, logger, auth, fedAuth, uint32(outbuf.PackageSize()))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = sendLogin(outbuf, login)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Loop until a packet containing a login acknowledgement is received.
|
|
// SSPI and federated authentication scenarios may require multiple
|
|
// packet exchanges to complete the login sequence.
|
|
for loginAck := false; !loginAck; {
|
|
reader := startReading(&sess, ctx, outputs{})
|
|
// don't send attention or wait for cancel confirmation during login
|
|
reader.noAttn = true
|
|
|
|
for {
|
|
tok, err := reader.nextToken()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if tok == nil {
|
|
break
|
|
}
|
|
|
|
switch token := tok.(type) {
|
|
case sspiMsg:
|
|
sspi_msg, err := auth.NextBytes(token)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(sspi_msg) > 0 {
|
|
outbuf.BeginPacket(packSSPIMessage, false)
|
|
_, err = outbuf.Write(sspi_msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = outbuf.FinishPacket()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sspi_msg = nil
|
|
}
|
|
// TODO: for Live ID authentication it may be necessary to
|
|
// compare fedAuth.Nonce == token.Nonce and keep track of signature
|
|
//case fedAuthAckStruct:
|
|
//fedAuth.Signature = token.Signature
|
|
case fedAuthInfoStruct:
|
|
// For ADAL workflows this contains the STS URL and server SPN.
|
|
// If received outside of an ADAL workflow, ignore.
|
|
if c == nil || c.adalTokenProvider == nil {
|
|
continue
|
|
}
|
|
|
|
// Request the AD token given the server SPN and STS URL
|
|
fedAuth.FedAuthToken, err = c.adalTokenProvider(ctx, token.ServerSPN, token.STSURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Now need to send the token as a FEDINFO packet
|
|
err = sendFedAuthInfo(outbuf, fedAuth)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
case loginAckStruct:
|
|
sess.loginAck = token
|
|
loginAck = true
|
|
case featureExtAck:
|
|
for _, v := range token {
|
|
switch v := v.(type) {
|
|
case colAckStruct:
|
|
if v.Version <= 2 && v.Version > 0 {
|
|
sess.alwaysEncrypted = true
|
|
if len(v.EnclaveType) > 0 {
|
|
sess.aeSettings.enclaveType = string(v.EnclaveType)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
case doneStruct:
|
|
if token.isError() {
|
|
tokenErr := token.getError()
|
|
tokenErr.Message = "login error: " + tokenErr.Message
|
|
return nil, tokenErr
|
|
}
|
|
case error:
|
|
return nil, fmt.Errorf("login error: %s", token.Error())
|
|
}
|
|
}
|
|
}
|
|
|
|
if sess.routedServer != "" {
|
|
toconn.Close()
|
|
// Need to handle case when routedServer is in "host\instance" format.
|
|
routedParts := strings.SplitN(sess.routedServer, "\\", 2)
|
|
p.Host = routedParts[0]
|
|
if len(routedParts) == 2 {
|
|
p.Instance = routedParts[1]
|
|
}
|
|
p.Port = uint64(sess.routedPort)
|
|
if !p.HostInCertificateProvided && p.TLSConfig != nil {
|
|
p.TLSConfig = p.TLSConfig.Clone()
|
|
p.TLSConfig.ServerName = p.Host
|
|
}
|
|
goto initiate_connection
|
|
}
|
|
return &sess, nil
|
|
}
|
|
|
|
// featureExtColumnEncryption is the LOGIN7 feature extension through which
// the client announces support for column encryption (Always Encrypted).
// It carries no state of its own; its identifier and payload come from the
// featureID and toBytes methods.
type featureExtColumnEncryption struct{}
|
|
|
|
func (f *featureExtColumnEncryption) featureID() byte {
|
|
return featExtCOLUMNENCRYPTION
|
|
}
|
|
|
|
func (f *featureExtColumnEncryption) toBytes() []byte {
|
|
/*
|
|
1 = The client supports column encryption without enclave computations.
|
|
2 = The client SHOULD<25> support column encryption when encrypted data require enclave computations.
|
|
3 = The client SHOULD<26> support column encryption when encrypted data require enclave computations
|
|
with the additional ability to cache column encryption keys that are to be sent to the enclave
|
|
and the ability to retry queries when the keys sent by the client do not match what is needed for the query to run.
|
|
*/
|
|
return []byte{0x01}
|
|
}
|