1
0
Fork 0
mirror of https://github.com/documize/community.git synced 2025-07-18 20:59:43 +02:00

Sync with Community

This commit is contained in:
HarveyKandola 2021-08-19 13:02:56 -04:00
parent df8f650319
commit 989b7cd62c
123 changed files with 5054 additions and 2015 deletions

View file

@ -20,8 +20,8 @@ import (
"github.com/documize/community/core/stringutil"
lm "github.com/documize/community/model/auth"
"github.com/documize/community/model/user"
ld "github.com/go-ldap/ldap/v3"
"github.com/pkg/errors"
ld "gopkg.in/ldap.v3"
)
// Connect establishes connection to LDAP server.

13
go.mod
View file

@ -6,13 +6,14 @@ require (
github.com/BurntSushi/toml v0.3.1
github.com/andygrunwald/go-jira v1.12.0
github.com/codegangsta/negroni v1.0.0
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc
github.com/denisenkom/go-mssqldb v0.10.1-0.20210728001037-ee2fbc25fd8f
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/documize/blackfriday v2.0.0+incompatible
github.com/documize/glick v0.0.0-20160503134043-a8ccbef88237
github.com/documize/html-diff v0.0.0-20160503140253-f61c192c7796
github.com/documize/slug v1.1.1
github.com/go-sql-driver/mysql v1.5.0
github.com/go-ldap/ldap/v3 v3.4.1
github.com/go-sql-driver/mysql v1.6.0
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect
github.com/golang/protobuf v1.4.0 // indirect
github.com/google/go-github v17.0.0+incompatible
@ -21,29 +22,29 @@ require (
github.com/gorilla/mux v1.7.4
github.com/jmoiron/sqlx v1.2.0
github.com/kr/pretty v0.2.0 // indirect
github.com/lib/pq v1.5.2
github.com/lib/pq v1.10.2
github.com/mb0/diff v0.0.0-20131118162322-d8d9a906c24d // indirect
github.com/microcosm-cc/bluemonday v1.0.2
github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d
github.com/pkg/errors v0.9.1
github.com/rainycape/unidecode v0.0.0-20150907023854-cb7f23ec59be // indirect
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9
golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
google.golang.org/appengine v1.6.6 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc
gopkg.in/cas.v2 v2.1.0
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/ldap.v3 v3.1.0
gopkg.in/yaml.v2 v2.2.2 // indirect
)
require (
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c // indirect
github.com/fatih/structs v1.0.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.3 // indirect
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect
github.com/trivago/tgo v1.0.1 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.21.0 // indirect
gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d // indirect
)

27
go.sum
View file

@ -1,12 +1,14 @@
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c h1:/IBSNwUN8+eKzUzbJPqhK839ygXJ82sde8x3ogr6R28=
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/andygrunwald/go-jira v1.12.0 h1:JJi2cEDmDxVtTXxC8ruLDbtOU6pA4OLeL0niyfNcoWw=
github.com/andygrunwald/go-jira v1.12.0/go.mod h1:jYi4kFDbRPZTJdJOVJO4mpMMIwdB+rcZwSO58DzPd2I=
github.com/codegangsta/negroni v1.0.0 h1:+aYywywx4bnKXWvoWtRfJ91vC59NbEhEY03sZjQhbVY=
github.com/codegangsta/negroni v1.0.0/go.mod h1:v0y3T5G7Y1UlFfyxFn/QLRU4a2EuNau2iZY63YTKWo0=
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc h1:VRRKCwnzqk8QCaRC4os14xoKDdbHqqlJtJA0oc1ZAjg=
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/denisenkom/go-mssqldb v0.10.1-0.20210728001037-ee2fbc25fd8f h1:3UtVZFKTqZwLZi65UbfSIqYR75aUTP8FYUAEQnMXSJs=
github.com/denisenkom/go-mssqldb v0.10.1-0.20210728001037-ee2fbc25fd8f/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/documize/blackfriday v2.0.0+incompatible h1:qjRGAIVwZlHBtA/b9u0LtseYM3v3WpIXofPCwNjcUsE=
@ -19,9 +21,14 @@ github.com/documize/slug v1.1.1 h1:OCJRbWxbOgrgiBYSbVzuFwxb9wVu4oy1LxvLJOC2s8Y=
github.com/documize/slug v1.1.1/go.mod h1:Vi7fQ5PzeOpXAiIrk1WCEDRihjTfU/bf4eWUPSD7tkU=
github.com/fatih/structs v1.0.0 h1:BrX964Rv5uQ3wwS+KRUAJCBBw5PQmgJfJ6v4yly5QwU=
github.com/fatih/structs v1.0.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
github.com/go-asn1-ber/asn1-ber v1.5.1/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
github.com/go-asn1-ber/asn1-ber v1.5.3 h1:u7utq56RUFiynqUzgVMFDymapcOtQ/MZkh3H4QYkxag=
github.com/go-asn1-ber/asn1-ber v1.5.3/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
github.com/go-ldap/ldap/v3 v3.4.1 h1:fU/0xli6HY02ocbMuozHAYsaHLcnkLjvho2r5a34BUU=
github.com/go-ldap/ldap/v3 v3.4.1/go.mod h1:iYS1MdmrmceOJ1QOTnRXrIs7i3kloqtmGQjRvjKpyMg=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
@ -55,8 +62,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.5.2 h1:yTSXVswvWUOQ3k1sd7vJfDrbSl8lKuscqFJRqjC0ifw=
github.com/lib/pq v1.5.2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8=
github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-sqlite3 v1.9.0 h1:pDRiWfl+++eC2FEFRy6jXmQlvp4Yh3z1MJKg4UeYM/4=
github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mb0/diff v0.0.0-20131118162322-d8d9a906c24d h1:eAS2t2Vy+6psf9LZ4T5WXWsbkBt3Tu5PWekJy5AGyEU=
@ -77,8 +84,8 @@ github.com/trivago/tgo v1.0.1/go.mod h1:w4dpD+3tzNIIiIfkWWa85w5/B77tlvdZckQ+6PkF
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw=
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 h1:vEg9joUBmeBcK9iSJftGNf3coIG4HqZElCPehJsfAYM=
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -109,14 +116,10 @@ google.golang.org/protobuf v1.21.0 h1:qdOKuR/EIArgaWNjetjgTzgVTAZ+S/WXVrq9HW9zim
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk=
gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d h1:TxyelI5cVkbREznMhfzycHdkp5cLA7DpE+GKjSslYhM=
gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d/go.mod h1:cuepJuh7vyXfUyUwEgHQXw849cJrilpS5NeIjOWESAw=
gopkg.in/cas.v2 v2.1.0 h1:sbYBMWtpanwLH75GAWjIp5JnON9wa3NodLZhouu0G9I=
gopkg.in/cas.v2 v2.1.0/go.mod h1:M291I/o/u3eeMl9SkXMPYpWasHp7weFY9G/pM5DbB+g=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/ldap.v3 v3.1.0 h1:DIDWEjI7vQWREh0S8X5/NFPCZ3MCVd55LmXKPW4XLGE=
gopkg.in/ldap.v3 v3.1.0/go.mod h1:dQjCc0R0kfyFjIlWNMH1DORwUASZyDxo2Ry1B51dXaQ=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View file

@ -1,3 +1,5 @@
/* eslint-disable ember/no-actions-hash */
/* eslint-disable ember/no-classic-classes */
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under

View file

@ -1,3 +1,4 @@
/* eslint-disable ember/no-classic-classes */
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under

View file

@ -186,7 +186,8 @@ export default Router.map(function () {
path: 'updates'
});
this.route('not-found', {
this.route('auth/login', {
path: '/*wildcard'
// path: '/*wildcard'
});
});

View file

@ -1,3 +1,4 @@
/* eslint-disable ember/no-classic-classes */
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under

View file

@ -43,12 +43,11 @@ func (m *middleware) cors(w http.ResponseWriter, r *http.Request, next http.Hand
w.Header().Set("Access-Control-Allow-Headers", "host, content-type, accept, authorization, origin, referer, user-agent, cache-control, x-requested-with, range")
w.Header().Set("Access-Control-Expose-Headers", "x-documize-version, x-documize-status, x-documize-filename, x-documize-subscription, Content-Disposition, Content-Length")
w.Header().Add("X-Documize-Version", m.Runtime.Product.Version)
w.Header().Add("Cache-Control", "no-cache")
if r.Method == "OPTIONS" {
w.Header().Add("X-Documize-Version", m.Runtime.Product.Version)
w.Header().Add("Cache-Control", "no-cache")
w.Write([]byte(""))
return
}

17
vendor/github.com/Azure/go-ntlmssp/.travis.yml generated vendored Normal file
View file

@ -0,0 +1,17 @@
sudo: false
language: go
before_script:
- go get -u golang.org/x/lint/golint
go:
- 1.10.x
- master
script:
- test -z "$(gofmt -s -l . | tee /dev/stderr)"
- test -z "$(golint ./... | tee /dev/stderr)"
- go vet ./...
- go build -v ./...
- go test -v ./...

21
vendor/github.com/Azure/go-ntlmssp/LICENSE generated vendored Normal file
View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016 Microsoft
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

29
vendor/github.com/Azure/go-ntlmssp/README.md generated vendored Normal file
View file

@ -0,0 +1,29 @@
# go-ntlmssp
Golang package that provides NTLM/Negotiate authentication over HTTP
[![GoDoc](https://godoc.org/github.com/Azure/go-ntlmssp?status.svg)](https://godoc.org/github.com/Azure/go-ntlmssp) [![Build Status](https://travis-ci.org/Azure/go-ntlmssp.svg?branch=dev)](https://travis-ci.org/Azure/go-ntlmssp)
Protocol details from https://msdn.microsoft.com/en-us/library/cc236621.aspx
Implementation hints from http://davenport.sourceforge.net/ntlm.html
This package only implements authentication, no key exchange or encryption. It
only supports Unicode (UTF16LE) encoding of protocol strings, no OEM encoding.
This package implements NTLMv2.
# Usage
```
url, user, password := "http://www.example.com/secrets", "robpike", "pw123"
client := &http.Client{
Transport: ntlmssp.Negotiator{
RoundTripper:&http.Transport{},
},
}
req, _ := http.NewRequest("GET", url, nil)
req.SetBasicAuth(user, password)
res, _ := client.Do(req)
```
-----
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

View file

@ -0,0 +1,183 @@
package ntlmssp
import (
"bytes"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
"strings"
"time"
)
type authenicateMessage struct {
LmChallengeResponse []byte
NtChallengeResponse []byte
TargetName string
UserName string
// only set if negotiateFlag_NTLMSSP_NEGOTIATE_KEY_EXCH
EncryptedRandomSessionKey []byte
NegotiateFlags negotiateFlags
MIC []byte
}
type authenticateMessageFields struct {
messageHeader
LmChallengeResponse varField
NtChallengeResponse varField
TargetName varField
UserName varField
Workstation varField
_ [8]byte
NegotiateFlags negotiateFlags
}
func (m authenicateMessage) MarshalBinary() ([]byte, error) {
if !m.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEUNICODE) {
return nil, errors.New("Only unicode is supported")
}
target, user := toUnicode(m.TargetName), toUnicode(m.UserName)
workstation := toUnicode("go-ntlmssp")
ptr := binary.Size(&authenticateMessageFields{})
f := authenticateMessageFields{
messageHeader: newMessageHeader(3),
NegotiateFlags: m.NegotiateFlags,
LmChallengeResponse: newVarField(&ptr, len(m.LmChallengeResponse)),
NtChallengeResponse: newVarField(&ptr, len(m.NtChallengeResponse)),
TargetName: newVarField(&ptr, len(target)),
UserName: newVarField(&ptr, len(user)),
Workstation: newVarField(&ptr, len(workstation)),
}
f.NegotiateFlags.Unset(negotiateFlagNTLMSSPNEGOTIATEVERSION)
b := bytes.Buffer{}
if err := binary.Write(&b, binary.LittleEndian, &f); err != nil {
return nil, err
}
if err := binary.Write(&b, binary.LittleEndian, &m.LmChallengeResponse); err != nil {
return nil, err
}
if err := binary.Write(&b, binary.LittleEndian, &m.NtChallengeResponse); err != nil {
return nil, err
}
if err := binary.Write(&b, binary.LittleEndian, &target); err != nil {
return nil, err
}
if err := binary.Write(&b, binary.LittleEndian, &user); err != nil {
return nil, err
}
if err := binary.Write(&b, binary.LittleEndian, &workstation); err != nil {
return nil, err
}
return b.Bytes(), nil
}
//ProcessChallenge crafts an AUTHENTICATE message in response to the CHALLENGE message
//that was received from the server
func ProcessChallenge(challengeMessageData []byte, user, password string) ([]byte, error) {
if user == "" && password == "" {
return nil, errors.New("Anonymous authentication not supported")
}
var cm challengeMessage
if err := cm.UnmarshalBinary(challengeMessageData); err != nil {
return nil, err
}
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATELMKEY) {
return nil, errors.New("Only NTLM v2 is supported, but server requested v1 (NTLMSSP_NEGOTIATE_LM_KEY)")
}
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEKEYEXCH) {
return nil, errors.New("Key exchange requested but not supported (NTLMSSP_NEGOTIATE_KEY_EXCH)")
}
am := authenicateMessage{
UserName: user,
TargetName: cm.TargetName,
NegotiateFlags: cm.NegotiateFlags,
}
timestamp := cm.TargetInfo[avIDMsvAvTimestamp]
if timestamp == nil { // no time sent, take current time
ft := uint64(time.Now().UnixNano()) / 100
ft += 116444736000000000 // add time between unix & windows offset
timestamp = make([]byte, 8)
binary.LittleEndian.PutUint64(timestamp, ft)
}
clientChallenge := make([]byte, 8)
rand.Reader.Read(clientChallenge)
ntlmV2Hash := getNtlmV2Hash(password, user, cm.TargetName)
am.NtChallengeResponse = computeNtlmV2Response(ntlmV2Hash,
cm.ServerChallenge[:], clientChallenge, timestamp, cm.TargetInfoRaw)
if cm.TargetInfoRaw == nil {
am.LmChallengeResponse = computeLmV2Response(ntlmV2Hash,
cm.ServerChallenge[:], clientChallenge)
}
return am.MarshalBinary()
}
func ProcessChallengeWithHash(challengeMessageData []byte, user, hash string) ([]byte, error) {
if user == "" && hash == "" {
return nil, errors.New("Anonymous authentication not supported")
}
var cm challengeMessage
if err := cm.UnmarshalBinary(challengeMessageData); err != nil {
return nil, err
}
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATELMKEY) {
return nil, errors.New("Only NTLM v2 is supported, but server requested v1 (NTLMSSP_NEGOTIATE_LM_KEY)")
}
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEKEYEXCH) {
return nil, errors.New("Key exchange requested but not supported (NTLMSSP_NEGOTIATE_KEY_EXCH)")
}
am := authenicateMessage{
UserName: user,
TargetName: cm.TargetName,
NegotiateFlags: cm.NegotiateFlags,
}
timestamp := cm.TargetInfo[avIDMsvAvTimestamp]
if timestamp == nil { // no time sent, take current time
ft := uint64(time.Now().UnixNano()) / 100
ft += 116444736000000000 // add time between unix & windows offset
timestamp = make([]byte, 8)
binary.LittleEndian.PutUint64(timestamp, ft)
}
clientChallenge := make([]byte, 8)
rand.Reader.Read(clientChallenge)
hashParts := strings.Split(hash, ":")
if len(hashParts) > 1 {
hash = hashParts[1]
}
hashBytes, err := hex.DecodeString(hash)
if err != nil {
return nil, err
}
ntlmV2Hash := hmacMd5(hashBytes, toUnicode(strings.ToUpper(user)+cm.TargetName))
am.NtChallengeResponse = computeNtlmV2Response(ntlmV2Hash,
cm.ServerChallenge[:], clientChallenge, timestamp, cm.TargetInfoRaw)
if cm.TargetInfoRaw == nil {
am.LmChallengeResponse = computeLmV2Response(ntlmV2Hash,
cm.ServerChallenge[:], clientChallenge)
}
return am.MarshalBinary()
}

37
vendor/github.com/Azure/go-ntlmssp/authheader.go generated vendored Normal file
View file

@ -0,0 +1,37 @@
package ntlmssp
import (
"encoding/base64"
"strings"
)
type authheader string
func (h authheader) IsBasic() bool {
return strings.HasPrefix(string(h), "Basic ")
}
func (h authheader) IsNegotiate() bool {
return strings.HasPrefix(string(h), "Negotiate")
}
func (h authheader) IsNTLM() bool {
return strings.HasPrefix(string(h), "NTLM")
}
func (h authheader) GetData() ([]byte, error) {
p := strings.Split(string(h), " ")
if len(p) < 2 {
return nil, nil
}
return base64.StdEncoding.DecodeString(string(p[1]))
}
func (h authheader) GetBasicCreds() (username, password string, err error) {
d, err := h.GetData()
if err != nil {
return "", "", err
}
parts := strings.SplitN(string(d), ":", 2)
return parts[0], parts[1], nil
}

17
vendor/github.com/Azure/go-ntlmssp/avids.go generated vendored Normal file
View file

@ -0,0 +1,17 @@
package ntlmssp
type avID uint16
const (
avIDMsvAvEOL avID = iota
avIDMsvAvNbComputerName
avIDMsvAvNbDomainName
avIDMsvAvDNSComputerName
avIDMsvAvDNSDomainName
avIDMsvAvDNSTreeName
avIDMsvAvFlags
avIDMsvAvTimestamp
avIDMsvAvSingleHost
avIDMsvAvTargetName
avIDMsvChannelBindings
)

View file

@ -0,0 +1,82 @@
package ntlmssp
import (
"bytes"
"encoding/binary"
"fmt"
)
type challengeMessageFields struct {
messageHeader
TargetName varField
NegotiateFlags negotiateFlags
ServerChallenge [8]byte
_ [8]byte
TargetInfo varField
}
func (m challengeMessageFields) IsValid() bool {
return m.messageHeader.IsValid() && m.MessageType == 2
}
type challengeMessage struct {
challengeMessageFields
TargetName string
TargetInfo map[avID][]byte
TargetInfoRaw []byte
}
func (m *challengeMessage) UnmarshalBinary(data []byte) error {
r := bytes.NewReader(data)
err := binary.Read(r, binary.LittleEndian, &m.challengeMessageFields)
if err != nil {
return err
}
if !m.challengeMessageFields.IsValid() {
return fmt.Errorf("Message is not a valid challenge message: %+v", m.challengeMessageFields.messageHeader)
}
if m.challengeMessageFields.TargetName.Len > 0 {
m.TargetName, err = m.challengeMessageFields.TargetName.ReadStringFrom(data, m.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEUNICODE))
if err != nil {
return err
}
}
if m.challengeMessageFields.TargetInfo.Len > 0 {
d, err := m.challengeMessageFields.TargetInfo.ReadFrom(data)
m.TargetInfoRaw = d
if err != nil {
return err
}
m.TargetInfo = make(map[avID][]byte)
r := bytes.NewReader(d)
for {
var id avID
var l uint16
err = binary.Read(r, binary.LittleEndian, &id)
if err != nil {
return err
}
if id == avIDMsvAvEOL {
break
}
err = binary.Read(r, binary.LittleEndian, &l)
if err != nil {
return err
}
value := make([]byte, l)
n, err := r.Read(value)
if err != nil {
return err
}
if n != int(l) {
return fmt.Errorf("Expected to read %d bytes, got only %d", l, n)
}
m.TargetInfo[id] = value
}
}
return nil
}

21
vendor/github.com/Azure/go-ntlmssp/messageheader.go generated vendored Normal file
View file

@ -0,0 +1,21 @@
package ntlmssp
import (
"bytes"
)
var signature = [8]byte{'N', 'T', 'L', 'M', 'S', 'S', 'P', 0}
type messageHeader struct {
Signature [8]byte
MessageType uint32
}
func (h messageHeader) IsValid() bool {
return bytes.Equal(h.Signature[:], signature[:]) &&
h.MessageType > 0 && h.MessageType < 4
}
func newMessageHeader(messageType uint32) messageHeader {
return messageHeader{signature, messageType}
}

52
vendor/github.com/Azure/go-ntlmssp/negotiate_flags.go generated vendored Normal file
View file

@ -0,0 +1,52 @@
package ntlmssp
type negotiateFlags uint32
const (
/*A*/ negotiateFlagNTLMSSPNEGOTIATEUNICODE negotiateFlags = 1 << 0
/*B*/ negotiateFlagNTLMNEGOTIATEOEM = 1 << 1
/*C*/ negotiateFlagNTLMSSPREQUESTTARGET = 1 << 2
/*D*/
negotiateFlagNTLMSSPNEGOTIATESIGN = 1 << 4
/*E*/ negotiateFlagNTLMSSPNEGOTIATESEAL = 1 << 5
/*F*/ negotiateFlagNTLMSSPNEGOTIATEDATAGRAM = 1 << 6
/*G*/ negotiateFlagNTLMSSPNEGOTIATELMKEY = 1 << 7
/*H*/
negotiateFlagNTLMSSPNEGOTIATENTLM = 1 << 9
/*J*/
negotiateFlagANONYMOUS = 1 << 11
/*K*/ negotiateFlagNTLMSSPNEGOTIATEOEMDOMAINSUPPLIED = 1 << 12
/*L*/ negotiateFlagNTLMSSPNEGOTIATEOEMWORKSTATIONSUPPLIED = 1 << 13
/*M*/
negotiateFlagNTLMSSPNEGOTIATEALWAYSSIGN = 1 << 15
/*N*/ negotiateFlagNTLMSSPTARGETTYPEDOMAIN = 1 << 16
/*O*/ negotiateFlagNTLMSSPTARGETTYPESERVER = 1 << 17
/*P*/
negotiateFlagNTLMSSPNEGOTIATEEXTENDEDSESSIONSECURITY = 1 << 19
/*Q*/ negotiateFlagNTLMSSPNEGOTIATEIDENTIFY = 1 << 20
/*R*/
negotiateFlagNTLMSSPREQUESTNONNTSESSIONKEY = 1 << 22
/*S*/ negotiateFlagNTLMSSPNEGOTIATETARGETINFO = 1 << 23
/*T*/
negotiateFlagNTLMSSPNEGOTIATEVERSION = 1 << 25
/*U*/
negotiateFlagNTLMSSPNEGOTIATE128 = 1 << 29
/*V*/ negotiateFlagNTLMSSPNEGOTIATEKEYEXCH = 1 << 30
/*W*/ negotiateFlagNTLMSSPNEGOTIATE56 = 1 << 31
)
func (field negotiateFlags) Has(flags negotiateFlags) bool {
return field&flags == flags
}
func (field *negotiateFlags) Unset(flags negotiateFlags) {
*field = *field ^ (*field & flags)
}

View file

@ -0,0 +1,64 @@
package ntlmssp
import (
"bytes"
"encoding/binary"
"errors"
"strings"
)
const expMsgBodyLen = 40
type negotiateMessageFields struct {
messageHeader
NegotiateFlags negotiateFlags
Domain varField
Workstation varField
Version
}
var defaultFlags = negotiateFlagNTLMSSPNEGOTIATETARGETINFO |
negotiateFlagNTLMSSPNEGOTIATE56 |
negotiateFlagNTLMSSPNEGOTIATE128 |
negotiateFlagNTLMSSPNEGOTIATEUNICODE |
negotiateFlagNTLMSSPNEGOTIATEEXTENDEDSESSIONSECURITY
//NewNegotiateMessage creates a new NEGOTIATE message with the
//flags that this package supports.
func NewNegotiateMessage(domainName, workstationName string) ([]byte, error) {
payloadOffset := expMsgBodyLen
flags := defaultFlags
if domainName != "" {
flags |= negotiateFlagNTLMSSPNEGOTIATEOEMDOMAINSUPPLIED
}
if workstationName != "" {
flags |= negotiateFlagNTLMSSPNEGOTIATEOEMWORKSTATIONSUPPLIED
}
msg := negotiateMessageFields{
messageHeader: newMessageHeader(1),
NegotiateFlags: flags,
Domain: newVarField(&payloadOffset, len(domainName)),
Workstation: newVarField(&payloadOffset, len(workstationName)),
Version: DefaultVersion(),
}
b := bytes.Buffer{}
if err := binary.Write(&b, binary.LittleEndian, &msg); err != nil {
return nil, err
}
if b.Len() != expMsgBodyLen {
return nil, errors.New("incorrect body length")
}
payload := strings.ToUpper(domainName + workstationName)
if _, err := b.WriteString(payload); err != nil {
return nil, err
}
return b.Bytes(), nil
}

144
vendor/github.com/Azure/go-ntlmssp/negotiator.go generated vendored Normal file
View file

@ -0,0 +1,144 @@
package ntlmssp
import (
"bytes"
"encoding/base64"
"io"
"io/ioutil"
"net/http"
"strings"
)
// GetDomain : parse domain name from based on slashes in the input
func GetDomain(user string) (string, string) {
domain := ""
if strings.Contains(user, "\\") {
ucomponents := strings.SplitN(user, "\\", 2)
domain = ucomponents[0]
user = ucomponents[1]
}
return user, domain
}
//Negotiator is a http.Roundtripper decorator that automatically
//converts basic authentication to NTLM/Negotiate authentication when appropriate.
type Negotiator struct{ http.RoundTripper }
//RoundTrip sends the request to the server, handling any authentication
//re-sends as needed.
func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) {
// Use default round tripper if not provided
rt := l.RoundTripper
if rt == nil {
rt = http.DefaultTransport
}
// If it is not basic auth, just round trip the request as usual
reqauth := authheader(req.Header.Get("Authorization"))
if !reqauth.IsBasic() {
return rt.RoundTrip(req)
}
// Save request body
body := bytes.Buffer{}
if req.Body != nil {
_, err = body.ReadFrom(req.Body)
if err != nil {
return nil, err
}
req.Body.Close()
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
}
// first try anonymous, in case the server still finds us
// authenticated from previous traffic
req.Header.Del("Authorization")
res, err = rt.RoundTrip(req)
if err != nil {
return nil, err
}
if res.StatusCode != http.StatusUnauthorized {
return res, err
}
resauth := authheader(res.Header.Get("Www-Authenticate"))
if !resauth.IsNegotiate() && !resauth.IsNTLM() {
// Unauthorized, Negotiate not requested, let's try with basic auth
req.Header.Set("Authorization", string(reqauth))
io.Copy(ioutil.Discard, res.Body)
res.Body.Close()
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
res, err = rt.RoundTrip(req)
if err != nil {
return nil, err
}
if res.StatusCode != http.StatusUnauthorized {
return res, err
}
resauth = authheader(res.Header.Get("Www-Authenticate"))
}
if resauth.IsNegotiate() || resauth.IsNTLM() {
// 401 with request:Basic and response:Negotiate
io.Copy(ioutil.Discard, res.Body)
res.Body.Close()
// recycle credentials
u, p, err := reqauth.GetBasicCreds()
if err != nil {
return nil, err
}
// get domain from username
domain := ""
u, domain = GetDomain(u)
// send negotiate
negotiateMessage, err := NewNegotiateMessage(domain, "")
if err != nil {
return nil, err
}
if resauth.IsNTLM() {
req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(negotiateMessage))
} else {
req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(negotiateMessage))
}
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
res, err = rt.RoundTrip(req)
if err != nil {
return nil, err
}
// receive challenge?
resauth = authheader(res.Header.Get("Www-Authenticate"))
challengeMessage, err := resauth.GetData()
if err != nil {
return nil, err
}
if !(resauth.IsNegotiate() || resauth.IsNTLM()) || len(challengeMessage) == 0 {
// Negotiation failed, let client deal with response
return res, nil
}
io.Copy(ioutil.Discard, res.Body)
res.Body.Close()
// send authenticate
authenticateMessage, err := ProcessChallenge(challengeMessage, u, p)
if err != nil {
return nil, err
}
if resauth.IsNTLM() {
req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(authenticateMessage))
} else {
req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(authenticateMessage))
}
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
return rt.RoundTrip(req)
}
return res, err
}

51
vendor/github.com/Azure/go-ntlmssp/nlmp.go generated vendored Normal file
View file

@ -0,0 +1,51 @@
// Package ntlmssp provides NTLM/Negotiate authentication over HTTP
//
// Protocol details from https://msdn.microsoft.com/en-us/library/cc236621.aspx,
// implementation hints from http://davenport.sourceforge.net/ntlm.html .
// This package only implements authentication, no key exchange or encryption. It
// only supports Unicode (UTF16LE) encoding of protocol strings, no OEM encoding.
// This package implements NTLMv2.
package ntlmssp
import (
"crypto/hmac"
"crypto/md5"
"golang.org/x/crypto/md4"
"strings"
)
func getNtlmV2Hash(password, username, target string) []byte {
return hmacMd5(getNtlmHash(password), toUnicode(strings.ToUpper(username)+target))
}
func getNtlmHash(password string) []byte {
hash := md4.New()
hash.Write(toUnicode(password))
return hash.Sum(nil)
}
func computeNtlmV2Response(ntlmV2Hash, serverChallenge, clientChallenge,
timestamp, targetInfo []byte) []byte {
temp := []byte{1, 1, 0, 0, 0, 0, 0, 0}
temp = append(temp, timestamp...)
temp = append(temp, clientChallenge...)
temp = append(temp, 0, 0, 0, 0)
temp = append(temp, targetInfo...)
temp = append(temp, 0, 0, 0, 0)
NTProofStr := hmacMd5(ntlmV2Hash, serverChallenge, temp)
return append(NTProofStr, temp...)
}
func computeLmV2Response(ntlmV2Hash, serverChallenge, clientChallenge []byte) []byte {
return append(hmacMd5(ntlmV2Hash, serverChallenge, clientChallenge), clientChallenge...)
}
func hmacMd5(key []byte, data ...[]byte) []byte {
mac := hmac.New(md5.New, key)
for _, d := range data {
mac.Write(d)
}
return mac.Sum(nil)
}

29
vendor/github.com/Azure/go-ntlmssp/unicode.go generated vendored Normal file
View file

@ -0,0 +1,29 @@
package ntlmssp
import (
"bytes"
"encoding/binary"
"errors"
"unicode/utf16"
)
// helper func's for dealing with Windows Unicode (UTF16LE)
func fromUnicode(d []byte) (string, error) {
if len(d)%2 > 0 {
return "", errors.New("Unicode (UTF 16 LE) specified, but uneven data length")
}
s := make([]uint16, len(d)/2)
err := binary.Read(bytes.NewReader(d), binary.LittleEndian, &s)
if err != nil {
return "", err
}
return string(utf16.Decode(s)), nil
}
func toUnicode(s string) []byte {
uints := utf16.Encode([]rune(s))
b := bytes.Buffer{}
binary.Write(&b, binary.LittleEndian, &uints)
return b.Bytes()
}

40
vendor/github.com/Azure/go-ntlmssp/varfield.go generated vendored Normal file
View file

@ -0,0 +1,40 @@
package ntlmssp
import (
"errors"
)
type varField struct {
Len uint16
MaxLen uint16
BufferOffset uint32
}
func (f varField) ReadFrom(buffer []byte) ([]byte, error) {
if len(buffer) < int(f.BufferOffset+uint32(f.Len)) {
return nil, errors.New("Error reading data, varField extends beyond buffer")
}
return buffer[f.BufferOffset : f.BufferOffset+uint32(f.Len)], nil
}
func (f varField) ReadStringFrom(buffer []byte, unicode bool) (string, error) {
d, err := f.ReadFrom(buffer)
if err != nil {
return "", err
}
if unicode { // UTF-16LE encoding scheme
return fromUnicode(d)
}
// OEM encoding, close enough to ASCII, since no code page is specified
return string(d), err
}
func newVarField(ptr *int, fieldsize int) varField {
f := varField{
Len: uint16(fieldsize),
MaxLen: uint16(fieldsize),
BufferOffset: uint32(*ptr),
}
*ptr += fieldsize
return f
}

20
vendor/github.com/Azure/go-ntlmssp/version.go generated vendored Normal file
View file

@ -0,0 +1,20 @@
package ntlmssp
// Version is a struct representing https://msdn.microsoft.com/en-us/library/cc236654.aspx
type Version struct {
ProductMajorVersion uint8
ProductMinorVersion uint8
ProductBuild uint16
_ [3]byte
NTLMRevisionCurrent uint8
}
// DefaultVersion returns a Version with "sensible" defaults (Windows 7)
func DefaultVersion() Version {
return Version{
ProductMajorVersion: 6,
ProductMinorVersion: 1,
ProductBuild: 7601,
NTLMRevisionCurrent: 15,
}
}

13
vendor/github.com/denisenkom/go-mssqldb/.gitignore generated vendored Normal file
View file

@ -0,0 +1,13 @@
/.idea
/.connstr
.vscode
.terraform
*.tfstate*
*.log
*.swp
*~
coverage.json
coverage.txt
coverage.xml
testresults.xml

10
vendor/github.com/denisenkom/go-mssqldb/.golangci.yml generated vendored Normal file
View file

@ -0,0 +1,10 @@
linters:
enable:
# basic go linters
- gofmt
- golint
- govet
# sql related linters
- rowserrcheck
- sqlclosecheck

View file

@ -1,6 +1,6 @@
# A pure Go MSSQL driver for Go's database/sql package
[![GoDoc](https://godoc.org/github.com/denisenkom/go-mssqldb?status.svg)](http://godoc.org/github.com/denisenkom/go-mssqldb)
[![Go Reference](https://pkg.go.dev/badge/github.com/denisenkom/go-mssqldb.svg)](https://pkg.go.dev/github.com/denisenkom/go-mssqldb)
[![Build status](https://ci.appveyor.com/api/projects/status/jrln8cs62wj9i0a2?svg=true)](https://ci.appveyor.com/project/denisenkom/go-mssqldb)
[![codecov](https://codecov.io/gh/denisenkom/go-mssqldb/branch/master/graph/badge.svg)](https://codecov.io/gh/denisenkom/go-mssqldb)
@ -16,7 +16,7 @@ The recommended connection string uses a URL format:
`sqlserver://username:password@host/instance?param1=value&param2=value`
Other supported formats are listed below.
### Common parameters:
### Common parameters
* `user id` - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used. The user domain sensitive to the case which is defined in the connection string.
* `password`
@ -29,24 +29,24 @@ Other supported formats are listed below.
* `true` - Data sent between client and server is encrypted.
* `app name` - The application name (default is go-mssqldb)
### Connection parameters for ODBC and ADO style connection strings:
### Connection parameters for ODBC and ADO style connection strings
* `server` - host or host\instance (default localhost)
* `port` - used only when there is no instance in server (default 1433)
### Less common parameters:
### Less common parameters
* `keepAlive` - in seconds; 0 to disable (default is 30)
* `failoverpartner` - host or host\instance (default is no partner).
* `failoverpartner` - host or host\instance (default is no partner).
* `failoverport` - used only when there is no instance in failoverpartner (default 1433)
* `packet size` - in bytes; 512 to 32767 (default is 4096)
* Encrypted connections have a maximum packet size of 16383 bytes
* Further information on usage: https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
* Further information on usage: <https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option>
* `log` - logging flags (default 0/no logging, 63 for full logging)
* 1 log errors
* 2 log messages
* 4 log rows affected
* 8 trace sql statements
* 1 log errors
* 2 log messages
* 4 log rows affected
* 8 trace sql statements
* 16 log statement parameters
* 32 log transaction begin/end
* `TrustServerCertificate`
@ -56,72 +56,82 @@ Other supported formats are listed below.
* `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host.
* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
* `Workstation ID` - The workstation name (default is the host name)
* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`.
### The connection string can be specified in one of three formats:
* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`.
### The connection string can be specified in one of three formats
1. URL: with `sqlserver` scheme. username and password appears before the host. Any instance appears as
the first segment in the path. All other options are query parameters. Examples:
* `sqlserver://username:password@host/instance?param1=value&param2=value`
* `sqlserver://username:password@host:port?param1=value&param2=value`
* `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress instance.
* `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass.
* `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30` // port 1234 on localhost.
* `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass"
* `sqlserver://username:password@host/instance?param1=value&param2=value`
* `sqlserver://username:password@host:port?param1=value&param2=value`
* `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress instance.
* `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass.
* `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30` // port 1234 on localhost.
* `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass"
A string of this format can be constructed using the `URL` type in the `net/url` package.
A string of this format can be constructed using the `URL` type in the `net/url` package.
```go
```go
query := url.Values{}
query.Add("app name", "MyAppName")
query := url.Values{}
query.Add("app name", "MyAppName")
u := &url.URL{
Scheme: "sqlserver",
User: url.UserPassword(username, password),
Host: fmt.Sprintf("%s:%d", hostname, port),
// Path: instance, // if connecting to an instance instead of a port
RawQuery: query.Encode(),
}
db, err := sql.Open("sqlserver", u.String())
u := &url.URL{
Scheme: "sqlserver",
User: url.UserPassword(username, password),
Host: fmt.Sprintf("%s:%d", hostname, port),
// Path: instance, // if connecting to an instance instead of a port
RawQuery: query.Encode(),
}
db, err := sql.Open("sqlserver", u.String())
```
```
2. ADO: `key=value` pairs separated by `;`. Values may not contain `;`, leading and trailing whitespace is ignored.
Examples:
* `server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName`
* `server=localhost;user id=sa;database=master;app name=MyAppName`
* `server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName`
* `server=localhost;user id=sa;database=master;app name=MyAppName`
ADO strings support synonyms for database, app name, user id, and server
* server <= addr, address, network address, data source
* user id <= user, uid
* database <= initial catalog
* app name <= application name
3. ODBC: Prefix with `odbc`, `key=value` pairs separated by `;`. Allow `;` by wrapping
values in `{}`. Examples:
* `odbc:server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName`
* `odbc:server=localhost;user id=sa;database=master;app name=MyAppName`
* `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar"
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar"
* `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar "
* `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar"
* `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar"
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar"
* `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with `}}`, password is "foo}bar"
* `odbc:server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName`
* `odbc:server=localhost;user id=sa;database=master;app name=MyAppName`
* `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar"
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar"
* `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar "
* `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar"
* `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar"
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar"
* `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with`}}`, password is "foo}bar"
### Azure Active Directory authentication - preview
The configuration of functionality might change in the future.
Azure Active Directory (AAD) access tokens are relatively short lived and need to be
Azure Active Directory (AAD) access tokens are relatively short lived and need to be
valid when a new connection is made. Authentication is supported using a callback func that
provides a fresh and valid token using a connector:
``` golang
``` go
conn, err := mssql.NewAccessTokenConnector(
"Server=test.database.windows.net;Database=testdb",
tokenProvider)
"Server=test.database.windows.net;Database=testdb",
tokenProvider)
if err != nil {
// handle errors in DSN
}
db := sql.OpenDB(conn)
```
Where `tokenProvider` is a function that returns a fresh access token or an error. None of these statements
actually trigger the retrieval of a token, this happens when the first statment is issued and a connection
is created.
@ -129,18 +139,23 @@ is created.
## Executing Stored Procedures
To run a stored procedure, set the query text to the procedure name:
```go
var account = "abc"
_, err := db.ExecContext(ctx, "sp_RunMe",
sql.Named("ID", 123),
sql.Named("Account", sql.Out{Dest: &account}),
)
```
## Reading Output Parameters from a Stored Procedure with Resultset
To read output parameters from a stored procedure with resultset, make sure you read all the rows before reading the output parameters:
```go
sqltextcreate := `
CREATE PROCEDURE spwithoutputandrows
@bitparam BIT OUTPUT
@ -156,6 +171,7 @@ for rows.Next() {
err = rows.Scan(&strrow)
}
fmt.Printf("bitparam is %d", bitout)
```
## Caveat for local temporary tables
@ -201,21 +217,25 @@ _, err := conn.ExecContext(ctx, "insert into #mytemp (x) values (@p1)", 1)
To get the procedure return status, pass into the parameters a
`*mssql.ReturnStatus`. For example:
```
```go
var rs mssql.ReturnStatus
_, err := db.ExecContext(ctx, "theproc", &rs)
log.Printf("status=%d", rs)
```
or
```
```go
var rs mssql.ReturnStatus
_, err := db.QueryContext(ctx, "theproc", &rs)
for rows.Next() {
err = rows.Scan(&val)
}
log.Printf("status=%d", rs)
```
Limitation: ReturnStatus cannot be retrieved using `QueryRow`.
@ -226,7 +246,9 @@ The `sqlserver` driver uses normal MS SQL Server syntax and expects parameters i
the sql query to be in the form of either `@Name` or `@p1` to `@pN` (ordinal position).
```go
db.QueryContext(ctx, `select * from t where ID = @ID and Name = @p2;`, sql.Named("ID", 6), "Bob")
```
### Parameter Types
@ -235,30 +257,30 @@ To pass specific types to the query parameters, say `varchar` or `date` types,
you must convert the types to the type before passing in. The following types
are supported:
* string -> nvarchar
* mssql.VarChar -> varchar
* time.Time -> datetimeoffset or datetime (TDS version dependent)
* mssql.DateTime1 -> datetime
* mssql.DateTimeOffset -> datetimeoffset
* "github.com/golang-sql/civil".Date -> date
* "github.com/golang-sql/civil".DateTime -> datetime2
* "github.com/golang-sql/civil".Time -> time
* mssql.TVP -> Table Value Parameter (TDS version dependent)
* string -> nvarchar
* mssql.VarChar -> varchar
* time.Time -> datetimeoffset or datetime (TDS version dependent)
* mssql.DateTime1 -> datetime
* mssql.DateTimeOffset -> datetimeoffset
* "github.com/golang-sql/civil".Date -> date
* "github.com/golang-sql/civil".DateTime -> datetime2
* "github.com/golang-sql/civil".Time -> time
* mssql.TVP -> Table Value Parameter (TDS version dependent)
## Important Notes
* [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should
* [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should
not be used with this driver (or SQL Server) due to how the TDS protocol
works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql)
or add a `select ID = convert(bigint, SCOPE_IDENTITY());` to the end of your
query (ref [SCOPE_IDENTITY](https://docs.microsoft.com/en-us/sql/t-sql/functions/scope-identity-transact-sql)).
This will ensure you are getting the correct ID and will prevent a network round trip.
* [NewConnector](https://godoc.org/github.com/denisenkom/go-mssqldb#NewConnector)
works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql)
or add a `select ID = convert(bigint, SCOPE_IDENTITY());` to the end of your
query (ref [SCOPE_IDENTITY](https://docs.microsoft.com/en-us/sql/t-sql/functions/scope-identity-transact-sql)).
This will ensure you are getting the correct ID and will prevent a network round trip.
* [NewConnector](https://godoc.org/github.com/denisenkom/go-mssqldb#NewConnector)
may be used with [OpenDB](https://golang.org/pkg/database/sql/#OpenDB).
* [Connector.SessionInitSQL](https://godoc.org/github.com/denisenkom/go-mssqldb#Connector.SessionInitSQL)
may be set to set any driver specific session settings after the session
has been reset. If empty the session will still be reset but use the database
defaults in Go1.10+.
* [Connector.SessionInitSQL](https://godoc.org/github.com/denisenkom/go-mssqldb#Connector.SessionInitSQL)
may be set to set any driver specific session settings after the session
has been reset. If empty the session will still be reset but use the database
defaults in Go1.10+.
## Features
@ -280,7 +302,9 @@ Environment variables are used to pass login information.
Example:
```bash
env SQLSERVER_DSN=sqlserver://user:pass@hostname/instance?database=test1 go test
```
## Deprecated
@ -296,7 +320,7 @@ will be loosly parsed and an attempt to extract identifiers using one of
* :nnn
* $nnn
will be used. This is not recommended with SQL Server.
will be used. This is not recommended with SQL Server.
There is at least one existing `won't fix` issue with the query parsing.
Use the native "@Name" parameters instead with the "sqlserver" driver name.
@ -306,4 +330,4 @@ Use the native "@Name" parameters instead with the "sqlserver" driver name.
* SQL Server 2008 and 2008 R2 engine cannot handle login records when SSL encryption is not disabled.
To fix SQL Server 2008 R2 issue, install SQL Server 2008 R2 Service Pack 2.
To fix SQL Server 2008 issue, install Microsoft SQL Server 2008 Service Pack 3 and Cumulative update package 3 for SQL Server 2008 SP3.
More information: http://support.microsoft.com/kb/2653857
More information: <http://support.microsoft.com/kb/2653857>

View file

@ -6,19 +6,8 @@ import (
"context"
"database/sql/driver"
"errors"
"fmt"
)
var _ driver.Connector = &accessTokenConnector{}
// accessTokenConnector wraps Connector and injects a
// fresh access token when connecting to the database
type accessTokenConnector struct {
Connector
accessTokenProvider func() (string, error)
}
// NewAccessTokenConnector creates a new connector from a DSN and a token provider.
// The token provider func will be called when a new connection is requested and should return a valid access token.
// The returned connector may be used with sql.OpenDB.
@ -32,20 +21,11 @@ func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) (
return nil, err
}
c := &accessTokenConnector{
Connector: *conn,
accessTokenProvider: tokenProvider,
}
return c, nil
}
// Connect returns a new database connection
func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) {
var err error
c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider()
if err != nil {
return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err)
conn.fedAuthRequired = true
conn.fedAuthLibrary = fedAuthLibrarySecurityToken
conn.securityTokenProvider = func(ctx context.Context) (string, error) {
return tokenProvider()
}
return c.Connector.Connect(ctx)
return conn, nil
}

View file

@ -1,6 +1,7 @@
version: 1.0.{build}
os: Windows Server 2012 R2
image:
- Visual Studio 2015
clone_folder: c:\gopath\src\github.com\denisenkom\go-mssqldb
@ -9,21 +10,39 @@ environment:
HOST: localhost
SQLUSER: sa
SQLPASSWORD: Password12!
DATABASE: test
GOVERSION: 111
DATABASE: test
GOVERSION: 113
matrix:
- GOVERSION: 18
SQLINSTANCE: SQL2016
SQLINSTANCE: SQL2017
- GOVERSION: 19
SQLINSTANCE: SQL2016
SQLINSTANCE: SQL2017
- GOVERSION: 110
SQLINSTANCE: SQL2016
SQLINSTANCE: SQL2017
- GOVERSION: 111
SQLINSTANCE: SQL2016
SQLINSTANCE: SQL2017
- GOVERSION: 112
SQLINSTANCE: SQL2017
- SQLINSTANCE: SQL2017
- SQLINSTANCE: SQL2016
- SQLINSTANCE: SQL2014
- SQLINSTANCE: SQL2012SP1
- SQLINSTANCE: SQL2008R2SP2
# Go 1.14+ and SQL2019 are available on the Visual Studio 2019 image only
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 114
SQLINSTANCE: SQL2019
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 115
SQLINSTANCE: SQL2019
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 115
SQLINSTANCE: SQL2017
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 116
SQLINSTANCE: SQL2017
install:
- set GOROOT=c:\go%GOVERSION%
- set PATH=%GOPATH%\bin;%GOROOT%\bin;%PATH%
@ -35,15 +54,14 @@ build_script:
- go build
before_test:
# setup SQL Server
- ps: |
# setup SQL Server
- ps: |
$instanceName = $env:SQLINSTANCE
Start-Service "MSSQL`$$instanceName"
Start-Service "SQLBrowser"
- sqlcmd -S "(local)\%SQLINSTANCE%" -Q "Use [master]; CREATE DATABASE test;"
- sqlcmd -S "(local)\%SQLINSTANCE%" -h -1 -Q "set nocount on; Select @@version"
- pip install codecov
test_script:
- go test -race -cpu 4 -coverprofile=coverage.txt -covermode=atomic

View file

@ -48,8 +48,8 @@ type tdsBuffer struct {
func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
return &tdsBuffer{
packetSize: int(bufsize),
wbuf: make([]byte, 1<<16),
rbuf: make([]byte, 1<<16),
wbuf: make([]byte, bufsize),
rbuf: make([]byte, bufsize),
rpos: 8,
transport: transport,
}
@ -137,19 +137,28 @@ func (w *tdsBuffer) FinishPacket() error {
var headerSize = binary.Size(header{})
func (r *tdsBuffer) readNextPacket() error {
h := header{}
var err error
err = binary.Read(r.transport, binary.BigEndian, &h)
buf := r.rbuf[:headerSize]
_, err := io.ReadFull(r.transport, buf)
if err != nil {
return err
}
h := header{
PacketType: packetType(buf[0]),
Status: buf[1],
Size: binary.BigEndian.Uint16(buf[2:4]),
Spid: binary.BigEndian.Uint16(buf[4:6]),
PacketNo: buf[6],
Pad: buf[7],
}
if int(h.Size) > r.packetSize {
return errors.New("Invalid packet size, it is longer than buffer size")
return errors.New("invalid packet size, it is longer than buffer size")
}
if headerSize > int(h.Size) {
return errors.New("Invalid packet size, it is shorter than header size")
return errors.New("invalid packet size, it is shorter than header size")
}
_, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size])
//s := base64.StdEncoding.EncodeToString(r.rbuf[headerSize:h.Size])
//fmt.Print(s)
if err != nil {
return err
}

View file

@ -7,6 +7,7 @@ import (
"fmt"
"math"
"reflect"
"strconv"
"strings"
"time"
@ -44,8 +45,9 @@ type BulkOptions struct {
type DataValue interface{}
const (
sqlDateFormat = "2006-01-02"
sqlTimeFormat = "2006-01-02 15:04:05.999999999Z07:00"
sqlDateFormat = "2006-01-02"
sqlDateTimeFormat = "2006-01-02 15:04:05.999999999Z07:00"
sqlTimeFormat = "15:04:05.9999999"
)
func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
@ -86,7 +88,7 @@ func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) {
b.bulkColumns = append(b.bulkColumns, *bulkCol)
b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId)
} else {
return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename)
return fmt.Errorf("column %s does not exist in destination table %s", colname, b.tablename)
}
}
@ -166,7 +168,7 @@ func (b *Bulk) AddRow(row []interface{}) (err error) {
}
if len(row) != len(b.bulkColumns) {
return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d",
return fmt.Errorf("row does not have the same number of columns than the destination table %d %d",
len(row), len(b.bulkColumns))
}
@ -215,7 +217,7 @@ func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
}
func (b *Bulk) Done() (rowcount int64, err error) {
if b.headerSent == false {
if !b.headerSent {
//no rows had been sent
return 0, nil
}
@ -233,24 +235,13 @@ func (b *Bulk) Done() (rowcount int64, err error) {
buf.FinishPacket()
tokchan := make(chan tokenStruct, 5)
go processResponse(b.ctx, b.cn.sess, tokchan, nil)
var rowCount int64
for token := range tokchan {
switch token := token.(type) {
case doneStruct:
if token.Status&doneCount != 0 {
rowCount = int64(token.RowCount)
}
if token.isError() {
return 0, token.getError()
}
case error:
return 0, b.cn.checkBadConn(token)
}
reader := startReading(b.cn.sess, b.ctx, outputs{})
err = reader.iterateResponse()
if err != nil {
return 0, b.cn.checkBadConn(err, false)
}
return rowCount, nil
return reader.rowCount, nil
}
func (b *Bulk) createColMetadata() []byte {
@ -339,6 +330,10 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
intvalue = int64(val)
case int64:
intvalue = val
case float32:
intvalue = int64(val)
case float64:
intvalue = int64(val)
default:
err = fmt.Errorf("mssql: invalid type for int column: %T", val)
return
@ -383,6 +378,8 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
switch val := val.(type) {
case string:
res.buffer = str2ucs2(val)
case int64:
res.buffer = []byte(strconv.FormatInt(val, 10))
case []byte:
res.buffer = val
default:
@ -397,6 +394,8 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
res.buffer = []byte(val)
case []byte:
res.buffer = val
case int64:
res.buffer = []byte(strconv.FormatInt(val, 10))
default:
err = fmt.Errorf("mssql: invalid type for varchar column: %T %s", val, val)
return
@ -421,7 +420,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
res.ti.Size = len(res.buffer)
case string:
var t time.Time
if t, err = time.Parse(sqlTimeFormat, val); err != nil {
if t, err = time.Parse(sqlDateTimeFormat, val); err != nil {
return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
}
res.buffer = encodeDateTime2(t, int(col.ti.Scale))
@ -437,7 +436,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
res.ti.Size = len(res.buffer)
case string:
var t time.Time
if t, err = time.Parse(sqlTimeFormat, val); err != nil {
if t, err = time.Parse(sqlDateTimeFormat, val); err != nil {
return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
}
res.buffer = encodeDateTimeOffset(t, int(col.ti.Scale))
@ -468,7 +467,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
case time.Time:
t = val
case string:
if t, err = time.Parse(sqlTimeFormat, val); err != nil {
if t, err = time.Parse(sqlDateTimeFormat, val); err != nil {
return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
}
default:
@ -485,7 +484,22 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
} else {
err = fmt.Errorf("mssql: invalid size of column %d", col.ti.Size)
}
case typeTimeN:
var t time.Time
switch val := val.(type) {
case time.Time:
res.buffer = encodeTime(val.Hour(), val.Minute(), val.Second(), val.Nanosecond(), int(col.ti.Scale))
res.ti.Size = len(res.buffer)
case string:
if t, err = time.Parse(sqlTimeFormat, val); err != nil {
return res, fmt.Errorf("bulk: unable to convert string to time: %v", err)
}
res.buffer = encodeTime(t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), int(col.ti.Scale))
res.ti.Size = len(res.buffer)
default:
err = fmt.Errorf("mssql: invalid type for time column: %T %s", val, val)
return
}
// case typeMoney, typeMoney4, typeMoneyN:
case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
prec := col.ti.Prec

View file

@ -1,6 +1,7 @@
package mssql
import (
"database/sql/driver"
"fmt"
)
@ -17,6 +18,9 @@ type Error struct {
ServerName string
ProcName string
LineNo int32
// All lists all errors that were received from first to last.
// This includes the last one, which is described in the other members.
All []Error
}
func (e Error) Error() string {
@ -65,9 +69,53 @@ func streamErrorf(format string, v ...interface{}) StreamError {
}
func badStreamPanic(err error) {
panic(err)
panic(streamErrorf("%v", err))
}
func badStreamPanicf(format string, v ...interface{}) {
panic(streamErrorf(format, v...))
}
// ServerError is returned when the server got a fatal error
// that aborts the process and severs the connection.
//
// To get the errors returned before the process was aborted,
// unwrap this error or call errors.As with a pointer to an
// mssql.Error variable.
type ServerError struct {
sqlError Error
}
func (e ServerError) Error() string {
return "SQL Server had internal error"
}
func (e ServerError) Unwrap() error {
return e.sqlError
}
// RetryableError is returned when an error was caused by a bad
// connection at the start of a query and can be safely retried
// using database/sql's automatic retry logic.
//
// In many cases database/sql's retry logic will transparently
// handle this error, the retried call will return successfully,
// and you won't even see this error. However, you may see this
// error if the retry logic cannot successfully handle the error.
// In that case you can get the underlying error by calling this
// error's UnWrap function.
type RetryableError struct {
err error
}
func (r RetryableError) Error() string {
return r.err.Error()
}
func (r RetryableError) Unwrap() error {
return r.err
}
func (r RetryableError) Is(err error) bool {
return err == driver.ErrBadConn
}

78
vendor/github.com/denisenkom/go-mssqldb/fedauth.go generated vendored Normal file
View file

@ -0,0 +1,78 @@
package mssql
import (
"context"
"errors"
"github.com/denisenkom/go-mssqldb/msdsn"
)
// Federated authentication library affects the login data structure and message sequence.
const (
// fedAuthLibraryLiveIDCompactToken specifies the Microsoft Live ID Compact Token authentication scheme
fedAuthLibraryLiveIDCompactToken = 0x00
// fedAuthLibrarySecurityToken specifies a token-based authentication where the token is available
// without additional information provided during the login sequence.
fedAuthLibrarySecurityToken = 0x01
// fedAuthLibraryADAL specifies a token-based authentication where a token is obtained during the
// login sequence using the server SPN and STS URL provided by the server during login.
fedAuthLibraryADAL = 0x02
// fedAuthLibraryReserved is used to indicate that no federated authentication scheme applies.
fedAuthLibraryReserved = 0x7F
)
// Federated authentication ADAL workflow affects the mechanism used to authenticate.
const (
// fedAuthADALWorkflowPassword uses a username/password to obtain a token from Active Directory
fedAuthADALWorkflowPassword = 0x01
// fedAuthADALWorkflowPassword uses the Windows identity to obtain a token from Active Directory
fedAuthADALWorkflowIntegrated = 0x02
// fedAuthADALWorkflowMSI uses the managed identity service to obtain a token
fedAuthADALWorkflowMSI = 0x03
)
// newSecurityTokenConnector creates a new connector from a Config and a token provider.
// When invoked, token provider implementations should contact the security token
// service specified and obtain the appropriate token, or return an error
// to indicate why a token is not available.
// The returned connector may be used with sql.OpenDB.
func newSecurityTokenConnector(config msdsn.Config, tokenProvider func(ctx context.Context) (string, error)) (*Connector, error) {
if tokenProvider == nil {
return nil, errors.New("mssql: tokenProvider cannot be nil")
}
conn := NewConnectorConfig(config)
conn.fedAuthRequired = true
conn.fedAuthLibrary = fedAuthLibrarySecurityToken
conn.securityTokenProvider = tokenProvider
return conn, nil
}
// newADALTokenConnector creates a new connector from a Config and a Active Directory token provider.
// Token provider implementations are called during federated
// authentication login sequences where the server provides a service
// principal name and security token service endpoint that should be used
// to obtain the token. Implementations should contact the security token
// service specified and obtain the appropriate token, or return an error
// to indicate why a token is not available.
//
// The returned connector may be used with sql.OpenDB.
func newActiveDirectoryTokenConnector(config msdsn.Config, adalWorkflow byte, tokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)) (*Connector, error) {
if tokenProvider == nil {
return nil, errors.New("mssql: tokenProvider cannot be nil")
}
conn := NewConnectorConfig(config)
conn.fedAuthRequired = true
conn.fedAuthLibrary = fedAuthLibraryADAL
conn.fedAuthADALWorkflow = adalWorkflow
conn.adalTokenProvider = tokenProvider
return conn, nil
}

View file

@ -1,7 +1,10 @@
package mssql
package msdsn
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"net/url"
"os"
@ -11,49 +14,104 @@ import (
"unicode"
)
const defaultServerPort = 1433
type (
Encryption int
Log uint64
)
type connectParams struct {
logFlags uint64
port uint64
host string
instance string
database string
user string
password string
dial_timeout time.Duration
conn_timeout time.Duration
keepAlive time.Duration
encrypt bool
disableEncryption bool
trustServerCertificate bool
certificate string
hostInCertificate string
hostInCertificateProvided bool
serverSPN string
workstation string
appname string
typeFlags uint8
failOverPartner string
failOverPort uint64
packetSize uint16
fedAuthAccessToken string
const (
EncryptionOff = 0
EncryptionRequired = 1
EncryptionDisabled = 3
)
const (
LogErrors Log = 1
LogMessages Log = 2
LogRows Log = 4
LogSQL Log = 8
LogParams Log = 16
LogTransaction Log = 32
LogDebug Log = 64
)
type Config struct {
Port uint64
Host string
Instance string
Database string
User string
Password string
Encryption Encryption
TLSConfig *tls.Config
FailOverPartner string
FailOverPort uint64
// If true the TLSConfig servername should use the routed server.
HostInCertificateProvided bool
// Read Only intent for application database.
// NOTE: This does not make queries to most databases read-only.
ReadOnlyIntent bool
LogFlags Log
ServerSPN string
Workstation string
AppName string
// If true disables database/sql's automatic retry of queries
// that start on bad connections.
DisableRetry bool
// Do not use the following.
DialTimeout time.Duration // DialTimeout defaults to 15s. Set negative to disable.
ConnTimeout time.Duration // Use context for timeouts.
KeepAlive time.Duration // Leave at default.
PacketSize uint16
}
func parseConnectParams(dsn string) (connectParams, error) {
var p connectParams
func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string) (*tls.Config, error) {
var config tls.Config
if certificate != "" {
pem, err := ioutil.ReadFile(certificate)
if err != nil {
return nil, fmt.Errorf("cannot read certificate %q: %v", certificate, err)
}
certs := x509.NewCertPool()
certs.AppendCertsFromPEM(pem)
config.RootCAs = certs
}
if insecureSkipVerify {
config.InsecureSkipVerify = true
}
config.ServerName = hostInCertificate
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
return &config, nil
}
func Parse(dsn string) (Config, map[string]string, error) {
p := Config{}
var params map[string]string
if strings.HasPrefix(dsn, "odbc:") {
parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
if err != nil {
return p, err
return p, params, err
}
params = parameters
} else if strings.HasPrefix(dsn, "sqlserver://") {
parameters, err := splitConnectionStringURL(dsn)
if err != nil {
return p, err
return p, params, err
}
params = parameters
} else {
@ -62,57 +120,55 @@ func parseConnectParams(dsn string) (connectParams, error) {
strlog, ok := params["log"]
if ok {
var err error
p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
flags, err := strconv.ParseUint(strlog, 10, 64)
if err != nil {
return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
return p, params, fmt.Errorf("invalid log parameter '%s': %s", strlog, err.Error())
}
p.LogFlags = Log(flags)
}
server := params["server"]
parts := strings.SplitN(server, `\`, 2)
p.host = parts[0]
if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
p.host = "localhost"
p.Host = parts[0]
if p.Host == "." || strings.ToUpper(p.Host) == "(LOCAL)" || p.Host == "" {
p.Host = "localhost"
}
if len(parts) > 1 {
p.instance = parts[1]
p.Instance = parts[1]
}
p.database = params["database"]
p.user = params["user id"]
p.password = params["password"]
p.Database = params["database"]
p.User = params["user id"]
p.Password = params["password"]
p.port = 0
p.Port = 0
strport, ok := params["port"]
if ok {
var err error
p.port, err = strconv.ParseUint(strport, 10, 16)
p.Port, err = strconv.ParseUint(strport, 10, 16)
if err != nil {
f := "Invalid tcp port '%v': %v"
return p, fmt.Errorf(f, strport, err.Error())
f := "invalid tcp port '%v': %v"
return p, params, fmt.Errorf(f, strport, err.Error())
}
}
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
// Default packet size remains at 4096 bytes
p.packetSize = 4096
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option\
strpsize, ok := params["packet size"]
if ok {
var err error
psize, err := strconv.ParseUint(strpsize, 0, 16)
if err != nil {
f := "Invalid packet size '%v': %v"
return p, fmt.Errorf(f, strpsize, err.Error())
f := "invalid packet size '%v': %v"
return p, params, fmt.Errorf(f, strpsize, err.Error())
}
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
// a higher packet size, the server will respond with an ENVCHANGE request to
// alter the packet size to 16383 bytes.
p.packetSize = uint16(psize)
if p.packetSize < 512 {
p.packetSize = 512
} else if p.packetSize > 32767 {
p.packetSize = 32767
p.PacketSize = uint16(psize)
if p.PacketSize < 512 {
p.PacketSize = 512
} else if p.PacketSize > 32767 {
p.PacketSize = 32767
}
}
@ -123,79 +179,95 @@ func parseConnectParams(dsn string) (connectParams, error) {
if strconntimeout, ok := params["connection timeout"]; ok {
timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
if err != nil {
f := "Invalid connection timeout '%v': %v"
return p, fmt.Errorf(f, strconntimeout, err.Error())
f := "invalid connection timeout '%v': %v"
return p, params, fmt.Errorf(f, strconntimeout, err.Error())
}
p.conn_timeout = time.Duration(timeout) * time.Second
p.ConnTimeout = time.Duration(timeout) * time.Second
}
p.dial_timeout = 15 * time.Second
p.DialTimeout = 15 * time.Second
if strdialtimeout, ok := params["dial timeout"]; ok {
timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
if err != nil {
f := "Invalid dial timeout '%v': %v"
return p, fmt.Errorf(f, strdialtimeout, err.Error())
f := "invalid dial timeout '%v': %v"
return p, params, fmt.Errorf(f, strdialtimeout, err.Error())
}
p.dial_timeout = time.Duration(timeout) * time.Second
p.DialTimeout = time.Duration(timeout) * time.Second
}
// default keep alive should be 30 seconds according to spec:
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
p.keepAlive = 30 * time.Second
p.KeepAlive = 30 * time.Second
if keepAlive, ok := params["keepalive"]; ok {
timeout, err := strconv.ParseUint(keepAlive, 10, 64)
if err != nil {
f := "Invalid keepAlive value '%s': %s"
return p, fmt.Errorf(f, keepAlive, err.Error())
f := "invalid keepAlive value '%s': %s"
return p, params, fmt.Errorf(f, keepAlive, err.Error())
}
p.keepAlive = time.Duration(timeout) * time.Second
p.KeepAlive = time.Duration(timeout) * time.Second
}
var (
trustServerCert = false
certificate = ""
hostInCertificate = ""
)
encrypt, ok := params["encrypt"]
if ok {
if strings.EqualFold(encrypt, "DISABLE") {
p.disableEncryption = true
p.Encryption = EncryptionDisabled
} else {
var err error
p.encrypt, err = strconv.ParseBool(encrypt)
e, err := strconv.ParseBool(encrypt)
if err != nil {
f := "Invalid encrypt '%s': %s"
return p, fmt.Errorf(f, encrypt, err.Error())
f := "invalid encrypt '%s': %s"
return p, params, fmt.Errorf(f, encrypt, err.Error())
}
if e {
p.Encryption = EncryptionRequired
}
}
} else {
p.trustServerCertificate = true
trustServerCert = true
}
trust, ok := params["trustservercertificate"]
if ok {
var err error
p.trustServerCertificate, err = strconv.ParseBool(trust)
trustServerCert, err = strconv.ParseBool(trust)
if err != nil {
f := "Invalid trust server certificate '%s': %s"
return p, fmt.Errorf(f, trust, err.Error())
f := "invalid trust server certificate '%s': %s"
return p, params, fmt.Errorf(f, trust, err.Error())
}
}
p.certificate = params["certificate"]
p.hostInCertificate, ok = params["hostnameincertificate"]
certificate = params["certificate"]
hostInCertificate, ok = params["hostnameincertificate"]
if ok {
p.hostInCertificateProvided = true
p.HostInCertificateProvided = true
} else {
p.hostInCertificate = p.host
p.hostInCertificateProvided = false
hostInCertificate = p.Host
p.HostInCertificateProvided = false
}
if p.Encryption != EncryptionDisabled {
var err error
p.TLSConfig, err = SetupTLS(certificate, trustServerCert, hostInCertificate)
if err != nil {
return p, params, fmt.Errorf("failed to setup TLS: %w", err)
}
}
serverSPN, ok := params["serverspn"]
if ok {
p.serverSPN = serverSPN
p.ServerSPN = serverSPN
} else {
p.serverSPN = generateSpn(p.host, resolveServerPort(p.port))
p.ServerSPN = generateSpn(p.Host, p.Port)
}
workstation, ok := params["workstation id"]
if ok {
p.workstation = workstation
p.Workstation = workstation
} else {
workstation, err := os.Hostname()
if err == nil {
p.workstation = workstation
p.Workstation = workstation
}
}
@ -203,34 +275,86 @@ func parseConnectParams(dsn string) (connectParams, error) {
if !ok {
appname = "go-mssqldb"
}
p.appname = appname
p.AppName = appname
appintent, ok := params["applicationintent"]
if ok {
if appintent == "ReadOnly" {
if p.database == "" {
return p, fmt.Errorf("Database must be specified when ApplicationIntent is ReadOnly")
if p.Database == "" {
return p, params, fmt.Errorf("database must be specified when ApplicationIntent is ReadOnly")
}
p.typeFlags |= fReadOnlyIntent
p.ReadOnlyIntent = true
}
}
failOverPartner, ok := params["failoverpartner"]
if ok {
p.failOverPartner = failOverPartner
p.FailOverPartner = failOverPartner
}
failOverPort, ok := params["failoverport"]
if ok {
var err error
p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
p.FailOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
if err != nil {
f := "Invalid tcp port '%v': %v"
return p, fmt.Errorf(f, failOverPort, err.Error())
f := "invalid failover port '%v': %v"
return p, params, fmt.Errorf(f, failOverPort, err.Error())
}
}
return p, nil
disableRetry, ok := params["disableretry"]
if ok {
var err error
p.DisableRetry, err = strconv.ParseBool(disableRetry)
if err != nil {
f := "invalid disableRetry '%s': %s"
return p, params, fmt.Errorf(f, disableRetry, err.Error())
}
} else {
p.DisableRetry = disableRetryDefault
}
return p, params, nil
}
// convert connectionParams to url style connection string
// used mostly for testing
func (p Config) URL() *url.URL {
q := url.Values{}
if p.Database != "" {
q.Add("database", p.Database)
}
if p.LogFlags != 0 {
q.Add("log", strconv.FormatUint(uint64(p.LogFlags), 10))
}
host := p.Host
if p.Port > 0 {
host = fmt.Sprintf("%s:%d", p.Host, p.Port)
}
q.Add("disableRetry", fmt.Sprintf("%t", p.DisableRetry))
res := url.URL{
Scheme: "sqlserver",
Host: host,
User: url.UserPassword(p.User, p.Password),
}
if p.Instance != "" {
res.Path = p.Instance
}
if len(q) > 0 {
res.RawQuery = q.Encode()
}
return &res
}
var adoSynonyms = map[string]string{
"application name": "app name",
"data source": "server",
"address": "server",
"network address": "server",
"addr": "server",
"user": "user id",
"uid": "user id",
"initial catalog": "database",
}
func splitConnectionString(dsn string) (res map[string]string) {
@ -249,6 +373,20 @@ func splitConnectionString(dsn string) (res map[string]string) {
if len(lst) > 1 {
value = strings.TrimSpace(lst[1])
}
synonym, hasSynonym := adoSynonyms[name]
if hasSynonym {
name = synonym
}
// "server" in ADO can include a protocol and a port.
// We only support tcp protocol
if name == "server" {
value = strings.TrimPrefix(value, "tcp:")
serverParts := strings.Split(value, ",")
if len(serverParts) == 2 && len(serverParts[1]) > 0 {
value = serverParts[0]
res["port"] = serverParts[1]
}
}
res[name] = value
}
return res
@ -340,7 +478,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
case parserStateBeforeKey:
switch {
case c == '=':
return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
return res, fmt.Errorf("unexpected character = at index %d. Expected start of key or semi-colon or whitespace", i)
case !unicode.IsSpace(c) && c != ';':
state = parserStateKey
key += string(c)
@ -419,7 +557,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
case unicode.IsSpace(c):
// Ignore whitespace
default:
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
return res, fmt.Errorf("unexpected character %c at index %d. Expected semi-colon or whitespace", c, i)
}
case parserStateEndValue:
@ -429,7 +567,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
case unicode.IsSpace(c):
// Ignore whitespace
default:
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
return res, fmt.Errorf("unexpected character %c at index %d. Expected semi-colon or whitespace", c, i)
}
}
}
@ -444,7 +582,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
case parserStateBareValue:
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
case parserStateBracedValue:
return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
return res, fmt.Errorf("unexpected end of braced value at index %d", len(dsn))
case parserStateBracedValueClosingBrace: // End of braced value
res[key] = value
case parserStateEndValue: // Okay
@ -458,14 +596,6 @@ func normalizeOdbcKey(s string) string {
return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
}
func resolveServerPort(port uint64) uint64 {
if port == 0 {
return defaultServerPort
}
return port
}
func generateSpn(host string, port uint64) string {
return fmt.Sprintf("MSSQLSvc/%s:%d", host, port)
}

View file

@ -0,0 +1,9 @@
// +build go1.18
package msdsn
// disableRetryDefault is false for Go versions 1.18 and higher. This matches
// the behavior requested in issue #586. A query that fails at the start due to
// a bad connection is automatically retried. An error is returned only if the
// query fails all of its retries.
const disableRetryDefault bool = false

View file

@ -0,0 +1,9 @@
// +build !go1.18
package msdsn
// disableRetryDefault is true for versions of Go less than 1.18. This matches
// the behavior requested in issue #275. A query that fails at the start due to
// a bad connection is not retried. Instead, the detailed error is immediately
// returned to the caller.
const disableRetryDefault bool = true

View file

@ -16,6 +16,7 @@ import (
"unicode"
"github.com/denisenkom/go-mssqldb/internal/querytext"
"github.com/denisenkom/go-mssqldb/msdsn"
)
// ReturnStatus may be used to return the return value from a proc.
@ -31,12 +32,16 @@ var driverInstanceNoProcess = &Driver{processQueryText: false}
func init() {
sql.Register("mssql", driverInstance)
sql.Register("sqlserver", driverInstanceNoProcess)
createDialer = func(p *connectParams) Dialer {
return netDialer{&net.Dialer{KeepAlive: p.keepAlive}}
createDialer = func(p *msdsn.Config) Dialer {
ka := p.KeepAlive
if ka == 0 {
ka = 30 * time.Second
}
return netDialer{&net.Dialer{KeepAlive: ka}}
}
}
var createDialer func(p *connectParams) Dialer
var createDialer func(p *msdsn.Config) Dialer
type netDialer struct {
nd *net.Dialer
@ -54,10 +59,11 @@ type Driver struct {
// OpenConnector opens a new connector. Useful to dial with a context.
func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
params, err := parseConnectParams(dsn)
params, _, err := msdsn.Parse(dsn)
if err != nil {
return nil, err
}
return &Connector{
params: params,
driver: d,
@ -80,7 +86,7 @@ func (d *Driver) SetLogger(logger Logger) {
// NewConnector creates a new connector from a DSN.
// The returned connector may be used with sql.OpenDB.
func NewConnector(dsn string) (*Connector, error) {
params, err := parseConnectParams(dsn)
params, _, err := msdsn.Parse(dsn)
if err != nil {
return nil, err
}
@ -91,15 +97,34 @@ func NewConnector(dsn string) (*Connector, error) {
return c, nil
}
// NewConnectorConfig creates a new Connector for a DSN Config struct.
// The returned connector may be used with sql.OpenDB.
func NewConnectorConfig(config msdsn.Config) *Connector {
return &Connector{
params: config,
driver: driverInstanceNoProcess,
}
}
// Connector holds the parsed DSN and is ready to make a new connection
// at any time.
//
// In the future, settings that cannot be passed through a string DSN
// may be set directly on the connector.
type Connector struct {
params connectParams
params msdsn.Config
driver *Driver
fedAuthRequired bool
fedAuthLibrary int
fedAuthADALWorkflow byte
// callback that can provide a security token during login
securityTokenProvider func(ctx context.Context) (string, error)
// callback that can provide a security token during ADAL login
adalTokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)
// SessionInitSQL is executed after marking a given session to be reset.
// When not present, the next query will still reset the session to the
// database defaults.
@ -132,7 +157,7 @@ type Dialer interface {
DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
}
func (c *Connector) getDialer(p *connectParams) Dialer {
func (c *Connector) getDialer(p *msdsn.Config) Dialer {
if c != nil && c.Dialer != nil {
return c.Dialer
}
@ -148,34 +173,32 @@ type Conn struct {
processQueryText bool
connectionGood bool
outs map[string]interface{}
outs outputs
}
type outputs struct {
params map[string]interface{}
returnStatus *ReturnStatus
}
func (c *Conn) setReturnStatus(s ReturnStatus) {
if c.returnStatus == nil {
return
}
*c.returnStatus = s
// IsValid satisfies the driver.Validator interface.
func (c *Conn) IsValid() bool {
return c.connectionGood
}
func (c *Conn) checkBadConn(err error) error {
// this is a hack to address Issue #275
// we set connectionGood flag to false if
// error indicates that connection is not usable
// but we return actual error instead of ErrBadConn
// this will cause connection to stay in a pool
// but next request to this connection will return ErrBadConn
// it might be possible to revise this hack after
// https://github.com/golang/go/issues/20807
// is implemented
// checkBadConn marks the connection as bad based on the characteristics
// of the supplied error. Bad connections will be dropped from the connection
// pool rather than reused.
//
// If bad connection retry is enabled and the error + connection state permits
// retrying, checkBadConn will return a RetryableError that allows database/sql
// to automatically retry the query with another connection.
func (c *Conn) checkBadConn(err error, mayRetry bool) error {
switch err {
case nil:
return nil
case io.EOF:
c.connectionGood = false
return driver.ErrBadConn
case driver.ErrBadConn:
// It is an internal programming error if driver.ErrBadConn
// is ever passed to this function. driver.ErrBadConn should
@ -187,34 +210,33 @@ func (c *Conn) checkBadConn(err error) error {
switch err.(type) {
case net.Error:
c.connectionGood = false
return err
case StreamError:
c.connectionGood = false
return err
default:
return err
case ServerError:
c.connectionGood = false
}
if !c.connectionGood && mayRetry && !c.connector.params.DisableRetry {
return newRetryableError(err)
}
return err
}
func (c *Conn) clearOuts() {
c.outs = nil
c.outs = outputs{}
}
func (c *Conn) simpleProcessResp(ctx context.Context) error {
tokchan := make(chan tokenStruct, 5)
go processResponse(ctx, c.sess, tokchan, c.outs)
reader := startReading(c.sess, ctx, c.outs)
c.clearOuts()
for tok := range tokchan {
switch token := tok.(type) {
case doneStruct:
if token.isError() {
return c.checkBadConn(token.getError())
}
case error:
return c.checkBadConn(token)
}
var resultError error
err := reader.iterateResponse()
if err != nil {
return c.checkBadConn(err, false)
}
return nil
return resultError
}
func (c *Conn) Commit() error {
@ -222,7 +244,7 @@ func (c *Conn) Commit() error {
return driver.ErrBadConn
}
if err := c.sendCommitRequest(); err != nil {
return c.checkBadConn(err)
return c.checkBadConn(err, true)
}
return c.simpleProcessResp(c.transactionCtx)
}
@ -239,7 +261,7 @@ func (c *Conn) sendCommitRequest() error {
c.sess.log.Printf("Failed to send CommitXact with %v", err)
}
c.connectionGood = false
return fmt.Errorf("Faild to send CommitXact: %v", err)
return fmt.Errorf("faild to send CommitXact: %v", err)
}
return nil
}
@ -249,7 +271,7 @@ func (c *Conn) Rollback() error {
return driver.ErrBadConn
}
if err := c.sendRollbackRequest(); err != nil {
return c.checkBadConn(err)
return c.checkBadConn(err, true)
}
return c.simpleProcessResp(c.transactionCtx)
}
@ -266,7 +288,7 @@ func (c *Conn) sendRollbackRequest() error {
c.sess.log.Printf("Failed to send RollbackXact with %v", err)
}
c.connectionGood = false
return fmt.Errorf("Failed to send RollbackXact: %v", err)
return fmt.Errorf("failed to send RollbackXact: %v", err)
}
return nil
}
@ -281,11 +303,11 @@ func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx,
}
err = c.sendBeginRequest(ctx, tdsIsolation)
if err != nil {
return nil, c.checkBadConn(err)
return nil, c.checkBadConn(err, true)
}
tx, err = c.processBeginResponse(ctx)
if err != nil {
return nil, c.checkBadConn(err)
return nil, err
}
return
}
@ -303,7 +325,7 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro
c.sess.log.Printf("Failed to send BeginXact with %v", err)
}
c.connectionGood = false
return fmt.Errorf("Failed to send BeginXact: %v", err)
return fmt.Errorf("failed to send BeginXact: %v", err)
}
return nil
}
@ -318,25 +340,26 @@ func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
}
func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
params, err := parseConnectParams(dsn)
params, _, err := msdsn.Parse(dsn)
if err != nil {
return nil, err
}
return d.connect(ctx, nil, params)
c := &Connector{params: params}
return d.connect(ctx, c, params)
}
// connect to the server, using the provided context for dialing only.
func (d *Driver) connect(ctx context.Context, c *Connector, params connectParams) (*Conn, error) {
func (d *Driver) connect(ctx context.Context, c *Connector, params msdsn.Config) (*Conn, error) {
sess, err := connect(ctx, c, d.log, params)
if err != nil {
// main server failed, try fail-over partner
if params.failOverPartner == "" {
if params.FailOverPartner == "" {
return nil, err
}
params.host = params.failOverPartner
if params.failOverPort != 0 {
params.port = params.failOverPort
params.Host = params.FailOverPartner
if params.FailOverPort != 0 {
params.Port = params.FailOverPort
}
sess, err = connect(ctx, c, d.log, params)
@ -447,7 +470,8 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
reset := conn.resetSession
conn.resetSession = false
if len(args) == 0 {
isProc := isProc(s.query)
if len(args) == 0 && !isProc {
if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
if conn.sess.logFlags&logErrors != 0 {
conn.sess.log.Printf("Failed to send SqlBatch with %v", err)
@ -458,7 +482,7 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
} else {
proc := sp_ExecuteSql
var params []param
if isProc(s.query) {
if isProc {
proc.name = s.query
params, _, err = s.makeRPCParams(args, true)
if err != nil {
@ -478,7 +502,7 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
conn.sess.log.Printf("Failed to send Rpc with %v", err)
}
conn.connectionGood = false
return fmt.Errorf("Failed to send RPC: %v", err)
return fmt.Errorf("failed to send RPC: %v", err)
}
}
return
@ -500,30 +524,38 @@ func isProc(s string) bool {
for _, r := range s {
rPrev = rn1
rn1 = r
switch r {
// No newlines or string sequences.
case '\n', '\r', '\'', ';':
return false
if st != escaped {
switch r {
// No newlines or string sequences.
case '\n', '\r', '\'', ';':
return false
}
}
switch st {
case outside:
switch {
case unicode.IsSpace(r):
return false
case r == '[':
st = escaped
continue
case r == ']' && rPrev == ']':
st = escaped
continue
case unicode.IsLetter(r):
st = text
case r == '_':
st = text
case r == '#':
st = text
case r == '.':
default:
return false
}
case text:
switch {
case r == '.':
st = outside
continue
case r == '[':
return false
case r == '(':
return false
case unicode.IsSpace(r):
return false
}
@ -531,7 +563,6 @@ func isProc(s string) bool {
switch {
case r == ']':
st = outside
continue
}
}
}
@ -558,7 +589,13 @@ func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string,
name = fmt.Sprintf("@p%d", val.Ordinal)
}
params[i+offset].Name = name
decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+offset].ti))
const outputSuffix = " output"
var output string
if isOutputValue(val.Value) {
output = outputSuffix
}
decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(params[i+offset].ti), output)
}
return params, decls, nil
}
@ -581,6 +618,8 @@ func convertOldArgs(args []driver.Value) []namedValue {
}
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
defer s.c.clearOuts()
return s.queryContext(context.Background(), convertOldArgs(args))
}
@ -589,48 +628,60 @@ func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver
return nil, driver.ErrBadConn
}
if err = s.sendQuery(args); err != nil {
return nil, s.c.checkBadConn(err)
return nil, s.c.checkBadConn(err, true)
}
return s.processQueryResponse(ctx)
}
func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
tokchan := make(chan tokenStruct, 5)
ctx, cancel := context.WithCancel(ctx)
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
reader := startReading(s.c.sess, ctx, s.c.outs)
s.c.clearOuts()
// process metadata
var cols []columnStruct
loop:
for tok := range tokchan {
switch token := tok.(type) {
// By ignoring DONE token we effectively
// skip empty result-sets.
// This improves results in queries like that:
// set nocount on; select 1
// see TestIgnoreEmptyResults test
//case doneStruct:
//break loop
case []columnStruct:
cols = token
break loop
case doneStruct:
if token.isError() {
cancel()
return nil, s.c.checkBadConn(token.getError())
for {
tok, err := reader.nextToken()
if err == nil {
if tok == nil {
break
} else {
switch token := tok.(type) {
// By ignoring DONE token we effectively
// skip empty result-sets.
// This improves results in queries like that:
// set nocount on; select 1
// see TestIgnoreEmptyResults test
//case doneStruct:
//break loop
case []columnStruct:
cols = token
break loop
case doneStruct:
if token.isError() {
// need to cleanup cancellable context
cancel()
return nil, s.c.checkBadConn(token.getError(), false)
}
case ReturnStatus:
if reader.outs.returnStatus != nil {
*reader.outs.returnStatus = token
}
}
}
case ReturnStatus:
s.c.setReturnStatus(token)
case error:
} else {
// need to cleanup cancellable context
cancel()
return nil, s.c.checkBadConn(token)
return nil, s.c.checkBadConn(err, false)
}
}
res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
res = &Rows{stmt: s, reader: reader, cols: cols, cancel: cancel}
return
}
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
defer s.c.clearOuts()
return s.exec(context.Background(), convertOldArgs(args))
}
@ -639,57 +690,55 @@ func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result,
return nil, driver.ErrBadConn
}
if err = s.sendQuery(args); err != nil {
return nil, s.c.checkBadConn(err)
return nil, s.c.checkBadConn(err, true)
}
if res, err = s.processExec(ctx); err != nil {
return nil, s.c.checkBadConn(err)
return nil, err
}
return
}
func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
tokchan := make(chan tokenStruct, 5)
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
reader := startReading(s.c.sess, ctx, s.c.outs)
s.c.clearOuts()
var rowCount int64
for token := range tokchan {
switch token := token.(type) {
case doneInProcStruct:
if token.Status&doneCount != 0 {
rowCount += int64(token.RowCount)
}
case doneStruct:
if token.Status&doneCount != 0 {
rowCount += int64(token.RowCount)
}
if token.isError() {
return nil, token.getError()
}
case ReturnStatus:
s.c.setReturnStatus(token)
case error:
return nil, token
}
err = reader.iterateResponse()
if err != nil {
return nil, s.c.checkBadConn(err, false)
}
return &Result{s.c, rowCount}, nil
return &Result{s.c, reader.rowCount}, nil
}
type Rows struct {
stmt *Stmt
cols []columnStruct
tokchan chan tokenStruct
stmt *Stmt
cols []columnStruct
reader *tokenProcessor
nextCols []columnStruct
cancel func()
}
func (rc *Rows) Close() error {
// need to add a test which returns lots of rows
// and check closing after reading only few rows
rc.cancel()
for _ = range rc.tokchan {
for {
tok, err := rc.reader.nextToken()
if err == nil {
if tok == nil {
return nil
} else {
// continue consuming tokens
continue
}
} else {
if err == rc.reader.ctx.Err() {
return nil
} else {
return err
}
}
}
rc.tokchan = nil
return nil
}
func (rc *Rows) Columns() (res []string) {
@ -707,27 +756,36 @@ func (rc *Rows) Next(dest []driver.Value) error {
if rc.nextCols != nil {
return io.EOF
}
for tok := range rc.tokchan {
switch tokdata := tok.(type) {
case []columnStruct:
rc.nextCols = tokdata
return io.EOF
case []interface{}:
for i := range dest {
dest[i] = tokdata[i]
for {
tok, err := rc.reader.nextToken()
if err == nil {
if tok == nil {
return io.EOF
} else {
switch tokdata := tok.(type) {
case []columnStruct:
rc.nextCols = tokdata
return io.EOF
case []interface{}:
for i := range dest {
dest[i] = tokdata[i]
}
return nil
case doneStruct:
if tokdata.isError() {
return rc.stmt.c.checkBadConn(tokdata.getError(), false)
}
case ReturnStatus:
if rc.reader.outs.returnStatus != nil {
*rc.reader.outs.returnStatus = tokdata
}
}
}
return nil
case doneStruct:
if tokdata.isError() {
return rc.stmt.c.checkBadConn(tokdata.getError())
}
case ReturnStatus:
rc.stmt.c.setReturnStatus(tokdata)
case error:
return rc.stmt.c.checkBadConn(tokdata)
} else {
return rc.stmt.c.checkBadConn(err, false)
}
}
return io.EOF
}
func (rc *Rows) HasNextResultSet() bool {
@ -895,35 +953,41 @@ func (c *Conn) Ping(ctx context.Context) error {
var _ driver.ConnBeginTx = &Conn{}
func convertIsolationLevel(level sql.IsolationLevel) (isoLevel, error) {
switch level {
case sql.LevelDefault:
return isolationUseCurrent, nil
case sql.LevelReadUncommitted:
return isolationReadUncommited, nil
case sql.LevelReadCommitted:
return isolationReadCommited, nil
case sql.LevelWriteCommitted:
return isolationUseCurrent, errors.New("LevelWriteCommitted isolation level is not supported")
case sql.LevelRepeatableRead:
return isolationRepeatableRead, nil
case sql.LevelSnapshot:
return isolationSnapshot, nil
case sql.LevelSerializable:
return isolationSerializable, nil
case sql.LevelLinearizable:
return isolationUseCurrent, errors.New("LevelLinearizable isolation level is not supported")
default:
return isolationUseCurrent, errors.New("isolation level is not supported or unknown")
}
}
// BeginTx satisfies ConnBeginTx.
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if !c.connectionGood {
return nil, driver.ErrBadConn
}
if opts.ReadOnly {
return nil, errors.New("Read-only transactions are not supported")
return nil, errors.New("read-only transactions are not supported")
}
var tdsIsolation isoLevel
switch sql.IsolationLevel(opts.Isolation) {
case sql.LevelDefault:
tdsIsolation = isolationUseCurrent
case sql.LevelReadUncommitted:
tdsIsolation = isolationReadUncommited
case sql.LevelReadCommitted:
tdsIsolation = isolationReadCommited
case sql.LevelWriteCommitted:
return nil, errors.New("LevelWriteCommitted isolation level is not supported")
case sql.LevelRepeatableRead:
tdsIsolation = isolationRepeatableRead
case sql.LevelSnapshot:
tdsIsolation = isolationSnapshot
case sql.LevelSerializable:
tdsIsolation = isolationSerializable
case sql.LevelLinearizable:
return nil, errors.New("LevelLinearizable isolation level is not supported")
default:
return nil, errors.New("Isolation level is not supported or unknown")
tdsIsolation, err := convertIsolationLevel(sql.IsolationLevel(opts.Isolation))
if err != nil {
return nil, err
}
return c.begin(ctx, tdsIsolation)
}
@ -940,6 +1004,8 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
}
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
defer s.c.clearOuts()
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
@ -951,6 +1017,8 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
}
func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
defer s.c.clearOuts()
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}

View file

@ -48,5 +48,5 @@ func (c *Connector) Driver() driver.Driver {
}
func (r *Result) LastInsertId() (int64, error) {
return -1, errors.New("LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query.")
return -1, errors.New("LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query")
}

14
vendor/github.com/denisenkom/go-mssqldb/mssql_go118.go generated vendored Normal file
View file

@ -0,0 +1,14 @@
// +build go1.18
package mssql
// newRetryableError returns an error that allows the database/sql package
// to automatically retry the failed query. Versions of Go 1.18 and higher
// use errors.Is to determine whether or not a failed query can be retried.
// Therefore, we wrap the underlying error in a RetryableError that both
// implements errors.Is for automatic retry and maintains the error details.
func newRetryableError(err error) error {
return RetryableError{
err: err,
}
}

View file

@ -0,0 +1,17 @@
// +build !go1.18
package mssql
import (
"database/sql/driver"
)
// newRetryableError returns an error that allows the database/sql package
// to automatically retry the failed query. Versions of Go lower than 1.18
// compare directly to the sentinel error driver.ErrBadConn to determine
// whether or not a failed query can be retried. Therefore, we replace the
// actual error with driver.ErrBadConn, enabling retry but losing the error
// details.
func newRetryableError(err error) error {
return driver.ErrBadConn
}

View file

@ -66,10 +66,10 @@ func convertInputParameter(val interface{}) (interface{}, error) {
func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
switch v := nv.Value.(type) {
case sql.Out:
if c.outs == nil {
c.outs = make(map[string]interface{})
if c.outs.params == nil {
c.outs.params = make(map[string]interface{})
}
c.outs[nv.Name] = v.Dest
c.outs.params[nv.Name] = v.Dest
if v.Dest == nil {
return errors.New("destination is a nil pointer")
@ -110,7 +110,7 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
return nil
case *ReturnStatus:
*v = 0 // By default the return value should be zero.
c.returnStatus = v
c.outs.returnStatus = v
return driver.ErrRemoveArgument
case TVP:
return nil
@ -194,3 +194,8 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) {
func scanIntoOut(name string, fromServer, scanInto interface{}) error {
return convertAssign(scanInto, fromServer)
}
func isOutputValue(val driver.Value) bool {
_, out := val.(sql.Out)
return out
}

View file

@ -14,3 +14,7 @@ func (s *Stmt) makeParamExtra(val driver.Value) (param, error) {
func scanIntoOut(name string, fromServer, scanInto interface{}) error {
return fmt.Errorf("mssql: unsupported OUTPUT type, use a newer Go version")
}
func isOutputValue(val driver.Value) bool {
return false
}

View file

@ -7,8 +7,8 @@ import (
)
type timeoutConn struct {
c net.Conn
timeout time.Duration
c net.Conn
timeout time.Duration
}
func newTimeoutConn(conn net.Conn, timeout time.Duration) *timeoutConn {
@ -51,21 +51,21 @@ func (c timeoutConn) RemoteAddr() net.Addr {
}
func (c timeoutConn) SetDeadline(t time.Time) error {
panic("Not implemented")
return c.c.SetDeadline(t)
}
func (c timeoutConn) SetReadDeadline(t time.Time) error {
panic("Not implemented")
return c.c.SetReadDeadline(t)
}
func (c timeoutConn) SetWriteDeadline(t time.Time) error {
panic("Not implemented")
return c.c.SetWriteDeadline(t)
}
// this connection is used during TLS Handshake
// TDS protocol requires TLS handshake messages to be sent inside TDS packets
type tlsHandshakeConn struct {
buf *tdsBuffer
buf *tdsBuffer
packetPending bool
continueRead bool
}
@ -75,7 +75,7 @@ func (c *tlsHandshakeConn) Read(b []byte) (n int, err error) {
c.packetPending = false
err = c.buf.FinishPacket()
if err != nil {
err = fmt.Errorf("Cannot send handshake packet: %s", err.Error())
err = fmt.Errorf("cannot send handshake packet: %s", err.Error())
return
}
c.continueRead = false
@ -84,7 +84,7 @@ func (c *tlsHandshakeConn) Read(b []byte) (n int, err error) {
var packet packetType
packet, err = c.buf.BeginRead()
if err != nil {
err = fmt.Errorf("Cannot read handshake packet: %s", err.Error())
err = fmt.Errorf("cannot read handshake packet: %s", err.Error())
return
}
if packet != packPrelogin {
@ -105,27 +105,27 @@ func (c *tlsHandshakeConn) Write(b []byte) (n int, err error) {
}
func (c *tlsHandshakeConn) Close() error {
panic("Not implemented")
return c.buf.transport.Close()
}
func (c *tlsHandshakeConn) LocalAddr() net.Addr {
panic("Not implemented")
return nil
}
func (c *tlsHandshakeConn) RemoteAddr() net.Addr {
panic("Not implemented")
return nil
}
func (c *tlsHandshakeConn) SetDeadline(t time.Time) error {
panic("Not implemented")
func (c *tlsHandshakeConn) SetDeadline(_ time.Time) error {
return nil
}
func (c *tlsHandshakeConn) SetReadDeadline(t time.Time) error {
panic("Not implemented")
func (c *tlsHandshakeConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (c *tlsHandshakeConn) SetWriteDeadline(t time.Time) error {
panic("Not implemented")
func (c *tlsHandshakeConn) SetWriteDeadline(_ time.Time) error {
return nil
}
// this connection just delegates all methods to it's wrapped connection
@ -148,21 +148,21 @@ func (c passthroughConn) Close() error {
}
func (c passthroughConn) LocalAddr() net.Addr {
panic("Not implemented")
return c.c.LocalAddr()
}
func (c passthroughConn) RemoteAddr() net.Addr {
panic("Not implemented")
return c.c.RemoteAddr()
}
func (c passthroughConn) SetDeadline(t time.Time) error {
panic("Not implemented")
return c.c.SetDeadline(t)
}
func (c passthroughConn) SetReadDeadline(t time.Time) error {
panic("Not implemented")
return c.c.SetReadDeadline(t)
}
func (c passthroughConn) SetWriteDeadline(t time.Time) error {
panic("Not implemented")
return c.c.SetWriteDeadline(t)
}

View file

@ -14,6 +14,7 @@ import (
"time"
"unicode/utf16"
//lint:ignore SA1019 MD4 is used by legacy NTLM
"golang.org/x/crypto/md4"
)
@ -126,18 +127,6 @@ func createDesKey(bytes, material []byte) {
material[7] = (byte)(bytes[6] << 1)
}
func oddParity(bytes []byte) {
for i := 0; i < len(bytes); i++ {
b := bytes[i]
needsParity := (((b >> 7) ^ (b >> 6) ^ (b >> 5) ^ (b >> 4) ^ (b >> 3) ^ (b >> 2) ^ (b >> 1)) & 0x01) == 0
if needsParity {
bytes[i] = bytes[i] | byte(0x01)
} else {
bytes[i] = bytes[i] & byte(0xfe)
}
}
}
func encryptDes(key []byte, cleartext []byte, ciphertext []byte) {
var desKey [8]byte
createDesKey(key, desKey[:])

View file

@ -22,12 +22,6 @@ type param struct {
buffer []byte
}
const (
fWithRecomp = 1
fNoMetaData = 2
fReuseMetaData = 4
)
var (
sp_Cursor = procId{1, ""}
sp_CursorOpen = procId{2, ""}

View file

@ -3,7 +3,6 @@ package mssql
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
@ -13,8 +12,11 @@ import (
"sort"
"strconv"
"strings"
"time"
"unicode/utf16"
"unicode/utf8"
"github.com/denisenkom/go-mssqldb/msdsn"
)
func parseInstances(msg []byte) map[string]map[string]string {
@ -82,19 +84,20 @@ const (
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
const (
packSQLBatch packetType = 1
packRPCRequest = 3
packReply = 4
packRPCRequest packetType = 3
packReply packetType = 4
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
packAttention = 6
packAttention packetType = 6
packBulkLoadBCP = 7
packTransMgrReq = 14
packNormal = 15
packLogin7 = 16
packSSPIMessage = 17
packPrelogin = 18
packBulkLoadBCP packetType = 7
packFedAuthToken packetType = 8
packTransMgrReq packetType = 14
packNormal packetType = 15
packLogin7 packetType = 16
packSSPIMessage packetType = 17
packPrelogin packetType = 18
)
// prelogin fields
@ -118,6 +121,17 @@ const (
encryptReq = 3 // Encryption is required.
)
const (
featExtSESSIONRECOVERY byte = 0x01
featExtFEDAUTH byte = 0x02
featExtCOLUMNENCRYPTION byte = 0x04
featExtGLOBALTRANSACTIONS byte = 0x05
featExtAZURESQLSUPPORT byte = 0x08
featExtDATACLASSIFICATION byte = 0x09
featExtUTF8SUPPORT byte = 0x0A
featExtTERMINATOR byte = 0xFF
)
type tdsSession struct {
buf *tdsBuffer
loginAck loginAckStruct
@ -132,13 +146,21 @@ type tdsSession struct {
}
const (
logErrors = 1
logMessages = 2
logRows = 4
logSQL = 8
logParams = 16
logTransaction = 32
logDebug = 64
// Default packet size for a TDS buffer.
defaultPacketSize = 4096
// Default port if no port given.
defaultServerPort = 1433
)
const (
logErrors = uint64(msdsn.LogErrors)
logMessages = uint64(msdsn.LogMessages)
logRows = uint64(msdsn.LogRows)
logSQL = uint64(msdsn.LogSQL)
logParams = uint64(msdsn.LogParams)
logTransaction = uint64(msdsn.LogTransaction)
logDebug = uint64(msdsn.LogDebug)
)
type columnStruct struct {
@ -155,13 +177,13 @@ func (p keySlice) Less(i, j int) bool { return p[i] < p[j] }
func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
func writePrelogin(packetType packetType, w *tdsBuffer, fields map[uint8][]byte) error {
var err error
w.BeginPacket(packPrelogin, false)
w.BeginPacket(packetType, false)
offset := uint16(5*len(fields) + 1)
keys := make(keySlice, 0, len(fields))
for k, _ := range fields {
for k := range fields {
keys = append(keys, k)
}
sort.Sort(keys)
@ -210,12 +232,15 @@ func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
if err != nil {
return nil, err
}
if packet_type != 4 {
return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
if packet_type != packReply {
return nil, errors.New("invalid respones, expected packet type 4, PRELOGIN RESPONSE")
}
if len(struct_buf) == 0 {
return nil, errors.New("invalid empty PRELOGIN response, it must contain at least one byte")
}
offset := 0
results := map[uint8][]byte{}
for true {
for {
rec_type := struct_buf[offset]
if rec_type == preloginTERMINATOR {
break
@ -240,6 +265,16 @@ const (
fIntSecurity = 0x80
)
// OptionFlags3
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
const (
fChangePassword = 1
fSendYukonBinaryXML = 2
fUserInstance = 4
fUnknownCollationHandling = 8
fExtension = 0x10
)
// TypeFlags
const (
// 4 bits for fSQLType
@ -247,12 +282,6 @@ const (
fReadOnlyIntent = 32
)
// OptionFlags3
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac
const (
fExtension = 0x10
)
type login struct {
TDSVersion uint32
PacketSize uint32
@ -295,7 +324,7 @@ func (e *featureExts) Add(f featureExt) error {
}
id := f.featureID()
if _, exists := e.features[id]; exists {
f := "Login error: Feature with ID '%v' is already present in FeatureExt block."
f := "login error: Feature with ID '%v' is already present in FeatureExt block"
return fmt.Errorf(f, id)
}
if e.features == nil {
@ -326,37 +355,63 @@ func (e featureExts) toBytes() []byte {
return d
}
type featureExtFedAuthSTS struct {
FedAuthEcho bool
// featureExtFedAuth tracks federated authentication state before and during login
type featureExtFedAuth struct {
// FedAuthLibrary is populated by the federated authentication provider.
FedAuthLibrary int
// ADALWorkflow is populated by the federated authentication provider.
ADALWorkflow byte
// FedAuthEcho is populated from the prelogin response
FedAuthEcho bool
// FedAuthToken is populated during login with the value from the provider.
FedAuthToken string
Nonce []byte
// Nonce is populated during login with the value from the provider.
Nonce []byte
// Signature is populated during login with the value from the server.
Signature []byte
}
func (e *featureExtFedAuthSTS) featureID() byte {
return 0x02
func (e *featureExtFedAuth) featureID() byte {
return featExtFEDAUTH
}
func (e *featureExtFedAuthSTS) toBytes() []byte {
func (e *featureExtFedAuth) toBytes() []byte {
if e == nil {
return nil
}
options := byte(0x01) << 1 // 0x01 => STS bFedAuthLibrary 7BIT
options := byte(e.FedAuthLibrary) << 1
if e.FedAuthEcho {
options |= 1 // fFedAuthEcho
}
d := make([]byte, 5)
d[0] = options
// Feature extension format depends on the federated auth library.
// Options are described at
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac
var d []byte
// looks like string in
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508
tokenBytes := str2ucs2(e.FedAuthToken)
binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work
d = append(d, tokenBytes...)
switch e.FedAuthLibrary {
case fedAuthLibrarySecurityToken:
d = make([]byte, 5)
d[0] = options
if len(e.Nonce) == 32 {
d = append(d, e.Nonce...)
// looks like string in
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508
tokenBytes := str2ucs2(e.FedAuthToken)
binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work
d = append(d, tokenBytes...)
if len(e.Nonce) == 32 {
d = append(d, e.Nonce...)
}
case fedAuthLibraryADAL:
d = []byte{options, e.ADALWorkflow}
}
return d
@ -418,7 +473,7 @@ func str2ucs2(s string) []byte {
func ucs22str(s []byte) (string, error) {
if len(s)%2 != 0 {
return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))
return "", fmt.Errorf("illegal UCS2 string length: %d", len(s))
}
buf := make([]uint16, len(s)/2)
for i := 0; i < len(s); i += 2 {
@ -436,7 +491,7 @@ func manglePassword(password string) []byte {
}
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
func sendLogin(w *tdsBuffer, login login) error {
func sendLogin(w *tdsBuffer, login *login) error {
w.BeginPacket(packLogin7, false)
hostname := str2ucs2(login.HostName)
username := str2ucs2(login.UserName)
@ -572,6 +627,36 @@ func sendLogin(w *tdsBuffer, login login) error {
return w.FinishPacket()
}
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/827d9632-2957-4d54-b9ea-384530ae79d0
func sendFedAuthInfo(w *tdsBuffer, fedAuth *featureExtFedAuth) (err error) {
fedauthtoken := str2ucs2(fedAuth.FedAuthToken)
tokenlen := len(fedauthtoken)
datalen := 4 + tokenlen + len(fedAuth.Nonce)
w.BeginPacket(packFedAuthToken, false)
err = binary.Write(w, binary.LittleEndian, uint32(datalen))
if err != nil {
return
}
err = binary.Write(w, binary.LittleEndian, uint32(tokenlen))
if err != nil {
return
}
_, err = w.Write(fedauthtoken)
if err != nil {
return
}
_, err = w.Write(fedAuth.Nonce)
if err != nil {
return
}
return w.FinishPacket()
}
func readUcs2(r io.Reader, numchars int) (res string, err error) {
buf := make([]byte, numchars*2)
_, err = io.ReadFull(r, buf)
@ -768,26 +853,27 @@ type auth interface {
// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
// list of IP addresses. So if there is more than one, try them all and
// use the first one that allows a connection.
func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) {
func dialConnection(ctx context.Context, c *Connector, p msdsn.Config) (conn net.Conn, err error) {
var ips []net.IP
ips, err = net.LookupIP(p.host)
if err != nil {
ip := net.ParseIP(p.host)
if ip == nil {
return nil, err
ip := net.ParseIP(p.Host)
if ip == nil {
ips, err = net.LookupIP(p.Host)
if err != nil {
return
}
} else {
ips = []net.IP{ip}
}
if len(ips) == 1 {
d := c.getDialer(&p)
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(resolveServerPort(p.port))))
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(resolveServerPort(p.Port))))
conn, err = d.DialContext(ctx, "tcp", addr)
} else {
//Try Dials in parallel to avoid waiting for timeouts.
connChan := make(chan net.Conn, len(ips))
errChan := make(chan error, len(ips))
portStr := strconv.Itoa(int(resolveServerPort(p.port)))
portStr := strconv.Itoa(int(resolveServerPort(p.Port)))
for _, ip := range ips {
go func(ip net.IP) {
d := c.getDialer(&p)
@ -802,7 +888,7 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne
}
// Wait for either the *first* successful connection, or all the errors
wait_loop:
for i, _ := range ips {
for i := range ips {
select {
case conn = <-connChan:
// Got a connection to use, close any others
@ -824,66 +910,28 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne
}
// Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
if conn == nil {
f := "Unable to open tcp connection with host '%v:%v': %v"
return nil, fmt.Errorf(f, p.host, resolveServerPort(p.port), err.Error())
f := "unable to open tcp connection with host '%v:%v': %v"
return nil, fmt.Errorf(f, p.Host, resolveServerPort(p.Port), err.Error())
}
return conn, err
}
func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) {
dialCtx := ctx
if p.dial_timeout > 0 {
var cancel func()
dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout)
defer cancel()
}
// if instance is specified use instance resolution service
if p.instance != "" && p.port == 0 {
p.instance = strings.ToUpper(p.instance)
d := c.getDialer(&p)
instances, err := getInstances(dialCtx, d, p.host)
if err != nil {
f := "Unable to get instances from Sql Server Browser on host %v: %v"
return nil, fmt.Errorf(f, p.host, err.Error())
}
strport, ok := instances[p.instance]["tcp"]
if !ok {
f := "No instance matching '%v' returned from host '%v'"
return nil, fmt.Errorf(f, p.instance, p.host)
}
port, err := strconv.ParseUint(strport, 0, 16)
if err != nil {
f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
return nil, fmt.Errorf(f, strport, err.Error())
}
p.port = port
}
initiate_connection:
conn, err := dialConnection(dialCtx, c, p)
if err != nil {
return nil, err
}
toconn := newTimeoutConn(conn, p.conn_timeout)
outbuf := newTdsBuffer(p.packetSize, toconn)
sess := tdsSession{
buf: outbuf,
log: log,
logFlags: p.logFlags,
}
instance_buf := []byte(p.instance)
func preparePreloginFields(p msdsn.Config, fe *featureExtFedAuth) map[uint8][]byte {
instance_buf := []byte(p.Instance)
instance_buf = append(instance_buf, 0) // zero terminate instance name
var encrypt byte
if p.disableEncryption {
switch p.Encryption {
default:
panic(fmt.Errorf("Unsupported Encryption Config %v", p.Encryption))
case msdsn.EncryptionDisabled:
encrypt = encryptNotSup
} else if p.encrypt {
case msdsn.EncryptionRequired:
encrypt = encryptOn
} else {
case msdsn.EncryptionOff:
encrypt = encryptOff
}
fields := map[uint8][]byte{
preloginVERSION: {0, 0, 0, 0, 0, 0},
preloginENCRYPTION: {encrypt},
@ -892,7 +940,182 @@ initiate_connection:
preloginMARS: {0}, // MARS disabled
}
err = writePrelogin(outbuf, fields)
if fe.FedAuthLibrary != fedAuthLibraryReserved {
fields[preloginFEDAUTHREQUIRED] = []byte{1}
}
return fields
}
func interpretPreloginResponse(p msdsn.Config, fe *featureExtFedAuth, fields map[uint8][]byte) (encrypt byte, err error) {
// If the server returns the preloginFEDAUTHREQUIRED field, then federated authentication
// is supported. The actual value may be 0 or 1, where 0 means either SSPI or federated
// authentication is allowed, while 1 means only federated authentication is allowed.
if fedAuthSupport, ok := fields[preloginFEDAUTHREQUIRED]; ok {
if len(fedAuthSupport) != 1 {
return 0, fmt.Errorf("federated authentication flag length should be 1: is %d", len(fedAuthSupport))
}
// We need to be able to echo the value back to the server
fe.FedAuthEcho = fedAuthSupport[0] != 0
} else if fe.FedAuthLibrary != fedAuthLibraryReserved {
return 0, fmt.Errorf("federated authentication is not supported by the server")
}
encryptBytes, ok := fields[preloginENCRYPTION]
if !ok {
return 0, fmt.Errorf("encrypt negotiation failed")
}
encrypt = encryptBytes[0]
if p.Encryption == msdsn.EncryptionRequired && (encrypt == encryptNotSup || encrypt == encryptOff) {
return 0, fmt.Errorf("server does not support encryption")
}
return
}
func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, log optionalLogger, auth auth, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) {
var typeFlags uint8
if p.ReadOnlyIntent {
typeFlags |= fReadOnlyIntent
}
l = &login{
TDSVersion: verTDS74,
PacketSize: packetSize,
Database: p.Database,
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
HostName: p.Workstation,
ServerName: p.Host,
AppName: p.AppName,
TypeFlags: typeFlags,
}
switch {
case fe.FedAuthLibrary == fedAuthLibrarySecurityToken:
if uint64(p.LogFlags)&logDebug != 0 {
log.Println("Starting federated authentication using security token")
}
fe.FedAuthToken, err = c.securityTokenProvider(ctx)
if err != nil {
if uint64(p.LogFlags)&logDebug != 0 {
log.Printf("Failed to retrieve service principal token for federated authentication security token library: %v", err)
}
return nil, err
}
l.FeatureExt.Add(fe)
case fe.FedAuthLibrary == fedAuthLibraryADAL:
if uint64(p.LogFlags)&logDebug != 0 {
log.Println("Starting federated authentication using ADAL")
}
l.FeatureExt.Add(fe)
case auth != nil:
if uint64(p.LogFlags)&logDebug != 0 {
log.Println("Starting SSPI login")
}
l.SSPI, err = auth.InitialBytes()
if err != nil {
return nil, err
}
l.OptionFlags2 |= fIntSecurity
return l, nil
default:
// Default to SQL server authentication with user and password
l.UserName = p.User
l.Password = p.Password
}
return l, nil
}
func connect(ctx context.Context, c *Connector, log optionalLogger, p msdsn.Config) (res *tdsSession, err error) {
dialCtx := ctx
if p.DialTimeout >= 0 {
dt := p.DialTimeout
if dt == 0 {
dt = 15 * time.Second
}
var cancel func()
dialCtx, cancel = context.WithTimeout(ctx, dt)
defer cancel()
}
// if instance is specified use instance resolution service
if len(p.Instance) > 0 && p.Port != 0 {
// both instance name and port specified
// when port is specified instance name is not used
// you should not provide instance name when you provide port
log.Println("WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored")
}
if len(p.Instance) > 0 {
p.Instance = strings.ToUpper(p.Instance)
d := c.getDialer(&p)
instances, err := getInstances(dialCtx, d, p.Host)
if err != nil {
f := "unable to get instances from Sql Server Browser on host %v: %v"
return nil, fmt.Errorf(f, p.Host, err.Error())
}
strport, ok := instances[p.Instance]["tcp"]
if !ok {
f := "no instance matching '%v' returned from host '%v'"
return nil, fmt.Errorf(f, p.Instance, p.Host)
}
port, err := strconv.ParseUint(strport, 0, 16)
if err != nil {
f := "invalid tcp port returned from Sql Server Browser '%v': %v"
return nil, fmt.Errorf(f, strport, err.Error())
}
p.Port = port
}
if p.Port == 0 {
p.Port = defaultServerPort
}
packetSize := p.PacketSize
if packetSize == 0 {
packetSize = defaultPacketSize
}
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
// a higher packet size, the server will respond with an ENVCHANGE request to
// alter the packet size to 16383 bytes.
if packetSize < 512 {
packetSize = 512
} else if packetSize > 32767 {
packetSize = 32767
}
initiate_connection:
conn, err := dialConnection(dialCtx, c, p)
if err != nil {
return nil, err
}
toconn := newTimeoutConn(conn, p.ConnTimeout)
outbuf := newTdsBuffer(packetSize, toconn)
sess := tdsSession{
buf: outbuf,
log: log,
logFlags: uint64(p.LogFlags),
}
fedAuth := &featureExtFedAuth{
FedAuthLibrary: fedAuthLibraryReserved,
}
if c.fedAuthRequired {
fedAuth.FedAuthLibrary = c.fedAuthLibrary
fedAuth.ADALWorkflow = c.fedAuthADALWorkflow
}
fields := preparePreloginFields(p, fedAuth)
err = writePrelogin(packPrelogin, outbuf, fields)
if err != nil {
return nil, err
}
@ -902,39 +1125,36 @@ initiate_connection:
return nil, err
}
encryptBytes, ok := fields[preloginENCRYPTION]
if !ok {
return nil, fmt.Errorf("Encrypt negotiation failed")
}
encrypt = encryptBytes[0]
if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
return nil, fmt.Errorf("Server does not support encryption")
encrypt, err := interpretPreloginResponse(p, fedAuth, fields)
if err != nil {
return nil, err
}
if encrypt != encryptNotSup {
var config tls.Config
if p.certificate != "" {
pem, err := ioutil.ReadFile(p.certificate)
if err != nil {
return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
var config *tls.Config
if pc := p.TLSConfig; pc != nil {
config = pc
if config.DynamicRecordSizingDisabled == false {
config = config.Clone()
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
}
certs := x509.NewCertPool()
certs.AppendCertsFromPEM(pem)
config.RootCAs = certs
}
if p.trustServerCertificate {
config.InsecureSkipVerify = true
if config == nil {
config, err = msdsn.SetupTLS("", false, p.Host)
if err != nil {
return nil, err
}
}
config.ServerName = p.hostInCertificate
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
// setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
handshakeConn := tlsHandshakeConn{buf: outbuf}
passthrough := passthroughConn{c: &handshakeConn}
tlsConn := tls.Client(&passthrough, &config)
tlsConn := tls.Client(&passthrough, config)
err = tlsConn.Handshake()
passthrough.c = toconn
outbuf.transport = tlsConn
@ -948,54 +1168,48 @@ initiate_connection:
}
}
login := login{
TDSVersion: verTDS74,
PacketSize: uint32(outbuf.PackageSize()),
Database: p.database,
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
HostName: p.workstation,
ServerName: p.host,
AppName: p.appname,
TypeFlags: p.typeFlags,
}
auth, authOk := getAuth(p.user, p.password, p.serverSPN, p.workstation)
switch {
case p.fedAuthAccessToken != "": // accesstoken ignores user/password
featurext := &featureExtFedAuthSTS{
FedAuthEcho: len(fields[preloginFEDAUTHREQUIRED]) > 0 && fields[preloginFEDAUTHREQUIRED][0] == 1,
FedAuthToken: p.fedAuthAccessToken,
Nonce: fields[preloginNONCEOPT],
}
login.FeatureExt.Add(featurext)
case authOk:
login.SSPI, err = auth.InitialBytes()
if err != nil {
return nil, err
}
login.OptionFlags2 |= fIntSecurity
auth, authOk := getAuth(p.User, p.Password, p.ServerSPN, p.Workstation)
if authOk {
defer auth.Free()
default:
login.UserName = p.user
login.Password = p.password
} else {
auth = nil
}
login, err := prepareLogin(ctx, c, p, log, auth, fedAuth, uint32(outbuf.PackageSize()))
if err != nil {
return nil, err
}
err = sendLogin(outbuf, login)
if err != nil {
return nil, err
}
// processing login response
success := false
for {
tokchan := make(chan tokenStruct, 5)
go processResponse(context.Background(), &sess, tokchan, nil)
for tok := range tokchan {
// Loop until a packet containing a login acknowledgement is received.
// SSPI and federated authentication scenarios may require multiple
// packet exchanges to complete the login sequence.
for loginAck := false; !loginAck; {
reader := startReading(&sess, ctx, outputs{})
// don't send attention or wait for cancel confirmation during login
reader.noAttn = true
for {
tok, err := reader.nextToken()
if err != nil {
return nil, err
}
if tok == nil {
break
}
switch token := tok.(type) {
case sspiMsg:
sspi_msg, err := auth.NextBytes(token)
if err != nil {
return nil, err
}
if sspi_msg != nil && len(sspi_msg) > 0 {
if len(sspi_msg) > 0 {
outbuf.BeginPacket(packSSPIMessage, false)
_, err = outbuf.Write(sspi_msg)
if err != nil {
@ -1007,31 +1221,59 @@ initiate_connection:
}
sspi_msg = nil
}
// TODO: for Live ID authentication it may be necessary to
// compare fedAuth.Nonce == token.Nonce and keep track of signature
//case fedAuthAckStruct:
//fedAuth.Signature = token.Signature
case fedAuthInfoStruct:
// For ADAL workflows this contains the STS URL and server SPN.
// If received outside of an ADAL workflow, ignore.
if c == nil || c.adalTokenProvider == nil {
continue
}
// Request the AD token given the server SPN and STS URL
fedAuth.FedAuthToken, err = c.adalTokenProvider(ctx, token.ServerSPN, token.STSURL)
if err != nil {
return nil, err
}
// Now need to send the token as a FEDINFO packet
err = sendFedAuthInfo(outbuf, fedAuth)
if err != nil {
return nil, err
}
case loginAckStruct:
success = true
sess.loginAck = token
case error:
return nil, fmt.Errorf("Login error: %s", token.Error())
loginAck = true
case doneStruct:
if token.isError() {
return nil, fmt.Errorf("Login error: %s", token.getError())
tokenErr := token.getError()
tokenErr.Message = "login error: " + tokenErr.Message
return nil, tokenErr
}
goto loginEnd
case error:
return nil, fmt.Errorf("login error: %s", token.Error())
}
}
}
loginEnd:
if !success {
return nil, fmt.Errorf("Login failed")
}
if sess.routedServer != "" {
toconn.Close()
p.host = sess.routedServer
p.port = uint64(sess.routedPort)
if !p.hostInCertificateProvided {
p.hostInCertificate = sess.routedServer
p.Host = sess.routedServer
p.Port = uint64(sess.routedPort)
if !p.HostInCertificateProvided && p.TLSConfig != nil {
p.TLSConfig.ServerName = sess.routedServer
}
goto initiate_connection
}
return &sess, nil
}
func resolveServerPort(port uint64) uint64 {
if port == 0 {
return defaultServerPort
}
return port
}

View file

@ -6,12 +6,11 @@ import (
"errors"
"fmt"
"io"
"net"
"io/ioutil"
"strconv"
"strings"
)
//go:generate stringer -type token
//go:generate go run golang.org/x/tools/cmd/stringer -type token
type token byte
@ -29,6 +28,7 @@ const (
tokenNbcRow token = 210 // 0xd2
tokenEnvChange token = 227 // 0xE3
tokenSSPI token = 237 // 0xED
tokenFedAuthInfo token = 238 // 0xEE
tokenDone token = 253 // 0xFD
tokenDoneProc token = 254
tokenDoneInProc token = 255
@ -70,6 +70,11 @@ const (
envRouting = 20
)
const (
fedAuthInfoSTSURL = 0x01
fedAuthInfoSPN = 0x02
)
// COLMETADATA flags
// https://msdn.microsoft.com/en-us/library/dd357363.aspx
const (
@ -96,35 +101,18 @@ func (d doneStruct) isError() bool {
}
func (d doneStruct) getError() Error {
if len(d.errors) > 0 {
return d.errors[len(d.errors)-1]
} else {
n := len(d.errors)
if n == 0 {
return Error{Message: "Request failed but didn't provide reason"}
}
err := d.errors[n-1]
err.All = make([]Error, n)
copy(err.All, d.errors)
return err
}
type doneInProcStruct doneStruct
var doneFlags2str = map[uint16]string{
doneFinal: "final",
doneMore: "more",
doneError: "error",
doneInxact: "inxact",
doneCount: "count",
doneAttn: "attn",
doneSrvError: "srverror",
}
func doneFlags2Str(flags uint16) string {
strs := make([]string, 0, len(doneFlags2str))
for flag, tag := range doneFlags2str {
if flags&flag != 0 {
strs = append(strs, tag)
}
}
return strings.Join(strs, "|")
}
// ENVCHANGE stream
// http://msdn.microsoft.com/en-us/library/dd303449.aspx
func processEnvChg(sess *tdsSession) {
@ -380,9 +368,8 @@ func processEnvChg(sess *tdsSession) {
default:
// ignore rest of records because we don't know how to skip those
sess.log.Printf("WARN: Unknown ENVCHANGE record detected with type id = %d\n", envtype)
break
return
}
}
}
@ -425,6 +412,78 @@ func parseSSPIMsg(r *tdsBuffer) sspiMsg {
return sspiMsg(buf)
}
type fedAuthInfoStruct struct {
STSURL string
ServerSPN string
}
type fedAuthInfoOpt struct {
fedAuthInfoID byte
dataLength, dataOffset uint32
}
func parseFedAuthInfo(r *tdsBuffer) fedAuthInfoStruct {
size := r.uint32()
var STSURL, SPN string
var err error
// Each fedAuthInfoOpt is one byte to indicate the info ID,
// then a four byte offset and a four byte length.
count := r.uint32()
offset := uint32(4)
opts := make([]fedAuthInfoOpt, count)
for i := uint32(0); i < count; i++ {
fedAuthInfoID := r.byte()
dataLength := r.uint32()
dataOffset := r.uint32()
offset += 1 + 4 + 4
opts[i] = fedAuthInfoOpt{
fedAuthInfoID: fedAuthInfoID,
dataLength: dataLength,
dataOffset: dataOffset,
}
}
data := make([]byte, size-offset)
r.ReadFull(data)
for i := uint32(0); i < count; i++ {
if opts[i].dataOffset < offset {
badStreamPanicf("Fed auth info opt stated data offset %d is before data begins in packet at %d",
opts[i].dataOffset, offset)
// returns via panic
}
if opts[i].dataOffset+opts[i].dataLength > size {
badStreamPanicf("Fed auth info opt stated data length %d added to stated offset exceeds size of packet %d",
opts[i].dataOffset+opts[i].dataLength, size)
// returns via panic
}
optData := data[opts[i].dataOffset-offset : opts[i].dataOffset-offset+opts[i].dataLength]
switch opts[i].fedAuthInfoID {
case fedAuthInfoSTSURL:
STSURL, err = ucs22str(optData)
case fedAuthInfoSPN:
SPN, err = ucs22str(optData)
default:
err = fmt.Errorf("unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID))
}
if err != nil {
badStreamPanic(err)
}
}
return fedAuthInfoStruct{
STSURL: STSURL,
ServerSPN: SPN,
}
}
type loginAckStruct struct {
Interface uint8
TDSVersion uint32
@ -449,19 +508,43 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct {
}
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/2eb82f8e-11f0-46dc-b42d-27302fa4701a
func parseFeatureExtAck(r *tdsBuffer) {
// at most 1 featureAck per feature in featureExt
// go-mssqldb will add at most 1 feature, the spec defines 7 different features
for i := 0; i < 8; i++ {
featureID := r.byte() // FeatureID
if featureID == 0xff {
return
type fedAuthAckStruct struct {
Nonce []byte
Signature []byte
}
func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} {
ack := map[byte]interface{}{}
for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() {
length := r.uint32()
switch feature {
case featExtFEDAUTH:
// In theory we need to know the federated authentication library to
// know how to parse, but the alternatives provide compatible structures.
fedAuthAck := fedAuthAckStruct{}
if length >= 32 {
fedAuthAck.Nonce = make([]byte, 32)
r.ReadFull(fedAuthAck.Nonce)
length -= 32
}
if length >= 32 {
fedAuthAck.Signature = make([]byte, 32)
r.ReadFull(fedAuthAck.Signature)
length -= 32
}
ack[feature] = fedAuthAck
}
// Skip unprocessed bytes
if length > 0 {
io.CopyN(ioutil.Discard, r, int64(length))
}
size := r.uint32() // FeatureAckDataLen
d := make([]byte, size)
r.ReadFull(d)
}
panic("parsed more than 7 featureAck's, protocol implementation error?")
return ack
}
// http://msdn.microsoft.com/en-us/library/dd357363.aspx
@ -555,7 +638,7 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) {
return
}
func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs outputs) {
defer func() {
if err := recover(); err != nil {
if sess.logFlags&logErrors != 0 {
@ -579,7 +662,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
}
var columns []columnStruct
errs := make([]Error, 0, 5)
for {
for tokens := 0; ; tokens += 1 {
token := token(sess.buf.byte())
if sess.logFlags&logDebug != 0 {
sess.log.Printf("got token %v", token)
@ -588,6 +671,9 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
case tokenSSPI:
ch <- parseSSPIMsg(sess.buf)
return
case tokenFedAuthInfo:
ch <- parseFedAuthInfo(sess.buf)
return
case tokenReturnStatus:
returnStatus := parseReturnStatus(sess.buf)
ch <- returnStatus
@ -595,7 +681,8 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
loginAck := parseLoginAck(sess.buf)
ch <- loginAck
case tokenFeatureExtAck:
parseFeatureExtAck(sess.buf)
featureExtAck := parseFeatureExtAck(sess.buf)
ch <- featureExtAck
case tokenOrder:
order := parseOrder(sess.buf)
ch <- order
@ -612,7 +699,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
sess.log.Printf("got DONE or DONEPROC status=%d", done.Status)
}
if done.Status&doneSrvError != 0 {
ch <- errors.New("SQL Server had internal error")
ch <- ServerError{done.getError()}
return
}
if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
@ -656,7 +743,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
nv := parseReturnValue(sess.buf)
if len(nv.Name) > 0 {
name := nv.Name[1:] // Remove the leading "@".
if ov, has := outs[name]; has {
if ov, has := outs.params[name]; has {
err = scanIntoOut(name, nv.Value, ov)
if err != nil {
fmt.Println("scan error", err)
@ -670,154 +757,144 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
}
}
type parseRespIter byte
const (
parseRespIterContinue parseRespIter = iota // Continue parsing current token.
parseRespIterNext // Fetch the next token.
parseRespIterDone // Done with parsing the response.
)
type parseRespState byte
const (
parseRespStateNormal parseRespState = iota // Normal response state.
parseRespStateCancel // Query is canceled, wait for server to confirm.
parseRespStateClosing // Waiting for tokens to come through.
)
type parseResp struct {
sess *tdsSession
ctxDone <-chan struct{}
state parseRespState
cancelError error
type tokenProcessor struct {
tokChan chan tokenStruct
ctx context.Context
sess *tdsSession
outs outputs
lastRow []interface{}
rowCount int64
firstError error
// whether to skip sending attention when ctx is done
noAttn bool
}
func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter {
if err := sendAttention(ts.sess.buf); err != nil {
ts.dlogf("failed to send attention signal %v", err)
ch <- err
return parseRespIterDone
}
ts.state = parseRespStateCancel
return parseRespIterContinue
}
func (ts *parseResp) dlog(msg string) {
if ts.sess.logFlags&logDebug != 0 {
ts.sess.log.Println(msg)
}
}
func (ts *parseResp) dlogf(f string, v ...interface{}) {
if ts.sess.logFlags&logDebug != 0 {
ts.sess.log.Printf(f, v...)
}
}
func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter {
switch ts.state {
default:
panic("unknown state")
case parseRespStateNormal:
select {
case tok, ok := <-tokChan:
if !ok {
ts.dlog("response finished")
return parseRespIterDone
}
if err, ok := tok.(net.Error); ok && err.Timeout() {
ts.cancelError = err
ts.dlog("got timeout error, sending attention signal to server")
return ts.sendAttention(ch)
}
// Pass the token along.
ch <- tok
return parseRespIterContinue
case <-ts.ctxDone:
ts.ctxDone = nil
ts.dlog("got cancel message, sending attention signal to server")
return ts.sendAttention(ch)
}
case parseRespStateCancel: // Read all responses until a DONE or error is received.Auth
select {
case tok, ok := <-tokChan:
if !ok {
ts.dlog("response finished but waiting for attention ack")
return parseRespIterNext
}
switch tok := tok.(type) {
default:
// Ignore all other tokens while waiting.
// The TDS spec says other tokens may arrive after an attention
// signal is sent. Ignore these tokens and continue looking for
// a DONE with attention confirm mark.
case doneStruct:
if tok.Status&doneAttn != 0 {
ts.dlog("got cancellation confirmation from server")
if ts.cancelError != nil {
ch <- ts.cancelError
ts.cancelError = nil
} else {
ch <- ctx.Err()
}
return parseRespIterDone
}
// If an error happens during cancel, pass it along and just stop.
// We are uncertain to receive more tokens.
case error:
ch <- tok
ts.state = parseRespStateClosing
}
return parseRespIterContinue
case <-ts.ctxDone:
ts.ctxDone = nil
ts.state = parseRespStateClosing
return parseRespIterContinue
}
case parseRespStateClosing: // Wait for current token chan to close.
if _, ok := <-tokChan; !ok {
ts.dlog("response finished")
return parseRespIterDone
}
return parseRespIterContinue
}
}
func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
ts := &parseResp{
func startReading(sess *tdsSession, ctx context.Context, outs outputs) *tokenProcessor {
tokChan := make(chan tokenStruct, 5)
go processSingleResponse(sess, tokChan, outs)
return &tokenProcessor{
tokChan: tokChan,
ctx: ctx,
sess: sess,
ctxDone: ctx.Done(),
outs: outs,
}
defer func() {
// Ensure any remaining error is piped through
// or the query may look like it executed when it actually failed.
if ts.cancelError != nil {
ch <- ts.cancelError
ts.cancelError = nil
}
close(ch)
}()
}
// Loop over multiple responses.
func (t *tokenProcessor) iterateResponse() error {
for {
ts.dlog("initiating response reading")
tokChan := make(chan tokenStruct)
go processSingleResponse(sess, tokChan, outs)
// Loop over multiple tokens in response.
tokensLoop:
for {
switch ts.iter(ctx, ch, tokChan) {
case parseRespIterContinue:
// Nothing, continue to next token.
case parseRespIterNext:
break tokensLoop
case parseRespIterDone:
return
tok, err := t.nextToken()
if err == nil {
if tok == nil {
return t.firstError
} else {
switch token := tok.(type) {
case []columnStruct:
t.sess.columns = token
case []interface{}:
t.lastRow = token
case doneInProcStruct:
if token.Status&doneCount != 0 {
t.rowCount += int64(token.RowCount)
}
case doneStruct:
if token.Status&doneCount != 0 {
t.rowCount += int64(token.RowCount)
}
if token.isError() && t.firstError == nil {
t.firstError = token.getError()
}
case ReturnStatus:
if t.outs.returnStatus != nil {
*t.outs.returnStatus = token
}
/*case error:
if resultError == nil {
resultError = token
}*/
}
}
} else {
return err
}
}
}
func (t tokenProcessor) nextToken() (tokenStruct, error) {
// we do this separate non-blocking check on token channel to
// prioritize it over cancellation channel
select {
case tok, more := <-t.tokChan:
err, more := tok.(error)
if more {
// this is an error and not a token
return nil, err
} else {
return tok, nil
}
default:
// there are no tokens on the channel, will need to wait
}
select {
case tok, more := <-t.tokChan:
if more {
err, ok := tok.(error)
if ok {
// this is an error and not a token
return nil, err
} else {
return tok, nil
}
} else {
// completed reading response
return nil, nil
}
case <-t.ctx.Done():
if t.noAttn {
return nil, t.ctx.Err()
}
if err := sendAttention(t.sess.buf); err != nil {
// unable to send attention, current connection is bad
// notify caller and close channel
return nil, err
}
// now the server should send cancellation confirmation
// it is possible that we already received full response
// just before we sent cancellation request
// in this case current response would not contain confirmation
// and we would need to read one more response
// first lets finish reading current response and look
// for confirmation in it
if readCancelConfirmation(t.tokChan) {
// we got confirmation in current response
return nil, t.ctx.Err()
}
// we did not get cancellation confirmation in the current response
// read one more response, it must be there
t.tokChan = make(chan tokenStruct, 5)
go processSingleResponse(t.sess, t.tokChan, t.outs)
if readCancelConfirmation(t.tokChan) {
return nil, t.ctx.Err()
}
// we did not get cancellation confirmation, something is not
// right, this connection is not usable anymore
return nil, errors.New("did not get cancellation confirmation from the server")
}
}
func readCancelConfirmation(tokChan chan tokenStruct) bool {
for tok := range tokChan {
switch tok := tok.(type) {
default:
// just skip token
case doneStruct:
if tok.Status&doneAttn != 0 {
// got cancellation confirmation, exit
return true
}
}
}
return false
}

View file

@ -1,29 +1,24 @@
// Code generated by "stringer -type token"; DO NOT EDIT
// Code generated by "stringer -type token"; DO NOT EDIT.
package mssql
import "fmt"
import "strconv"
const (
_token_name_0 = "tokenReturnStatus"
_token_name_1 = "tokenColMetadata"
_token_name_2 = "tokenOrdertokenErrortokenInfo"
_token_name_3 = "tokenLoginAck"
_token_name_4 = "tokenRowtokenNbcRow"
_token_name_5 = "tokenEnvChange"
_token_name_6 = "tokenSSPI"
_token_name_7 = "tokenDonetokenDoneProctokenDoneInProc"
_token_name_2 = "tokenOrdertokenErrortokenInfotokenReturnValuetokenLoginAcktokenFeatureExtAck"
_token_name_3 = "tokenRowtokenNbcRow"
_token_name_4 = "tokenEnvChange"
_token_name_5 = "tokenSSPItokenFedAuthInfo"
_token_name_6 = "tokenDonetokenDoneProctokenDoneInProc"
)
var (
_token_index_0 = [...]uint8{0, 17}
_token_index_1 = [...]uint8{0, 16}
_token_index_2 = [...]uint8{0, 10, 20, 29}
_token_index_3 = [...]uint8{0, 13}
_token_index_4 = [...]uint8{0, 8, 19}
_token_index_5 = [...]uint8{0, 14}
_token_index_6 = [...]uint8{0, 9}
_token_index_7 = [...]uint8{0, 9, 22, 37}
_token_index_2 = [...]uint8{0, 10, 20, 29, 45, 58, 76}
_token_index_3 = [...]uint8{0, 8, 19}
_token_index_5 = [...]uint8{0, 9, 25}
_token_index_6 = [...]uint8{0, 9, 22, 37}
)
func (i token) String() string {
@ -32,22 +27,21 @@ func (i token) String() string {
return _token_name_0
case i == 129:
return _token_name_1
case 169 <= i && i <= 171:
case 169 <= i && i <= 174:
i -= 169
return _token_name_2[_token_index_2[i]:_token_index_2[i+1]]
case i == 173:
return _token_name_3
case 209 <= i && i <= 210:
i -= 209
return _token_name_4[_token_index_4[i]:_token_index_4[i+1]]
return _token_name_3[_token_index_3[i]:_token_index_3[i+1]]
case i == 227:
return _token_name_5
case i == 237:
return _token_name_6
return _token_name_4
case 237 <= i && i <= 238:
i -= 237
return _token_name_5[_token_index_5[i]:_token_index_5[i+1]]
case 253 <= i && i <= 255:
i -= 253
return _token_name_7[_token_index_7[i]:_token_index_7[i+1]]
return _token_name_6[_token_index_6[i]:_token_index_6[i+1]]
default:
return fmt.Sprintf("token(%d)", i)
return "token(" + strconv.FormatInt(int64(i), 10) + ")"
}
}

View file

@ -21,11 +21,11 @@ type isoLevel uint8
const (
isolationUseCurrent isoLevel = 0
isolationReadUncommited = 1
isolationReadCommited = 2
isolationRepeatableRead = 3
isolationSerializable = 4
isolationSnapshot = 5
isolationReadUncommited isoLevel = 1
isolationReadCommited isoLevel = 2
isolationRepeatableRead isoLevel = 3
isolationSerializable isoLevel = 4
isolationSnapshot isoLevel = 5
)
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) (err error) {

View file

@ -4,6 +4,7 @@ package mssql
import (
"bytes"
"database/sql"
"encoding/binary"
"errors"
"fmt"
@ -97,6 +98,9 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
for columnStrIdx, fieldIdx := range tvpFieldIndexes {
field := refStr.Field(fieldIdx)
tvpVal := field.Interface()
if tvp.verifyStandardTypeOnNull(buf, tvpVal) {
continue
}
valOf := reflect.ValueOf(tvpVal)
elemKind := field.Kind()
if elemKind == reflect.Ptr && valOf.IsNil() {
@ -155,7 +159,7 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
defaultValues = append(defaultValues, v.Interface())
continue
}
defaultValues = append(defaultValues, reflect.Zero(field.Type).Interface())
defaultValues = append(defaultValues, tvp.createZeroType(reflect.Zero(field.Type).Interface()))
}
if columnCount-len(tvpFieldIndexes) == columnCount {
@ -209,19 +213,23 @@ func getSchemeAndName(tvpName string) (string, string, error) {
}
splitVal := strings.Split(tvpName, ".")
if len(splitVal) > 2 {
return "", "", errors.New("wrong tvp name")
return "", "", ErrorObjectName
}
const (
openSquareBrackets = "["
closeSquareBrackets = "]"
)
if len(splitVal) == 2 {
res := make([]string, 2)
for key, value := range splitVal {
tmp := strings.Replace(value, "[", "", -1)
tmp = strings.Replace(tmp, "]", "", -1)
tmp := strings.Replace(value, openSquareBrackets, "", -1)
tmp = strings.Replace(tmp, closeSquareBrackets, "", -1)
res[key] = tmp
}
return res[0], res[1], nil
}
tmp := strings.Replace(splitVal[0], "[", "", -1)
tmp = strings.Replace(tmp, "]", "", -1)
tmp := strings.Replace(splitVal[0], openSquareBrackets, "", -1)
tmp = strings.Replace(tmp, closeSquareBrackets, "", -1)
return "", tmp, nil
}
@ -229,3 +237,56 @@ func getSchemeAndName(tvpName string) (string, string, error) {
func getCountSQLSeparators(str string) int {
return strings.Count(str, sqlSeparator)
}
// verify types https://golang.org/pkg/database/sql/
func (tvp TVP) createZeroType(fieldVal interface{}) interface{} {
const (
defaultBool = false
defaultFloat64 = float64(0)
defaultInt64 = int64(0)
defaultString = ""
)
switch fieldVal.(type) {
case sql.NullBool:
return defaultBool
case sql.NullFloat64:
return defaultFloat64
case sql.NullInt64:
return defaultInt64
case sql.NullString:
return defaultString
}
return fieldVal
}
// verify types https://golang.org/pkg/database/sql/
func (tvp TVP) verifyStandardTypeOnNull(buf *bytes.Buffer, tvpVal interface{}) bool {
const (
defaultNull = uint8(0)
)
switch val := tvpVal.(type) {
case sql.NullBool:
if !val.Valid {
binary.Write(buf, binary.LittleEndian, defaultNull)
return true
}
case sql.NullFloat64:
if !val.Valid {
binary.Write(buf, binary.LittleEndian, defaultNull)
return true
}
case sql.NullInt64:
if !val.Valid {
binary.Write(buf, binary.LittleEndian, defaultNull)
return true
}
case sql.NullString:
if !val.Valid {
binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL))
return true
}
}
return false
}

View file

@ -665,7 +665,7 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} {
default:
buf = bytes.NewBuffer(make([]byte, 0, size))
}
for true {
for {
chunksize := r.uint32()
if chunksize == 0 {
break
@ -690,6 +690,10 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} {
}
func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) {
if buf == nil {
err = binary.Write(w, binary.LittleEndian, uint64(_PLP_NULL))
return
}
if err = binary.Write(w, binary.LittleEndian, uint64(_UNKNOWN_PLP_LEN)); err != nil {
return
}
@ -807,7 +811,6 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) {
default:
badStreamPanicf("Invalid type %d", ti.TypeId)
}
return
}
func decodeMoney(buf []byte) []byte {
@ -834,8 +837,7 @@ func decodeGuid(buf []byte) []byte {
}
func decodeDecimal(prec uint8, scale uint8, buf []byte) []byte {
var sign uint8
sign = buf[0]
sign := buf[0]
var dec decimal.Decimal
dec.SetPositive(sign != 0)
dec.SetPrec(prec)
@ -1187,7 +1189,7 @@ func makeDecl(ti typeInfo) string {
return fmt.Sprintf("char(%d)", ti.Size)
case typeBigVarChar, typeVarChar:
if ti.Size > 8000 || ti.Size == 0 {
return fmt.Sprintf("varchar(max)")
return "varchar(max)"
} else {
return fmt.Sprintf("varchar(%d)", ti.Size)
}

43
vendor/github.com/go-asn1-ber/asn1-ber/.travis.yml generated vendored Normal file
View file

@ -0,0 +1,43 @@
language: go
go:
- 1.2.x
- 1.6.x
- 1.9.x
- 1.10.x
- 1.11.x
- 1.12.x
- 1.14.x
- tip
os:
- linux
arch:
- amd64
- ppc64le
dist: xenial
env:
- GOARCH=amd64
jobs:
include:
- os: windows
go: 1.14.x
- os: osx
go: 1.14.x
- os: linux
go: 1.14.x
arch: arm64
- os: linux
go: 1.14.x
env:
- GOARCH=386
script:
- go test -v -cover ./... || go test -v ./...
matrix:
allowfailures:
go: 1.2.x

View file

@ -8,6 +8,8 @@ import (
"math"
"os"
"reflect"
"time"
"unicode/utf8"
)
// MaxPacketLengthBytes specifies the maximum allowed packet size when calling ReadPacket or DecodePacket. Set to 0 for
@ -143,42 +145,46 @@ var TypeMap = map[Type]string{
TypeConstructed: "Constructed",
}
var Debug bool = false
var Debug = false
func PrintBytes(out io.Writer, buf []byte, indent string) {
data_lines := make([]string, (len(buf)/30)+1)
num_lines := make([]string, (len(buf)/30)+1)
dataLines := make([]string, (len(buf)/30)+1)
numLines := make([]string, (len(buf)/30)+1)
for i, b := range buf {
data_lines[i/30] += fmt.Sprintf("%02x ", b)
num_lines[i/30] += fmt.Sprintf("%02d ", (i+1)%100)
dataLines[i/30] += fmt.Sprintf("%02x ", b)
numLines[i/30] += fmt.Sprintf("%02d ", (i+1)%100)
}
for i := 0; i < len(data_lines); i++ {
out.Write([]byte(indent + data_lines[i] + "\n"))
out.Write([]byte(indent + num_lines[i] + "\n\n"))
for i := 0; i < len(dataLines); i++ {
_, _ = out.Write([]byte(indent + dataLines[i] + "\n"))
_, _ = out.Write([]byte(indent + numLines[i] + "\n\n"))
}
}
func WritePacket(out io.Writer, p *Packet) {
printPacket(out, p, 0, false)
}
func PrintPacket(p *Packet) {
printPacket(os.Stdout, p, 0, false)
}
func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) {
indent_str := ""
indentStr := ""
for len(indent_str) != indent {
indent_str += " "
for len(indentStr) != indent {
indentStr += " "
}
class_str := ClassMap[p.ClassType]
classStr := ClassMap[p.ClassType]
tagtype_str := TypeMap[p.TagType]
tagTypeStr := TypeMap[p.TagType]
tag_str := fmt.Sprintf("0x%02X", p.Tag)
tagStr := fmt.Sprintf("0x%02X", p.Tag)
if p.ClassType == ClassUniversal {
tag_str = tagMap[p.Tag]
tagStr = tagMap[p.Tag]
}
value := fmt.Sprint(p.Value)
@ -188,10 +194,10 @@ func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) {
description = p.Description + ": "
}
fmt.Fprintf(out, "%s%s(%s, %s, %s) Len=%d %q\n", indent_str, description, class_str, tagtype_str, tag_str, p.Data.Len(), value)
_, _ = fmt.Fprintf(out, "%s%s(%s, %s, %s) Len=%d %q\n", indentStr, description, classStr, tagTypeStr, tagStr, p.Data.Len(), value)
if printBytes {
PrintBytes(out, p.Bytes(), indent_str)
PrintBytes(out, p.Bytes(), indentStr)
}
for _, child := range p.Children {
@ -199,7 +205,7 @@ func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) {
}
}
// ReadPacket reads a single Packet from the reader
// ReadPacket reads a single Packet from the reader.
func ReadPacket(reader io.Reader) (*Packet, error) {
p, _, err := readPacket(reader)
if err != nil {
@ -235,7 +241,7 @@ func encodeInteger(i int64) []byte {
var j int
for ; n > 0; n-- {
out[j] = (byte(i >> uint((n-1)*8)))
out[j] = byte(i >> uint((n-1)*8))
j++
}
@ -267,7 +273,7 @@ func DecodePacket(data []byte) *Packet {
}
// DecodePacketErr decodes the given bytes into a single Packet
// If a decode error is encountered, nil is returned
// If a decode error is encountered, nil is returned.
func DecodePacketErr(data []byte) (*Packet, error) {
p, _, err := readPacket(bytes.NewBuffer(data))
if err != nil {
@ -276,7 +282,7 @@ func DecodePacketErr(data []byte) (*Packet, error) {
return p, nil
}
// readPacket reads a single Packet from the reader, returning the number of bytes read
// readPacket reads a single Packet from the reader, returning the number of bytes read.
func readPacket(reader io.Reader) (*Packet, int, error) {
identifier, length, read, err := readHeader(reader)
if err != nil {
@ -338,7 +344,7 @@ func readPacket(reader io.Reader) (*Packet, int, error) {
if MaxPacketLengthBytes > 0 && int64(length) > MaxPacketLengthBytes {
return nil, read, fmt.Errorf("length %d greater than maximum %d", length, MaxPacketLengthBytes)
}
content := make([]byte, length, length)
content := make([]byte, length)
if length > 0 {
_, err := io.ReadFull(reader, content)
if err != nil {
@ -373,22 +379,42 @@ func readPacket(reader io.Reader) (*Packet, int, error) {
case TagObjectDescriptor:
case TagExternal:
case TagRealFloat:
p.Value, err = ParseReal(content)
case TagEnumerated:
p.Value, _ = ParseInt64(content)
case TagEmbeddedPDV:
case TagUTF8String:
p.Value = DecodeString(content)
val := DecodeString(content)
if !utf8.Valid([]byte(val)) {
err = errors.New("invalid UTF-8 string")
} else {
p.Value = val
}
case TagRelativeOID:
case TagSequence:
case TagSet:
case TagNumericString:
case TagPrintableString:
p.Value = DecodeString(content)
val := DecodeString(content)
if err = isPrintableString(val); err == nil {
p.Value = val
}
case TagT61String:
case TagVideotexString:
case TagIA5String:
val := DecodeString(content)
for i, c := range val {
if c >= 0x7F {
err = fmt.Errorf("invalid character for IA5String at pos %d: %c", i, c)
break
}
}
if err == nil {
p.Value = val
}
case TagUTCTime:
case TagGeneralizedTime:
p.Value, err = ParseGeneralizedTime(content)
case TagGraphicString:
case TagVisibleString:
case TagGeneralString:
@ -400,7 +426,24 @@ func readPacket(reader io.Reader) (*Packet, int, error) {
p.Data.Write(content)
}
return p, read, nil
return p, read, err
}
func isPrintableString(val string) error {
for i, c := range val {
switch {
case c >= 'a' && c <= 'z':
case c >= 'A' && c <= 'Z':
case c >= '0' && c <= '9':
default:
switch c {
case '\'', '(', ')', '+', ',', '-', '.', '=', '/', ':', '?', ' ':
default:
return fmt.Errorf("invalid character in position %d", i)
}
}
}
return nil
}
func (p *Packet) Bytes() []byte {
@ -418,61 +461,99 @@ func (p *Packet) AppendChild(child *Packet) {
p.Children = append(p.Children, child)
}
func Encode(ClassType Class, TagType Type, Tag Tag, Value interface{}, Description string) *Packet {
func Encode(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet {
p := new(Packet)
p.ClassType = ClassType
p.TagType = TagType
p.Tag = Tag
p.ClassType = classType
p.TagType = tagType
p.Tag = tag
p.Data = new(bytes.Buffer)
p.Children = make([]*Packet, 0, 2)
p.Value = Value
p.Description = Description
p.Value = value
p.Description = description
if Value != nil {
v := reflect.ValueOf(Value)
if value != nil {
v := reflect.ValueOf(value)
if ClassType == ClassUniversal {
switch Tag {
if classType == ClassUniversal {
switch tag {
case TagOctetString:
sv, ok := v.Interface().(string)
if ok {
p.Data.Write([]byte(sv))
}
case TagEnumerated:
bv, ok := v.Interface().([]byte)
if ok {
p.Data.Write(bv)
}
case TagEmbeddedPDV:
bv, ok := v.Interface().([]byte)
if ok {
p.Data.Write(bv)
}
}
} else if classType == ClassContext {
switch tag {
case TagEnumerated:
bv, ok := v.Interface().([]byte)
if ok {
p.Data.Write(bv)
}
case TagEmbeddedPDV:
bv, ok := v.Interface().([]byte)
if ok {
p.Data.Write(bv)
}
}
}
}
return p
}
func NewSequence(Description string) *Packet {
return Encode(ClassUniversal, TypeConstructed, TagSequence, nil, Description)
func NewSequence(description string) *Packet {
return Encode(ClassUniversal, TypeConstructed, TagSequence, nil, description)
}
func NewBoolean(ClassType Class, TagType Type, Tag Tag, Value bool, Description string) *Packet {
func NewBoolean(classType Class, tagType Type, tag Tag, value bool, description string) *Packet {
intValue := int64(0)
if Value {
if value {
intValue = 1
}
p := Encode(ClassType, TagType, Tag, nil, Description)
p := Encode(classType, tagType, tag, nil, description)
p.Value = Value
p.Value = value
p.Data.Write(encodeInteger(intValue))
return p
}
func NewInteger(ClassType Class, TagType Type, Tag Tag, Value interface{}, Description string) *Packet {
p := Encode(ClassType, TagType, Tag, nil, Description)
// NewLDAPBoolean returns a RFC 4511-compliant Boolean packet.
func NewLDAPBoolean(classType Class, tagType Type, tag Tag, value bool, description string) *Packet {
intValue := int64(0)
p.Value = Value
switch v := Value.(type) {
if value {
intValue = 255
}
p := Encode(classType, tagType, tag, nil, description)
p.Value = value
p.Data.Write(encodeInteger(intValue))
return p
}
func NewInteger(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet {
p := Encode(classType, tagType, tag, nil, description)
p.Value = value
switch v := value.(type) {
case int:
p.Data.Write(encodeInteger(int64(v)))
case uint:
@ -502,11 +583,38 @@ func NewInteger(ClassType Class, TagType Type, Tag Tag, Value interface{}, Descr
return p
}
func NewString(ClassType Class, TagType Type, Tag Tag, Value, Description string) *Packet {
p := Encode(ClassType, TagType, Tag, nil, Description)
func NewString(classType Class, tagType Type, tag Tag, value, description string) *Packet {
p := Encode(classType, tagType, tag, nil, description)
p.Value = Value
p.Data.Write([]byte(Value))
p.Value = value
p.Data.Write([]byte(value))
return p
}
func NewGeneralizedTime(classType Class, tagType Type, tag Tag, value time.Time, description string) *Packet {
p := Encode(classType, tagType, tag, nil, description)
var s string
if value.Nanosecond() != 0 {
s = value.Format(`20060102150405.000000000Z`)
} else {
s = value.Format(`20060102150405Z`)
}
p.Value = s
p.Data.Write([]byte(s))
return p
}
func NewReal(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet {
p := Encode(classType, tagType, tag, nil, description)
switch v := value.(type) {
case float64:
p.Data.Write(encodeFloat(v))
case float32:
p.Data.Write(encodeFloat(float64(v)))
default:
panic(fmt.Sprintf("Invalid type %T, expected float{64|32}", v))
}
return p
}

View file

@ -6,7 +6,7 @@ func encodeUnsignedInteger(i uint64) []byte {
var j int
for ; n > 0; n-- {
out[j] = (byte(i >> uint((n-1)*8)))
out[j] = byte(i >> uint((n-1)*8))
j++
}

View file

@ -0,0 +1,105 @@
package ber
import (
"bytes"
"errors"
"fmt"
"strconv"
"time"
)
// ErrInvalidTimeFormat is returned when the generalizedTime string was not correct.
var ErrInvalidTimeFormat = errors.New("invalid time format")
var zeroTime = time.Time{}
// ParseGeneralizedTime parses a string value and if it conforms to
// GeneralizedTime[^0] format, will return a time.Time for that value.
//
// [^0]: https://www.itu.int/rec/T-REC-X.690-201508-I/en Section 11.7
func ParseGeneralizedTime(v []byte) (time.Time, error) {
var format string
var fract time.Duration
str := []byte(DecodeString(v))
tzIndex := bytes.IndexAny(str, "Z+-")
if tzIndex < 0 {
return zeroTime, ErrInvalidTimeFormat
}
dot := bytes.IndexAny(str, ".,")
switch dot {
case -1:
switch tzIndex {
case 10:
format = `2006010215Z`
case 12:
format = `200601021504Z`
case 14:
format = `20060102150405Z`
default:
return zeroTime, ErrInvalidTimeFormat
}
case 10, 12:
if tzIndex < dot {
return zeroTime, ErrInvalidTimeFormat
}
// a "," is also allowed, but would not be parsed by time.Parse():
str[dot] = '.'
// If <minute> is omitted, then <fraction> represents a fraction of an
// hour; otherwise, if <second> and <leap-second> are omitted, then
// <fraction> represents a fraction of a minute; otherwise, <fraction>
// represents a fraction of a second.
// parse as float from dot to timezone
f, err := strconv.ParseFloat(string(str[dot:tzIndex]), 64)
if err != nil {
return zeroTime, fmt.Errorf("failed to parse float: %s", err)
}
// ...and strip that part
str = append(str[:dot], str[tzIndex:]...)
tzIndex = dot
if dot == 10 {
fract = time.Duration(int64(f * float64(time.Hour)))
format = `2006010215Z`
} else {
fract = time.Duration(int64(f * float64(time.Minute)))
format = `200601021504Z`
}
case 14:
if tzIndex < dot {
return zeroTime, ErrInvalidTimeFormat
}
str[dot] = '.'
// no need for fractional seconds, time.Parse() handles that
format = `20060102150405Z`
default:
return zeroTime, ErrInvalidTimeFormat
}
l := len(str)
switch l - tzIndex {
case 1:
if str[l-1] != 'Z' {
return zeroTime, ErrInvalidTimeFormat
}
case 3:
format += `0700`
str = append(str, []byte("00")...)
case 5:
format += `0700`
default:
return zeroTime, ErrInvalidTimeFormat
}
t, err := time.Parse(format, string(str))
if err != nil {
return zeroTime, fmt.Errorf("%s: %s", ErrInvalidTimeFormat, err)
}
return t.Add(fract), nil
}

View file

@ -7,19 +7,22 @@ import (
)
func readHeader(reader io.Reader) (identifier Identifier, length int, read int, err error) {
if i, c, err := readIdentifier(reader); err != nil {
return Identifier{}, 0, read, err
} else {
identifier = i
read += c
}
var (
c, l int
i Identifier
)
if l, c, err := readLength(reader); err != nil {
if i, c, err = readIdentifier(reader); err != nil {
return Identifier{}, 0, read, err
} else {
length = l
read += c
}
identifier = i
read += c
if l, c, err = readLength(reader); err != nil {
return Identifier{}, 0, read, err
}
length = l
read += c
// Validate length type with identifier (x.600, 8.1.3.2.a)
if length == LengthIndefinite && identifier.TagType == TypePrimitive {

View file

@ -71,11 +71,11 @@ func readLength(reader io.Reader) (length int, read int, err error) {
}
func encodeLength(length int) []byte {
length_bytes := encodeUnsignedInteger(uint64(length))
if length > 127 || len(length_bytes) > 1 {
longFormBytes := []byte{(LengthLongFormBitmask | byte(len(length_bytes)))}
longFormBytes = append(longFormBytes, length_bytes...)
length_bytes = longFormBytes
lengthBytes := encodeUnsignedInteger(uint64(length))
if length > 127 || len(lengthBytes) > 1 {
longFormBytes := []byte{LengthLongFormBitmask | byte(len(lengthBytes))}
longFormBytes = append(longFormBytes, lengthBytes...)
lengthBytes = longFormBytes
}
return length_bytes
return lengthBytes
}

157
vendor/github.com/go-asn1-ber/asn1-ber/real.go generated vendored Normal file
View file

@ -0,0 +1,157 @@
package ber
import (
"bytes"
"errors"
"fmt"
"math"
"strconv"
"strings"
)
func encodeFloat(v float64) []byte {
switch {
case math.IsInf(v, 1):
return []byte{0x40}
case math.IsInf(v, -1):
return []byte{0x41}
case math.IsNaN(v):
return []byte{0x42}
case v == 0.0:
if math.Signbit(v) {
return []byte{0x43}
}
return []byte{}
default:
// we take the easy part ;-)
value := []byte(strconv.FormatFloat(v, 'G', -1, 64))
var ret []byte
if bytes.Contains(value, []byte{'E'}) {
ret = []byte{0x03}
} else {
ret = []byte{0x02}
}
ret = append(ret, value...)
return ret
}
}
func ParseReal(v []byte) (val float64, err error) {
if len(v) == 0 {
return 0.0, nil
}
switch {
case v[0]&0x80 == 0x80:
val, err = parseBinaryFloat(v)
case v[0]&0xC0 == 0x40:
val, err = parseSpecialFloat(v)
case v[0]&0xC0 == 0x0:
val, err = parseDecimalFloat(v)
default:
return 0.0, fmt.Errorf("invalid info block")
}
if err != nil {
return 0.0, err
}
if val == 0.0 && !math.Signbit(val) {
return 0.0, errors.New("REAL value +0 must be encoded with zero-length value block")
}
return val, nil
}
func parseBinaryFloat(v []byte) (float64, error) {
var info byte
var buf []byte
info, v = v[0], v[1:]
var base int
switch info & 0x30 {
case 0x00:
base = 2
case 0x10:
base = 8
case 0x20:
base = 16
case 0x30:
return 0.0, errors.New("bits 6 and 5 of information octet for REAL are equal to 11")
}
scale := uint((info & 0x0c) >> 2)
var expLen int
switch info & 0x03 {
case 0x00:
expLen = 1
case 0x01:
expLen = 2
case 0x02:
expLen = 3
case 0x03:
expLen = int(v[0])
if expLen > 8 {
return 0.0, errors.New("too big value of exponent")
}
v = v[1:]
}
buf, v = v[:expLen], v[expLen:]
exponent, err := ParseInt64(buf)
if err != nil {
return 0.0, err
}
if len(v) > 8 {
return 0.0, errors.New("too big value of mantissa")
}
mant, err := ParseInt64(v)
if err != nil {
return 0.0, err
}
mantissa := mant << scale
if info&0x40 == 0x40 {
mantissa = -mantissa
}
return float64(mantissa) * math.Pow(float64(base), float64(exponent)), nil
}
func parseDecimalFloat(v []byte) (val float64, err error) {
switch v[0] & 0x3F {
case 0x01: // NR form 1
var iVal int64
iVal, err = strconv.ParseInt(strings.TrimLeft(string(v[1:]), " "), 10, 64)
val = float64(iVal)
case 0x02, 0x03: // NR form 2, 3
val, err = strconv.ParseFloat(strings.Replace(strings.TrimLeft(string(v[1:]), " "), ",", ".", -1), 64)
default:
err = errors.New("incorrect NR form")
}
if err != nil {
return 0.0, err
}
if val == 0.0 && math.Signbit(val) {
return 0.0, errors.New("REAL value -0 must be encoded as a special value")
}
return val, nil
}
func parseSpecialFloat(v []byte) (float64, error) {
if len(v) != 1 {
return 0.0, errors.New(`encoding of "special value" must not contain exponent and mantissa`)
}
switch v[0] {
case 0x40:
return math.Inf(1), nil
case 0x41:
return math.Inf(-1), nil
case 0x42:
return math.NaN(), nil
case 0x43:
return math.Copysign(0, -1), nil
}
return 0.0, errors.New(`encoding of "special value" not from ASN.1 standard`)
}

View file

@ -3,7 +3,7 @@ package ber
import "io"
func readByte(reader io.Reader) (byte, error) {
bytes := make([]byte, 1, 1)
bytes := make([]byte, 1)
_, err := io.ReadFull(reader, bytes)
if err != nil {
if err == io.EOF {

View file

@ -1,18 +1,9 @@
//
// https://tools.ietf.org/html/rfc4511
//
// AddRequest ::= [APPLICATION 8] SEQUENCE {
// entry LDAPDN,
// attributes AttributeList }
//
// AttributeList ::= SEQUENCE OF attribute Attribute
package ldap
import (
"log"
ber "gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
// Attribute represents an LDAP attribute

540
vendor/github.com/go-ldap/ldap/v3/bind.go generated vendored Normal file
View file

@ -0,0 +1,540 @@
package ldap
import (
"bytes"
"crypto/md5"
enchex "encoding/hex"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"strings"
"github.com/Azure/go-ntlmssp"
ber "github.com/go-asn1-ber/asn1-ber"
)
// SimpleBindRequest represents a username/password bind operation
type SimpleBindRequest struct {
// Username is the name of the Directory object that the client wishes to bind as
Username string
// Password is the credentials to bind with
Password string
// Controls are optional controls to send with the bind request
Controls []Control
// AllowEmptyPassword sets whether the client allows binding with an empty password
// (normally used for unauthenticated bind).
AllowEmptyPassword bool
}
// SimpleBindResult contains the response from the server
type SimpleBindResult struct {
Controls []Control
}
// NewSimpleBindRequest returns a bind request
func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest {
return &SimpleBindRequest{
Username: username,
Password: password,
Controls: controls,
AllowEmptyPassword: false,
}
}
func (req *SimpleBindRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Username, "User Name"))
pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.Password, "Password"))
envelope.AppendChild(pkt)
if len(req.Controls) > 0 {
envelope.AppendChild(encodeControls(req.Controls))
}
return nil
}
// SimpleBind performs the simple bind operation defined in the given request
func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) {
if simpleBindRequest.Password == "" && !simpleBindRequest.AllowEmptyPassword {
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
}
msgCtx, err := l.doRequest(simpleBindRequest)
if err != nil {
return nil, err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return nil, err
}
result := &SimpleBindResult{
Controls: make([]Control, 0),
}
if len(packet.Children) == 3 {
for _, child := range packet.Children[2].Children {
decodedChild, decodeErr := DecodeControl(child)
if decodeErr != nil {
return nil, fmt.Errorf("failed to decode child control: %s", decodeErr)
}
result.Controls = append(result.Controls, decodedChild)
}
}
err = GetLDAPError(packet)
return result, err
}
// Bind performs a bind with the given username and password.
//
// It does not allow unauthenticated bind (i.e. empty password). Use the UnauthenticatedBind method
// for that.
func (l *Conn) Bind(username, password string) error {
req := &SimpleBindRequest{
Username: username,
Password: password,
AllowEmptyPassword: false,
}
_, err := l.SimpleBind(req)
return err
}
// UnauthenticatedBind performs an unauthenticated bind.
//
// A username may be provided for trace (e.g. logging) purpose only, but it is normally not
// authenticated or otherwise validated by the LDAP server.
//
// See https://tools.ietf.org/html/rfc4513#section-5.1.2 .
// See https://tools.ietf.org/html/rfc4513#section-6.3.1 .
func (l *Conn) UnauthenticatedBind(username string) error {
req := &SimpleBindRequest{
Username: username,
Password: "",
AllowEmptyPassword: true,
}
_, err := l.SimpleBind(req)
return err
}
// DigestMD5BindRequest represents a digest-md5 bind operation
type DigestMD5BindRequest struct {
Host string
// Username is the name of the Directory object that the client wishes to bind as
Username string
// Password is the credentials to bind with
Password string
// Controls are optional controls to send with the bind request
Controls []Control
}
func (req *DigestMD5BindRequest) appendTo(envelope *ber.Packet) error {
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
auth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "DIGEST-MD5", "SASL Mech"))
request.AppendChild(auth)
envelope.AppendChild(request)
if len(req.Controls) > 0 {
envelope.AppendChild(encodeControls(req.Controls))
}
return nil
}
// DigestMD5BindResult contains the response from the server
type DigestMD5BindResult struct {
Controls []Control
}
// MD5Bind performs a digest-md5 bind with the given host, username and password.
func (l *Conn) MD5Bind(host, username, password string) error {
req := &DigestMD5BindRequest{
Host: host,
Username: username,
Password: password,
}
_, err := l.DigestMD5Bind(req)
return err
}
// DigestMD5Bind performs the digest-md5 bind operation defined in the given request
func (l *Conn) DigestMD5Bind(digestMD5BindRequest *DigestMD5BindRequest) (*DigestMD5BindResult, error) {
if digestMD5BindRequest.Password == "" {
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
}
msgCtx, err := l.doRequest(digestMD5BindRequest)
if err != nil {
return nil, err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return nil, err
}
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if l.Debug {
if err = addLDAPDescriptions(packet); err != nil {
return nil, err
}
ber.PrintPacket(packet)
}
result := &DigestMD5BindResult{
Controls: make([]Control, 0),
}
var params map[string]string
if len(packet.Children) == 2 {
if len(packet.Children[1].Children) == 4 {
child := packet.Children[1].Children[0]
if child.Tag != ber.TagEnumerated {
return result, GetLDAPError(packet)
}
if child.Value.(int64) != 14 {
return result, GetLDAPError(packet)
}
child = packet.Children[1].Children[3]
if child.Tag != ber.TagObjectDescriptor {
return result, GetLDAPError(packet)
}
if child.Data == nil {
return result, GetLDAPError(packet)
}
data, _ := ioutil.ReadAll(child.Data)
params, err = parseParams(string(data))
if err != nil {
return result, fmt.Errorf("parsing digest-challenge: %s", err)
}
}
}
if params != nil {
resp := computeResponse(
params,
"ldap/"+strings.ToLower(digestMD5BindRequest.Host),
digestMD5BindRequest.Username,
digestMD5BindRequest.Password,
)
packet = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
auth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "DIGEST-MD5", "SASL Mech"))
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, resp, "Credentials"))
request.AppendChild(auth)
packet.AppendChild(request)
msgCtx, err = l.sendMessage(packet)
if err != nil {
return nil, fmt.Errorf("send message: %s", err)
}
defer l.finishMessage(msgCtx)
packetResponse, ok := <-msgCtx.responses
if !ok {
return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
}
packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil {
return nil, fmt.Errorf("read packet: %s", err)
}
}
err = GetLDAPError(packet)
return result, err
}
func parseParams(str string) (map[string]string, error) {
m := make(map[string]string)
var key, value string
var state int
for i := 0; i <= len(str); i++ {
switch state {
case 0: //reading key
if i == len(str) {
return nil, fmt.Errorf("syntax error on %d", i)
}
if str[i] != '=' {
key += string(str[i])
continue
}
state = 1
case 1: //reading value
if i == len(str) {
m[key] = value
break
}
switch str[i] {
case ',':
m[key] = value
state = 0
key = ""
value = ""
case '"':
if value != "" {
return nil, fmt.Errorf("syntax error on %d", i)
}
state = 2
default:
value += string(str[i])
}
case 2: //inside quotes
if i == len(str) {
return nil, fmt.Errorf("syntax error on %d", i)
}
if str[i] != '"' {
value += string(str[i])
} else {
state = 1
}
}
}
return m, nil
}
func computeResponse(params map[string]string, uri, username, password string) string {
nc := "00000001"
qop := "auth"
cnonce := enchex.EncodeToString(randomBytes(16))
x := username + ":" + params["realm"] + ":" + password
y := md5Hash([]byte(x))
a1 := bytes.NewBuffer(y)
a1.WriteString(":" + params["nonce"] + ":" + cnonce)
if len(params["authzid"]) > 0 {
a1.WriteString(":" + params["authzid"])
}
a2 := bytes.NewBuffer([]byte("AUTHENTICATE"))
a2.WriteString(":" + uri)
ha1 := enchex.EncodeToString(md5Hash(a1.Bytes()))
ha2 := enchex.EncodeToString(md5Hash(a2.Bytes()))
kd := ha1
kd += ":" + params["nonce"]
kd += ":" + nc
kd += ":" + cnonce
kd += ":" + qop
kd += ":" + ha2
resp := enchex.EncodeToString(md5Hash([]byte(kd)))
return fmt.Sprintf(
`username="%s",realm="%s",nonce="%s",cnonce="%s",nc=00000001,qop=%s,digest-uri="%s",response=%s`,
username,
params["realm"],
params["nonce"],
cnonce,
qop,
uri,
resp,
)
}
func md5Hash(b []byte) []byte {
hasher := md5.New()
hasher.Write(b)
return hasher.Sum(nil)
}
func randomBytes(len int) []byte {
b := make([]byte, len)
for i := 0; i < len; i++ {
b[i] = byte(rand.Intn(256))
}
return b
}
var externalBindRequest = requestFunc(func(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
saslAuth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "EXTERNAL", "SASL Mech"))
saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "SASL Cred"))
pkt.AppendChild(saslAuth)
envelope.AppendChild(pkt)
return nil
})
// ExternalBind performs SASL/EXTERNAL authentication.
//
// Use ldap.DialURL("ldapi://") to connect to the Unix socket before ExternalBind.
//
// See https://tools.ietf.org/html/rfc4422#appendix-A
func (l *Conn) ExternalBind() error {
msgCtx, err := l.doRequest(externalBindRequest)
if err != nil {
return err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return err
}
return GetLDAPError(packet)
}
// NTLMBind performs an NTLMSSP bind leveraging https://github.com/Azure/go-ntlmssp
// NTLMBindRequest represents an NTLMSSP bind operation
type NTLMBindRequest struct {
// Domain is the AD Domain to authenticate too. If not specified, it will be grabbed from the NTLMSSP Challenge
Domain string
// Username is the name of the Directory object that the client wishes to bind as
Username string
// Password is the credentials to bind with
Password string
// Hash is the hex NTLM hash to bind with. Password or hash must be provided
Hash string
// Controls are optional controls to send with the bind request
Controls []Control
}
func (req *NTLMBindRequest) appendTo(envelope *ber.Packet) error {
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
// generate an NTLMSSP Negotiation message for the specified domain (it can be blank)
negMessage, err := ntlmssp.NewNegotiateMessage(req.Domain, "")
if err != nil {
return fmt.Errorf("err creating negmessage: %s", err)
}
// append the generated NTLMSSP message as a TagEnumerated BER value
auth := ber.Encode(ber.ClassContext, ber.TypePrimitive, ber.TagEnumerated, negMessage, "authentication")
request.AppendChild(auth)
envelope.AppendChild(request)
if len(req.Controls) > 0 {
envelope.AppendChild(encodeControls(req.Controls))
}
return nil
}
// NTLMBindResult contains the response from the server
type NTLMBindResult struct {
Controls []Control
}
// NTLMBind performs an NTLMSSP Bind with the given domain, username and password
func (l *Conn) NTLMBind(domain, username, password string) error {
req := &NTLMBindRequest{
Domain: domain,
Username: username,
Password: password,
}
_, err := l.NTLMChallengeBind(req)
return err
}
// NTLMBindWithHash performs an NTLM Bind with an NTLM hash instead of plaintext password (pass-the-hash)
func (l *Conn) NTLMBindWithHash(domain, username, hash string) error {
req := &NTLMBindRequest{
Domain: domain,
Username: username,
Hash: hash,
}
_, err := l.NTLMChallengeBind(req)
return err
}
// NTLMChallengeBind performs the NTLMSSP bind operation defined in the given request
func (l *Conn) NTLMChallengeBind(ntlmBindRequest *NTLMBindRequest) (*NTLMBindResult, error) {
if ntlmBindRequest.Password == "" && ntlmBindRequest.Hash == "" {
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
}
msgCtx, err := l.doRequest(ntlmBindRequest)
if err != nil {
return nil, err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return nil, err
}
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if l.Debug {
if err = addLDAPDescriptions(packet); err != nil {
return nil, err
}
ber.PrintPacket(packet)
}
result := &NTLMBindResult{
Controls: make([]Control, 0),
}
var ntlmsspChallenge []byte
// now find the NTLM Response Message
if len(packet.Children) == 2 {
if len(packet.Children[1].Children) == 3 {
child := packet.Children[1].Children[1]
ntlmsspChallenge = child.ByteValue
// Check to make sure we got the right message. It will always start with NTLMSSP
if len(ntlmsspChallenge) < 7 || !bytes.Equal(ntlmsspChallenge[:7], []byte("NTLMSSP")) {
return result, GetLDAPError(packet)
}
l.Debug.Printf("%d: found ntlmssp challenge", msgCtx.id)
}
}
if ntlmsspChallenge != nil {
var err error
var responseMessage []byte
// generate a response message to the challenge with the given Username/Password if password is provided
if ntlmBindRequest.Password != "" {
responseMessage, err = ntlmssp.ProcessChallenge(ntlmsspChallenge, ntlmBindRequest.Username, ntlmBindRequest.Password)
} else if ntlmBindRequest.Hash != "" {
responseMessage, err = ntlmssp.ProcessChallengeWithHash(ntlmsspChallenge, ntlmBindRequest.Username, ntlmBindRequest.Hash)
} else {
err = fmt.Errorf("need a password or hash to generate reply")
}
if err != nil {
return result, fmt.Errorf("parsing ntlm-challenge: %s", err)
}
packet = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
// append the challenge response message as a TagEmbeddedPDV BER value
auth := ber.Encode(ber.ClassContext, ber.TypePrimitive, ber.TagEmbeddedPDV, responseMessage, "authentication")
request.AppendChild(auth)
packet.AppendChild(request)
msgCtx, err = l.sendMessage(packet)
if err != nil {
return nil, fmt.Errorf("send message: %s", err)
}
defer l.finishMessage(msgCtx)
packetResponse, ok := <-msgCtx.responses
if !ok {
return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
}
packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil {
return nil, fmt.Errorf("read packet: %s", err)
}
}
err = GetLDAPError(packet)
return result, err
}

View file

@ -10,6 +10,7 @@ type Client interface {
Start()
StartTLS(*tls.Config) error
Close()
IsClosing() bool
SetTimeout(time.Duration)
Bind(username, password string) error
@ -21,6 +22,7 @@ type Client interface {
Del(*DelRequest) error
Modify(*ModifyRequest) error
ModifyDN(*ModifyDNRequest) error
ModifyWithResult(*ModifyRequest) (*ModifyResult, error)
Compare(dn, attribute, value string) (bool, error)
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)

View file

@ -1,28 +1,9 @@
// File contains Compare functionality
//
// https://tools.ietf.org/html/rfc4511
//
// CompareRequest ::= [APPLICATION 14] SEQUENCE {
// entry LDAPDN,
// ava AttributeValueAssertion }
//
// AttributeValueAssertion ::= SEQUENCE {
// attributeDesc AttributeDescription,
// assertionValue AssertionValue }
//
// AttributeDescription ::= LDAPString
// -- Constrained to <attributedescription>
// -- [RFC4512]
//
// AttributeValue ::= OCTET STRING
//
package ldap
import (
"fmt"
ber "gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
// CompareRequest represents an LDAP CompareRequest operation.

View file

@ -1,6 +1,7 @@
package ldap
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
@ -11,7 +12,7 @@ import (
"sync/atomic"
"time"
ber "gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
const (
@ -112,8 +113,72 @@ var _ Client = &Conn{}
// multiple places will probably result in undesired behaviour.
var DefaultTimeout = 60 * time.Second
// DialOpt configures DialContext.
type DialOpt func(*DialContext)
// DialWithDialer updates net.Dialer in DialContext.
func DialWithDialer(d *net.Dialer) DialOpt {
return func(dc *DialContext) {
dc.d = d
}
}
// DialWithTLSConfig updates tls.Config in DialContext.
func DialWithTLSConfig(tc *tls.Config) DialOpt {
return func(dc *DialContext) {
dc.tc = tc
}
}
// DialWithTLSDialer is a wrapper for DialWithTLSConfig with the option to
// specify a net.Dialer to for example define a timeout or a custom resolver.
func DialWithTLSDialer(tlsConfig *tls.Config, dialer *net.Dialer) DialOpt {
return func(dc *DialContext) {
dc.tc = tlsConfig
dc.d = dialer
}
}
// DialContext contains necessary parameters to dial the given ldap URL.
type DialContext struct {
d *net.Dialer
tc *tls.Config
}
func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
if u.Scheme == "ldapi" {
if u.Path == "" || u.Path == "/" {
u.Path = "/var/run/slapd/ldapi"
}
return dc.d.Dial("unix", u.Path)
}
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
// we assume that error is due to missing port
host = u.Host
port = ""
}
switch u.Scheme {
case "ldap":
if port == "" {
port = DefaultLdapPort
}
return dc.d.Dial("tcp", net.JoinHostPort(host, port))
case "ldaps":
if port == "" {
port = DefaultLdapsPort
}
return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc)
}
return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
}
// Dial connects to the given address on the given network using net.Dial
// and then returns a new Conn for the connection.
// @deprecated Use DialURL instead.
func Dial(network, addr string) (*Conn, error) {
c, err := net.DialTimeout(network, addr, DefaultTimeout)
if err != nil {
@ -126,6 +191,7 @@ func Dial(network, addr string) (*Conn, error) {
// DialTLS connects to the given address on the given network using tls.Dial
// and then returns a new Conn for the connection.
// @deprecated Use DialURL instead.
func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config)
if err != nil {
@ -136,44 +202,31 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
return conn, nil
}
// DialURL connects to the given ldap URL vie TCP using tls.Dial or net.Dial if ldaps://
// or ldap:// specified as protocol. On success a new Conn for the connection
// is returned.
func DialURL(addr string) (*Conn, error) {
lurl, err := url.Parse(addr)
// DialURL connects to the given ldap URL.
// The following schemas are supported: ldap://, ldaps://, ldapi://.
// On success a new Conn for the connection is returned.
func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
u, err := url.Parse(addr)
if err != nil {
return nil, NewError(ErrorNetwork, err)
}
host, port, err := net.SplitHostPort(lurl.Host)
var dc DialContext
for _, opt := range opts {
opt(&dc)
}
if dc.d == nil {
dc.d = &net.Dialer{Timeout: DefaultTimeout}
}
c, err := dc.dial(u)
if err != nil {
// we asume that error is due to missing port
host = lurl.Host
port = ""
return nil, NewError(ErrorNetwork, err)
}
switch lurl.Scheme {
case "ldapi":
if lurl.Path == "" || lurl.Path == "/" {
lurl.Path = "/var/run/slapd/ldapi"
}
return Dial("unix", lurl.Path)
case "ldap":
if port == "" {
port = DefaultLdapPort
}
return Dial("tcp", net.JoinHostPort(host, port))
case "ldaps":
if port == "" {
port = DefaultLdapsPort
}
tlsConf := &tls.Config{
ServerName: host,
}
return DialTLS("tcp", net.JoinHostPort(host, port), tlsConf)
}
return nil, NewError(ErrorNetwork, fmt.Errorf("Unknown scheme '%s'", lurl.Scheme))
conn := NewConn(c, u.Scheme == "ldaps")
conn.Start()
return conn, nil
}
// NewConn returns a new Conn using conn for network I/O.
@ -191,9 +244,9 @@ func NewConn(conn net.Conn, isTLS bool) *Conn {
// Start initializes goroutines to read responses and process messages
func (l *Conn) Start() {
l.wgClose.Add(1)
go l.reader()
go l.processMessages()
l.wgClose.Add(1)
}
// IsClosing returns whether or not we're currently closing.
@ -278,7 +331,7 @@ func (l *Conn) StartTLS(config *tls.Config) error {
l.Close()
return err
}
ber.PrintPacket(packet)
l.Debug.PrintPacket(packet)
}
if err := GetLDAPError(packet); err == nil {
@ -347,7 +400,12 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags)
responses: responses,
},
}
l.sendProcessMessage(message)
if !l.sendProcessMessage(message) {
if l.IsClosing() {
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
}
return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message for unknown reason"))
}
return message.Context, nil
}
@ -451,14 +509,14 @@ func (l *Conn) processMessages() {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
} else {
log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing())
ber.PrintPacket(message.Packet)
l.Debug.PrintPacket(message.Packet)
}
case MessageTimeout:
// Handle the timeout by closing the channel
// All reads will return immediately
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))})
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
}
@ -484,12 +542,13 @@ func (l *Conn) reader() {
}
}()
bufConn := bufio.NewReader(l.conn)
for {
if cleanstop {
l.Debug.Printf("reader clean stopping (without closing the connection)")
return
}
packet, err := ber.ReadPacket(l.conn)
packet, err := ber.ReadPacket(bufConn)
if err != nil {
// A read error is expected here if we are closing the connection...
if !l.IsClosing() {

View file

@ -4,7 +4,7 @@ import (
"fmt"
"strconv"
"gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
const (
@ -18,20 +18,25 @@ const (
ControlTypeVChuPasswordWarning = "2.16.840.1.113730.3.4.5"
// ControlTypeManageDsaIT - https://tools.ietf.org/html/rfc3296
ControlTypeManageDsaIT = "2.16.840.1.113730.3.4.2"
// ControlTypeWhoAmI - https://tools.ietf.org/html/rfc4532
ControlTypeWhoAmI = "1.3.6.1.4.1.4203.1.11.3"
// ControlTypeMicrosoftNotification - https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx
ControlTypeMicrosoftNotification = "1.2.840.113556.1.4.528"
// ControlTypeMicrosoftShowDeleted - https://msdn.microsoft.com/en-us/library/aa366989(v=vs.85).aspx
ControlTypeMicrosoftShowDeleted = "1.2.840.113556.1.4.417"
// ControlTypeMicrosoftServerLinkTTL - https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/f4f523a8-abc0-4b3a-a471-6b2fef135481?redirectedfrom=MSDN
ControlTypeMicrosoftServerLinkTTL = "1.2.840.113556.1.4.2309"
)
// ControlTypeMap maps controls to text descriptions
var ControlTypeMap = map[string]string{
ControlTypePaging: "Paging",
ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft",
ControlTypeManageDsaIT: "Manage DSA IT",
ControlTypeMicrosoftNotification: "Change Notification - Microsoft",
ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft",
ControlTypePaging: "Paging",
ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft",
ControlTypeManageDsaIT: "Manage DSA IT",
ControlTypeMicrosoftNotification: "Change Notification - Microsoft",
ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft",
ControlTypeMicrosoftServerLinkTTL: "Return TTL-DNs for link values with associated expiry times - Microsoft",
}
// Control defines an interface controls provide to encode and describe themselves
@ -305,6 +310,35 @@ func NewControlMicrosoftShowDeleted() *ControlMicrosoftShowDeleted {
return &ControlMicrosoftShowDeleted{}
}
// ControlMicrosoftServerLinkTTL implements the control described in https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/f4f523a8-abc0-4b3a-a471-6b2fef135481?redirectedfrom=MSDN
type ControlMicrosoftServerLinkTTL struct{}
// GetControlType returns the OID
func (c *ControlMicrosoftServerLinkTTL) GetControlType() string {
return ControlTypeMicrosoftServerLinkTTL
}
// Encode returns the ber packet representation
func (c *ControlMicrosoftServerLinkTTL) Encode() *ber.Packet {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeMicrosoftServerLinkTTL, "Control Type ("+ControlTypeMap[ControlTypeMicrosoftServerLinkTTL]+")"))
return packet
}
// String returns a human-readable description
func (c *ControlMicrosoftServerLinkTTL) String() string {
return fmt.Sprintf(
"Control Type: %s (%q)",
ControlTypeMap[ControlTypeMicrosoftServerLinkTTL],
ControlTypeMicrosoftServerLinkTTL)
}
// NewControlMicrosoftServerLinkTTL returns a ControlMicrosoftServerLinkTTL control
func NewControlMicrosoftServerLinkTTL() *ControlMicrosoftServerLinkTTL {
return &ControlMicrosoftServerLinkTTL{}
}
// FindControl returns the first control of the given type in the list, or nil
func FindControl(controls []Control, controlType string) Control {
for _, c := range controls {
@ -404,33 +438,26 @@ func DecodeControl(packet *ber.Packet) (Control, error) {
if child.Tag == 0 {
//Warning
warningPacket := child.Children[0]
packet, err := ber.DecodePacketErr(warningPacket.Data.Bytes())
val, err := ber.ParseInt64(warningPacket.Data.Bytes())
if err != nil {
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
}
val, ok := packet.Value.(int64)
if ok {
if warningPacket.Tag == 0 {
//timeBeforeExpiration
c.Expire = val
warningPacket.Value = c.Expire
} else if warningPacket.Tag == 1 {
//graceAuthNsRemaining
c.Grace = val
warningPacket.Value = c.Grace
}
if warningPacket.Tag == 0 {
//timeBeforeExpiration
c.Expire = val
warningPacket.Value = c.Expire
} else if warningPacket.Tag == 1 {
//graceAuthNsRemaining
c.Grace = val
warningPacket.Value = c.Grace
}
} else if child.Tag == 1 {
// Error
packet, err := ber.DecodePacketErr(child.Data.Bytes())
if err != nil {
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
}
val, ok := packet.Value.(int8)
if !ok {
// what to do?
val = -1
bs := child.Data.Bytes()
if len(bs) != 1 || bs[0] > 8 {
return nil, fmt.Errorf("failed to decode data bytes: %s", "invalid PasswordPolicyResponse enum value")
}
val := int8(bs[0])
c.Error = val
child.Value = c.Error
c.ErrorString = BeheraPasswordPolicyErrorMap[c.Error]
@ -456,6 +483,8 @@ func DecodeControl(packet *ber.Packet) (Control, error) {
return NewControlMicrosoftNotification(), nil
case ControlTypeMicrosoftShowDeleted:
return NewControlMicrosoftShowDeleted(), nil
case ControlTypeMicrosoftServerLinkTTL:
return NewControlMicrosoftServerLinkTTL(), nil
default:
c := new(ControlString)
c.ControlType = ControlType

View file

@ -3,7 +3,7 @@ package ldap
import (
"log"
ber "gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
// debugging type
@ -25,6 +25,6 @@ func (debug debugging) Printf(format string, args ...interface{}) {
// PrintPacket dumps a packet.
func (debug debugging) PrintPacket(packet *ber.Packet) {
if debug {
ber.PrintPacket(packet)
ber.WritePacket(log.Writer(), packet)
}
}

View file

@ -1,14 +1,9 @@
//
// https://tools.ietf.org/html/rfc4511
//
// DelRequest ::= [APPLICATION 10] LDAPDN
package ldap
import (
"log"
ber "gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
// DelRequest implements an LDAP deletion request

View file

@ -1,44 +1,3 @@
// File contains DN parsing functionality
//
// https://tools.ietf.org/html/rfc4514
//
// distinguishedName = [ relativeDistinguishedName
// *( COMMA relativeDistinguishedName ) ]
// relativeDistinguishedName = attributeTypeAndValue
// *( PLUS attributeTypeAndValue )
// attributeTypeAndValue = attributeType EQUALS attributeValue
// attributeType = descr / numericoid
// attributeValue = string / hexstring
//
// ; The following characters are to be escaped when they appear
// ; in the value to be encoded: ESC, one of <escaped>, leading
// ; SHARP or SPACE, trailing SPACE, and NULL.
// string = [ ( leadchar / pair ) [ *( stringchar / pair )
// ( trailchar / pair ) ] ]
//
// leadchar = LUTF1 / UTFMB
// LUTF1 = %x01-1F / %x21 / %x24-2A / %x2D-3A /
// %x3D / %x3F-5B / %x5D-7F
//
// trailchar = TUTF1 / UTFMB
// TUTF1 = %x01-1F / %x21 / %x23-2A / %x2D-3A /
// %x3D / %x3F-5B / %x5D-7F
//
// stringchar = SUTF1 / UTFMB
// SUTF1 = %x01-21 / %x23-2A / %x2D-3A /
// %x3D / %x3F-5B / %x5D-7F
//
// pair = ESC ( ESC / special / hexpair )
// special = escaped / SPACE / SHARP / EQUALS
// escaped = DQUOTE / PLUS / COMMA / SEMI / LANGLE / RANGLE
// hexstring = SHARP 1*hexpair
// hexpair = HEX HEX
//
// where the productions <descr>, <numericoid>, <COMMA>, <DQUOTE>,
// <EQUALS>, <ESC>, <HEX>, <LANGLE>, <NULL>, <PLUS>, <RANGLE>, <SEMI>,
// <SPACE>, <SHARP>, and <UTFMB> are defined in [RFC4512].
//
package ldap
import (
@ -48,7 +7,7 @@ import (
"fmt"
"strings"
"gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514
@ -69,7 +28,8 @@ type DN struct {
RDNs []*RelativeDN
}
// ParseDN returns a distinguishedName or an error
// ParseDN returns a distinguishedName or an error.
// The function respects https://tools.ietf.org/html/rfc4514
func ParseDN(str string) (*DN, error) {
dn := new(DN)
dn.RDNs = make([]*RelativeDN, 0)
@ -245,3 +205,66 @@ func (r *RelativeDN) hasAllAttributes(attrs []*AttributeTypeAndValue) bool {
func (a *AttributeTypeAndValue) Equal(other *AttributeTypeAndValue) bool {
return strings.EqualFold(a.Type, other.Type) && a.Value == other.Value
}
// Equal returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
// Returns true if they have the same number of relative distinguished names
// and corresponding relative distinguished names (by position) are the same.
// Case of the attribute type and value is not significant
func (d *DN) EqualFold(other *DN) bool {
if len(d.RDNs) != len(other.RDNs) {
return false
}
for i := range d.RDNs {
if !d.RDNs[i].EqualFold(other.RDNs[i]) {
return false
}
}
return true
}
// AncestorOfFold returns true if the other DN consists of at least one RDN followed by all the RDNs of the current DN.
// Case of the attribute type and value is not significant
func (d *DN) AncestorOfFold(other *DN) bool {
if len(d.RDNs) >= len(other.RDNs) {
return false
}
// Take the last `len(d.RDNs)` RDNs from the other DN to compare against
otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):]
for i := range d.RDNs {
if !d.RDNs[i].EqualFold(otherRDNs[i]) {
return false
}
}
return true
}
// Equal returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
// Case of the attribute type is not significant
func (r *RelativeDN) EqualFold(other *RelativeDN) bool {
if len(r.Attributes) != len(other.Attributes) {
return false
}
return r.hasAllAttributesFold(other.Attributes) && other.hasAllAttributesFold(r.Attributes)
}
func (r *RelativeDN) hasAllAttributesFold(attrs []*AttributeTypeAndValue) bool {
for _, attr := range attrs {
found := false
for _, myattr := range r.Attributes {
if myattr.EqualFold(attr) {
found = true
break
}
}
if !found {
return false
}
}
return true
}
// EqualFold returns true if the AttributeTypeAndValue is equivalent to the specified AttributeTypeAndValue
// Case of the attribute type and value is not significant
func (a *AttributeTypeAndValue) EqualFold(other *AttributeTypeAndValue) bool {
return strings.EqualFold(a.Type, other.Type) && strings.EqualFold(a.Value, other.Value)
}

View file

@ -3,7 +3,7 @@ package ldap
import (
"fmt"
ber "gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
// LDAP Result Codes
@ -184,6 +184,8 @@ type Error struct {
ResultCode uint16
// MatchedDN is the matchedDN returned if any
MatchedDN string
// Packet is the returned packet if any
Packet *ber.Packet
}
func (e *Error) Error() string {
@ -201,19 +203,23 @@ func GetLDAPError(packet *ber.Packet) error {
if len(packet.Children) >= 2 {
response := packet.Children[1]
if response == nil {
return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet")}
return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet"), Packet: packet}
}
if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 {
resultCode := uint16(response.Children[0].Value.(int64))
if resultCode == 0 { // No error
return nil
}
return &Error{ResultCode: resultCode, MatchedDN: response.Children[1].Value.(string),
Err: fmt.Errorf("%s", response.Children[2].Value.(string))}
return &Error{
ResultCode: resultCode,
MatchedDN: response.Children[1].Value.(string),
Err: fmt.Errorf("%s", response.Children[2].Value.(string)),
Packet: packet,
}
}
}
return &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("Invalid packet format")}
return &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("Invalid packet format"), Packet: packet}
}
// NewError creates an LDAP error with the given code and underlying error
@ -221,8 +227,8 @@ func NewError(resultCode uint16, err error) error {
return &Error{ResultCode: resultCode, Err: err}
}
// IsErrorWithCode returns true if the given error is an LDAP error with the given result code
func IsErrorWithCode(err error, desiredResultCode uint16) bool {
// IsErrorAnyOf returns true if the given error is an LDAP error with any one of the given result codes
func IsErrorAnyOf(err error, codes ...uint16) bool {
if err == nil {
return false
}
@ -232,5 +238,16 @@ func IsErrorWithCode(err error, desiredResultCode uint16) bool {
return false
}
return serverError.ResultCode == desiredResultCode
for _, code := range codes {
if serverError.ResultCode == code {
return true
}
}
return false
}
// IsErrorWithCode returns true if the given error is an LDAP error with the given result code
func IsErrorWithCode(err error, desiredResultCode uint16) bool {
return IsErrorAnyOf(err, desiredResultCode)
}

View file

@ -5,10 +5,12 @@ import (
hexpac "encoding/hex"
"errors"
"fmt"
"io"
"strings"
"unicode"
"unicode/utf8"
"gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
// Filter choices
@ -69,6 +71,8 @@ var MatchingRuleAssertionMap = map[uint64]string{
MatchingRuleAssertionDNAttributes: "Matching Rule Assertion DN Attributes",
}
var _SymbolAny = []byte{'*'}
// CompileFilter converts a string representation of a filter into a BER-encoded packet
func CompileFilter(filter string) (*ber.Packet, error) {
if len(filter) == 0 || filter[0] != '(' {
@ -88,74 +92,75 @@ func CompileFilter(filter string) (*ber.Packet, error) {
}
// DecompileFilter converts a packet representation of a filter into a string representation
func DecompileFilter(packet *ber.Packet) (ret string, err error) {
func DecompileFilter(packet *ber.Packet) (_ string, err error) {
defer func() {
if r := recover(); r != nil {
err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter"))
}
}()
ret = "("
err = nil
buf := bytes.NewBuffer(nil)
buf.WriteByte('(')
childStr := ""
switch packet.Tag {
case FilterAnd:
ret += "&"
buf.WriteByte('&')
for _, child := range packet.Children {
childStr, err = DecompileFilter(child)
if err != nil {
return
}
ret += childStr
buf.WriteString(childStr)
}
case FilterOr:
ret += "|"
buf.WriteByte('|')
for _, child := range packet.Children {
childStr, err = DecompileFilter(child)
if err != nil {
return
}
ret += childStr
buf.WriteString(childStr)
}
case FilterNot:
ret += "!"
buf.WriteByte('!')
childStr, err = DecompileFilter(packet.Children[0])
if err != nil {
return
}
ret += childStr
buf.WriteString(childStr)
case FilterSubstrings:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "="
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
buf.WriteByte('=')
for i, child := range packet.Children[1].Children {
if i == 0 && child.Tag != FilterSubstringsInitial {
ret += "*"
buf.Write(_SymbolAny)
}
ret += EscapeFilter(ber.DecodeString(child.Data.Bytes()))
buf.WriteString(EscapeFilter(ber.DecodeString(child.Data.Bytes())))
if child.Tag != FilterSubstringsFinal {
ret += "*"
buf.Write(_SymbolAny)
}
}
case FilterEqualityMatch:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "="
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
buf.WriteByte('=')
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
case FilterGreaterOrEqual:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += ">="
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
buf.WriteString(">=")
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
case FilterLessOrEqual:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "<="
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
buf.WriteString("<=")
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
case FilterPresent:
ret += ber.DecodeString(packet.Data.Bytes())
ret += "=*"
buf.WriteString(ber.DecodeString(packet.Data.Bytes()))
buf.WriteString("=*")
case FilterApproxMatch:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "~="
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
buf.WriteString("~=")
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
case FilterExtensibleMatch:
attr := ""
dnAttributes := false
@ -176,21 +181,22 @@ func DecompileFilter(packet *ber.Packet) (ret string, err error) {
}
if len(attr) > 0 {
ret += attr
buf.WriteString(attr)
}
if dnAttributes {
ret += ":dn"
buf.WriteString(":dn")
}
if len(matchingRule) > 0 {
ret += ":"
ret += matchingRule
buf.WriteString(":")
buf.WriteString(matchingRule)
}
ret += ":="
ret += EscapeFilter(value)
buf.WriteString(":=")
buf.WriteString(EscapeFilter(value))
}
ret += ")"
return
buf.WriteByte(')')
return buf.String(), nil
}
func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) {
@ -253,11 +259,10 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
)
state := stateReadingAttr
attribute := ""
attribute := bytes.NewBuffer(nil)
extensibleDNAttributes := false
extensibleMatchingRule := ""
condition := ""
extensibleMatchingRule := bytes.NewBuffer(nil)
condition := bytes.NewBuffer(nil)
for newPos < len(filter) {
remainingFilter := filter[newPos:]
@ -324,7 +329,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
// Still reading the attribute name
default:
attribute += fmt.Sprintf("%c", currentRune)
attribute.WriteRune(currentRune)
newPos += currentWidth
}
@ -338,13 +343,13 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
// Still reading the matching rule oid
default:
extensibleMatchingRule += fmt.Sprintf("%c", currentRune)
extensibleMatchingRule.WriteRune(currentRune)
newPos += currentWidth
}
case stateReadingCondition:
// append to the condition
condition += fmt.Sprintf("%c", currentRune)
condition.WriteRune(currentRune)
newPos += currentWidth
}
}
@ -368,17 +373,17 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
// }
// Include the matching rule oid, if specified
if len(extensibleMatchingRule) > 0 {
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule, MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule]))
if extensibleMatchingRule.Len() > 0 {
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule.String(), MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule]))
}
// Include the attribute, if specified
if len(attribute) > 0 {
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute, MatchingRuleAssertionMap[MatchingRuleAssertionType]))
if attribute.Len() > 0 {
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute.String(), MatchingRuleAssertionMap[MatchingRuleAssertionType]))
}
// Add the value (only required child)
encodedString, encodeErr := escapedStringToEncodedBytes(condition)
encodedString, encodeErr := decodeEscapedSymbols(condition.Bytes())
if encodeErr != nil {
return packet, newPos, encodeErr
}
@ -389,16 +394,16 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
packet.AppendChild(ber.NewBoolean(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionDNAttributes, extensibleDNAttributes, MatchingRuleAssertionMap[MatchingRuleAssertionDNAttributes]))
}
case packet.Tag == FilterEqualityMatch && condition == "*":
packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute, FilterMap[FilterPresent])
case packet.Tag == FilterEqualityMatch && strings.Contains(condition, "*"):
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
case packet.Tag == FilterEqualityMatch && bytes.Equal(condition.Bytes(), _SymbolAny):
packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute.String(), FilterMap[FilterPresent])
case packet.Tag == FilterEqualityMatch && bytes.Index(condition.Bytes(), _SymbolAny) > -1:
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute.String(), "Attribute"))
packet.Tag = FilterSubstrings
packet.Description = FilterMap[uint64(packet.Tag)]
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
parts := strings.Split(condition, "*")
parts := bytes.Split(condition.Bytes(), _SymbolAny)
for i, part := range parts {
if part == "" {
if len(part) == 0 {
continue
}
var tag ber.Tag
@ -410,7 +415,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
default:
tag = FilterSubstringsAny
}
encodedString, encodeErr := escapedStringToEncodedBytes(part)
encodedString, encodeErr := decodeEscapedSymbols(part)
if encodeErr != nil {
return packet, newPos, encodeErr
}
@ -418,11 +423,11 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
}
packet.AppendChild(seq)
default:
encodedString, encodeErr := escapedStringToEncodedBytes(condition)
encodedString, encodeErr := decodeEscapedSymbols(condition.Bytes())
if encodeErr != nil {
return packet, newPos, encodeErr
}
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute.String(), "Attribute"))
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, encodedString, "Condition"))
}
@ -432,34 +437,51 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
}
// Convert from "ABC\xx\xx\xx" form to literal bytes for transport
func escapedStringToEncodedBytes(escapedString string) (string, error) {
var buffer bytes.Buffer
i := 0
for i < len(escapedString) {
currentRune, currentWidth := utf8.DecodeRuneInString(escapedString[i:])
if currentRune == utf8.RuneError {
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", i))
func decodeEscapedSymbols(src []byte) (string, error) {
var (
buffer bytes.Buffer
offset int
reader = bytes.NewReader(src)
byteHex []byte
byteVal []byte
)
for {
runeVal, runeSize, err := reader.ReadRune()
if err == io.EOF {
return buffer.String(), nil
} else if err != nil {
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: failed to read filter: %v", err))
} else if runeVal == unicode.ReplacementChar {
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", offset))
}
// Check for escaped hex characters and convert them to their literal value for transport.
if currentRune == '\\' {
if runeVal == '\\' {
// http://tools.ietf.org/search/rfc4515
// \ (%x5C) is not a valid character unless it is followed by two HEX characters due to not
// being a member of UTF1SUBSET.
if i+2 > len(escapedString) {
return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter"))
if byteHex == nil {
byteHex = make([]byte, 2)
byteVal = make([]byte, 1)
}
escByte, decodeErr := hexpac.DecodeString(escapedString[i+1 : i+3])
if decodeErr != nil {
return "", NewError(ErrorFilterCompile, errors.New("ldap: invalid characters for escape in filter"))
if _, err := io.ReadFull(reader, byteHex); err != nil {
if err == io.ErrUnexpectedEOF {
return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter"))
}
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: invalid characters for escape in filter: %v", err))
}
buffer.WriteByte(escByte[0])
i += 2 // +1 from end of loop, so 3 total for \xx.
if _, err := hexpac.Decode(byteVal, byteHex); err != nil {
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: invalid characters for escape in filter: %v", err))
}
buffer.Write(byteVal)
} else {
buffer.WriteRune(currentRune)
buffer.WriteRune(runeVal)
}
i += currentWidth
offset += runeSize
}
return buffer.String(), nil
}

View file

@ -5,7 +5,7 @@ import (
"io/ioutil"
"os"
ber "gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
// LDAP Application Codes
@ -223,32 +223,26 @@ func addControlDescriptions(packet *ber.Packet) error {
if child.Tag == 0 {
//Warning
warningPacket := child.Children[0]
packet, err := ber.DecodePacketErr(warningPacket.Data.Bytes())
val, err := ber.ParseInt64(warningPacket.Data.Bytes())
if err != nil {
return fmt.Errorf("failed to decode data bytes: %s", err)
}
val, ok := packet.Value.(int64)
if ok {
if warningPacket.Tag == 0 {
//timeBeforeExpiration
value.Description += " (TimeBeforeExpiration)"
warningPacket.Value = val
} else if warningPacket.Tag == 1 {
//graceAuthNsRemaining
value.Description += " (GraceAuthNsRemaining)"
warningPacket.Value = val
}
if warningPacket.Tag == 0 {
//timeBeforeExpiration
value.Description += " (TimeBeforeExpiration)"
warningPacket.Value = val
} else if warningPacket.Tag == 1 {
//graceAuthNsRemaining
value.Description += " (GraceAuthNsRemaining)"
warningPacket.Value = val
}
} else if child.Tag == 1 {
// Error
packet, err := ber.DecodePacketErr(child.Data.Bytes())
if err != nil {
return fmt.Errorf("failed to decode data bytes: %s", err)
}
val, ok := packet.Value.(int8)
if !ok {
val = -1
bs := child.Data.Bytes()
if len(bs) != 1 || bs[0] > 8 {
return fmt.Errorf("failed to decode data bytes: %s", "invalid PasswordPolicyResponse enum value")
}
val := int8(bs[0])
child.Description = "Error"
child.Value = val
}
@ -269,13 +263,18 @@ func addRequestDescriptions(packet *ber.Packet) error {
}
func addDefaultLDAPResponseDescriptions(packet *ber.Packet) error {
err := GetLDAPError(packet)
if err == nil {
return nil
resultCode := uint16(LDAPResultSuccess)
matchedDN := ""
description := "Success"
if err := GetLDAPError(packet); err != nil {
resultCode = err.(*Error).ResultCode
matchedDN = err.(*Error).MatchedDN
description = "Error Message"
}
packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[err.(*Error).ResultCode] + ")"
packet.Children[1].Children[1].Description = "Matched DN (" + err.(*Error).MatchedDN + ")"
packet.Children[1].Children[2].Description = "Error Message"
packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[resultCode] + ")"
packet.Children[1].Children[1].Description = "Matched DN (" + matchedDN + ")"
packet.Children[1].Children[2].Description = description
if len(packet.Children[1].Children) > 3 {
packet.Children[1].Children[3].Description = "Referral"
}

View file

@ -1,19 +1,9 @@
// Package ldap - moddn.go contains ModifyDN functionality
//
// https://tools.ietf.org/html/rfc4511
// ModifyDNRequest ::= [APPLICATION 12] SEQUENCE {
// entry LDAPDN,
// newrdn RelativeLDAPDN,
// deleteoldrdn BOOLEAN,
// newSuperior [0] LDAPDN OPTIONAL }
//
//
package ldap
import (
"log"
ber "gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
// ModifyDNRequest holds the request to modify a DN
@ -22,6 +12,8 @@ type ModifyDNRequest struct {
NewRDN string
DeleteOldRDN bool
NewSuperior string
// Controls hold optional controls to send with the request
Controls []Control
}
// NewModifyDNRequest creates a new request which can be passed to ModifyDN().
@ -45,16 +37,39 @@ func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *Modi
}
}
// NewModifyDNWithControlsRequest creates a new request which can be passed to ModifyDN()
// and also allows setting LDAP request controls.
//
// Refer NewModifyDNRequest for other parameters
func NewModifyDNWithControlsRequest(dn string, rdn string, delOld bool,
newSup string, controls []Control) *ModifyDNRequest {
return &ModifyDNRequest{
DN: dn,
NewRDN: rdn,
DeleteOldRDN: delOld,
NewSuperior: newSup,
Controls: controls,
}
}
func (req *ModifyDNRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.NewRDN, "New RDN"))
pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.DeleteOldRDN, "Delete old RDN"))
if req.DeleteOldRDN {
buf := []byte{0xff}
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, string(buf), "Delete old RDN"))
} else {
pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.DeleteOldRDN, "Delete old RDN"))
}
if req.NewSuperior != "" {
pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.NewSuperior, "New Superior"))
}
envelope.AppendChild(pkt)
if len(req.Controls) > 0 {
envelope.AppendChild(encodeControls(req.Controls))
}
return nil
}

View file

@ -1,41 +1,18 @@
// File contains Modify functionality
//
// https://tools.ietf.org/html/rfc4511
//
// ModifyRequest ::= [APPLICATION 6] SEQUENCE {
// object LDAPDN,
// changes SEQUENCE OF change SEQUENCE {
// operation ENUMERATED {
// add (0),
// delete (1),
// replace (2),
// ... },
// modification PartialAttribute } }
//
// PartialAttribute ::= SEQUENCE {
// type AttributeDescription,
// vals SET OF value AttributeValue }
//
// AttributeDescription ::= LDAPString
// -- Constrained to <attributedescription>
// -- [RFC4512]
//
// AttributeValue ::= OCTET STRING
//
package ldap
import (
"errors"
"log"
ber "gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
// Change operation choices
const (
AddAttribute = 0
DeleteAttribute = 1
ReplaceAttribute = 2
AddAttribute = 0
DeleteAttribute = 1
ReplaceAttribute = 2
IncrementAttribute = 3 // (https://tools.ietf.org/html/rfc4525)
)
// PartialAttribute for a ModifyRequest as defined in https://tools.ietf.org/html/rfc4511
@ -97,6 +74,11 @@ func (req *ModifyRequest) Replace(attrType string, attrVals []string) {
req.appendChange(ReplaceAttribute, attrType, attrVals)
}
// Increment appends the given attribute to the list of changes to be made
func (req *ModifyRequest) Increment(attrType string, attrVal string) {
req.appendChange(IncrementAttribute, attrType, []string{attrVal})
}
func (req *ModifyRequest) appendChange(operation uint, attrType string, attrVals []string) {
req.Changes = append(req.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}})
}
@ -149,3 +131,47 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error {
}
return nil
}
// ModifyResult holds the server's response to a modify request
type ModifyResult struct {
// Controls are the returned controls
Controls []Control
}
// ModifyWithResult performs the ModifyRequest and returns the result
func (l *Conn) ModifyWithResult(modifyRequest *ModifyRequest) (*ModifyResult, error) {
msgCtx, err := l.doRequest(modifyRequest)
if err != nil {
return nil, err
}
defer l.finishMessage(msgCtx)
result := &ModifyResult{
Controls: make([]Control, 0),
}
l.Debug.Printf("%d: waiting for response", msgCtx.id)
packet, err := l.readPacket(msgCtx)
if err != nil {
return nil, err
}
switch packet.Children[1].Tag {
case ApplicationModifyResponse:
err := GetLDAPError(packet)
if err != nil {
return nil, err
}
if len(packet.Children) == 3 {
for _, child := range packet.Children[2].Children {
decodedChild, err := DecodeControl(child)
if err != nil {
return nil, errors.New("failed to decode child control: " + err.Error())
}
result.Controls = append(result.Controls, decodedChild)
}
}
}
l.Debug.Printf("%d: returning", msgCtx.id)
return result, nil
}

View file

@ -1,14 +1,9 @@
// This file contains the password modify extended operation as specified in rfc 3062
//
// https://tools.ietf.org/html/rfc3062
//
package ldap
import (
"fmt"
ber "gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
const (
@ -61,7 +56,7 @@ func (req *PasswordModifyRequest) appendTo(envelope *ber.Packet) error {
// NewPasswordModifyRequest creates a new PasswordModifyRequest
//
// According to the RFC 3602:
// According to the RFC 3602 (https://tools.ietf.org/html/rfc3062):
// userIdentity is a string representing the user associated with the request.
// This string may or may not be an LDAPDN (RFC 2253).
// If userIdentity is empty then the operation will act on the user associated

View file

@ -3,12 +3,13 @@ package ldap
import (
"errors"
ber "gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
var (
errRespChanClosed = errors.New("ldap: response channel closed")
errCouldNotRetMsg = errors.New("ldap: could not retrieve message")
ErrNilConnection = errors.New("ldap: conn is nil, expected net.Conn")
)
type request interface {
@ -22,6 +23,10 @@ func (f requestFunc) appendTo(p *ber.Packet) error {
}
func (l *Conn) doRequest(req request) (*messageContext, error) {
if l == nil || l.conn == nil {
return nil, ErrNilConnection
}
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
if err := req.appendTo(packet); err != nil {
@ -29,7 +34,7 @@ func (l *Conn) doRequest(req request) (*messageContext, error) {
}
if l.Debug {
ber.PrintPacket(packet)
l.Debug.PrintPacket(packet)
}
msgCtx, err := l.sendMessage(packet)
@ -60,7 +65,7 @@ func (l *Conn) readPacket(msgCtx *messageContext) (*ber.Packet, error) {
if err = addLDAPDescriptions(packet); err != nil {
return nil, err
}
ber.PrintPacket(packet)
l.Debug.PrintPacket(packet)
}
return packet, nil
}

View file

@ -1,58 +1,3 @@
// File contains Search functionality
//
// https://tools.ietf.org/html/rfc4511
//
// SearchRequest ::= [APPLICATION 3] SEQUENCE {
// baseObject LDAPDN,
// scope ENUMERATED {
// baseObject (0),
// singleLevel (1),
// wholeSubtree (2),
// ... },
// derefAliases ENUMERATED {
// neverDerefAliases (0),
// derefInSearching (1),
// derefFindingBaseObj (2),
// derefAlways (3) },
// sizeLimit INTEGER (0 .. maxInt),
// timeLimit INTEGER (0 .. maxInt),
// typesOnly BOOLEAN,
// filter Filter,
// attributes AttributeSelection }
//
// AttributeSelection ::= SEQUENCE OF selector LDAPString
// -- The LDAPString is constrained to
// -- <attributeSelector> in Section 4.5.1.8
//
// Filter ::= CHOICE {
// and [0] SET SIZE (1..MAX) OF filter Filter,
// or [1] SET SIZE (1..MAX) OF filter Filter,
// not [2] Filter,
// equalityMatch [3] AttributeValueAssertion,
// substrings [4] SubstringFilter,
// greaterOrEqual [5] AttributeValueAssertion,
// lessOrEqual [6] AttributeValueAssertion,
// present [7] AttributeDescription,
// approxMatch [8] AttributeValueAssertion,
// extensibleMatch [9] MatchingRuleAssertion,
// ... }
//
// SubstringFilter ::= SEQUENCE {
// type AttributeDescription,
// substrings SEQUENCE SIZE (1..MAX) OF substring CHOICE {
// initial [0] AssertionValue, -- can occur at most once
// any [1] AssertionValue,
// final [2] AssertionValue } -- can occur at most once
// }
//
// MatchingRuleAssertion ::= SEQUENCE {
// matchingRule [1] MatchingRuleId OPTIONAL,
// type [2] AttributeDescription OPTIONAL,
// matchValue [3] AssertionValue,
// dnAttributes [4] BOOLEAN DEFAULT FALSE }
//
//
package ldap
import (
@ -61,7 +6,7 @@ import (
"sort"
"strings"
ber "gopkg.in/asn1-ber.v1"
ber "github.com/go-asn1-ber/asn1-ber"
)
// scope choices
@ -132,6 +77,17 @@ func (e *Entry) GetAttributeValues(attribute string) []string {
return []string{}
}
// GetEqualFoldAttributeValues returns the values for the named attribute, or an
// empty list. Attribute matching is done with strings.EqualFold.
func (e *Entry) GetEqualFoldAttributeValues(attribute string) []string {
for _, attr := range e.Attributes {
if strings.EqualFold(attribute, attr.Name) {
return attr.Values
}
}
return []string{}
}
// GetRawAttributeValues returns the byte values for the named attribute, or an empty list
func (e *Entry) GetRawAttributeValues(attribute string) [][]byte {
for _, attr := range e.Attributes {
@ -142,6 +98,16 @@ func (e *Entry) GetRawAttributeValues(attribute string) [][]byte {
return [][]byte{}
}
// GetEqualFoldRawAttributeValues returns the byte values for the named attribute, or an empty list
func (e *Entry) GetEqualFoldRawAttributeValues(attribute string) [][]byte {
for _, attr := range e.Attributes {
if strings.EqualFold(attr.Name, attribute) {
return attr.ByteValues
}
}
return [][]byte{}
}
// GetAttributeValue returns the first value for the named attribute, or ""
func (e *Entry) GetAttributeValue(attribute string) string {
values := e.GetAttributeValues(attribute)
@ -151,6 +117,16 @@ func (e *Entry) GetAttributeValue(attribute string) string {
return values[0]
}
// GetEqualFoldAttributeValue returns the first value for the named attribute, or "".
// Attribute comparison is done with strings.EqualFold.
func (e *Entry) GetEqualFoldAttributeValue(attribute string) string {
values := e.GetEqualFoldAttributeValues(attribute)
if len(values) == 0 {
return ""
}
return values[0]
}
// GetRawAttributeValue returns the first value for the named attribute, or an empty slice
func (e *Entry) GetRawAttributeValue(attribute string) []byte {
values := e.GetRawAttributeValues(attribute)
@ -160,6 +136,15 @@ func (e *Entry) GetRawAttributeValue(attribute string) []byte {
return values[0]
}
// GetEqualFoldRawAttributeValue returns the first value for the named attribute, or an empty slice
func (e *Entry) GetEqualFoldRawAttributeValue(attribute string) []byte {
values := e.GetEqualFoldRawAttributeValues(attribute)
if len(values) == 0 {
return []byte{}
}
return values[0]
}
// Print outputs a human-readable description
func (e *Entry) Print() {
fmt.Printf("DN: %s\n", e.DN)
@ -386,33 +371,26 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
for {
packet, err := l.readPacket(msgCtx)
if err != nil {
return nil, err
return result, err
}
switch packet.Children[1].Tag {
case 4:
entry := new(Entry)
entry.DN = packet.Children[1].Children[0].Value.(string)
for _, child := range packet.Children[1].Children[1].Children {
attr := new(EntryAttribute)
attr.Name = child.Children[0].Value.(string)
for _, value := range child.Children[1].Children {
attr.Values = append(attr.Values, value.Value.(string))
attr.ByteValues = append(attr.ByteValues, value.ByteValue)
}
entry.Attributes = append(entry.Attributes, attr)
entry := &Entry{
DN: packet.Children[1].Children[0].Value.(string),
Attributes: unpackAttributes(packet.Children[1].Children[1].Children),
}
result.Entries = append(result.Entries, entry)
case 5:
err := GetLDAPError(packet)
if err != nil {
return nil, err
return result, err
}
if len(packet.Children) == 3 {
for _, child := range packet.Children[2].Children {
decodedChild, err := DecodeControl(child)
if err != nil {
return nil, fmt.Errorf("failed to decode child control: %s", err)
return result, fmt.Errorf("failed to decode child control: %s", err)
}
result.Controls = append(result.Controls, decodedChild)
}
@ -423,3 +401,27 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
}
}
}
// unpackAttributes will extract all given LDAP attributes and it's values
// from the ber.Packet
func unpackAttributes(children []*ber.Packet) []*EntryAttribute {
entries := make([]*EntryAttribute, len(children))
for i, child := range children {
length := len(child.Children[1].Children)
entry := &EntryAttribute{
Name: child.Children[0].Value.(string),
// pre-allocate the slice since we can determine
// the number of attributes at this point
Values: make([]string, length),
ByteValues: make([][]byte, length),
}
for i, value := range child.Children[1].Children {
entry.ByteValues[i] = value.ByteValue
entry.Values[i] = value.Value.(string)
}
entries[i] = entry
}
return entries
}

37
vendor/github.com/go-ldap/ldap/v3/unbind.go generated vendored Normal file
View file

@ -0,0 +1,37 @@
package ldap
import (
"errors"
ber "github.com/go-asn1-ber/asn1-ber"
)
var ErrConnUnbound = NewError(ErrorNetwork, errors.New("ldap: connection is closed"))
type unbindRequest struct{}
func (unbindRequest) appendTo(envelope *ber.Packet) error {
envelope.AppendChild(ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationUnbindRequest, nil, ApplicationMap[ApplicationUnbindRequest]))
return nil
}
// Unbind will perform an unbind request. The Unbind operation
// should be thought of as the "quit" operation.
// See https://datatracker.ietf.org/doc/html/rfc4511#section-4.3
func (l *Conn) Unbind() error {
if l.IsClosing() {
return ErrConnUnbound
}
_, err := l.doRequest(unbindRequest{})
if err != nil {
return err
}
// Sending an unbindRequest will make the connection unusable.
// Pending requests will fail with:
// LDAP Result Code 200 "Network Error": ldap: response channel closed
l.Close()
return nil
}

91
vendor/github.com/go-ldap/ldap/v3/whoami.go generated vendored Normal file
View file

@ -0,0 +1,91 @@
package ldap
// This file contains the "Who Am I?" extended operation as specified in rfc 4532
//
// https://tools.ietf.org/html/rfc4532
import (
"errors"
"fmt"
ber "github.com/go-asn1-ber/asn1-ber"
)
type whoAmIRequest bool
// WhoAmIResult is returned by the WhoAmI() call
type WhoAmIResult struct {
AuthzID string
}
func (r whoAmIRequest) encode() (*ber.Packet, error) {
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Who Am I? Extended Operation")
request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, ControlTypeWhoAmI, "Extended Request Name: Who Am I? OID"))
return request, nil
}
// WhoAmI returns the authzId the server thinks we are, you may pass controls
// like a Proxied Authorization control
func (l *Conn) WhoAmI(controls []Control) (*WhoAmIResult, error) {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
req := whoAmIRequest(true)
encodedWhoAmIRequest, err := req.encode()
if err != nil {
return nil, err
}
packet.AppendChild(encodedWhoAmIRequest)
if len(controls) != 0 {
packet.AppendChild(encodeControls(controls))
}
l.Debug.PrintPacket(packet)
msgCtx, err := l.sendMessage(packet)
if err != nil {
return nil, err
}
defer l.finishMessage(msgCtx)
result := &WhoAmIResult{}
l.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-msgCtx.responses
if !ok {
return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
}
packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil {
return nil, err
}
if packet == nil {
return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))
}
if l.Debug {
if err := addLDAPDescriptions(packet); err != nil {
return nil, err
}
ber.PrintPacket(packet)
}
if packet.Children[1].Tag == ApplicationExtendedResponse {
if err := GetLDAPError(packet); err != nil {
return nil, err
}
} else {
return nil, NewError(ErrorUnexpectedResponse, fmt.Errorf("Unexpected Response: %d", packet.Children[1].Tag))
}
extendedResponse := packet.Children[1]
for _, child := range extendedResponse.Children {
if child.Tag == 11 {
result.AuthzID = ber.DecodeString(child.Data.Bytes())
}
}
return result, nil
}

View file

@ -1,129 +0,0 @@
sudo: false
language: go
go:
- 1.10.x
- 1.11.x
- 1.12.x
- 1.13.x
- master
before_install:
- go get golang.org/x/tools/cmd/cover
- go get github.com/mattn/goveralls
before_script:
- echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB" | sudo tee -a /etc/mysql/my.cnf
- sudo service mysql restart
- .travis/wait_mysql.sh
- mysql -e 'create database gotest;'
matrix:
include:
- env: DB=MYSQL8
sudo: required
dist: trusty
go: 1.10.x
services:
- docker
before_install:
- go get golang.org/x/tools/cmd/cover
- go get github.com/mattn/goveralls
- docker pull mysql:8.0
- docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret
mysql:8.0 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1
- cp .travis/docker.cnf ~/.my.cnf
- .travis/wait_mysql.sh
before_script:
- export MYSQL_TEST_USER=gotest
- export MYSQL_TEST_PASS=secret
- export MYSQL_TEST_ADDR=127.0.0.1:3307
- export MYSQL_TEST_CONCURRENT=1
- env: DB=MYSQL57
sudo: required
dist: trusty
go: 1.10.x
services:
- docker
before_install:
- go get golang.org/x/tools/cmd/cover
- go get github.com/mattn/goveralls
- docker pull mysql:5.7
- docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret
mysql:5.7 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1
- cp .travis/docker.cnf ~/.my.cnf
- .travis/wait_mysql.sh
before_script:
- export MYSQL_TEST_USER=gotest
- export MYSQL_TEST_PASS=secret
- export MYSQL_TEST_ADDR=127.0.0.1:3307
- export MYSQL_TEST_CONCURRENT=1
- env: DB=MARIA55
sudo: required
dist: trusty
go: 1.10.x
services:
- docker
before_install:
- go get golang.org/x/tools/cmd/cover
- go get github.com/mattn/goveralls
- docker pull mariadb:5.5
- docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret
mariadb:5.5 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1
- cp .travis/docker.cnf ~/.my.cnf
- .travis/wait_mysql.sh
before_script:
- export MYSQL_TEST_USER=gotest
- export MYSQL_TEST_PASS=secret
- export MYSQL_TEST_ADDR=127.0.0.1:3307
- export MYSQL_TEST_CONCURRENT=1
- env: DB=MARIA10_1
sudo: required
dist: trusty
go: 1.10.x
services:
- docker
before_install:
- go get golang.org/x/tools/cmd/cover
- go get github.com/mattn/goveralls
- docker pull mariadb:10.1
- docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret
mariadb:10.1 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1
- cp .travis/docker.cnf ~/.my.cnf
- .travis/wait_mysql.sh
before_script:
- export MYSQL_TEST_USER=gotest
- export MYSQL_TEST_PASS=secret
- export MYSQL_TEST_ADDR=127.0.0.1:3307
- export MYSQL_TEST_CONCURRENT=1
- os: osx
osx_image: xcode10.1
addons:
homebrew:
packages:
- mysql
update: true
go: 1.12.x
before_install:
- go get golang.org/x/tools/cmd/cover
- go get github.com/mattn/goveralls
before_script:
- echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB\nlocal_infile=1" >> /usr/local/etc/my.cnf
- mysql.server start
- mysql -uroot -e 'CREATE USER gotest IDENTIFIED BY "secret"'
- mysql -uroot -e 'GRANT ALL ON *.* TO gotest'
- mysql -uroot -e 'create database gotest;'
- export MYSQL_TEST_USER=gotest
- export MYSQL_TEST_PASS=secret
- export MYSQL_TEST_ADDR=127.0.0.1:3306
- export MYSQL_TEST_CONCURRENT=1
script:
- go test -v -covermode=count -coverprofile=coverage.out
- go vet ./...
- .travis/gofmt.sh
after_script:
- $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci

View file

@ -13,11 +13,15 @@
Aaron Hopkins <go-sql-driver at die.net>
Achille Roussel <achille.roussel at gmail.com>
Alex Snast <alexsn at fb.com>
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
Andrew Reid <andrew.reid at tixtrack.com>
Animesh Ray <mail.rayanimesh at gmail.com>
Arne Hormann <arnehormann at gmail.com>
Ariel Mashraki <ariel at mashraki.co.il>
Asta Xie <xiemengjun at gmail.com>
Bulat Gaifullin <gaifullinbf at gmail.com>
Caine Jette <jette at alum.mit.edu>
Carlos Nieto <jose.carlos at menteslibres.net>
Chris Moos <chris at tech9computers.com>
Craig Wilson <craiggwilson at gmail.com>
@ -52,6 +56,7 @@ Julien Schmidt <go-sql-driver at julienschmidt.com>
Justin Li <jli at j-li.net>
Justin Nuß <nuss.justin at gmail.com>
Kamil Dziedzic <kamil at klecza.pl>
Kei Kamikawa <x00.x7f.x86 at gmail.com>
Kevin Malachowski <kevin at chowski.com>
Kieron Woodhouse <kieron.woodhouse at infosum.com>
Lennart Rudolph <lrudolph at hmc.edu>
@ -74,20 +79,26 @@ Reed Allman <rdallman10 at gmail.com>
Richard Wilkes <wilkes at me.com>
Robert Russell <robert at rrbrussell.com>
Runrioter Wung <runrioter at gmail.com>
Sho Iizuka <sho.i518 at gmail.com>
Sho Ikeda <suicaicoca at gmail.com>
Shuode Li <elemount at qq.com>
Simon J Mudd <sjmudd at pobox.com>
Soroush Pour <me at soroushjp.com>
Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Steven Hartland <steven.hartland at multiplay.co.uk>
Tan Jinhua <312841925 at qq.com>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tim Ruffles <timruffles at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
Vladimir Kovpak <cn007b at gmail.com>
Vladyslav Zhelezniak <zhvladi at gmail.com>
Xiangyu Hu <xiangyu.hu at outlook.com>
Xiaobing Jiang <s7v7nislands at gmail.com>
Xiuming Chen <cc at cxm.cc>
Xuehong Chan <chanxuehong at gmail.com>
Zhenye Xie <xiezhenye at gmail.com>
Zhixin Wen <john.wenzhixin at gmail.com>
# Organizations
@ -103,3 +114,4 @@ Multiplay Ltd.
Percona LLC
Pivotal Inc.
Stripe Inc.
Zendesk Inc.

View file

@ -1,3 +1,29 @@
## Version 1.6 (2021-04-01)
Changes:
- Migrate the CI service from travis-ci to GitHub Actions (#1176, #1183, #1190)
- `NullTime` is deprecated (#960, #1144)
- Reduce allocations when building SET command (#1111)
- Performance improvement for time formatting (#1118)
- Performance improvement for time parsing (#1098, #1113)
New Features:
- Implement `driver.Validator` interface (#1106, #1174)
- Support returning `uint64` from `Valuer` in `ConvertValue` (#1143)
- Add `json.RawMessage` for converter and prepared statement (#1059)
- Interpolate `json.RawMessage` as `string` (#1058)
- Implements `CheckNamedValue` (#1090)
Bugfixes:
- Stop rounding times (#1121, #1172)
- Put zero filler into the SSL handshake packet (#1066)
- Fix checking cancelled connections back into the connection pool (#1095)
- Fix remove last 0 byte for mysql_old_password when password is empty (#1133)
## Version 1.5 (2020-01-07)
Changes:

View file

@ -35,7 +35,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
* Supports queries larger than 16MB
* Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support.
* Intelligent `LONG DATA` handling in prepared statements
* Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support
* Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support
* Optional `time.Time` parsing
* Optional placeholder interpolation
@ -56,15 +56,37 @@ Make sure [Git is installed](https://git-scm.com/downloads) on your machine and
_Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then.
Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`:
```go
import "database/sql"
import _ "github.com/go-sql-driver/mysql"
import (
"database/sql"
"time"
_ "github.com/go-sql-driver/mysql"
)
// ...
db, err := sql.Open("mysql", "user:password@/dbname")
if err != nil {
panic(err)
}
// See "Important settings" section.
db.SetConnMaxLifetime(time.Minute * 3)
db.SetMaxOpenConns(10)
db.SetMaxIdleConns(10)
```
[Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples").
### Important settings
`db.SetConnMaxLifetime()` is required to ensure connections are closed by the driver safely before connection is closed by MySQL server, OS, or other middlewares. Since some middlewares close idle connections by 5 minutes, we recommend timeout shorter than 5 minutes. This setting helps load balancing and changing system variables too.
`db.SetMaxOpenConns()` is highly recommended to limit the number of connection used by the application. There is no recommended limit number because it depends on application and MySQL server.
`db.SetMaxIdleConns()` is recommended to be set same to (or greater than) `db.SetMaxOpenConns()`. When it is smaller than `SetMaxOpenConns()`, connections can be opened and closed very frequently than you expect. Idle connections can be closed by the `db.SetConnMaxLifetime()`. If you want to close idle connections more rapidly, you can use `db.SetConnMaxIdleTime()` since Go 1.15.
### DSN (Data Source Name)
@ -122,7 +144,7 @@ Valid Values: true, false
Default: false
```
`allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files.
`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files.
[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)
##### `allowCleartextPasswords`
@ -133,7 +155,7 @@ Valid Values: true, false
Default: false
```
`allowCleartextPasswords=true` allows using the [cleartext client side plugin](http://dev.mysql.com/doc/en/cleartext-authentication-plugin.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network.
`allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network.
##### `allowNativePasswords`
@ -230,7 +252,7 @@ Default: false
If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`.
*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are blacklisted as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!*
*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are rejected as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!*
##### `loc`
@ -376,7 +398,7 @@ Rules:
Examples:
* `autocommit=1`: `SET autocommit=1`
* [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'`
* [`tx_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `SET tx_isolation='REPEATABLE-READ'`
* [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'`
#### Examples
@ -445,7 +467,7 @@ For this feature you need direct access to the package. Therefore you must chang
import "github.com/go-sql-driver/mysql"
```
Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).
Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).
To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore.
@ -459,8 +481,6 @@ However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` v
**Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
Alternatively you can use the [`NullTime`](https://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`.
### Unicode support
Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default.
@ -477,7 +497,7 @@ To run the driver tests you may need to adjust the configuration. See the [Testi
Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated.
If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open) or review a [pull request](https://github.com/go-sql-driver/mysql/pulls).
See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/CONTRIBUTING.md) for details.
See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/.github/CONTRIBUTING.md) for details.
---------------------------------------
@ -498,4 +518,3 @@ Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you
You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE).
![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow")

View file

@ -15,6 +15,7 @@ import (
"crypto/sha256"
"crypto/x509"
"encoding/pem"
"fmt"
"sync"
)
@ -136,10 +137,6 @@ func pwHash(password []byte) (result [2]uint32) {
// Hash password using insecure pre 4.1 method
func scrambleOldPassword(scramble []byte, password string) []byte {
if len(password) == 0 {
return nil
}
scramble = scramble[:8]
hashPw := pwHash([]byte(password))
@ -247,6 +244,9 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
if !mc.cfg.AllowOldPasswords {
return nil, ErrOldPassword
}
if len(mc.cfg.Passwd) == 0 {
return nil, nil
}
// Note: there are edge cases where this should work but doesn't;
// this is currently "wontfix":
// https://github.com/go-sql-driver/mysql/issues/184
@ -372,7 +372,10 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
return err
}
block, _ := pem.Decode(data[1:])
block, rest := pem.Decode(data[1:])
if block == nil {
return fmt.Errorf("No Pem data found, data: %s", rest)
}
pkix, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err

View file

@ -247,7 +247,7 @@ var collations = map[string]byte{
"utf8mb4_0900_ai_ci": 255,
}
// A blacklist of collations which is unsafe to interpolate parameters.
// A denylist of collations which is unsafe to interpolate parameters.
// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes.
var unsafeCollations = map[string]bool{
"big5_chinese_ci": true,

View file

@ -12,6 +12,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"io"
"net"
"strconv"
@ -46,9 +47,10 @@ type mysqlConn struct {
// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
var cmdSet strings.Builder
for param, val := range mc.cfg.Params {
switch param {
// Charset
// Charset: character_set_connection, character_set_client, character_set_results
case "charset":
charsets := strings.Split(val, ",")
for i := range charsets {
@ -62,12 +64,25 @@ func (mc *mysqlConn) handleParams() (err error) {
return
}
// System Vars
// Other system vars accumulated in a single SET command
default:
err = mc.exec("SET " + param + "=" + val + "")
if err != nil {
return
if cmdSet.Len() == 0 {
// Heuristic: 29 chars for each other key=value to reduce reallocations
cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1))
cmdSet.WriteString("SET ")
} else {
cmdSet.WriteByte(',')
}
cmdSet.WriteString(param)
cmdSet.WriteByte('=')
cmdSet.WriteString(val)
}
}
if cmdSet.Len() > 0 {
err = mc.exec(cmdSet.String())
if err != nil {
return
}
}
@ -230,47 +245,21 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
v := v.In(mc.cfg.Loc)
v = v.Add(time.Nanosecond * 500) // To round under microsecond
year := v.Year()
year100 := year / 100
year1 := year % 100
month := v.Month()
day := v.Day()
hour := v.Hour()
minute := v.Minute()
second := v.Second()
micro := v.Nanosecond() / 1000
buf = append(buf, []byte{
'\'',
digits10[year100], digits01[year100],
digits10[year1], digits01[year1],
'-',
digits10[month], digits01[month],
'-',
digits10[day], digits01[day],
' ',
digits10[hour], digits01[hour],
':',
digits10[minute], digits01[minute],
':',
digits10[second], digits01[second],
}...)
if micro != 0 {
micro10000 := micro / 10000
micro100 := micro / 100 % 100
micro1 := micro % 100
buf = append(buf, []byte{
'.',
digits10[micro10000], digits01[micro10000],
digits10[micro100], digits01[micro100],
digits10[micro1], digits01[micro1],
}...)
buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc))
if err != nil {
return "", err
}
buf = append(buf, '\'')
}
case json.RawMessage:
buf = append(buf, '\'')
if mc.status&statusNoBackslashEscapes == 0 {
buf = escapeBytesBackslash(buf, v)
} else {
buf = escapeBytesQuotes(buf, v)
}
buf = append(buf, '\'')
case []byte:
if v == nil {
buf = append(buf, "NULL"...)
@ -480,6 +469,10 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
// BeginTx implements driver.ConnBeginTx interface
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if mc.closed.IsSet() {
return nil, driver.ErrBadConn
}
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
@ -649,3 +642,9 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
mc.reset = true
return nil
}
// IsValid implements driver.Validator interface
// (From Go 1.15)
func (mc *mysqlConn) IsValid() bool {
return !mc.closed.IsSet()
}

View file

@ -375,7 +375,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
// cfg params
switch value := param[1]; param[0] {
// Disable INFILE whitelist / enable all files
// Disable INFILE allowlist / enable all files
case "allowAllFiles":
var isBool bool
cfg.AllowAllFiles, isBool = readBool(value)

View file

@ -106,7 +106,7 @@ var (
scanTypeInt64 = reflect.TypeOf(int64(0))
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
scanTypeNullTime = reflect.TypeOf(NullTime{})
scanTypeNullTime = reflect.TypeOf(nullTime{})
scanTypeUint8 = reflect.TypeOf(uint8(0))
scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0))

24
vendor/github.com/go-sql-driver/mysql/fuzz.go generated vendored Normal file
View file

@ -0,0 +1,24 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
//
// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build gofuzz
package mysql
import (
"database/sql"
)
func Fuzz(data []byte) int {
db, err := sql.Open("mysql", string(data))
if err != nil {
return 0
}
db.Close()
return 1
}

View file

@ -23,7 +23,7 @@ var (
readerRegisterLock sync.RWMutex
)
// RegisterLocalFile adds the given file to the file whitelist,
// RegisterLocalFile adds the given file to the file allowlist,
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
// Alternatively you can allow the use of all local files with
// the DSN parameter 'allowAllFiles=true'
@ -45,7 +45,7 @@ func RegisterLocalFile(filePath string) {
fileRegisterLock.Unlock()
}
// DeregisterLocalFile removes the given filepath from the whitelist.
// DeregisterLocalFile removes the given filepath from the allowlist.
func DeregisterLocalFile(filePath string) {
fileRegisterLock.Lock()
delete(fileRegister, strings.Trim(filePath, `"`))

View file

@ -28,11 +28,11 @@ func (nt *NullTime) Scan(value interface{}) (err error) {
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, err = parseDateTime(string(v), time.UTC)
nt.Time, err = parseDateTime(v, time.UTC)
nt.Valid = (err == nil)
return
case string:
nt.Time, err = parseDateTime(v, time.UTC)
nt.Time, err = parseDateTime([]byte(v), time.UTC)
nt.Valid = (err == nil)
return
}

View file

@ -28,4 +28,13 @@ import (
// }
//
// This NullTime implementation is not driver-specific
//
// Deprecated: NullTime doesn't honor the loc DSN parameter.
// NullTime.Scan interprets a time as UTC, not the loc DSN parameter.
// Use sql.NullTime instead.
type NullTime sql.NullTime
// for internal use.
// the mysql package uses sql.NullTime if it is available.
// if not, the package uses mysql.NullTime.
type nullTime = sql.NullTime // sql.NullTime is available

View file

@ -32,3 +32,8 @@ type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// for internal use.
// the mysql package uses sql.NullTime if it is available.
// if not, the package uses mysql.NullTime.
type nullTime = NullTime // sql.NullTime is not available

View file

@ -13,6 +13,7 @@ import (
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
@ -348,6 +349,12 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
return errors.New("unknown collation")
}
// Filler [23 bytes] (all 0x00)
pos := 13
for ; pos < 13+23; pos++ {
data[pos] = 0
}
// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if mc.cfg.tls != nil {
@ -366,12 +373,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
mc.buf.nc = tlsConn
}
// Filler [23 bytes] (all 0x00)
pos := 13
for ; pos < 13+23; pos++ {
data[pos] = 0
}
// User [null terminated string]
if len(mc.cfg.User) > 0 {
pos += copy(data[pos:], mc.cfg.User)
@ -777,7 +778,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
case fieldTypeTimestamp, fieldTypeDateTime,
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
string(dest[i].([]byte)),
dest[i].([]byte),
mc.cfg.Loc,
)
if err == nil {
@ -1003,6 +1004,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
continue
}
if v, ok := arg.(json.RawMessage); ok {
arg = []byte(v)
}
// cache types and values
switch v := arg.(type) {
case int64:
@ -1112,7 +1116,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if v.IsZero() {
b = append(b, "0000-00-00"...)
} else {
b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat)
b, err = appendDateTime(b, v.In(mc.cfg.Loc))
if err != nil {
return err
}
}
paramValues = appendLengthEncodedInteger(paramValues,

View file

@ -10,6 +10,7 @@ package mysql
import (
"database/sql/driver"
"encoding/json"
"fmt"
"io"
"reflect"
@ -43,6 +44,11 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
return converter{}
}
func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = converter{}.ConvertValue(nv.Value)
return
}
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
@ -129,6 +135,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
return rows, err
}
var jsonType = reflect.TypeOf(json.RawMessage{})
type converter struct{}
// ConvertValue mirrors the reference/default converter in database/sql/driver
@ -146,12 +154,17 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
if err != nil {
return nil, err
}
if !driver.IsValue(sv) {
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
if driver.IsValue(sv) {
return sv, nil
}
return sv, nil
// A value returend from the Valuer interface can be "a type handled by
// a database driver's NamedValueChecker interface" so we should accept
// uint64 here as well.
if u, ok := sv.(uint64); ok {
return u, nil
}
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
}
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Ptr:
@ -170,11 +183,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
case reflect.Bool:
return rv.Bool(), nil
case reflect.Slice:
ek := rv.Type().Elem().Kind()
if ek == reflect.Uint8 {
switch t := rv.Type(); {
case t == jsonType:
return v, nil
case t.Elem().Kind() == reflect.Uint8:
return rv.Bytes(), nil
default:
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind())
}
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
case reflect.String:
return rv.String(), nil
}

View file

@ -106,27 +106,136 @@ func readBool(input string) (value bool, valid bool) {
* Time related utils *
******************************************************************************/
func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
base := "0000-00-00 00:00:00.0000000"
switch len(str) {
func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
const base = "0000-00-00 00:00:00.000000"
switch len(b) {
case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
if str == base[:len(str)] {
return
if string(b) == base[:len(b)] {
return time.Time{}, nil
}
t, err = time.Parse(timeFormat[:len(str)], str)
year, err := parseByteYear(b)
if err != nil {
return time.Time{}, err
}
if year <= 0 {
year = 1
}
if b[4] != '-' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4])
}
m, err := parseByte2Digits(b[5], b[6])
if err != nil {
return time.Time{}, err
}
if m <= 0 {
m = 1
}
month := time.Month(m)
if b[7] != '-' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7])
}
day, err := parseByte2Digits(b[8], b[9])
if err != nil {
return time.Time{}, err
}
if day <= 0 {
day = 1
}
if len(b) == 10 {
return time.Date(year, month, day, 0, 0, 0, 0, loc), nil
}
if b[10] != ' ' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10])
}
hour, err := parseByte2Digits(b[11], b[12])
if err != nil {
return time.Time{}, err
}
if b[13] != ':' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13])
}
min, err := parseByte2Digits(b[14], b[15])
if err != nil {
return time.Time{}, err
}
if b[16] != ':' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16])
}
sec, err := parseByte2Digits(b[17], b[18])
if err != nil {
return time.Time{}, err
}
if len(b) == 19 {
return time.Date(year, month, day, hour, min, sec, 0, loc), nil
}
if b[19] != '.' {
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19])
}
nsec, err := parseByteNanoSec(b[20:])
if err != nil {
return time.Time{}, err
}
return time.Date(year, month, day, hour, min, sec, nsec, loc), nil
default:
err = fmt.Errorf("invalid time string: %s", str)
return
return time.Time{}, fmt.Errorf("invalid time bytes: %s", b)
}
}
// Adjust location
if err == nil && loc != time.UTC {
y, mo, d := t.Date()
h, mi, s := t.Clock()
t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil
func parseByteYear(b []byte) (int, error) {
year, n := 0, 1000
for i := 0; i < 4; i++ {
v, err := bToi(b[i])
if err != nil {
return 0, err
}
year += v * n
n = n / 10
}
return year, nil
}
return
func parseByte2Digits(b1, b2 byte) (int, error) {
d1, err := bToi(b1)
if err != nil {
return 0, err
}
d2, err := bToi(b2)
if err != nil {
return 0, err
}
return d1*10 + d2, nil
}
func parseByteNanoSec(b []byte) (int, error) {
ns, digit := 0, 100000 // max is 6-digits
for i := 0; i < len(b); i++ {
v, err := bToi(b[i])
if err != nil {
return 0, err
}
ns += v * digit
digit /= 10
}
// nanoseconds has 10-digits. (needs to scale digits)
// 10 - 6 = 4, so we have to multiple 1000.
return ns * 1000, nil
}
func bToi(b byte) (int, error) {
if b < '0' || b > '9' {
return 0, errors.New("not [0-9]")
}
return int(b - '0'), nil
}
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
@ -167,6 +276,64 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va
return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
}
func appendDateTime(buf []byte, t time.Time) ([]byte, error) {
year, month, day := t.Date()
hour, min, sec := t.Clock()
nsec := t.Nanosecond()
if year < 1 || year > 9999 {
return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap
}
year100 := year / 100
year1 := year % 100
var localBuf [len("2006-01-02T15:04:05.999999999")]byte // does not escape
localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1]
localBuf[4] = '-'
localBuf[5], localBuf[6] = digits10[month], digits01[month]
localBuf[7] = '-'
localBuf[8], localBuf[9] = digits10[day], digits01[day]
if hour == 0 && min == 0 && sec == 0 && nsec == 0 {
return append(buf, localBuf[:10]...), nil
}
localBuf[10] = ' '
localBuf[11], localBuf[12] = digits10[hour], digits01[hour]
localBuf[13] = ':'
localBuf[14], localBuf[15] = digits10[min], digits01[min]
localBuf[16] = ':'
localBuf[17], localBuf[18] = digits10[sec], digits01[sec]
if nsec == 0 {
return append(buf, localBuf[:19]...), nil
}
nsec100000000 := nsec / 100000000
nsec1000000 := (nsec / 1000000) % 100
nsec10000 := (nsec / 10000) % 100
nsec100 := (nsec / 100) % 100
nsec1 := nsec % 100
localBuf[19] = '.'
// milli second
localBuf[20], localBuf[21], localBuf[22] =
digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000]
// micro second
localBuf[23], localBuf[24], localBuf[25] =
digits10[nsec10000], digits01[nsec10000], digits10[nsec100]
// nano second
localBuf[26], localBuf[27], localBuf[28] =
digits01[nsec100], digits10[nsec1], digits01[nsec1]
// trim trailing zeros
n := len(localBuf)
for n > 0 && localBuf[n-1] == '0' {
n--
}
return append(buf, localBuf[:n]...), nil
}
// zeroDateTime is used in formatBinaryDateTime to avoid an allocation
// if the DATE or DATETIME has the zero value.
// It must never be changed.

View file

@ -2,3 +2,5 @@
*.test
*~
*.swp
.idea
.vscode

Some files were not shown because too many files have changed in this diff Show more