1
0
Fork 0
mirror of https://github.com/documize/community.git synced 2025-07-19 05:09:42 +02:00
documize/vendor/github.com/microsoft/go-mssqldb/protocol.go
2024-01-10 14:47:40 -05:00

169 lines
4.7 KiB
Go

package mssql
import (
"context"
"fmt"
"net"
"os"
"strconv"
"strings"
"github.com/microsoft/go-mssqldb/msdsn"
)
// MssqlProtocolDialer is the interface for protocol-specific dialers that can
// open the network connection to a SQL Server for a given Connector and
// connection configuration.
type MssqlProtocolDialer interface {
// DialSqlConnection creates a net.Conn from a Connector based on the Config.
// ctx is handed to the dial operation; on failure conn is nil and err is set.
DialSqlConnection(ctx context.Context, c *Connector, p *msdsn.Config) (conn net.Conn, err error)
}
// tcpDialer is the stateless dialer for the TCP protocol. Its
// DialSqlConnection method has the signature required by MssqlProtocolDialer.
type tcpDialer struct{}
// ParseBrowserData looks up p.Instance in the SQL Server Browser response and
// stores the TCP port advertised for that instance in p.Port.
//
// It is meant for the case where an instance is specified but no port, so the
// port has to be discovered through SQL Server Browser.
//
// When data is non-empty, p.Instance is upper-cased in place before matching.
// Names are compared in the normalized form produced by
// stringForInstanceNameComparison.
//
// An error is returned, and p.Port is left untouched, when:
//   - data is empty, no entry matches, or the matching entry has no "tcp"
//     value (all reported as "no instance matching ...");
//   - the "tcp" value does not parse as a 16-bit unsigned integer. The
//     strconv error is wrapped, so errors.Is / errors.As can inspect it.
func (t tcpDialer) ParseBrowserData(data msdsn.BrowserData, p *msdsn.Config) error {
	portText, found := "", false
	if len(data) > 0 {
		p.Instance = strings.ToUpper(p.Instance)
		wanted := stringForInstanceNameComparison(p.Instance)
		for _, entry := range data {
			name, hasName := entry["InstanceName"]
			if !hasName || stringForInstanceNameComparison(name) != wanted {
				continue
			}
			// First matching instance wins; found stays false if it
			// advertises no tcp endpoint.
			portText, found = entry["tcp"]
			break
		}
	}
	if !found {
		return fmt.Errorf("no instance matching '%v' returned from host '%v'", p.Instance, p.Host)
	}

	// Base 0 lets strconv pick the base from the prefix; bitSize 16 rejects
	// anything outside the valid port range.
	port, err := strconv.ParseUint(portText, 0, 16)
	if err != nil {
		return fmt.Errorf("invalid tcp port returned from Sql Server Browser '%v': %w", portText, err)
	}
	p.Port = port
	return nil
}
// stringForInstanceNameComparison converts an instance name into a canonical
// ASCII form so two names can be compared with ==.
//
// SQL returns ASCII encoded instance names in which a character such as Å is
// the single byte 0xc5, while the same character in a Go string is the UTF-8
// sequence [195 133]. strconv.QuoteToASCII renders the former as `\xc5` and
// the latter as `\u00c5`, so the `\u00` prefix (and then any remaining `\u`
// prefix) is rewritten to `\x` to make both spellings identical.
//
// The result keeps the surrounding double quotes added by QuoteToASCII.
func stringForInstanceNameComparison(inst string) (instanceName string) {
	quoted := strconv.QuoteToASCII(inst)
	shortEscapes := strings.ReplaceAll(quoted, `\u00`, `\x`)
	instanceName = strings.ReplaceAll(shortEscapes, `\u`, `\x`)
	return instanceName
}
// DialConnection always fails with a nil connection: the TCP dialer can only
// be used through DialSqlConnection, which receives the Connector it needs.
// Both arguments are ignored.
func (t tcpDialer) DialConnection(ctx context.Context, p *msdsn.Config) (conn net.Conn, err error) {
	err = fmt.Errorf("tcp dialer requires a Connector instance")
	return nil, err
}
// DialSqlConnection opens a TCP connection to p.Host on the port given by
// resolveServerPort(p.Port), using the dialer returned by c.getDialer(p).
//
// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
// list of IP addresses. So if there is more than one, try them all and
// use the first one that allows a connection.
//
// Effects on p: after a successful dial through the IP path, p.ServerSPN is
// generated if it was empty and p.Port is replaced by the resolved port. The
// HostDialer path returns the dialer's result directly and leaves p unchanged.
//
// Errors: a net.LookupIP failure is returned as-is; a failed dial is passed
// through wrapConnErr. When several addresses are tried and all fail, the
// error reported is the last one received.
func (t tcpDialer) DialSqlConnection(ctx context.Context, c *Connector, p *msdsn.Config) (conn net.Conn, err error) {
var ips []net.IP
// net.ParseIP returns nil when p.Host is not an IP literal, i.e. it is a name.
ip := net.ParseIP(p.Host)
if ip == nil {
// if the custom dialer is a host dialer, the DNS is resolved within the network
// the dialer is sending the request to, rather than the one the driver is running on
d := c.getDialer(p)
if _, ok := d.(HostDialer); ok {
addr := net.JoinHostPort(p.Host, strconv.Itoa(int(resolveServerPort(p.Port))))
return d.DialContext(ctx, "tcp", addr)
}
ips, err = net.LookupIP(p.Host)
if err != nil {
// conn is still nil here; the lookup error is returned unwrapped.
return
}
} else {
ips = []net.IP{ip}
}
if len(ips) == 1 {
// Exactly one candidate address: dial it synchronously.
d := c.getDialer(p)
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(resolveServerPort(p.Port))))
conn, err = d.DialContext(ctx, "tcp", addr)
} else {
//Try Dials in parallel to avoid waiting for timeouts.
// Both channels are buffered to len(ips): every goroutine sends exactly one
// value, so no sender can block even once this function stops receiving.
connChan := make(chan net.Conn, len(ips))
errChan := make(chan error, len(ips))
portStr := strconv.Itoa(int(resolveServerPort(p.Port)))
for _, ip := range ips {
// ip is passed as an argument so each goroutine works on its own copy.
go func(ip net.IP) {
d := c.getDialer(p)
addr := net.JoinHostPort(ip.String(), portStr)
// conn and err are local to the goroutine; results leave only via the channels.
conn, err := d.DialContext(ctx, "tcp", addr)
if err == nil {
connChan <- conn
} else {
errChan <- err
}
}(ip)
}
// Wait for either the *first* successful connection, or all the errors
wait_loop:
for i := range ips {
select {
case conn = <-connChan:
// Got a connection to use, close any others
// len(ips)-i-1 results are still outstanding; drain them in the background
// and close any connection that arrives late.
go func(n int) {
for i := 0; i < n; i++ {
select {
case conn := <-connChan:
conn.Close()
case <-errChan:
}
}
}(len(ips) - i - 1)
// Remove any earlier errors we may have collected
err = nil
break wait_loop
case err = <-errChan:
// Keep the most recent failure; it is reported only if no dial succeeds.
}
}
}
// Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
if conn == nil {
return nil, wrapConnErr(p, err)
}
// Fill in the SPN only when the configuration did not provide one.
if p.ServerSPN == "" {
p.ServerSPN = generateSpn(p.Host, instanceOrPort(p.Instance, p.Port))
}
// Record the port actually used (the default when p.Port was zero).
p.Port = resolveServerPort(p.Port)
return conn, err
}
// CallBrowser reports whether p names an instance but has no port set
// (p.Port is zero) — the case in which the port must be discovered through
// SQL Server Browser.
func (t tcpDialer) CallBrowser(p *msdsn.Config) bool {
	if p.Port != 0 {
		return false
	}
	return p.Instance != ""
}
// instanceOrPort returns instance when it is non-empty; otherwise it returns
// the decimal text of the effective port (resolveServerPort(port), which
// substitutes the default for zero).
//
// The port is formatted as an unsigned value so that no uint64 input can be
// rendered as a negative number.
func instanceOrPort(instance string, port uint64) string {
	if instance != "" {
		return instance
	}
	return strconv.FormatUint(resolveServerPort(port), 10)
}
// resolveServerPort returns port unchanged when it is non-zero and
// defaultServerPort when it is zero (i.e. no port was configured).
func resolveServerPort(port uint64) uint64 {
	if port != 0 {
		return port
	}
	return defaultServerPort
}
// generateSpn builds a service principal name of the form
// "MSSQLSvc/<host>:<port>".
//
// When host is a loopback IP address it is replaced by the local machine's
// host name. If the host name cannot be determined (os.Hostname fails or
// returns an empty string) the original address is kept, so the SPN never
// ends up with an empty host part.
//
// port is inserted verbatim; it may be a port number or an instance name.
func generateSpn(host string, port string) string {
	if ip := net.ParseIP(host); ip != nil && ip.IsLoopback() {
		if name, err := os.Hostname(); err == nil && name != "" {
			host = name
		}
	}
	return fmt.Sprintf("MSSQLSvc/%s:%s", host, port)
}