diff --git a/domain/auth/ldap/ldap.go b/domain/auth/ldap/ldap.go index b19bdbc4..65bbb03c 100644 --- a/domain/auth/ldap/ldap.go +++ b/domain/auth/ldap/ldap.go @@ -20,8 +20,8 @@ import ( "github.com/documize/community/core/stringutil" lm "github.com/documize/community/model/auth" "github.com/documize/community/model/user" + ld "github.com/go-ldap/ldap/v3" "github.com/pkg/errors" - ld "gopkg.in/ldap.v3" ) // Connect establishes connection to LDAP server. diff --git a/go.mod b/go.mod index 1631b6ff..8fb84987 100644 --- a/go.mod +++ b/go.mod @@ -6,13 +6,14 @@ require ( github.com/BurntSushi/toml v0.3.1 github.com/andygrunwald/go-jira v1.12.0 github.com/codegangsta/negroni v1.0.0 - github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc + github.com/denisenkom/go-mssqldb v0.10.1-0.20210728001037-ee2fbc25fd8f github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/documize/blackfriday v2.0.0+incompatible github.com/documize/glick v0.0.0-20160503134043-a8ccbef88237 github.com/documize/html-diff v0.0.0-20160503140253-f61c192c7796 github.com/documize/slug v1.1.1 - github.com/go-sql-driver/mysql v1.5.0 + github.com/go-ldap/ldap/v3 v3.4.1 + github.com/go-sql-driver/mysql v1.6.0 github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect github.com/golang/protobuf v1.4.0 // indirect github.com/google/go-github v17.0.0+incompatible @@ -21,29 +22,29 @@ require ( github.com/gorilla/mux v1.7.4 github.com/jmoiron/sqlx v1.2.0 github.com/kr/pretty v0.2.0 // indirect - github.com/lib/pq v1.5.2 + github.com/lib/pq v1.10.2 github.com/mb0/diff v0.0.0-20131118162322-d8d9a906c24d // indirect github.com/microcosm-cc/bluemonday v1.0.2 github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d github.com/pkg/errors v0.9.1 github.com/rainycape/unidecode v0.0.0-20150907023854-cb7f23ec59be // indirect github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect - golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 + golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d google.golang.org/appengine v1.6.6 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc gopkg.in/cas.v2 v2.1.0 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect - gopkg.in/ldap.v3 v3.1.0 gopkg.in/yaml.v2 v2.2.2 // indirect ) require ( + github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c // indirect github.com/fatih/structs v1.0.0 // indirect + github.com/go-asn1-ber/asn1-ber v1.5.3 // indirect github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect github.com/trivago/tgo v1.0.1 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/protobuf v1.21.0 // indirect - gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d // indirect ) diff --git a/go.sum b/go.sum index d0adcf4c..cc093d32 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,14 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c h1:/IBSNwUN8+eKzUzbJPqhK839ygXJ82sde8x3ogr6R28= +github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/andygrunwald/go-jira v1.12.0 h1:JJi2cEDmDxVtTXxC8ruLDbtOU6pA4OLeL0niyfNcoWw= github.com/andygrunwald/go-jira v1.12.0/go.mod h1:jYi4kFDbRPZTJdJOVJO4mpMMIwdB+rcZwSO58DzPd2I= github.com/codegangsta/negroni v1.0.0 h1:+aYywywx4bnKXWvoWtRfJ91vC59NbEhEY03sZjQhbVY= github.com/codegangsta/negroni v1.0.0/go.mod h1:v0y3T5G7Y1UlFfyxFn/QLRU4a2EuNau2iZY63YTKWo0= -github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc h1:VRRKCwnzqk8QCaRC4os14xoKDdbHqqlJtJA0oc1ZAjg= -github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/denisenkom/go-mssqldb v0.10.1-0.20210728001037-ee2fbc25fd8f h1:3UtVZFKTqZwLZi65UbfSIqYR75aUTP8FYUAEQnMXSJs= +github.com/denisenkom/go-mssqldb v0.10.1-0.20210728001037-ee2fbc25fd8f/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/documize/blackfriday v2.0.0+incompatible h1:qjRGAIVwZlHBtA/b9u0LtseYM3v3WpIXofPCwNjcUsE= @@ -19,9 +21,14 @@ github.com/documize/slug v1.1.1 h1:OCJRbWxbOgrgiBYSbVzuFwxb9wVu4oy1LxvLJOC2s8Y= github.com/documize/slug v1.1.1/go.mod h1:Vi7fQ5PzeOpXAiIrk1WCEDRihjTfU/bf4eWUPSD7tkU= github.com/fatih/structs v1.0.0 h1:BrX964Rv5uQ3wwS+KRUAJCBBw5PQmgJfJ6v4yly5QwU= github.com/fatih/structs v1.0.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= +github.com/go-asn1-ber/asn1-ber v1.5.1/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= +github.com/go-asn1-ber/asn1-ber v1.5.3 h1:u7utq56RUFiynqUzgVMFDymapcOtQ/MZkh3H4QYkxag= +github.com/go-asn1-ber/asn1-ber v1.5.3/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= +github.com/go-ldap/ldap/v3 v3.4.1 h1:fU/0xli6HY02ocbMuozHAYsaHLcnkLjvho2r5a34BUU= +github.com/go-ldap/ldap/v3 v3.4.1/go.mod h1:iYS1MdmrmceOJ1QOTnRXrIs7i3kloqtmGQjRvjKpyMg= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= -github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= @@ -55,8 +62,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.5.2 h1:yTSXVswvWUOQ3k1sd7vJfDrbSl8lKuscqFJRqjC0ifw= -github.com/lib/pq v1.5.2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-sqlite3 v1.9.0 h1:pDRiWfl+++eC2FEFRy6jXmQlvp4Yh3z1MJKg4UeYM/4= github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mb0/diff v0.0.0-20131118162322-d8d9a906c24d h1:eAS2t2Vy+6psf9LZ4T5WXWsbkBt3Tu5PWekJy5AGyEU= @@ -77,8 +84,8 @@ github.com/trivago/tgo v1.0.1/go.mod h1:w4dpD+3tzNIIiIfkWWa85w5/B77tlvdZckQ+6PkF golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw= -golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 h1:vEg9joUBmeBcK9iSJftGNf3coIG4HqZElCPehJsfAYM= +golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -109,14 +116,10 @@ google.golang.org/protobuf v1.21.0 h1:qdOKuR/EIArgaWNjetjgTzgVTAZ+S/WXVrq9HW9zim google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= -gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d h1:TxyelI5cVkbREznMhfzycHdkp5cLA7DpE+GKjSslYhM= -gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d/go.mod h1:cuepJuh7vyXfUyUwEgHQXw849cJrilpS5NeIjOWESAw= gopkg.in/cas.v2 v2.1.0 h1:sbYBMWtpanwLH75GAWjIp5JnON9wa3NodLZhouu0G9I= gopkg.in/cas.v2 v2.1.0/go.mod h1:M291I/o/u3eeMl9SkXMPYpWasHp7weFY9G/pM5DbB+g= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/ldap.v3 v3.1.0 h1:DIDWEjI7vQWREh0S8X5/NFPCZ3MCVd55LmXKPW4XLGE= -gopkg.in/ldap.v3 v3.1.0/go.mod h1:dQjCc0R0kfyFjIlWNMH1DORwUASZyDxo2Ry1B51dXaQ= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/gui/app/pods/auth/login/controller.js b/gui/app/pods/auth/login/controller.js index 770fc3cd..50483d53 100644 --- a/gui/app/pods/auth/login/controller.js +++ b/gui/app/pods/auth/login/controller.js @@ -1,3 +1,5 @@ +/* eslint-disable ember/no-actions-hash */ +/* eslint-disable ember/no-classic-classes */ // Copyright 2016 Documize Inc. . All rights reserved. // // This software (Documize Community Edition) is licensed under diff --git a/gui/app/pods/auth/login/route.js b/gui/app/pods/auth/login/route.js index 18babbfd..b315ad5c 100644 --- a/gui/app/pods/auth/login/route.js +++ b/gui/app/pods/auth/login/route.js @@ -1,3 +1,4 @@ +/* eslint-disable ember/no-classic-classes */ // Copyright 2016 Documize Inc. . All rights reserved. // // This software (Documize Community Edition) is licensed under diff --git a/gui/app/router.js b/gui/app/router.js index 06831cbe..e0b8c142 100644 --- a/gui/app/router.js +++ b/gui/app/router.js @@ -186,7 +186,8 @@ export default Router.map(function () { path: 'updates' }); - this.route('not-found', { + this.route('auth/login', { path: '/*wildcard' + // path: '/*wildcard' }); }); diff --git a/gui/app/routes/application.js b/gui/app/routes/application.js index d84598f2..875d04df 100644 --- a/gui/app/routes/application.js +++ b/gui/app/routes/application.js @@ -1,3 +1,4 @@ +/* eslint-disable ember/no-classic-classes */ // Copyright 2016 Documize Inc. . All rights reserved. // // This software (Documize Community Edition) is licensed under diff --git a/server/middleware.go b/server/middleware.go index a655c100..58b6f948 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -43,12 +43,11 @@ func (m *middleware) cors(w http.ResponseWriter, r *http.Request, next http.Hand w.Header().Set("Access-Control-Allow-Headers", "host, content-type, accept, authorization, origin, referer, user-agent, cache-control, x-requested-with, range") w.Header().Set("Access-Control-Expose-Headers", "x-documize-version, x-documize-status, x-documize-filename, x-documize-subscription, Content-Disposition, Content-Length") + w.Header().Add("X-Documize-Version", m.Runtime.Product.Version) + w.Header().Add("Cache-Control", "no-cache") + if r.Method == "OPTIONS" { - w.Header().Add("X-Documize-Version", m.Runtime.Product.Version) - w.Header().Add("Cache-Control", "no-cache") - w.Write([]byte("")) - return } diff --git a/vendor/github.com/Azure/go-ntlmssp/.travis.yml b/vendor/github.com/Azure/go-ntlmssp/.travis.yml new file mode 100644 index 00000000..23c95fe9 --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/.travis.yml @@ -0,0 +1,17 @@ +sudo: false + +language: go + +before_script: + - go get -u golang.org/x/lint/golint + +go: + - 1.10.x + - master + +script: + - test -z "$(gofmt -s -l . | tee /dev/stderr)" + - test -z "$(golint ./... | tee /dev/stderr)" + - go vet ./... + - go build -v ./... + - go test -v ./... diff --git a/vendor/github.com/Azure/go-ntlmssp/LICENSE b/vendor/github.com/Azure/go-ntlmssp/LICENSE new file mode 100644 index 00000000..dc1cf39d --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Microsoft + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/Azure/go-ntlmssp/README.md b/vendor/github.com/Azure/go-ntlmssp/README.md new file mode 100644 index 00000000..55cdcefa --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/README.md @@ -0,0 +1,29 @@ +# go-ntlmssp +Golang package that provides NTLM/Negotiate authentication over HTTP + +[![GoDoc](https://godoc.org/github.com/Azure/go-ntlmssp?status.svg)](https://godoc.org/github.com/Azure/go-ntlmssp) [![Build Status](https://travis-ci.org/Azure/go-ntlmssp.svg?branch=dev)](https://travis-ci.org/Azure/go-ntlmssp) + +Protocol details from https://msdn.microsoft.com/en-us/library/cc236621.aspx +Implementation hints from http://davenport.sourceforge.net/ntlm.html + +This package only implements authentication, no key exchange or encryption. It +only supports Unicode (UTF16LE) encoding of protocol strings, no OEM encoding. +This package implements NTLMv2. + +# Usage + +``` +url, user, password := "http://www.example.com/secrets", "robpike", "pw123" +client := &http.Client{ + Transport: ntlmssp.Negotiator{ + RoundTripper:&http.Transport{}, + }, +} + +req, _ := http.NewRequest("GET", url, nil) +req.SetBasicAuth(user, password) +res, _ := client.Do(req) +``` + +----- +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. diff --git a/vendor/github.com/Azure/go-ntlmssp/authenticate_message.go b/vendor/github.com/Azure/go-ntlmssp/authenticate_message.go new file mode 100644 index 00000000..c8930680 --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/authenticate_message.go @@ -0,0 +1,183 @@ +package ntlmssp + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "encoding/hex" + "errors" + "strings" + "time" +) + +type authenicateMessage struct { + LmChallengeResponse []byte + NtChallengeResponse []byte + + TargetName string + UserName string + + // only set if negotiateFlag_NTLMSSP_NEGOTIATE_KEY_EXCH + EncryptedRandomSessionKey []byte + + NegotiateFlags negotiateFlags + + MIC []byte +} + +type authenticateMessageFields struct { + messageHeader + LmChallengeResponse varField + NtChallengeResponse varField + TargetName varField + UserName varField + Workstation varField + _ [8]byte + NegotiateFlags negotiateFlags +} + +func (m authenicateMessage) MarshalBinary() ([]byte, error) { + if !m.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEUNICODE) { + return nil, errors.New("Only unicode is supported") + } + + target, user := toUnicode(m.TargetName), toUnicode(m.UserName) + workstation := toUnicode("go-ntlmssp") + + ptr := binary.Size(&authenticateMessageFields{}) + f := authenticateMessageFields{ + messageHeader: newMessageHeader(3), + NegotiateFlags: m.NegotiateFlags, + LmChallengeResponse: newVarField(&ptr, len(m.LmChallengeResponse)), + NtChallengeResponse: newVarField(&ptr, len(m.NtChallengeResponse)), + TargetName: newVarField(&ptr, len(target)), + UserName: newVarField(&ptr, len(user)), + Workstation: newVarField(&ptr, len(workstation)), + } + + f.NegotiateFlags.Unset(negotiateFlagNTLMSSPNEGOTIATEVERSION) + + b := bytes.Buffer{} + if err := binary.Write(&b, binary.LittleEndian, &f); err != nil { + return nil, err + } + if err := binary.Write(&b, binary.LittleEndian, &m.LmChallengeResponse); err != nil { + return nil, err + } + if err := binary.Write(&b, binary.LittleEndian, &m.NtChallengeResponse); err != nil { + return nil, err + } + if err := binary.Write(&b, binary.LittleEndian, &target); err != nil { + return nil, err + } + if err := binary.Write(&b, binary.LittleEndian, &user); err != nil { + return nil, err + } + if err := binary.Write(&b, binary.LittleEndian, &workstation); err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +//ProcessChallenge crafts an AUTHENTICATE message in response to the CHALLENGE message +//that was received from the server +func ProcessChallenge(challengeMessageData []byte, user, password string) ([]byte, error) { + if user == "" && password == "" { + return nil, errors.New("Anonymous authentication not supported") + } + + var cm challengeMessage + if err := cm.UnmarshalBinary(challengeMessageData); err != nil { + return nil, err + } + + if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATELMKEY) { + return nil, errors.New("Only NTLM v2 is supported, but server requested v1 (NTLMSSP_NEGOTIATE_LM_KEY)") + } + if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEKEYEXCH) { + return nil, errors.New("Key exchange requested but not supported (NTLMSSP_NEGOTIATE_KEY_EXCH)") + } + + am := authenicateMessage{ + UserName: user, + TargetName: cm.TargetName, + NegotiateFlags: cm.NegotiateFlags, + } + + timestamp := cm.TargetInfo[avIDMsvAvTimestamp] + if timestamp == nil { // no time sent, take current time + ft := uint64(time.Now().UnixNano()) / 100 + ft += 116444736000000000 // add time between unix & windows offset + timestamp = make([]byte, 8) + binary.LittleEndian.PutUint64(timestamp, ft) + } + + clientChallenge := make([]byte, 8) + rand.Reader.Read(clientChallenge) + + ntlmV2Hash := getNtlmV2Hash(password, user, cm.TargetName) + + am.NtChallengeResponse = computeNtlmV2Response(ntlmV2Hash, + cm.ServerChallenge[:], clientChallenge, timestamp, cm.TargetInfoRaw) + + if cm.TargetInfoRaw == nil { + am.LmChallengeResponse = computeLmV2Response(ntlmV2Hash, + cm.ServerChallenge[:], clientChallenge) + } + return am.MarshalBinary() +} + +func ProcessChallengeWithHash(challengeMessageData []byte, user, hash string) ([]byte, error) { + if user == "" && hash == "" { + return nil, errors.New("Anonymous authentication not supported") + } + + var cm challengeMessage + if err := cm.UnmarshalBinary(challengeMessageData); err != nil { + return nil, err + } + + if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATELMKEY) { + return nil, errors.New("Only NTLM v2 is supported, but server requested v1 (NTLMSSP_NEGOTIATE_LM_KEY)") + } + if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEKEYEXCH) { + return nil, errors.New("Key exchange requested but not supported (NTLMSSP_NEGOTIATE_KEY_EXCH)") + } + + am := authenicateMessage{ + UserName: user, + TargetName: cm.TargetName, + NegotiateFlags: cm.NegotiateFlags, + } + + timestamp := cm.TargetInfo[avIDMsvAvTimestamp] + if timestamp == nil { // no time sent, take current time + ft := uint64(time.Now().UnixNano()) / 100 + ft += 116444736000000000 // add time between unix & windows offset + timestamp = make([]byte, 8) + binary.LittleEndian.PutUint64(timestamp, ft) + } + + clientChallenge := make([]byte, 8) + rand.Reader.Read(clientChallenge) + + hashParts := strings.Split(hash, ":") + if len(hashParts) > 1 { + hash = hashParts[1] + } + hashBytes, err := hex.DecodeString(hash) + if err != nil { + return nil, err + } + ntlmV2Hash := hmacMd5(hashBytes, toUnicode(strings.ToUpper(user)+cm.TargetName)) + + am.NtChallengeResponse = computeNtlmV2Response(ntlmV2Hash, + cm.ServerChallenge[:], clientChallenge, timestamp, cm.TargetInfoRaw) + + if cm.TargetInfoRaw == nil { + am.LmChallengeResponse = computeLmV2Response(ntlmV2Hash, + cm.ServerChallenge[:], clientChallenge) + } + return am.MarshalBinary() +} diff --git a/vendor/github.com/Azure/go-ntlmssp/authheader.go b/vendor/github.com/Azure/go-ntlmssp/authheader.go new file mode 100644 index 00000000..aac3f77d --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/authheader.go @@ -0,0 +1,37 @@ +package ntlmssp + +import ( + "encoding/base64" + "strings" +) + +type authheader string + +func (h authheader) IsBasic() bool { + return strings.HasPrefix(string(h), "Basic ") +} + +func (h authheader) IsNegotiate() bool { + return strings.HasPrefix(string(h), "Negotiate") +} + +func (h authheader) IsNTLM() bool { + return strings.HasPrefix(string(h), "NTLM") +} + +func (h authheader) GetData() ([]byte, error) { + p := strings.Split(string(h), " ") + if len(p) < 2 { + return nil, nil + } + return base64.StdEncoding.DecodeString(string(p[1])) +} + +func (h authheader) GetBasicCreds() (username, password string, err error) { + d, err := h.GetData() + if err != nil { + return "", "", err + } + parts := strings.SplitN(string(d), ":", 2) + return parts[0], parts[1], nil +} diff --git a/vendor/github.com/Azure/go-ntlmssp/avids.go b/vendor/github.com/Azure/go-ntlmssp/avids.go new file mode 100644 index 00000000..196b5f13 --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/avids.go @@ -0,0 +1,17 @@ +package ntlmssp + +type avID uint16 + +const ( + avIDMsvAvEOL avID = iota + avIDMsvAvNbComputerName + avIDMsvAvNbDomainName + avIDMsvAvDNSComputerName + avIDMsvAvDNSDomainName + avIDMsvAvDNSTreeName + avIDMsvAvFlags + avIDMsvAvTimestamp + avIDMsvAvSingleHost + avIDMsvAvTargetName + avIDMsvChannelBindings +) diff --git a/vendor/github.com/Azure/go-ntlmssp/challenge_message.go b/vendor/github.com/Azure/go-ntlmssp/challenge_message.go new file mode 100644 index 00000000..053b55e4 --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/challenge_message.go @@ -0,0 +1,82 @@ +package ntlmssp + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +type challengeMessageFields struct { + messageHeader + TargetName varField + NegotiateFlags negotiateFlags + ServerChallenge [8]byte + _ [8]byte + TargetInfo varField +} + +func (m challengeMessageFields) IsValid() bool { + return m.messageHeader.IsValid() && m.MessageType == 2 +} + +type challengeMessage struct { + challengeMessageFields + TargetName string + TargetInfo map[avID][]byte + TargetInfoRaw []byte +} + +func (m *challengeMessage) UnmarshalBinary(data []byte) error { + r := bytes.NewReader(data) + err := binary.Read(r, binary.LittleEndian, &m.challengeMessageFields) + if err != nil { + return err + } + if !m.challengeMessageFields.IsValid() { + return fmt.Errorf("Message is not a valid challenge message: %+v", m.challengeMessageFields.messageHeader) + } + + if m.challengeMessageFields.TargetName.Len > 0 { + m.TargetName, err = m.challengeMessageFields.TargetName.ReadStringFrom(data, m.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEUNICODE)) + if err != nil { + return err + } + } + + if m.challengeMessageFields.TargetInfo.Len > 0 { + d, err := m.challengeMessageFields.TargetInfo.ReadFrom(data) + m.TargetInfoRaw = d + if err != nil { + return err + } + m.TargetInfo = make(map[avID][]byte) + r := bytes.NewReader(d) + for { + var id avID + var l uint16 + err = binary.Read(r, binary.LittleEndian, &id) + if err != nil { + return err + } + if id == avIDMsvAvEOL { + break + } + + err = binary.Read(r, binary.LittleEndian, &l) + if err != nil { + return err + } + value := make([]byte, l) + n, err := r.Read(value) + if err != nil { + return err + } + if n != int(l) { + return fmt.Errorf("Expected to read %d bytes, got only %d", l, n) + } + m.TargetInfo[id] = value + } + } + + return nil +} diff --git a/vendor/github.com/Azure/go-ntlmssp/messageheader.go b/vendor/github.com/Azure/go-ntlmssp/messageheader.go new file mode 100644 index 00000000..247e2846 --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/messageheader.go @@ -0,0 +1,21 @@ +package ntlmssp + +import ( + "bytes" +) + +var signature = [8]byte{'N', 'T', 'L', 'M', 'S', 'S', 'P', 0} + +type messageHeader struct { + Signature [8]byte + MessageType uint32 +} + +func (h messageHeader) IsValid() bool { + return bytes.Equal(h.Signature[:], signature[:]) && + h.MessageType > 0 && h.MessageType < 4 +} + +func newMessageHeader(messageType uint32) messageHeader { + return messageHeader{signature, messageType} +} diff --git a/vendor/github.com/Azure/go-ntlmssp/negotiate_flags.go b/vendor/github.com/Azure/go-ntlmssp/negotiate_flags.go new file mode 100644 index 00000000..5905c023 --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/negotiate_flags.go @@ -0,0 +1,52 @@ +package ntlmssp + +type negotiateFlags uint32 + +const ( + /*A*/ negotiateFlagNTLMSSPNEGOTIATEUNICODE negotiateFlags = 1 << 0 + /*B*/ negotiateFlagNTLMNEGOTIATEOEM = 1 << 1 + /*C*/ negotiateFlagNTLMSSPREQUESTTARGET = 1 << 2 + + /*D*/ + negotiateFlagNTLMSSPNEGOTIATESIGN = 1 << 4 + /*E*/ negotiateFlagNTLMSSPNEGOTIATESEAL = 1 << 5 + /*F*/ negotiateFlagNTLMSSPNEGOTIATEDATAGRAM = 1 << 6 + /*G*/ negotiateFlagNTLMSSPNEGOTIATELMKEY = 1 << 7 + + /*H*/ + negotiateFlagNTLMSSPNEGOTIATENTLM = 1 << 9 + + /*J*/ + negotiateFlagANONYMOUS = 1 << 11 + /*K*/ negotiateFlagNTLMSSPNEGOTIATEOEMDOMAINSUPPLIED = 1 << 12 + /*L*/ negotiateFlagNTLMSSPNEGOTIATEOEMWORKSTATIONSUPPLIED = 1 << 13 + + /*M*/ + negotiateFlagNTLMSSPNEGOTIATEALWAYSSIGN = 1 << 15 + /*N*/ negotiateFlagNTLMSSPTARGETTYPEDOMAIN = 1 << 16 + /*O*/ negotiateFlagNTLMSSPTARGETTYPESERVER = 1 << 17 + + /*P*/ + negotiateFlagNTLMSSPNEGOTIATEEXTENDEDSESSIONSECURITY = 1 << 19 + /*Q*/ negotiateFlagNTLMSSPNEGOTIATEIDENTIFY = 1 << 20 + + /*R*/ + negotiateFlagNTLMSSPREQUESTNONNTSESSIONKEY = 1 << 22 + /*S*/ negotiateFlagNTLMSSPNEGOTIATETARGETINFO = 1 << 23 + + /*T*/ + negotiateFlagNTLMSSPNEGOTIATEVERSION = 1 << 25 + + /*U*/ + negotiateFlagNTLMSSPNEGOTIATE128 = 1 << 29 + /*V*/ negotiateFlagNTLMSSPNEGOTIATEKEYEXCH = 1 << 30 + /*W*/ negotiateFlagNTLMSSPNEGOTIATE56 = 1 << 31 +) + +func (field negotiateFlags) Has(flags negotiateFlags) bool { + return field&flags == flags +} + +func (field *negotiateFlags) Unset(flags negotiateFlags) { + *field = *field ^ (*field & flags) +} diff --git a/vendor/github.com/Azure/go-ntlmssp/negotiate_message.go b/vendor/github.com/Azure/go-ntlmssp/negotiate_message.go new file mode 100644 index 00000000..e466a986 --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/negotiate_message.go @@ -0,0 +1,64 @@ +package ntlmssp + +import ( + "bytes" + "encoding/binary" + "errors" + "strings" +) + +const expMsgBodyLen = 40 + +type negotiateMessageFields struct { + messageHeader + NegotiateFlags negotiateFlags + + Domain varField + Workstation varField + + Version +} + +var defaultFlags = negotiateFlagNTLMSSPNEGOTIATETARGETINFO | + negotiateFlagNTLMSSPNEGOTIATE56 | + negotiateFlagNTLMSSPNEGOTIATE128 | + negotiateFlagNTLMSSPNEGOTIATEUNICODE | + negotiateFlagNTLMSSPNEGOTIATEEXTENDEDSESSIONSECURITY + +//NewNegotiateMessage creates a new NEGOTIATE message with the +//flags that this package supports. +func NewNegotiateMessage(domainName, workstationName string) ([]byte, error) { + payloadOffset := expMsgBodyLen + flags := defaultFlags + + if domainName != "" { + flags |= negotiateFlagNTLMSSPNEGOTIATEOEMDOMAINSUPPLIED + } + + if workstationName != "" { + flags |= negotiateFlagNTLMSSPNEGOTIATEOEMWORKSTATIONSUPPLIED + } + + msg := negotiateMessageFields{ + messageHeader: newMessageHeader(1), + NegotiateFlags: flags, + Domain: newVarField(&payloadOffset, len(domainName)), + Workstation: newVarField(&payloadOffset, len(workstationName)), + Version: DefaultVersion(), + } + + b := bytes.Buffer{} + if err := binary.Write(&b, binary.LittleEndian, &msg); err != nil { + return nil, err + } + if b.Len() != expMsgBodyLen { + return nil, errors.New("incorrect body length") + } + + payload := strings.ToUpper(domainName + workstationName) + if _, err := b.WriteString(payload); err != nil { + return nil, err + } + + return b.Bytes(), nil +} diff --git a/vendor/github.com/Azure/go-ntlmssp/negotiator.go b/vendor/github.com/Azure/go-ntlmssp/negotiator.go new file mode 100644 index 00000000..7705eae4 --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/negotiator.go @@ -0,0 +1,144 @@ +package ntlmssp + +import ( + "bytes" + "encoding/base64" + "io" + "io/ioutil" + "net/http" + "strings" +) + +// GetDomain : parse domain name from based on slashes in the input +func GetDomain(user string) (string, string) { + domain := "" + + if strings.Contains(user, "\\") { + ucomponents := strings.SplitN(user, "\\", 2) + domain = ucomponents[0] + user = ucomponents[1] + } + return user, domain +} + +//Negotiator is a http.Roundtripper decorator that automatically +//converts basic authentication to NTLM/Negotiate authentication when appropriate. +type Negotiator struct{ http.RoundTripper } + +//RoundTrip sends the request to the server, handling any authentication +//re-sends as needed. +func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) { + // Use default round tripper if not provided + rt := l.RoundTripper + if rt == nil { + rt = http.DefaultTransport + } + // If it is not basic auth, just round trip the request as usual + reqauth := authheader(req.Header.Get("Authorization")) + if !reqauth.IsBasic() { + return rt.RoundTrip(req) + } + // Save request body + body := bytes.Buffer{} + if req.Body != nil { + _, err = body.ReadFrom(req.Body) + if err != nil { + return nil, err + } + + req.Body.Close() + req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes())) + } + // first try anonymous, in case the server still finds us + // authenticated from previous traffic + req.Header.Del("Authorization") + res, err = rt.RoundTrip(req) + if err != nil { + return nil, err + } + if res.StatusCode != http.StatusUnauthorized { + return res, err + } + + resauth := authheader(res.Header.Get("Www-Authenticate")) + if !resauth.IsNegotiate() && !resauth.IsNTLM() { + // Unauthorized, Negotiate not requested, let's try with basic auth + req.Header.Set("Authorization", string(reqauth)) + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes())) + + res, err = rt.RoundTrip(req) + if err != nil { + return nil, err + } + if res.StatusCode != http.StatusUnauthorized { + return res, err + } + resauth = authheader(res.Header.Get("Www-Authenticate")) + } + + if resauth.IsNegotiate() || resauth.IsNTLM() { + // 401 with request:Basic and response:Negotiate + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + + // recycle credentials + u, p, err := reqauth.GetBasicCreds() + if err != nil { + return nil, err + } + + // get domain from username + domain := "" + u, domain = GetDomain(u) + + // send negotiate + negotiateMessage, err := NewNegotiateMessage(domain, "") + if err != nil { + return nil, err + } + if resauth.IsNTLM() { + req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(negotiateMessage)) + } else { + req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(negotiateMessage)) + } + + req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes())) + + res, err = rt.RoundTrip(req) + if err != nil { + return nil, err + } + + // receive challenge? + resauth = authheader(res.Header.Get("Www-Authenticate")) + challengeMessage, err := resauth.GetData() + if err != nil { + return nil, err + } + if !(resauth.IsNegotiate() || resauth.IsNTLM()) || len(challengeMessage) == 0 { + // Negotiation failed, let client deal with response + return res, nil + } + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + + // send authenticate + authenticateMessage, err := ProcessChallenge(challengeMessage, u, p) + if err != nil { + return nil, err + } + if resauth.IsNTLM() { + req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(authenticateMessage)) + } else { + req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(authenticateMessage)) + } + + req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes())) + + return rt.RoundTrip(req) + } + + return res, err +} diff --git a/vendor/github.com/Azure/go-ntlmssp/nlmp.go b/vendor/github.com/Azure/go-ntlmssp/nlmp.go new file mode 100644 index 00000000..1e65abe8 --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/nlmp.go @@ -0,0 +1,51 @@ +// Package ntlmssp provides NTLM/Negotiate authentication over HTTP +// +// Protocol details from https://msdn.microsoft.com/en-us/library/cc236621.aspx, +// implementation hints from http://davenport.sourceforge.net/ntlm.html . +// This package only implements authentication, no key exchange or encryption. It +// only supports Unicode (UTF16LE) encoding of protocol strings, no OEM encoding. +// This package implements NTLMv2. +package ntlmssp + +import ( + "crypto/hmac" + "crypto/md5" + "golang.org/x/crypto/md4" + "strings" +) + +func getNtlmV2Hash(password, username, target string) []byte { + return hmacMd5(getNtlmHash(password), toUnicode(strings.ToUpper(username)+target)) +} + +func getNtlmHash(password string) []byte { + hash := md4.New() + hash.Write(toUnicode(password)) + return hash.Sum(nil) +} + +func computeNtlmV2Response(ntlmV2Hash, serverChallenge, clientChallenge, + timestamp, targetInfo []byte) []byte { + + temp := []byte{1, 1, 0, 0, 0, 0, 0, 0} + temp = append(temp, timestamp...) + temp = append(temp, clientChallenge...) + temp = append(temp, 0, 0, 0, 0) + temp = append(temp, targetInfo...) + temp = append(temp, 0, 0, 0, 0) + + NTProofStr := hmacMd5(ntlmV2Hash, serverChallenge, temp) + return append(NTProofStr, temp...) +} + +func computeLmV2Response(ntlmV2Hash, serverChallenge, clientChallenge []byte) []byte { + return append(hmacMd5(ntlmV2Hash, serverChallenge, clientChallenge), clientChallenge...) +} + +func hmacMd5(key []byte, data ...[]byte) []byte { + mac := hmac.New(md5.New, key) + for _, d := range data { + mac.Write(d) + } + return mac.Sum(nil) +} diff --git a/vendor/github.com/Azure/go-ntlmssp/unicode.go b/vendor/github.com/Azure/go-ntlmssp/unicode.go new file mode 100644 index 00000000..7b4f4716 --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/unicode.go @@ -0,0 +1,29 @@ +package ntlmssp + +import ( + "bytes" + "encoding/binary" + "errors" + "unicode/utf16" +) + +// helper func's for dealing with Windows Unicode (UTF16LE) + +func fromUnicode(d []byte) (string, error) { + if len(d)%2 > 0 { + return "", errors.New("Unicode (UTF 16 LE) specified, but uneven data length") + } + s := make([]uint16, len(d)/2) + err := binary.Read(bytes.NewReader(d), binary.LittleEndian, &s) + if err != nil { + return "", err + } + return string(utf16.Decode(s)), nil +} + +func toUnicode(s string) []byte { + uints := utf16.Encode([]rune(s)) + b := bytes.Buffer{} + binary.Write(&b, binary.LittleEndian, &uints) + return b.Bytes() +} diff --git a/vendor/github.com/Azure/go-ntlmssp/varfield.go b/vendor/github.com/Azure/go-ntlmssp/varfield.go new file mode 100644 index 00000000..15f9aa11 --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/varfield.go @@ -0,0 +1,40 @@ +package ntlmssp + +import ( + "errors" +) + +type varField struct { + Len uint16 + MaxLen uint16 + BufferOffset uint32 +} + +func (f varField) ReadFrom(buffer []byte) ([]byte, error) { + if len(buffer) < int(f.BufferOffset+uint32(f.Len)) { + return nil, errors.New("Error reading data, varField extends beyond buffer") + } + return buffer[f.BufferOffset : f.BufferOffset+uint32(f.Len)], nil +} + +func (f varField) ReadStringFrom(buffer []byte, unicode bool) (string, error) { + d, err := f.ReadFrom(buffer) + if err != nil { + return "", err + } + if unicode { // UTF-16LE encoding scheme + return fromUnicode(d) + } + // OEM encoding, close enough to ASCII, since no code page is specified + return string(d), err +} + +func newVarField(ptr *int, fieldsize int) varField { + f := varField{ + Len: uint16(fieldsize), + MaxLen: uint16(fieldsize), + BufferOffset: uint32(*ptr), + } + *ptr += fieldsize + return f +} diff --git a/vendor/github.com/Azure/go-ntlmssp/version.go b/vendor/github.com/Azure/go-ntlmssp/version.go new file mode 100644 index 00000000..6d848921 --- /dev/null +++ b/vendor/github.com/Azure/go-ntlmssp/version.go @@ -0,0 +1,20 @@ +package ntlmssp + +// Version is a struct representing https://msdn.microsoft.com/en-us/library/cc236654.aspx +type Version struct { + ProductMajorVersion uint8 + ProductMinorVersion uint8 + ProductBuild uint16 + _ [3]byte + NTLMRevisionCurrent uint8 +} + +// DefaultVersion returns a Version with "sensible" defaults (Windows 7) +func DefaultVersion() Version { + return Version{ + ProductMajorVersion: 6, + ProductMinorVersion: 1, + ProductBuild: 7601, + NTLMRevisionCurrent: 15, + } +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/.gitignore b/vendor/github.com/denisenkom/go-mssqldb/.gitignore new file mode 100644 index 00000000..1f8b088f --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/.gitignore @@ -0,0 +1,13 @@ +/.idea +/.connstr +.vscode +.terraform +*.tfstate* +*.log +*.swp +*~ +coverage.json +coverage.txt +coverage.xml +testresults.xml + diff --git a/vendor/github.com/denisenkom/go-mssqldb/.golangci.yml b/vendor/github.com/denisenkom/go-mssqldb/.golangci.yml new file mode 100644 index 00000000..959cd5e6 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/.golangci.yml @@ -0,0 +1,10 @@ +linters: + enable: + # basic go linters + - gofmt + - golint + - govet + + # sql related linters + - rowserrcheck + - sqlclosecheck diff --git a/vendor/github.com/denisenkom/go-mssqldb/README.md b/vendor/github.com/denisenkom/go-mssqldb/README.md index 94d87fe0..92e74b65 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/README.md +++ b/vendor/github.com/denisenkom/go-mssqldb/README.md @@ -1,6 +1,6 @@ # A pure Go MSSQL driver for Go's database/sql package -[![GoDoc](https://godoc.org/github.com/denisenkom/go-mssqldb?status.svg)](http://godoc.org/github.com/denisenkom/go-mssqldb) +[![Go Reference](https://pkg.go.dev/badge/github.com/denisenkom/go-mssqldb.svg)](https://pkg.go.dev/github.com/denisenkom/go-mssqldb) [![Build status](https://ci.appveyor.com/api/projects/status/jrln8cs62wj9i0a2?svg=true)](https://ci.appveyor.com/project/denisenkom/go-mssqldb) [![codecov](https://codecov.io/gh/denisenkom/go-mssqldb/branch/master/graph/badge.svg)](https://codecov.io/gh/denisenkom/go-mssqldb) @@ -16,7 +16,7 @@ The recommended connection string uses a URL format: `sqlserver://username:password@host/instance?param1=value¶m2=value` Other supported formats are listed below. -### Common parameters: +### Common parameters * `user id` - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used. The user domain sensitive to the case which is defined in the connection string. * `password` @@ -29,24 +29,24 @@ Other supported formats are listed below. * `true` - Data sent between client and server is encrypted. * `app name` - The application name (default is go-mssqldb) -### Connection parameters for ODBC and ADO style connection strings: +### Connection parameters for ODBC and ADO style connection strings * `server` - host or host\instance (default localhost) * `port` - used only when there is no instance in server (default 1433) -### Less common parameters: +### Less common parameters * `keepAlive` - in seconds; 0 to disable (default is 30) -* `failoverpartner` - host or host\instance (default is no partner). +* `failoverpartner` - host or host\instance (default is no partner). * `failoverport` - used only when there is no instance in failoverpartner (default 1433) * `packet size` - in bytes; 512 to 32767 (default is 4096) * Encrypted connections have a maximum packet size of 16383 bytes - * Further information on usage: https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option + * Further information on usage: * `log` - logging flags (default 0/no logging, 63 for full logging) - * 1 log errors - * 2 log messages - * 4 log rows affected - * 8 trace sql statements + * 1 log errors + * 2 log messages + * 4 log rows affected + * 8 trace sql statements * 16 log statement parameters * 32 log transaction begin/end * `TrustServerCertificate` @@ -56,72 +56,82 @@ Other supported formats are listed below. * `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host. * `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. * `Workstation ID` - The workstation name (default is the host name) -* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. - -### The connection string can be specified in one of three formats: +* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. +### The connection string can be specified in one of three formats 1. URL: with `sqlserver` scheme. username and password appears before the host. Any instance appears as the first segment in the path. All other options are query parameters. Examples: - * `sqlserver://username:password@host/instance?param1=value¶m2=value` - * `sqlserver://username:password@host:port?param1=value¶m2=value` - * `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress instance. - * `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass. - * `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30` // port 1234 on localhost. - * `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass" + * `sqlserver://username:password@host/instance?param1=value¶m2=value` + * `sqlserver://username:password@host:port?param1=value¶m2=value` + * `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress instance. + * `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass. + * `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30` // port 1234 on localhost. + * `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass" + A string of this format can be constructed using the `URL` type in the `net/url` package. - A string of this format can be constructed using the `URL` type in the `net/url` package. + ```go -```go - query := url.Values{} - query.Add("app name", "MyAppName") + query := url.Values{} + query.Add("app name", "MyAppName") + + u := &url.URL{ + Scheme: "sqlserver", + User: url.UserPassword(username, password), + Host: fmt.Sprintf("%s:%d", hostname, port), + // Path: instance, // if connecting to an instance instead of a port + RawQuery: query.Encode(), + } + db, err := sql.Open("sqlserver", u.String()) - u := &url.URL{ - Scheme: "sqlserver", - User: url.UserPassword(username, password), - Host: fmt.Sprintf("%s:%d", hostname, port), - // Path: instance, // if connecting to an instance instead of a port - RawQuery: query.Encode(), - } - db, err := sql.Open("sqlserver", u.String()) -``` + ``` 2. ADO: `key=value` pairs separated by `;`. Values may not contain `;`, leading and trailing whitespace is ignored. Examples: - - * `server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` - * `server=localhost;user id=sa;database=master;app name=MyAppName` + + * `server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` + * `server=localhost;user id=sa;database=master;app name=MyAppName` + + ADO strings support synonyms for database, app name, user id, and server + * server <= addr, address, network address, data source + * user id <= user, uid + * database <= initial catalog + * app name <= application name 3. ODBC: Prefix with `odbc`, `key=value` pairs separated by `;`. Allow `;` by wrapping values in `{}`. Examples: - - * `odbc:server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` - * `odbc:server=localhost;user id=sa;database=master;app name=MyAppName` - * `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar" - * `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar" - * `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar " - * `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar" - * `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar" - * `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar" - * `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with `}}`, password is "foo}bar" + + * `odbc:server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` + * `odbc:server=localhost;user id=sa;database=master;app name=MyAppName` + * `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar" + * `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar" + * `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar " + * `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar" + * `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar" + * `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar" + * `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with`}}`, password is "foo}bar" ### Azure Active Directory authentication - preview The configuration of functionality might change in the future. -Azure Active Directory (AAD) access tokens are relatively short lived and need to be +Azure Active Directory (AAD) access tokens are relatively short lived and need to be valid when a new connection is made. Authentication is supported using a callback func that provides a fresh and valid token using a connector: -``` golang + +``` go + conn, err := mssql.NewAccessTokenConnector( - "Server=test.database.windows.net;Database=testdb", - tokenProvider) + "Server=test.database.windows.net;Database=testdb", + tokenProvider) if err != nil { // handle errors in DSN } db := sql.OpenDB(conn) + ``` + Where `tokenProvider` is a function that returns a fresh access token or an error. None of these statements actually trigger the retrieval of a token, this happens when the first statment is issued and a connection is created. @@ -129,18 +139,23 @@ is created. ## Executing Stored Procedures To run a stored procedure, set the query text to the procedure name: + ```go + var account = "abc" _, err := db.ExecContext(ctx, "sp_RunMe", sql.Named("ID", 123), sql.Named("Account", sql.Out{Dest: &account}), ) + ``` ## Reading Output Parameters from a Stored Procedure with Resultset To read output parameters from a stored procedure with resultset, make sure you read all the rows before reading the output parameters: + ```go + sqltextcreate := ` CREATE PROCEDURE spwithoutputandrows @bitparam BIT OUTPUT @@ -156,6 +171,7 @@ for rows.Next() { err = rows.Scan(&strrow) } fmt.Printf("bitparam is %d", bitout) + ``` ## Caveat for local temporary tables @@ -201,21 +217,25 @@ _, err := conn.ExecContext(ctx, "insert into #mytemp (x) values (@p1)", 1) To get the procedure return status, pass into the parameters a `*mssql.ReturnStatus`. For example: -``` + +```go + var rs mssql.ReturnStatus _, err := db.ExecContext(ctx, "theproc", &rs) log.Printf("status=%d", rs) + ``` or -``` +```go var rs mssql.ReturnStatus _, err := db.QueryContext(ctx, "theproc", &rs) for rows.Next() { err = rows.Scan(&val) } log.Printf("status=%d", rs) + ``` Limitation: ReturnStatus cannot be retrieved using `QueryRow`. @@ -226,7 +246,9 @@ The `sqlserver` driver uses normal MS SQL Server syntax and expects parameters i the sql query to be in the form of either `@Name` or `@p1` to `@pN` (ordinal position). ```go + db.QueryContext(ctx, `select * from t where ID = @ID and Name = @p2;`, sql.Named("ID", 6), "Bob") + ``` ### Parameter Types @@ -235,30 +257,30 @@ To pass specific types to the query parameters, say `varchar` or `date` types, you must convert the types to the type before passing in. The following types are supported: - * string -> nvarchar - * mssql.VarChar -> varchar - * time.Time -> datetimeoffset or datetime (TDS version dependent) - * mssql.DateTime1 -> datetime - * mssql.DateTimeOffset -> datetimeoffset - * "github.com/golang-sql/civil".Date -> date - * "github.com/golang-sql/civil".DateTime -> datetime2 - * "github.com/golang-sql/civil".Time -> time - * mssql.TVP -> Table Value Parameter (TDS version dependent) +* string -> nvarchar +* mssql.VarChar -> varchar +* time.Time -> datetimeoffset or datetime (TDS version dependent) +* mssql.DateTime1 -> datetime +* mssql.DateTimeOffset -> datetimeoffset +* "github.com/golang-sql/civil".Date -> date +* "github.com/golang-sql/civil".DateTime -> datetime2 +* "github.com/golang-sql/civil".Time -> time +* mssql.TVP -> Table Value Parameter (TDS version dependent) ## Important Notes - * [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should +* [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should not be used with this driver (or SQL Server) due to how the TDS protocol - works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql) - or add a `select ID = convert(bigint, SCOPE_IDENTITY());` to the end of your - query (ref [SCOPE_IDENTITY](https://docs.microsoft.com/en-us/sql/t-sql/functions/scope-identity-transact-sql)). - This will ensure you are getting the correct ID and will prevent a network round trip. - * [NewConnector](https://godoc.org/github.com/denisenkom/go-mssqldb#NewConnector) + works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql) + or add a `select ID = convert(bigint, SCOPE_IDENTITY());` to the end of your + query (ref [SCOPE_IDENTITY](https://docs.microsoft.com/en-us/sql/t-sql/functions/scope-identity-transact-sql)). + This will ensure you are getting the correct ID and will prevent a network round trip. +* [NewConnector](https://godoc.org/github.com/denisenkom/go-mssqldb#NewConnector) may be used with [OpenDB](https://golang.org/pkg/database/sql/#OpenDB). - * [Connector.SessionInitSQL](https://godoc.org/github.com/denisenkom/go-mssqldb#Connector.SessionInitSQL) - may be set to set any driver specific session settings after the session - has been reset. If empty the session will still be reset but use the database - defaults in Go1.10+. +* [Connector.SessionInitSQL](https://godoc.org/github.com/denisenkom/go-mssqldb#Connector.SessionInitSQL) + may be set to set any driver specific session settings after the session + has been reset. If empty the session will still be reset but use the database + defaults in Go1.10+. ## Features @@ -280,7 +302,9 @@ Environment variables are used to pass login information. Example: +```bash env SQLSERVER_DSN=sqlserver://user:pass@hostname/instance?database=test1 go test +``` ## Deprecated @@ -296,7 +320,7 @@ will be loosly parsed and an attempt to extract identifiers using one of * :nnn * $nnn -will be used. This is not recommended with SQL Server. +will be used. This is not recommended with SQL Server. There is at least one existing `won't fix` issue with the query parsing. Use the native "@Name" parameters instead with the "sqlserver" driver name. @@ -306,4 +330,4 @@ Use the native "@Name" parameters instead with the "sqlserver" driver name. * SQL Server 2008 and 2008 R2 engine cannot handle login records when SSL encryption is not disabled. To fix SQL Server 2008 R2 issue, install SQL Server 2008 R2 Service Pack 2. To fix SQL Server 2008 issue, install Microsoft SQL Server 2008 Service Pack 3 and Cumulative update package 3 for SQL Server 2008 SP3. -More information: http://support.microsoft.com/kb/2653857 +More information: diff --git a/vendor/github.com/denisenkom/go-mssqldb/accesstokenconnector.go b/vendor/github.com/denisenkom/go-mssqldb/accesstokenconnector.go index 8dbe5099..80213d1e 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/accesstokenconnector.go +++ b/vendor/github.com/denisenkom/go-mssqldb/accesstokenconnector.go @@ -6,19 +6,8 @@ import ( "context" "database/sql/driver" "errors" - "fmt" ) -var _ driver.Connector = &accessTokenConnector{} - -// accessTokenConnector wraps Connector and injects a -// fresh access token when connecting to the database -type accessTokenConnector struct { - Connector - - accessTokenProvider func() (string, error) -} - // NewAccessTokenConnector creates a new connector from a DSN and a token provider. // The token provider func will be called when a new connection is requested and should return a valid access token. // The returned connector may be used with sql.OpenDB. @@ -32,20 +21,11 @@ func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) ( return nil, err } - c := &accessTokenConnector{ - Connector: *conn, - accessTokenProvider: tokenProvider, - } - return c, nil -} - -// Connect returns a new database connection -func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) { - var err error - c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider() - if err != nil { - return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err) + conn.fedAuthRequired = true + conn.fedAuthLibrary = fedAuthLibrarySecurityToken + conn.securityTokenProvider = func(ctx context.Context) (string, error) { + return tokenProvider() } - return c.Connector.Connect(ctx) + return conn, nil } diff --git a/vendor/github.com/denisenkom/go-mssqldb/appveyor.yml b/vendor/github.com/denisenkom/go-mssqldb/appveyor.yml index c4d2bb06..ecb893a3 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/appveyor.yml +++ b/vendor/github.com/denisenkom/go-mssqldb/appveyor.yml @@ -1,6 +1,7 @@ version: 1.0.{build} -os: Windows Server 2012 R2 +image: + - Visual Studio 2015 clone_folder: c:\gopath\src\github.com\denisenkom\go-mssqldb @@ -9,21 +10,39 @@ environment: HOST: localhost SQLUSER: sa SQLPASSWORD: Password12! - DATABASE: test - GOVERSION: 111 + DATABASE: test + GOVERSION: 113 matrix: - GOVERSION: 18 - SQLINSTANCE: SQL2016 + SQLINSTANCE: SQL2017 - GOVERSION: 19 - SQLINSTANCE: SQL2016 + SQLINSTANCE: SQL2017 - GOVERSION: 110 - SQLINSTANCE: SQL2016 + SQLINSTANCE: SQL2017 - GOVERSION: 111 - SQLINSTANCE: SQL2016 + SQLINSTANCE: SQL2017 + - GOVERSION: 112 + SQLINSTANCE: SQL2017 + - SQLINSTANCE: SQL2017 + - SQLINSTANCE: SQL2016 - SQLINSTANCE: SQL2014 - SQLINSTANCE: SQL2012SP1 - SQLINSTANCE: SQL2008R2SP2 - + + # Go 1.14+ and SQL2019 are available on the Visual Studio 2019 image only + - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 + GOVERSION: 114 + SQLINSTANCE: SQL2019 + - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 + GOVERSION: 115 + SQLINSTANCE: SQL2019 + - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 + GOVERSION: 115 + SQLINSTANCE: SQL2017 + - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 + GOVERSION: 116 + SQLINSTANCE: SQL2017 + install: - set GOROOT=c:\go%GOVERSION% - set PATH=%GOPATH%\bin;%GOROOT%\bin;%PATH% @@ -35,15 +54,14 @@ build_script: - go build before_test: - # setup SQL Server - - ps: | + # setup SQL Server + - ps: | $instanceName = $env:SQLINSTANCE Start-Service "MSSQL`$$instanceName" Start-Service "SQLBrowser" - sqlcmd -S "(local)\%SQLINSTANCE%" -Q "Use [master]; CREATE DATABASE test;" - sqlcmd -S "(local)\%SQLINSTANCE%" -h -1 -Q "set nocount on; Select @@version" - pip install codecov - test_script: - go test -race -cpu 4 -coverprofile=coverage.txt -covermode=atomic diff --git a/vendor/github.com/denisenkom/go-mssqldb/buf.go b/vendor/github.com/denisenkom/go-mssqldb/buf.go index ba39b40f..051d7ba5 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/buf.go +++ b/vendor/github.com/denisenkom/go-mssqldb/buf.go @@ -48,8 +48,8 @@ type tdsBuffer struct { func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer { return &tdsBuffer{ packetSize: int(bufsize), - wbuf: make([]byte, 1<<16), - rbuf: make([]byte, 1<<16), + wbuf: make([]byte, bufsize), + rbuf: make([]byte, bufsize), rpos: 8, transport: transport, } @@ -137,19 +137,28 @@ func (w *tdsBuffer) FinishPacket() error { var headerSize = binary.Size(header{}) func (r *tdsBuffer) readNextPacket() error { - h := header{} - var err error - err = binary.Read(r.transport, binary.BigEndian, &h) + buf := r.rbuf[:headerSize] + _, err := io.ReadFull(r.transport, buf) if err != nil { return err } + h := header{ + PacketType: packetType(buf[0]), + Status: buf[1], + Size: binary.BigEndian.Uint16(buf[2:4]), + Spid: binary.BigEndian.Uint16(buf[4:6]), + PacketNo: buf[6], + Pad: buf[7], + } if int(h.Size) > r.packetSize { - return errors.New("Invalid packet size, it is longer than buffer size") + return errors.New("invalid packet size, it is longer than buffer size") } if headerSize > int(h.Size) { - return errors.New("Invalid packet size, it is shorter than header size") + return errors.New("invalid packet size, it is shorter than header size") } _, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size]) + //s := base64.StdEncoding.EncodeToString(r.rbuf[headerSize:h.Size]) + //fmt.Print(s) if err != nil { return err } diff --git a/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go b/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go index 1d5eacb3..bbae4e69 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go +++ b/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go @@ -7,6 +7,7 @@ import ( "fmt" "math" "reflect" + "strconv" "strings" "time" @@ -44,8 +45,9 @@ type BulkOptions struct { type DataValue interface{} const ( - sqlDateFormat = "2006-01-02" - sqlTimeFormat = "2006-01-02 15:04:05.999999999Z07:00" + sqlDateFormat = "2006-01-02" + sqlDateTimeFormat = "2006-01-02 15:04:05.999999999Z07:00" + sqlTimeFormat = "15:04:05.9999999" ) func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) { @@ -86,7 +88,7 @@ func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) { b.bulkColumns = append(b.bulkColumns, *bulkCol) b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId) } else { - return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename) + return fmt.Errorf("column %s does not exist in destination table %s", colname, b.tablename) } } @@ -166,7 +168,7 @@ func (b *Bulk) AddRow(row []interface{}) (err error) { } if len(row) != len(b.bulkColumns) { - return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d", + return fmt.Errorf("row does not have the same number of columns than the destination table %d %d", len(row), len(b.bulkColumns)) } @@ -215,7 +217,7 @@ func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) { } func (b *Bulk) Done() (rowcount int64, err error) { - if b.headerSent == false { + if !b.headerSent { //no rows had been sent return 0, nil } @@ -233,24 +235,13 @@ func (b *Bulk) Done() (rowcount int64, err error) { buf.FinishPacket() - tokchan := make(chan tokenStruct, 5) - go processResponse(b.ctx, b.cn.sess, tokchan, nil) - - var rowCount int64 - for token := range tokchan { - switch token := token.(type) { - case doneStruct: - if token.Status&doneCount != 0 { - rowCount = int64(token.RowCount) - } - if token.isError() { - return 0, token.getError() - } - case error: - return 0, b.cn.checkBadConn(token) - } + reader := startReading(b.cn.sess, b.ctx, outputs{}) + err = reader.iterateResponse() + if err != nil { + return 0, b.cn.checkBadConn(err, false) } - return rowCount, nil + + return reader.rowCount, nil } func (b *Bulk) createColMetadata() []byte { @@ -339,6 +330,10 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) intvalue = int64(val) case int64: intvalue = val + case float32: + intvalue = int64(val) + case float64: + intvalue = int64(val) default: err = fmt.Errorf("mssql: invalid type for int column: %T", val) return @@ -383,6 +378,8 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) switch val := val.(type) { case string: res.buffer = str2ucs2(val) + case int64: + res.buffer = []byte(strconv.FormatInt(val, 10)) case []byte: res.buffer = val default: @@ -397,6 +394,8 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) res.buffer = []byte(val) case []byte: res.buffer = val + case int64: + res.buffer = []byte(strconv.FormatInt(val, 10)) default: err = fmt.Errorf("mssql: invalid type for varchar column: %T %s", val, val) return @@ -421,7 +420,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) res.ti.Size = len(res.buffer) case string: var t time.Time - if t, err = time.Parse(sqlTimeFormat, val); err != nil { + if t, err = time.Parse(sqlDateTimeFormat, val); err != nil { return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) } res.buffer = encodeDateTime2(t, int(col.ti.Scale)) @@ -437,7 +436,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) res.ti.Size = len(res.buffer) case string: var t time.Time - if t, err = time.Parse(sqlTimeFormat, val); err != nil { + if t, err = time.Parse(sqlDateTimeFormat, val); err != nil { return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) } res.buffer = encodeDateTimeOffset(t, int(col.ti.Scale)) @@ -468,7 +467,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) case time.Time: t = val case string: - if t, err = time.Parse(sqlTimeFormat, val); err != nil { + if t, err = time.Parse(sqlDateTimeFormat, val); err != nil { return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) } default: @@ -485,7 +484,22 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) } else { err = fmt.Errorf("mssql: invalid size of column %d", col.ti.Size) } - + case typeTimeN: + var t time.Time + switch val := val.(type) { + case time.Time: + res.buffer = encodeTime(val.Hour(), val.Minute(), val.Second(), val.Nanosecond(), int(col.ti.Scale)) + res.ti.Size = len(res.buffer) + case string: + if t, err = time.Parse(sqlTimeFormat, val); err != nil { + return res, fmt.Errorf("bulk: unable to convert string to time: %v", err) + } + res.buffer = encodeTime(t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), int(col.ti.Scale)) + res.ti.Size = len(res.buffer) + default: + err = fmt.Errorf("mssql: invalid type for time column: %T %s", val, val) + return + } // case typeMoney, typeMoney4, typeMoneyN: case typeDecimal, typeDecimalN, typeNumeric, typeNumericN: prec := col.ti.Prec diff --git a/vendor/github.com/denisenkom/go-mssqldb/error.go b/vendor/github.com/denisenkom/go-mssqldb/error.go index 2e5bacee..59a7a1a4 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/error.go +++ b/vendor/github.com/denisenkom/go-mssqldb/error.go @@ -1,6 +1,7 @@ package mssql import ( + "database/sql/driver" "fmt" ) @@ -17,6 +18,9 @@ type Error struct { ServerName string ProcName string LineNo int32 + // All lists all errors that were received from first to last. + // This includes the last one, which is described in the other members. + All []Error } func (e Error) Error() string { @@ -65,9 +69,53 @@ func streamErrorf(format string, v ...interface{}) StreamError { } func badStreamPanic(err error) { - panic(err) + panic(streamErrorf("%v", err)) } func badStreamPanicf(format string, v ...interface{}) { panic(streamErrorf(format, v...)) } + +// ServerError is returned when the server got a fatal error +// that aborts the process and severs the connection. +// +// To get the errors returned before the process was aborted, +// unwrap this error or call errors.As with a pointer to an +// mssql.Error variable. +type ServerError struct { + sqlError Error +} + +func (e ServerError) Error() string { + return "SQL Server had internal error" +} + +func (e ServerError) Unwrap() error { + return e.sqlError +} + +// RetryableError is returned when an error was caused by a bad +// connection at the start of a query and can be safely retried +// using database/sql's automatic retry logic. +// +// In many cases database/sql's retry logic will transparently +// handle this error, the retried call will return successfully, +// and you won't even see this error. However, you may see this +// error if the retry logic cannot successfully handle the error. +// In that case you can get the underlying error by calling this +// error's UnWrap function. +type RetryableError struct { + err error +} + +func (r RetryableError) Error() string { + return r.err.Error() +} + +func (r RetryableError) Unwrap() error { + return r.err +} + +func (r RetryableError) Is(err error) bool { + return err == driver.ErrBadConn +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/fedauth.go b/vendor/github.com/denisenkom/go-mssqldb/fedauth.go new file mode 100644 index 00000000..459c6641 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/fedauth.go @@ -0,0 +1,78 @@ +package mssql + +import ( + "context" + "errors" + + "github.com/denisenkom/go-mssqldb/msdsn" +) + +// Federated authentication library affects the login data structure and message sequence. +const ( + // fedAuthLibraryLiveIDCompactToken specifies the Microsoft Live ID Compact Token authentication scheme + fedAuthLibraryLiveIDCompactToken = 0x00 + + // fedAuthLibrarySecurityToken specifies a token-based authentication where the token is available + // without additional information provided during the login sequence. + fedAuthLibrarySecurityToken = 0x01 + + // fedAuthLibraryADAL specifies a token-based authentication where a token is obtained during the + // login sequence using the server SPN and STS URL provided by the server during login. + fedAuthLibraryADAL = 0x02 + + // fedAuthLibraryReserved is used to indicate that no federated authentication scheme applies. + fedAuthLibraryReserved = 0x7F +) + +// Federated authentication ADAL workflow affects the mechanism used to authenticate. +const ( + // fedAuthADALWorkflowPassword uses a username/password to obtain a token from Active Directory + fedAuthADALWorkflowPassword = 0x01 + + // fedAuthADALWorkflowPassword uses the Windows identity to obtain a token from Active Directory + fedAuthADALWorkflowIntegrated = 0x02 + + // fedAuthADALWorkflowMSI uses the managed identity service to obtain a token + fedAuthADALWorkflowMSI = 0x03 +) + +// newSecurityTokenConnector creates a new connector from a Config and a token provider. +// When invoked, token provider implementations should contact the security token +// service specified and obtain the appropriate token, or return an error +// to indicate why a token is not available. +// The returned connector may be used with sql.OpenDB. +func newSecurityTokenConnector(config msdsn.Config, tokenProvider func(ctx context.Context) (string, error)) (*Connector, error) { + if tokenProvider == nil { + return nil, errors.New("mssql: tokenProvider cannot be nil") + } + + conn := NewConnectorConfig(config) + conn.fedAuthRequired = true + conn.fedAuthLibrary = fedAuthLibrarySecurityToken + conn.securityTokenProvider = tokenProvider + + return conn, nil +} + +// newADALTokenConnector creates a new connector from a Config and a Active Directory token provider. +// Token provider implementations are called during federated +// authentication login sequences where the server provides a service +// principal name and security token service endpoint that should be used +// to obtain the token. Implementations should contact the security token +// service specified and obtain the appropriate token, or return an error +// to indicate why a token is not available. +// +// The returned connector may be used with sql.OpenDB. +func newActiveDirectoryTokenConnector(config msdsn.Config, adalWorkflow byte, tokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)) (*Connector, error) { + if tokenProvider == nil { + return nil, errors.New("mssql: tokenProvider cannot be nil") + } + + conn := NewConnectorConfig(config) + conn.fedAuthRequired = true + conn.fedAuthLibrary = fedAuthLibraryADAL + conn.fedAuthADALWorkflow = adalWorkflow + conn.adalTokenProvider = tokenProvider + + return conn, nil +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/conn_str.go b/vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str.go similarity index 51% rename from vendor/github.com/denisenkom/go-mssqldb/conn_str.go rename to vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str.go index 26ac50f3..d0eddbd8 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/conn_str.go +++ b/vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str.go @@ -1,7 +1,10 @@ -package mssql +package msdsn import ( + "crypto/tls" + "crypto/x509" "fmt" + "io/ioutil" "net" "net/url" "os" @@ -11,49 +14,104 @@ import ( "unicode" ) -const defaultServerPort = 1433 +type ( + Encryption int + Log uint64 +) -type connectParams struct { - logFlags uint64 - port uint64 - host string - instance string - database string - user string - password string - dial_timeout time.Duration - conn_timeout time.Duration - keepAlive time.Duration - encrypt bool - disableEncryption bool - trustServerCertificate bool - certificate string - hostInCertificate string - hostInCertificateProvided bool - serverSPN string - workstation string - appname string - typeFlags uint8 - failOverPartner string - failOverPort uint64 - packetSize uint16 - fedAuthAccessToken string +const ( + EncryptionOff = 0 + EncryptionRequired = 1 + EncryptionDisabled = 3 +) + +const ( + LogErrors Log = 1 + LogMessages Log = 2 + LogRows Log = 4 + LogSQL Log = 8 + LogParams Log = 16 + LogTransaction Log = 32 + LogDebug Log = 64 +) + +type Config struct { + Port uint64 + Host string + Instance string + Database string + User string + Password string + Encryption Encryption + TLSConfig *tls.Config + + FailOverPartner string + FailOverPort uint64 + + // If true the TLSConfig servername should use the routed server. + HostInCertificateProvided bool + + // Read Only intent for application database. + // NOTE: This does not make queries to most databases read-only. + ReadOnlyIntent bool + + LogFlags Log + + ServerSPN string + Workstation string + AppName string + + // If true disables database/sql's automatic retry of queries + // that start on bad connections. + DisableRetry bool + + // Do not use the following. + + DialTimeout time.Duration // DialTimeout defaults to 15s. Set negative to disable. + ConnTimeout time.Duration // Use context for timeouts. + KeepAlive time.Duration // Leave at default. + PacketSize uint16 } -func parseConnectParams(dsn string) (connectParams, error) { - var p connectParams +func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string) (*tls.Config, error) { + var config tls.Config + if certificate != "" { + pem, err := ioutil.ReadFile(certificate) + if err != nil { + return nil, fmt.Errorf("cannot read certificate %q: %v", certificate, err) + } + certs := x509.NewCertPool() + certs.AppendCertsFromPEM(pem) + config.RootCAs = certs + } + if insecureSkipVerify { + config.InsecureSkipVerify = true + } + config.ServerName = hostInCertificate + + // fix for https://github.com/denisenkom/go-mssqldb/issues/166 + // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, + // while SQL Server seems to expect one TCP segment per encrypted TDS package. + // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package + config.DynamicRecordSizingDisabled = true + + return &config, nil +} + +func Parse(dsn string) (Config, map[string]string, error) { + p := Config{} var params map[string]string if strings.HasPrefix(dsn, "odbc:") { parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):]) if err != nil { - return p, err + return p, params, err } params = parameters } else if strings.HasPrefix(dsn, "sqlserver://") { parameters, err := splitConnectionStringURL(dsn) if err != nil { - return p, err + return p, params, err } params = parameters } else { @@ -62,57 +120,55 @@ func parseConnectParams(dsn string) (connectParams, error) { strlog, ok := params["log"] if ok { - var err error - p.logFlags, err = strconv.ParseUint(strlog, 10, 64) + flags, err := strconv.ParseUint(strlog, 10, 64) if err != nil { - return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error()) + return p, params, fmt.Errorf("invalid log parameter '%s': %s", strlog, err.Error()) } + p.LogFlags = Log(flags) } server := params["server"] parts := strings.SplitN(server, `\`, 2) - p.host = parts[0] - if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" { - p.host = "localhost" + p.Host = parts[0] + if p.Host == "." || strings.ToUpper(p.Host) == "(LOCAL)" || p.Host == "" { + p.Host = "localhost" } if len(parts) > 1 { - p.instance = parts[1] + p.Instance = parts[1] } - p.database = params["database"] - p.user = params["user id"] - p.password = params["password"] + p.Database = params["database"] + p.User = params["user id"] + p.Password = params["password"] - p.port = 0 + p.Port = 0 strport, ok := params["port"] if ok { var err error - p.port, err = strconv.ParseUint(strport, 10, 16) + p.Port, err = strconv.ParseUint(strport, 10, 16) if err != nil { - f := "Invalid tcp port '%v': %v" - return p, fmt.Errorf(f, strport, err.Error()) + f := "invalid tcp port '%v': %v" + return p, params, fmt.Errorf(f, strport, err.Error()) } } - // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option - // Default packet size remains at 4096 bytes - p.packetSize = 4096 + // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option\ strpsize, ok := params["packet size"] if ok { var err error psize, err := strconv.ParseUint(strpsize, 0, 16) if err != nil { - f := "Invalid packet size '%v': %v" - return p, fmt.Errorf(f, strpsize, err.Error()) + f := "invalid packet size '%v': %v" + return p, params, fmt.Errorf(f, strpsize, err.Error()) } // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request // a higher packet size, the server will respond with an ENVCHANGE request to // alter the packet size to 16383 bytes. - p.packetSize = uint16(psize) - if p.packetSize < 512 { - p.packetSize = 512 - } else if p.packetSize > 32767 { - p.packetSize = 32767 + p.PacketSize = uint16(psize) + if p.PacketSize < 512 { + p.PacketSize = 512 + } else if p.PacketSize > 32767 { + p.PacketSize = 32767 } } @@ -123,79 +179,95 @@ func parseConnectParams(dsn string) (connectParams, error) { if strconntimeout, ok := params["connection timeout"]; ok { timeout, err := strconv.ParseUint(strconntimeout, 10, 64) if err != nil { - f := "Invalid connection timeout '%v': %v" - return p, fmt.Errorf(f, strconntimeout, err.Error()) + f := "invalid connection timeout '%v': %v" + return p, params, fmt.Errorf(f, strconntimeout, err.Error()) } - p.conn_timeout = time.Duration(timeout) * time.Second + p.ConnTimeout = time.Duration(timeout) * time.Second } - p.dial_timeout = 15 * time.Second + p.DialTimeout = 15 * time.Second if strdialtimeout, ok := params["dial timeout"]; ok { timeout, err := strconv.ParseUint(strdialtimeout, 10, 64) if err != nil { - f := "Invalid dial timeout '%v': %v" - return p, fmt.Errorf(f, strdialtimeout, err.Error()) + f := "invalid dial timeout '%v': %v" + return p, params, fmt.Errorf(f, strdialtimeout, err.Error()) } - p.dial_timeout = time.Duration(timeout) * time.Second + p.DialTimeout = time.Duration(timeout) * time.Second } // default keep alive should be 30 seconds according to spec: // https://msdn.microsoft.com/en-us/library/dd341108.aspx - p.keepAlive = 30 * time.Second + p.KeepAlive = 30 * time.Second if keepAlive, ok := params["keepalive"]; ok { timeout, err := strconv.ParseUint(keepAlive, 10, 64) if err != nil { - f := "Invalid keepAlive value '%s': %s" - return p, fmt.Errorf(f, keepAlive, err.Error()) + f := "invalid keepAlive value '%s': %s" + return p, params, fmt.Errorf(f, keepAlive, err.Error()) } - p.keepAlive = time.Duration(timeout) * time.Second + p.KeepAlive = time.Duration(timeout) * time.Second } + + var ( + trustServerCert = false + certificate = "" + hostInCertificate = "" + ) encrypt, ok := params["encrypt"] if ok { if strings.EqualFold(encrypt, "DISABLE") { - p.disableEncryption = true + p.Encryption = EncryptionDisabled } else { - var err error - p.encrypt, err = strconv.ParseBool(encrypt) + e, err := strconv.ParseBool(encrypt) if err != nil { - f := "Invalid encrypt '%s': %s" - return p, fmt.Errorf(f, encrypt, err.Error()) + f := "invalid encrypt '%s': %s" + return p, params, fmt.Errorf(f, encrypt, err.Error()) + } + if e { + p.Encryption = EncryptionRequired } } } else { - p.trustServerCertificate = true + trustServerCert = true } trust, ok := params["trustservercertificate"] if ok { var err error - p.trustServerCertificate, err = strconv.ParseBool(trust) + trustServerCert, err = strconv.ParseBool(trust) if err != nil { - f := "Invalid trust server certificate '%s': %s" - return p, fmt.Errorf(f, trust, err.Error()) + f := "invalid trust server certificate '%s': %s" + return p, params, fmt.Errorf(f, trust, err.Error()) } } - p.certificate = params["certificate"] - p.hostInCertificate, ok = params["hostnameincertificate"] + certificate = params["certificate"] + hostInCertificate, ok = params["hostnameincertificate"] if ok { - p.hostInCertificateProvided = true + p.HostInCertificateProvided = true } else { - p.hostInCertificate = p.host - p.hostInCertificateProvided = false + hostInCertificate = p.Host + p.HostInCertificateProvided = false + } + + if p.Encryption != EncryptionDisabled { + var err error + p.TLSConfig, err = SetupTLS(certificate, trustServerCert, hostInCertificate) + if err != nil { + return p, params, fmt.Errorf("failed to setup TLS: %w", err) + } } serverSPN, ok := params["serverspn"] if ok { - p.serverSPN = serverSPN + p.ServerSPN = serverSPN } else { - p.serverSPN = generateSpn(p.host, resolveServerPort(p.port)) + p.ServerSPN = generateSpn(p.Host, p.Port) } workstation, ok := params["workstation id"] if ok { - p.workstation = workstation + p.Workstation = workstation } else { workstation, err := os.Hostname() if err == nil { - p.workstation = workstation + p.Workstation = workstation } } @@ -203,34 +275,86 @@ func parseConnectParams(dsn string) (connectParams, error) { if !ok { appname = "go-mssqldb" } - p.appname = appname + p.AppName = appname appintent, ok := params["applicationintent"] if ok { if appintent == "ReadOnly" { - if p.database == "" { - return p, fmt.Errorf("Database must be specified when ApplicationIntent is ReadOnly") + if p.Database == "" { + return p, params, fmt.Errorf("database must be specified when ApplicationIntent is ReadOnly") } - p.typeFlags |= fReadOnlyIntent + p.ReadOnlyIntent = true } } failOverPartner, ok := params["failoverpartner"] if ok { - p.failOverPartner = failOverPartner + p.FailOverPartner = failOverPartner } failOverPort, ok := params["failoverport"] if ok { var err error - p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16) + p.FailOverPort, err = strconv.ParseUint(failOverPort, 0, 16) if err != nil { - f := "Invalid tcp port '%v': %v" - return p, fmt.Errorf(f, failOverPort, err.Error()) + f := "invalid failover port '%v': %v" + return p, params, fmt.Errorf(f, failOverPort, err.Error()) } } - return p, nil + disableRetry, ok := params["disableretry"] + if ok { + var err error + p.DisableRetry, err = strconv.ParseBool(disableRetry) + if err != nil { + f := "invalid disableRetry '%s': %s" + return p, params, fmt.Errorf(f, disableRetry, err.Error()) + } + } else { + p.DisableRetry = disableRetryDefault + } + + return p, params, nil +} + +// convert connectionParams to url style connection string +// used mostly for testing +func (p Config) URL() *url.URL { + q := url.Values{} + if p.Database != "" { + q.Add("database", p.Database) + } + if p.LogFlags != 0 { + q.Add("log", strconv.FormatUint(uint64(p.LogFlags), 10)) + } + host := p.Host + if p.Port > 0 { + host = fmt.Sprintf("%s:%d", p.Host, p.Port) + } + q.Add("disableRetry", fmt.Sprintf("%t", p.DisableRetry)) + res := url.URL{ + Scheme: "sqlserver", + Host: host, + User: url.UserPassword(p.User, p.Password), + } + if p.Instance != "" { + res.Path = p.Instance + } + if len(q) > 0 { + res.RawQuery = q.Encode() + } + return &res +} + +var adoSynonyms = map[string]string{ + "application name": "app name", + "data source": "server", + "address": "server", + "network address": "server", + "addr": "server", + "user": "user id", + "uid": "user id", + "initial catalog": "database", } func splitConnectionString(dsn string) (res map[string]string) { @@ -249,6 +373,20 @@ func splitConnectionString(dsn string) (res map[string]string) { if len(lst) > 1 { value = strings.TrimSpace(lst[1]) } + synonym, hasSynonym := adoSynonyms[name] + if hasSynonym { + name = synonym + } + // "server" in ADO can include a protocol and a port. + // We only support tcp protocol + if name == "server" { + value = strings.TrimPrefix(value, "tcp:") + serverParts := strings.Split(value, ",") + if len(serverParts) == 2 && len(serverParts[1]) > 0 { + value = serverParts[0] + res["port"] = serverParts[1] + } + } res[name] = value } return res @@ -340,7 +478,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) { case parserStateBeforeKey: switch { case c == '=': - return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i) + return res, fmt.Errorf("unexpected character = at index %d. Expected start of key or semi-colon or whitespace", i) case !unicode.IsSpace(c) && c != ';': state = parserStateKey key += string(c) @@ -419,7 +557,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) { case unicode.IsSpace(c): // Ignore whitespace default: - return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i) + return res, fmt.Errorf("unexpected character %c at index %d. Expected semi-colon or whitespace", c, i) } case parserStateEndValue: @@ -429,7 +567,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) { case unicode.IsSpace(c): // Ignore whitespace default: - return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i) + return res, fmt.Errorf("unexpected character %c at index %d. Expected semi-colon or whitespace", c, i) } } } @@ -444,7 +582,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) { case parserStateBareValue: res[key] = strings.TrimRightFunc(value, unicode.IsSpace) case parserStateBracedValue: - return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn)) + return res, fmt.Errorf("unexpected end of braced value at index %d", len(dsn)) case parserStateBracedValueClosingBrace: // End of braced value res[key] = value case parserStateEndValue: // Okay @@ -458,14 +596,6 @@ func normalizeOdbcKey(s string) string { return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace)) } -func resolveServerPort(port uint64) uint64 { - if port == 0 { - return defaultServerPort - } - - return port -} - func generateSpn(host string, port uint64) string { return fmt.Sprintf("MSSQLSvc/%s:%d", host, port) } diff --git a/vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str_go118.go b/vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str_go118.go new file mode 100644 index 00000000..b1a1d1d4 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str_go118.go @@ -0,0 +1,9 @@ +// +build go1.18 + +package msdsn + +// disableRetryDefault is false for Go versions 1.18 and higher. This matches +// the behavior requested in issue #586. A query that fails at the start due to +// a bad connection is automatically retried. An error is returned only if the +// query fails all of its retries. +const disableRetryDefault bool = false diff --git a/vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str_go118pre.go b/vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str_go118pre.go new file mode 100644 index 00000000..d3ce1956 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str_go118pre.go @@ -0,0 +1,9 @@ +// +build !go1.18 + +package msdsn + +// disableRetryDefault is true for versions of Go less than 1.18. This matches +// the behavior requested in issue #275. A query that fails at the start due to +// a bad connection is not retried. Instead, the detailed error is immediately +// returned to the caller. +const disableRetryDefault bool = true diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql.go b/vendor/github.com/denisenkom/go-mssqldb/mssql.go index a74bc7e3..550cb91c 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/mssql.go +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql.go @@ -16,6 +16,7 @@ import ( "unicode" "github.com/denisenkom/go-mssqldb/internal/querytext" + "github.com/denisenkom/go-mssqldb/msdsn" ) // ReturnStatus may be used to return the return value from a proc. @@ -31,12 +32,16 @@ var driverInstanceNoProcess = &Driver{processQueryText: false} func init() { sql.Register("mssql", driverInstance) sql.Register("sqlserver", driverInstanceNoProcess) - createDialer = func(p *connectParams) Dialer { - return netDialer{&net.Dialer{KeepAlive: p.keepAlive}} + createDialer = func(p *msdsn.Config) Dialer { + ka := p.KeepAlive + if ka == 0 { + ka = 30 * time.Second + } + return netDialer{&net.Dialer{KeepAlive: ka}} } } -var createDialer func(p *connectParams) Dialer +var createDialer func(p *msdsn.Config) Dialer type netDialer struct { nd *net.Dialer @@ -54,10 +59,11 @@ type Driver struct { // OpenConnector opens a new connector. Useful to dial with a context. func (d *Driver) OpenConnector(dsn string) (*Connector, error) { - params, err := parseConnectParams(dsn) + params, _, err := msdsn.Parse(dsn) if err != nil { return nil, err } + return &Connector{ params: params, driver: d, @@ -80,7 +86,7 @@ func (d *Driver) SetLogger(logger Logger) { // NewConnector creates a new connector from a DSN. // The returned connector may be used with sql.OpenDB. func NewConnector(dsn string) (*Connector, error) { - params, err := parseConnectParams(dsn) + params, _, err := msdsn.Parse(dsn) if err != nil { return nil, err } @@ -91,15 +97,34 @@ func NewConnector(dsn string) (*Connector, error) { return c, nil } +// NewConnectorConfig creates a new Connector for a DSN Config struct. +// The returned connector may be used with sql.OpenDB. +func NewConnectorConfig(config msdsn.Config) *Connector { + return &Connector{ + params: config, + driver: driverInstanceNoProcess, + } +} + // Connector holds the parsed DSN and is ready to make a new connection // at any time. // // In the future, settings that cannot be passed through a string DSN // may be set directly on the connector. type Connector struct { - params connectParams + params msdsn.Config driver *Driver + fedAuthRequired bool + fedAuthLibrary int + fedAuthADALWorkflow byte + + // callback that can provide a security token during login + securityTokenProvider func(ctx context.Context) (string, error) + + // callback that can provide a security token during ADAL login + adalTokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error) + // SessionInitSQL is executed after marking a given session to be reset. // When not present, the next query will still reset the session to the // database defaults. @@ -132,7 +157,7 @@ type Dialer interface { DialContext(ctx context.Context, network string, addr string) (net.Conn, error) } -func (c *Connector) getDialer(p *connectParams) Dialer { +func (c *Connector) getDialer(p *msdsn.Config) Dialer { if c != nil && c.Dialer != nil { return c.Dialer } @@ -148,34 +173,32 @@ type Conn struct { processQueryText bool connectionGood bool - outs map[string]interface{} + outs outputs +} + +type outputs struct { + params map[string]interface{} returnStatus *ReturnStatus } -func (c *Conn) setReturnStatus(s ReturnStatus) { - if c.returnStatus == nil { - return - } - *c.returnStatus = s +// IsValid satisfies the driver.Validator interface. +func (c *Conn) IsValid() bool { + return c.connectionGood } -func (c *Conn) checkBadConn(err error) error { - // this is a hack to address Issue #275 - // we set connectionGood flag to false if - // error indicates that connection is not usable - // but we return actual error instead of ErrBadConn - // this will cause connection to stay in a pool - // but next request to this connection will return ErrBadConn - - // it might be possible to revise this hack after - // https://github.com/golang/go/issues/20807 - // is implemented +// checkBadConn marks the connection as bad based on the characteristics +// of the supplied error. Bad connections will be dropped from the connection +// pool rather than reused. +// +// If bad connection retry is enabled and the error + connection state permits +// retrying, checkBadConn will return a RetryableError that allows database/sql +// to automatically retry the query with another connection. +func (c *Conn) checkBadConn(err error, mayRetry bool) error { switch err { case nil: return nil case io.EOF: c.connectionGood = false - return driver.ErrBadConn case driver.ErrBadConn: // It is an internal programming error if driver.ErrBadConn // is ever passed to this function. driver.ErrBadConn should @@ -187,34 +210,33 @@ func (c *Conn) checkBadConn(err error) error { switch err.(type) { case net.Error: c.connectionGood = false - return err case StreamError: c.connectionGood = false - return err - default: - return err + case ServerError: + c.connectionGood = false } + + if !c.connectionGood && mayRetry && !c.connector.params.DisableRetry { + return newRetryableError(err) + } + + return err } func (c *Conn) clearOuts() { - c.outs = nil + c.outs = outputs{} } func (c *Conn) simpleProcessResp(ctx context.Context) error { - tokchan := make(chan tokenStruct, 5) - go processResponse(ctx, c.sess, tokchan, c.outs) + reader := startReading(c.sess, ctx, c.outs) c.clearOuts() - for tok := range tokchan { - switch token := tok.(type) { - case doneStruct: - if token.isError() { - return c.checkBadConn(token.getError()) - } - case error: - return c.checkBadConn(token) - } + + var resultError error + err := reader.iterateResponse() + if err != nil { + return c.checkBadConn(err, false) } - return nil + return resultError } func (c *Conn) Commit() error { @@ -222,7 +244,7 @@ func (c *Conn) Commit() error { return driver.ErrBadConn } if err := c.sendCommitRequest(); err != nil { - return c.checkBadConn(err) + return c.checkBadConn(err, true) } return c.simpleProcessResp(c.transactionCtx) } @@ -239,7 +261,7 @@ func (c *Conn) sendCommitRequest() error { c.sess.log.Printf("Failed to send CommitXact with %v", err) } c.connectionGood = false - return fmt.Errorf("Faild to send CommitXact: %v", err) + return fmt.Errorf("faild to send CommitXact: %v", err) } return nil } @@ -249,7 +271,7 @@ func (c *Conn) Rollback() error { return driver.ErrBadConn } if err := c.sendRollbackRequest(); err != nil { - return c.checkBadConn(err) + return c.checkBadConn(err, true) } return c.simpleProcessResp(c.transactionCtx) } @@ -266,7 +288,7 @@ func (c *Conn) sendRollbackRequest() error { c.sess.log.Printf("Failed to send RollbackXact with %v", err) } c.connectionGood = false - return fmt.Errorf("Failed to send RollbackXact: %v", err) + return fmt.Errorf("failed to send RollbackXact: %v", err) } return nil } @@ -281,11 +303,11 @@ func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, } err = c.sendBeginRequest(ctx, tdsIsolation) if err != nil { - return nil, c.checkBadConn(err) + return nil, c.checkBadConn(err, true) } tx, err = c.processBeginResponse(ctx) if err != nil { - return nil, c.checkBadConn(err) + return nil, err } return } @@ -303,7 +325,7 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro c.sess.log.Printf("Failed to send BeginXact with %v", err) } c.connectionGood = false - return fmt.Errorf("Failed to send BeginXact: %v", err) + return fmt.Errorf("failed to send BeginXact: %v", err) } return nil } @@ -318,25 +340,26 @@ func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) { } func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) { - params, err := parseConnectParams(dsn) + params, _, err := msdsn.Parse(dsn) if err != nil { return nil, err } - return d.connect(ctx, nil, params) + c := &Connector{params: params} + return d.connect(ctx, c, params) } // connect to the server, using the provided context for dialing only. -func (d *Driver) connect(ctx context.Context, c *Connector, params connectParams) (*Conn, error) { +func (d *Driver) connect(ctx context.Context, c *Connector, params msdsn.Config) (*Conn, error) { sess, err := connect(ctx, c, d.log, params) if err != nil { // main server failed, try fail-over partner - if params.failOverPartner == "" { + if params.FailOverPartner == "" { return nil, err } - params.host = params.failOverPartner - if params.failOverPort != 0 { - params.port = params.failOverPort + params.Host = params.FailOverPartner + if params.FailOverPort != 0 { + params.Port = params.FailOverPort } sess, err = connect(ctx, c, d.log, params) @@ -447,7 +470,8 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) { reset := conn.resetSession conn.resetSession = false - if len(args) == 0 { + isProc := isProc(s.query) + if len(args) == 0 && !isProc { if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil { if conn.sess.logFlags&logErrors != 0 { conn.sess.log.Printf("Failed to send SqlBatch with %v", err) @@ -458,7 +482,7 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) { } else { proc := sp_ExecuteSql var params []param - if isProc(s.query) { + if isProc { proc.name = s.query params, _, err = s.makeRPCParams(args, true) if err != nil { @@ -478,7 +502,7 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) { conn.sess.log.Printf("Failed to send Rpc with %v", err) } conn.connectionGood = false - return fmt.Errorf("Failed to send RPC: %v", err) + return fmt.Errorf("failed to send RPC: %v", err) } } return @@ -500,30 +524,38 @@ func isProc(s string) bool { for _, r := range s { rPrev = rn1 rn1 = r - switch r { - // No newlines or string sequences. - case '\n', '\r', '\'', ';': - return false + if st != escaped { + switch r { + // No newlines or string sequences. + case '\n', '\r', '\'', ';': + return false + } } switch st { case outside: switch { - case unicode.IsSpace(r): - return false case r == '[': st = escaped - continue case r == ']' && rPrev == ']': st = escaped - continue case unicode.IsLetter(r): st = text + case r == '_': + st = text + case r == '#': + st = text + case r == '.': + default: + return false } case text: switch { case r == '.': st = outside - continue + case r == '[': + return false + case r == '(': + return false case unicode.IsSpace(r): return false } @@ -531,7 +563,6 @@ func isProc(s string) bool { switch { case r == ']': st = outside - continue } } } @@ -558,7 +589,13 @@ func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string, name = fmt.Sprintf("@p%d", val.Ordinal) } params[i+offset].Name = name - decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+offset].ti)) + const outputSuffix = " output" + var output string + if isOutputValue(val.Value) { + output = outputSuffix + } + decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(params[i+offset].ti), output) + } return params, decls, nil } @@ -581,6 +618,8 @@ func convertOldArgs(args []driver.Value) []namedValue { } func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { + defer s.c.clearOuts() + return s.queryContext(context.Background(), convertOldArgs(args)) } @@ -589,48 +628,60 @@ func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver return nil, driver.ErrBadConn } if err = s.sendQuery(args); err != nil { - return nil, s.c.checkBadConn(err) + return nil, s.c.checkBadConn(err, true) } return s.processQueryResponse(ctx) } func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) { - tokchan := make(chan tokenStruct, 5) ctx, cancel := context.WithCancel(ctx) - go processResponse(ctx, s.c.sess, tokchan, s.c.outs) + reader := startReading(s.c.sess, ctx, s.c.outs) s.c.clearOuts() // process metadata var cols []columnStruct loop: - for tok := range tokchan { - switch token := tok.(type) { - // By ignoring DONE token we effectively - // skip empty result-sets. - // This improves results in queries like that: - // set nocount on; select 1 - // see TestIgnoreEmptyResults test - //case doneStruct: - //break loop - case []columnStruct: - cols = token - break loop - case doneStruct: - if token.isError() { - cancel() - return nil, s.c.checkBadConn(token.getError()) + for { + tok, err := reader.nextToken() + if err == nil { + if tok == nil { + break + } else { + switch token := tok.(type) { + // By ignoring DONE token we effectively + // skip empty result-sets. + // This improves results in queries like that: + // set nocount on; select 1 + // see TestIgnoreEmptyResults test + //case doneStruct: + //break loop + case []columnStruct: + cols = token + break loop + case doneStruct: + if token.isError() { + // need to cleanup cancellable context + cancel() + return nil, s.c.checkBadConn(token.getError(), false) + } + case ReturnStatus: + if reader.outs.returnStatus != nil { + *reader.outs.returnStatus = token + } + } } - case ReturnStatus: - s.c.setReturnStatus(token) - case error: + } else { + // need to cleanup cancellable context cancel() - return nil, s.c.checkBadConn(token) + return nil, s.c.checkBadConn(err, false) } } - res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel} + res = &Rows{stmt: s, reader: reader, cols: cols, cancel: cancel} return } func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { + defer s.c.clearOuts() + return s.exec(context.Background(), convertOldArgs(args)) } @@ -639,57 +690,55 @@ func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, return nil, driver.ErrBadConn } if err = s.sendQuery(args); err != nil { - return nil, s.c.checkBadConn(err) + return nil, s.c.checkBadConn(err, true) } if res, err = s.processExec(ctx); err != nil { - return nil, s.c.checkBadConn(err) + return nil, err } return } func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) { - tokchan := make(chan tokenStruct, 5) - go processResponse(ctx, s.c.sess, tokchan, s.c.outs) + reader := startReading(s.c.sess, ctx, s.c.outs) s.c.clearOuts() - var rowCount int64 - for token := range tokchan { - switch token := token.(type) { - case doneInProcStruct: - if token.Status&doneCount != 0 { - rowCount += int64(token.RowCount) - } - case doneStruct: - if token.Status&doneCount != 0 { - rowCount += int64(token.RowCount) - } - if token.isError() { - return nil, token.getError() - } - case ReturnStatus: - s.c.setReturnStatus(token) - case error: - return nil, token - } + err = reader.iterateResponse() + if err != nil { + return nil, s.c.checkBadConn(err, false) } - return &Result{s.c, rowCount}, nil + return &Result{s.c, reader.rowCount}, nil } type Rows struct { - stmt *Stmt - cols []columnStruct - tokchan chan tokenStruct - + stmt *Stmt + cols []columnStruct + reader *tokenProcessor nextCols []columnStruct cancel func() } func (rc *Rows) Close() error { + // need to add a test which returns lots of rows + // and check closing after reading only few rows rc.cancel() - for _ = range rc.tokchan { + + for { + tok, err := rc.reader.nextToken() + if err == nil { + if tok == nil { + return nil + } else { + // continue consuming tokens + continue + } + } else { + if err == rc.reader.ctx.Err() { + return nil + } else { + return err + } + } } - rc.tokchan = nil - return nil } func (rc *Rows) Columns() (res []string) { @@ -707,27 +756,36 @@ func (rc *Rows) Next(dest []driver.Value) error { if rc.nextCols != nil { return io.EOF } - for tok := range rc.tokchan { - switch tokdata := tok.(type) { - case []columnStruct: - rc.nextCols = tokdata - return io.EOF - case []interface{}: - for i := range dest { - dest[i] = tokdata[i] + for { + tok, err := rc.reader.nextToken() + if err == nil { + if tok == nil { + return io.EOF + } else { + switch tokdata := tok.(type) { + case []columnStruct: + rc.nextCols = tokdata + return io.EOF + case []interface{}: + for i := range dest { + dest[i] = tokdata[i] + } + return nil + case doneStruct: + if tokdata.isError() { + return rc.stmt.c.checkBadConn(tokdata.getError(), false) + } + case ReturnStatus: + if rc.reader.outs.returnStatus != nil { + *rc.reader.outs.returnStatus = tokdata + } + } } - return nil - case doneStruct: - if tokdata.isError() { - return rc.stmt.c.checkBadConn(tokdata.getError()) - } - case ReturnStatus: - rc.stmt.c.setReturnStatus(tokdata) - case error: - return rc.stmt.c.checkBadConn(tokdata) + + } else { + return rc.stmt.c.checkBadConn(err, false) } } - return io.EOF } func (rc *Rows) HasNextResultSet() bool { @@ -895,35 +953,41 @@ func (c *Conn) Ping(ctx context.Context) error { var _ driver.ConnBeginTx = &Conn{} +func convertIsolationLevel(level sql.IsolationLevel) (isoLevel, error) { + switch level { + case sql.LevelDefault: + return isolationUseCurrent, nil + case sql.LevelReadUncommitted: + return isolationReadUncommited, nil + case sql.LevelReadCommitted: + return isolationReadCommited, nil + case sql.LevelWriteCommitted: + return isolationUseCurrent, errors.New("LevelWriteCommitted isolation level is not supported") + case sql.LevelRepeatableRead: + return isolationRepeatableRead, nil + case sql.LevelSnapshot: + return isolationSnapshot, nil + case sql.LevelSerializable: + return isolationSerializable, nil + case sql.LevelLinearizable: + return isolationUseCurrent, errors.New("LevelLinearizable isolation level is not supported") + default: + return isolationUseCurrent, errors.New("isolation level is not supported or unknown") + } +} + // BeginTx satisfies ConnBeginTx. func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if !c.connectionGood { return nil, driver.ErrBadConn } if opts.ReadOnly { - return nil, errors.New("Read-only transactions are not supported") + return nil, errors.New("read-only transactions are not supported") } - var tdsIsolation isoLevel - switch sql.IsolationLevel(opts.Isolation) { - case sql.LevelDefault: - tdsIsolation = isolationUseCurrent - case sql.LevelReadUncommitted: - tdsIsolation = isolationReadUncommited - case sql.LevelReadCommitted: - tdsIsolation = isolationReadCommited - case sql.LevelWriteCommitted: - return nil, errors.New("LevelWriteCommitted isolation level is not supported") - case sql.LevelRepeatableRead: - tdsIsolation = isolationRepeatableRead - case sql.LevelSnapshot: - tdsIsolation = isolationSnapshot - case sql.LevelSerializable: - tdsIsolation = isolationSerializable - case sql.LevelLinearizable: - return nil, errors.New("LevelLinearizable isolation level is not supported") - default: - return nil, errors.New("Isolation level is not supported or unknown") + tdsIsolation, err := convertIsolationLevel(sql.IsolationLevel(opts.Isolation)) + if err != nil { + return nil, err } return c.begin(ctx, tdsIsolation) } @@ -940,6 +1004,8 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e } func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + defer s.c.clearOuts() + if !s.c.connectionGood { return nil, driver.ErrBadConn } @@ -951,6 +1017,8 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv } func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + defer s.c.clearOuts() + if !s.c.connectionGood { return nil, driver.ErrBadConn } diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go index 6d76fbad..e4edc752 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go @@ -48,5 +48,5 @@ func (c *Connector) Driver() driver.Driver { } func (r *Result) LastInsertId() (int64, error) { - return -1, errors.New("LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query.") + return -1, errors.New("LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query") } diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go118.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go118.go new file mode 100644 index 00000000..9b8014b7 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql_go118.go @@ -0,0 +1,14 @@ +// +build go1.18 + +package mssql + +// newRetryableError returns an error that allows the database/sql package +// to automatically retry the failed query. Versions of Go 1.18 and higher +// use errors.Is to determine whether or not a failed query can be retried. +// Therefore, we wrap the underlying error in a RetryableError that both +// implements errors.Is for automatic retry and maintains the error details. +func newRetryableError(err error) error { + return RetryableError{ + err: err, + } +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go118pre.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go118pre.go new file mode 100644 index 00000000..ac83e1d2 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql_go118pre.go @@ -0,0 +1,17 @@ +// +build !go1.18 + +package mssql + +import ( + "database/sql/driver" +) + +// newRetryableError returns an error that allows the database/sql package +// to automatically retry the failed query. Versions of Go lower than 1.18 +// compare directly to the sentinel error driver.ErrBadConn to determine +// whether or not a failed query can be retried. Therefore, we replace the +// actual error with driver.ErrBadConn, enabling retry but losing the error +// details. +func newRetryableError(err error) error { + return driver.ErrBadConn +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go index a2bd1167..e77eebba 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go @@ -66,10 +66,10 @@ func convertInputParameter(val interface{}) (interface{}, error) { func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { switch v := nv.Value.(type) { case sql.Out: - if c.outs == nil { - c.outs = make(map[string]interface{}) + if c.outs.params == nil { + c.outs.params = make(map[string]interface{}) } - c.outs[nv.Name] = v.Dest + c.outs.params[nv.Name] = v.Dest if v.Dest == nil { return errors.New("destination is a nil pointer") @@ -110,7 +110,7 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { return nil case *ReturnStatus: *v = 0 // By default the return value should be zero. - c.returnStatus = v + c.outs.returnStatus = v return driver.ErrRemoveArgument case TVP: return nil @@ -194,3 +194,8 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) { func scanIntoOut(name string, fromServer, scanInto interface{}) error { return convertAssign(scanInto, fromServer) } + +func isOutputValue(val driver.Value) bool { + _, out := val.(sql.Out) + return out +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go index 9680f510..6758e0d0 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go @@ -14,3 +14,7 @@ func (s *Stmt) makeParamExtra(val driver.Value) (param, error) { func scanIntoOut(name string, fromServer, scanInto interface{}) error { return fmt.Errorf("mssql: unsupported OUTPUT type, use a newer Go version") } + +func isOutputValue(val driver.Value) bool { + return false +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/net.go b/vendor/github.com/denisenkom/go-mssqldb/net.go index 94858cc7..bb7b784c 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/net.go +++ b/vendor/github.com/denisenkom/go-mssqldb/net.go @@ -7,8 +7,8 @@ import ( ) type timeoutConn struct { - c net.Conn - timeout time.Duration + c net.Conn + timeout time.Duration } func newTimeoutConn(conn net.Conn, timeout time.Duration) *timeoutConn { @@ -51,21 +51,21 @@ func (c timeoutConn) RemoteAddr() net.Addr { } func (c timeoutConn) SetDeadline(t time.Time) error { - panic("Not implemented") + return c.c.SetDeadline(t) } func (c timeoutConn) SetReadDeadline(t time.Time) error { - panic("Not implemented") + return c.c.SetReadDeadline(t) } func (c timeoutConn) SetWriteDeadline(t time.Time) error { - panic("Not implemented") + return c.c.SetWriteDeadline(t) } // this connection is used during TLS Handshake // TDS protocol requires TLS handshake messages to be sent inside TDS packets type tlsHandshakeConn struct { - buf *tdsBuffer + buf *tdsBuffer packetPending bool continueRead bool } @@ -75,7 +75,7 @@ func (c *tlsHandshakeConn) Read(b []byte) (n int, err error) { c.packetPending = false err = c.buf.FinishPacket() if err != nil { - err = fmt.Errorf("Cannot send handshake packet: %s", err.Error()) + err = fmt.Errorf("cannot send handshake packet: %s", err.Error()) return } c.continueRead = false @@ -84,7 +84,7 @@ func (c *tlsHandshakeConn) Read(b []byte) (n int, err error) { var packet packetType packet, err = c.buf.BeginRead() if err != nil { - err = fmt.Errorf("Cannot read handshake packet: %s", err.Error()) + err = fmt.Errorf("cannot read handshake packet: %s", err.Error()) return } if packet != packPrelogin { @@ -105,27 +105,27 @@ func (c *tlsHandshakeConn) Write(b []byte) (n int, err error) { } func (c *tlsHandshakeConn) Close() error { - panic("Not implemented") + return c.buf.transport.Close() } func (c *tlsHandshakeConn) LocalAddr() net.Addr { - panic("Not implemented") + return nil } func (c *tlsHandshakeConn) RemoteAddr() net.Addr { - panic("Not implemented") + return nil } -func (c *tlsHandshakeConn) SetDeadline(t time.Time) error { - panic("Not implemented") +func (c *tlsHandshakeConn) SetDeadline(_ time.Time) error { + return nil } -func (c *tlsHandshakeConn) SetReadDeadline(t time.Time) error { - panic("Not implemented") +func (c *tlsHandshakeConn) SetReadDeadline(_ time.Time) error { + return nil } -func (c *tlsHandshakeConn) SetWriteDeadline(t time.Time) error { - panic("Not implemented") +func (c *tlsHandshakeConn) SetWriteDeadline(_ time.Time) error { + return nil } // this connection just delegates all methods to it's wrapped connection @@ -148,21 +148,21 @@ func (c passthroughConn) Close() error { } func (c passthroughConn) LocalAddr() net.Addr { - panic("Not implemented") + return c.c.LocalAddr() } func (c passthroughConn) RemoteAddr() net.Addr { - panic("Not implemented") + return c.c.RemoteAddr() } func (c passthroughConn) SetDeadline(t time.Time) error { - panic("Not implemented") + return c.c.SetDeadline(t) } func (c passthroughConn) SetReadDeadline(t time.Time) error { - panic("Not implemented") + return c.c.SetReadDeadline(t) } func (c passthroughConn) SetWriteDeadline(t time.Time) error { - panic("Not implemented") + return c.c.SetWriteDeadline(t) } diff --git a/vendor/github.com/denisenkom/go-mssqldb/ntlm.go b/vendor/github.com/denisenkom/go-mssqldb/ntlm.go index ea9148ae..90adb5a0 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/ntlm.go +++ b/vendor/github.com/denisenkom/go-mssqldb/ntlm.go @@ -14,6 +14,7 @@ import ( "time" "unicode/utf16" + //lint:ignore SA1019 MD4 is used by legacy NTLM "golang.org/x/crypto/md4" ) @@ -126,18 +127,6 @@ func createDesKey(bytes, material []byte) { material[7] = (byte)(bytes[6] << 1) } -func oddParity(bytes []byte) { - for i := 0; i < len(bytes); i++ { - b := bytes[i] - needsParity := (((b >> 7) ^ (b >> 6) ^ (b >> 5) ^ (b >> 4) ^ (b >> 3) ^ (b >> 2) ^ (b >> 1)) & 0x01) == 0 - if needsParity { - bytes[i] = bytes[i] | byte(0x01) - } else { - bytes[i] = bytes[i] & byte(0xfe) - } - } -} - func encryptDes(key []byte, cleartext []byte, ciphertext []byte) { var desKey [8]byte createDesKey(key, desKey[:]) diff --git a/vendor/github.com/denisenkom/go-mssqldb/rpc.go b/vendor/github.com/denisenkom/go-mssqldb/rpc.go index 4ca22578..f7d4c00e 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/rpc.go +++ b/vendor/github.com/denisenkom/go-mssqldb/rpc.go @@ -22,12 +22,6 @@ type param struct { buffer []byte } -const ( - fWithRecomp = 1 - fNoMetaData = 2 - fReuseMetaData = 4 -) - var ( sp_Cursor = procId{1, ""} sp_CursorOpen = procId{2, ""} diff --git a/vendor/github.com/denisenkom/go-mssqldb/tds.go b/vendor/github.com/denisenkom/go-mssqldb/tds.go index 832c4fd2..30c72a7d 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/tds.go +++ b/vendor/github.com/denisenkom/go-mssqldb/tds.go @@ -3,7 +3,6 @@ package mssql import ( "context" "crypto/tls" - "crypto/x509" "encoding/binary" "errors" "fmt" @@ -13,8 +12,11 @@ import ( "sort" "strconv" "strings" + "time" "unicode/utf16" "unicode/utf8" + + "github.com/denisenkom/go-mssqldb/msdsn" ) func parseInstances(msg []byte) map[string]map[string]string { @@ -82,19 +84,20 @@ const ( // https://msdn.microsoft.com/en-us/library/dd304214.aspx const ( packSQLBatch packetType = 1 - packRPCRequest = 3 - packReply = 4 + packRPCRequest packetType = 3 + packReply packetType = 4 // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx - packAttention = 6 + packAttention packetType = 6 - packBulkLoadBCP = 7 - packTransMgrReq = 14 - packNormal = 15 - packLogin7 = 16 - packSSPIMessage = 17 - packPrelogin = 18 + packBulkLoadBCP packetType = 7 + packFedAuthToken packetType = 8 + packTransMgrReq packetType = 14 + packNormal packetType = 15 + packLogin7 packetType = 16 + packSSPIMessage packetType = 17 + packPrelogin packetType = 18 ) // prelogin fields @@ -118,6 +121,17 @@ const ( encryptReq = 3 // Encryption is required. ) +const ( + featExtSESSIONRECOVERY byte = 0x01 + featExtFEDAUTH byte = 0x02 + featExtCOLUMNENCRYPTION byte = 0x04 + featExtGLOBALTRANSACTIONS byte = 0x05 + featExtAZURESQLSUPPORT byte = 0x08 + featExtDATACLASSIFICATION byte = 0x09 + featExtUTF8SUPPORT byte = 0x0A + featExtTERMINATOR byte = 0xFF +) + type tdsSession struct { buf *tdsBuffer loginAck loginAckStruct @@ -132,13 +146,21 @@ type tdsSession struct { } const ( - logErrors = 1 - logMessages = 2 - logRows = 4 - logSQL = 8 - logParams = 16 - logTransaction = 32 - logDebug = 64 + // Default packet size for a TDS buffer. + defaultPacketSize = 4096 + + // Default port if no port given. + defaultServerPort = 1433 +) + +const ( + logErrors = uint64(msdsn.LogErrors) + logMessages = uint64(msdsn.LogMessages) + logRows = uint64(msdsn.LogRows) + logSQL = uint64(msdsn.LogSQL) + logParams = uint64(msdsn.LogParams) + logTransaction = uint64(msdsn.LogTransaction) + logDebug = uint64(msdsn.LogDebug) ) type columnStruct struct { @@ -155,13 +177,13 @@ func (p keySlice) Less(i, j int) bool { return p[i] < p[j] } func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // http://msdn.microsoft.com/en-us/library/dd357559.aspx -func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error { +func writePrelogin(packetType packetType, w *tdsBuffer, fields map[uint8][]byte) error { var err error - w.BeginPacket(packPrelogin, false) + w.BeginPacket(packetType, false) offset := uint16(5*len(fields) + 1) keys := make(keySlice, 0, len(fields)) - for k, _ := range fields { + for k := range fields { keys = append(keys, k) } sort.Sort(keys) @@ -210,12 +232,15 @@ func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) { if err != nil { return nil, err } - if packet_type != 4 { - return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE") + if packet_type != packReply { + return nil, errors.New("invalid respones, expected packet type 4, PRELOGIN RESPONSE") + } + if len(struct_buf) == 0 { + return nil, errors.New("invalid empty PRELOGIN response, it must contain at least one byte") } offset := 0 results := map[uint8][]byte{} - for true { + for { rec_type := struct_buf[offset] if rec_type == preloginTERMINATOR { break @@ -240,6 +265,16 @@ const ( fIntSecurity = 0x80 ) +// OptionFlags3 +// http://msdn.microsoft.com/en-us/library/dd304019.aspx +const ( + fChangePassword = 1 + fSendYukonBinaryXML = 2 + fUserInstance = 4 + fUnknownCollationHandling = 8 + fExtension = 0x10 +) + // TypeFlags const ( // 4 bits for fSQLType @@ -247,12 +282,6 @@ const ( fReadOnlyIntent = 32 ) -// OptionFlags3 -// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac -const ( - fExtension = 0x10 -) - type login struct { TDSVersion uint32 PacketSize uint32 @@ -295,7 +324,7 @@ func (e *featureExts) Add(f featureExt) error { } id := f.featureID() if _, exists := e.features[id]; exists { - f := "Login error: Feature with ID '%v' is already present in FeatureExt block." + f := "login error: Feature with ID '%v' is already present in FeatureExt block" return fmt.Errorf(f, id) } if e.features == nil { @@ -326,37 +355,63 @@ func (e featureExts) toBytes() []byte { return d } -type featureExtFedAuthSTS struct { - FedAuthEcho bool +// featureExtFedAuth tracks federated authentication state before and during login +type featureExtFedAuth struct { + // FedAuthLibrary is populated by the federated authentication provider. + FedAuthLibrary int + + // ADALWorkflow is populated by the federated authentication provider. + ADALWorkflow byte + + // FedAuthEcho is populated from the prelogin response + FedAuthEcho bool + + // FedAuthToken is populated during login with the value from the provider. FedAuthToken string - Nonce []byte + + // Nonce is populated during login with the value from the provider. + Nonce []byte + + // Signature is populated during login with the value from the server. + Signature []byte } -func (e *featureExtFedAuthSTS) featureID() byte { - return 0x02 +func (e *featureExtFedAuth) featureID() byte { + return featExtFEDAUTH } -func (e *featureExtFedAuthSTS) toBytes() []byte { +func (e *featureExtFedAuth) toBytes() []byte { if e == nil { return nil } - options := byte(0x01) << 1 // 0x01 => STS bFedAuthLibrary 7BIT + options := byte(e.FedAuthLibrary) << 1 if e.FedAuthEcho { options |= 1 // fFedAuthEcho } - d := make([]byte, 5) - d[0] = options + // Feature extension format depends on the federated auth library. + // Options are described at + // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac + var d []byte - // looks like string in - // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508 - tokenBytes := str2ucs2(e.FedAuthToken) - binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work - d = append(d, tokenBytes...) + switch e.FedAuthLibrary { + case fedAuthLibrarySecurityToken: + d = make([]byte, 5) + d[0] = options - if len(e.Nonce) == 32 { - d = append(d, e.Nonce...) + // looks like string in + // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508 + tokenBytes := str2ucs2(e.FedAuthToken) + binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work + d = append(d, tokenBytes...) + + if len(e.Nonce) == 32 { + d = append(d, e.Nonce...) + } + + case fedAuthLibraryADAL: + d = []byte{options, e.ADALWorkflow} } return d @@ -418,7 +473,7 @@ func str2ucs2(s string) []byte { func ucs22str(s []byte) (string, error) { if len(s)%2 != 0 { - return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s)) + return "", fmt.Errorf("illegal UCS2 string length: %d", len(s)) } buf := make([]uint16, len(s)/2) for i := 0; i < len(s); i += 2 { @@ -436,7 +491,7 @@ func manglePassword(password string) []byte { } // http://msdn.microsoft.com/en-us/library/dd304019.aspx -func sendLogin(w *tdsBuffer, login login) error { +func sendLogin(w *tdsBuffer, login *login) error { w.BeginPacket(packLogin7, false) hostname := str2ucs2(login.HostName) username := str2ucs2(login.UserName) @@ -572,6 +627,36 @@ func sendLogin(w *tdsBuffer, login login) error { return w.FinishPacket() } +// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/827d9632-2957-4d54-b9ea-384530ae79d0 +func sendFedAuthInfo(w *tdsBuffer, fedAuth *featureExtFedAuth) (err error) { + fedauthtoken := str2ucs2(fedAuth.FedAuthToken) + tokenlen := len(fedauthtoken) + datalen := 4 + tokenlen + len(fedAuth.Nonce) + + w.BeginPacket(packFedAuthToken, false) + err = binary.Write(w, binary.LittleEndian, uint32(datalen)) + if err != nil { + return + } + + err = binary.Write(w, binary.LittleEndian, uint32(tokenlen)) + if err != nil { + return + } + + _, err = w.Write(fedauthtoken) + if err != nil { + return + } + + _, err = w.Write(fedAuth.Nonce) + if err != nil { + return + } + + return w.FinishPacket() +} + func readUcs2(r io.Reader, numchars int) (res string, err error) { buf := make([]byte, numchars*2) _, err = io.ReadFull(r, buf) @@ -768,26 +853,27 @@ type auth interface { // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a // list of IP addresses. So if there is more than one, try them all and // use the first one that allows a connection. -func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) { +func dialConnection(ctx context.Context, c *Connector, p msdsn.Config) (conn net.Conn, err error) { var ips []net.IP - ips, err = net.LookupIP(p.host) - if err != nil { - ip := net.ParseIP(p.host) - if ip == nil { - return nil, err + ip := net.ParseIP(p.Host) + if ip == nil { + ips, err = net.LookupIP(p.Host) + if err != nil { + return } + } else { ips = []net.IP{ip} } if len(ips) == 1 { d := c.getDialer(&p) - addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(resolveServerPort(p.port)))) + addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(resolveServerPort(p.Port)))) conn, err = d.DialContext(ctx, "tcp", addr) } else { //Try Dials in parallel to avoid waiting for timeouts. connChan := make(chan net.Conn, len(ips)) errChan := make(chan error, len(ips)) - portStr := strconv.Itoa(int(resolveServerPort(p.port))) + portStr := strconv.Itoa(int(resolveServerPort(p.Port))) for _, ip := range ips { go func(ip net.IP) { d := c.getDialer(&p) @@ -802,7 +888,7 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne } // Wait for either the *first* successful connection, or all the errors wait_loop: - for i, _ := range ips { + for i := range ips { select { case conn = <-connChan: // Got a connection to use, close any others @@ -824,66 +910,28 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne } // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection if conn == nil { - f := "Unable to open tcp connection with host '%v:%v': %v" - return nil, fmt.Errorf(f, p.host, resolveServerPort(p.port), err.Error()) + f := "unable to open tcp connection with host '%v:%v': %v" + return nil, fmt.Errorf(f, p.Host, resolveServerPort(p.Port), err.Error()) } return conn, err } -func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) { - dialCtx := ctx - if p.dial_timeout > 0 { - var cancel func() - dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout) - defer cancel() - } - // if instance is specified use instance resolution service - if p.instance != "" && p.port == 0 { - p.instance = strings.ToUpper(p.instance) - d := c.getDialer(&p) - instances, err := getInstances(dialCtx, d, p.host) - if err != nil { - f := "Unable to get instances from Sql Server Browser on host %v: %v" - return nil, fmt.Errorf(f, p.host, err.Error()) - } - strport, ok := instances[p.instance]["tcp"] - if !ok { - f := "No instance matching '%v' returned from host '%v'" - return nil, fmt.Errorf(f, p.instance, p.host) - } - port, err := strconv.ParseUint(strport, 0, 16) - if err != nil { - f := "Invalid tcp port returned from Sql Server Browser '%v': %v" - return nil, fmt.Errorf(f, strport, err.Error()) - } - p.port = port - } - -initiate_connection: - conn, err := dialConnection(dialCtx, c, p) - if err != nil { - return nil, err - } - - toconn := newTimeoutConn(conn, p.conn_timeout) - - outbuf := newTdsBuffer(p.packetSize, toconn) - sess := tdsSession{ - buf: outbuf, - log: log, - logFlags: p.logFlags, - } - - instance_buf := []byte(p.instance) +func preparePreloginFields(p msdsn.Config, fe *featureExtFedAuth) map[uint8][]byte { + instance_buf := []byte(p.Instance) instance_buf = append(instance_buf, 0) // zero terminate instance name + var encrypt byte - if p.disableEncryption { + switch p.Encryption { + default: + panic(fmt.Errorf("Unsupported Encryption Config %v", p.Encryption)) + case msdsn.EncryptionDisabled: encrypt = encryptNotSup - } else if p.encrypt { + case msdsn.EncryptionRequired: encrypt = encryptOn - } else { + case msdsn.EncryptionOff: encrypt = encryptOff } + fields := map[uint8][]byte{ preloginVERSION: {0, 0, 0, 0, 0, 0}, preloginENCRYPTION: {encrypt}, @@ -892,7 +940,182 @@ initiate_connection: preloginMARS: {0}, // MARS disabled } - err = writePrelogin(outbuf, fields) + if fe.FedAuthLibrary != fedAuthLibraryReserved { + fields[preloginFEDAUTHREQUIRED] = []byte{1} + } + + return fields +} + +func interpretPreloginResponse(p msdsn.Config, fe *featureExtFedAuth, fields map[uint8][]byte) (encrypt byte, err error) { + // If the server returns the preloginFEDAUTHREQUIRED field, then federated authentication + // is supported. The actual value may be 0 or 1, where 0 means either SSPI or federated + // authentication is allowed, while 1 means only federated authentication is allowed. + if fedAuthSupport, ok := fields[preloginFEDAUTHREQUIRED]; ok { + if len(fedAuthSupport) != 1 { + return 0, fmt.Errorf("federated authentication flag length should be 1: is %d", len(fedAuthSupport)) + } + + // We need to be able to echo the value back to the server + fe.FedAuthEcho = fedAuthSupport[0] != 0 + } else if fe.FedAuthLibrary != fedAuthLibraryReserved { + return 0, fmt.Errorf("federated authentication is not supported by the server") + } + + encryptBytes, ok := fields[preloginENCRYPTION] + if !ok { + return 0, fmt.Errorf("encrypt negotiation failed") + } + encrypt = encryptBytes[0] + if p.Encryption == msdsn.EncryptionRequired && (encrypt == encryptNotSup || encrypt == encryptOff) { + return 0, fmt.Errorf("server does not support encryption") + } + + return +} + +func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, log optionalLogger, auth auth, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) { + var typeFlags uint8 + if p.ReadOnlyIntent { + typeFlags |= fReadOnlyIntent + } + l = &login{ + TDSVersion: verTDS74, + PacketSize: packetSize, + Database: p.Database, + OptionFlags2: fODBC, // to get unlimited TEXTSIZE + HostName: p.Workstation, + ServerName: p.Host, + AppName: p.AppName, + TypeFlags: typeFlags, + } + switch { + case fe.FedAuthLibrary == fedAuthLibrarySecurityToken: + if uint64(p.LogFlags)&logDebug != 0 { + log.Println("Starting federated authentication using security token") + } + + fe.FedAuthToken, err = c.securityTokenProvider(ctx) + if err != nil { + if uint64(p.LogFlags)&logDebug != 0 { + log.Printf("Failed to retrieve service principal token for federated authentication security token library: %v", err) + } + return nil, err + } + + l.FeatureExt.Add(fe) + + case fe.FedAuthLibrary == fedAuthLibraryADAL: + if uint64(p.LogFlags)&logDebug != 0 { + log.Println("Starting federated authentication using ADAL") + } + + l.FeatureExt.Add(fe) + + case auth != nil: + if uint64(p.LogFlags)&logDebug != 0 { + log.Println("Starting SSPI login") + } + + l.SSPI, err = auth.InitialBytes() + if err != nil { + return nil, err + } + + l.OptionFlags2 |= fIntSecurity + return l, nil + + default: + // Default to SQL server authentication with user and password + l.UserName = p.User + l.Password = p.Password + } + + return l, nil +} + +func connect(ctx context.Context, c *Connector, log optionalLogger, p msdsn.Config) (res *tdsSession, err error) { + dialCtx := ctx + if p.DialTimeout >= 0 { + dt := p.DialTimeout + if dt == 0 { + dt = 15 * time.Second + } + var cancel func() + dialCtx, cancel = context.WithTimeout(ctx, dt) + defer cancel() + } + // if instance is specified use instance resolution service + if len(p.Instance) > 0 && p.Port != 0 { + // both instance name and port specified + // when port is specified instance name is not used + // you should not provide instance name when you provide port + log.Println("WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored") + } + if len(p.Instance) > 0 { + p.Instance = strings.ToUpper(p.Instance) + d := c.getDialer(&p) + instances, err := getInstances(dialCtx, d, p.Host) + if err != nil { + f := "unable to get instances from Sql Server Browser on host %v: %v" + return nil, fmt.Errorf(f, p.Host, err.Error()) + } + strport, ok := instances[p.Instance]["tcp"] + if !ok { + f := "no instance matching '%v' returned from host '%v'" + return nil, fmt.Errorf(f, p.Instance, p.Host) + } + port, err := strconv.ParseUint(strport, 0, 16) + if err != nil { + f := "invalid tcp port returned from Sql Server Browser '%v': %v" + return nil, fmt.Errorf(f, strport, err.Error()) + } + p.Port = port + } + if p.Port == 0 { + p.Port = defaultServerPort + } + + packetSize := p.PacketSize + if packetSize == 0 { + packetSize = defaultPacketSize + } + // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes + // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request + // a higher packet size, the server will respond with an ENVCHANGE request to + // alter the packet size to 16383 bytes. + if packetSize < 512 { + packetSize = 512 + } else if packetSize > 32767 { + packetSize = 32767 + } + +initiate_connection: + conn, err := dialConnection(dialCtx, c, p) + if err != nil { + return nil, err + } + + toconn := newTimeoutConn(conn, p.ConnTimeout) + + outbuf := newTdsBuffer(packetSize, toconn) + sess := tdsSession{ + buf: outbuf, + log: log, + logFlags: uint64(p.LogFlags), + } + + fedAuth := &featureExtFedAuth{ + FedAuthLibrary: fedAuthLibraryReserved, + } + if c.fedAuthRequired { + fedAuth.FedAuthLibrary = c.fedAuthLibrary + fedAuth.ADALWorkflow = c.fedAuthADALWorkflow + } + + fields := preparePreloginFields(p, fedAuth) + + err = writePrelogin(packPrelogin, outbuf, fields) if err != nil { return nil, err } @@ -902,39 +1125,36 @@ initiate_connection: return nil, err } - encryptBytes, ok := fields[preloginENCRYPTION] - if !ok { - return nil, fmt.Errorf("Encrypt negotiation failed") - } - encrypt = encryptBytes[0] - if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) { - return nil, fmt.Errorf("Server does not support encryption") + encrypt, err := interpretPreloginResponse(p, fedAuth, fields) + if err != nil { + return nil, err } if encrypt != encryptNotSup { - var config tls.Config - if p.certificate != "" { - pem, err := ioutil.ReadFile(p.certificate) - if err != nil { - return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err) + var config *tls.Config + if pc := p.TLSConfig; pc != nil { + config = pc + if config.DynamicRecordSizingDisabled == false { + config = config.Clone() + + // fix for https://github.com/denisenkom/go-mssqldb/issues/166 + // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, + // while SQL Server seems to expect one TCP segment per encrypted TDS package. + // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package + config.DynamicRecordSizingDisabled = true } - certs := x509.NewCertPool() - certs.AppendCertsFromPEM(pem) - config.RootCAs = certs } - if p.trustServerCertificate { - config.InsecureSkipVerify = true + if config == nil { + config, err = msdsn.SetupTLS("", false, p.Host) + if err != nil { + return nil, err + } } - config.ServerName = p.hostInCertificate - // fix for https://github.com/denisenkom/go-mssqldb/issues/166 - // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, - // while SQL Server seems to expect one TCP segment per encrypted TDS package. - // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package - config.DynamicRecordSizingDisabled = true + // setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream handshakeConn := tlsHandshakeConn{buf: outbuf} passthrough := passthroughConn{c: &handshakeConn} - tlsConn := tls.Client(&passthrough, &config) + tlsConn := tls.Client(&passthrough, config) err = tlsConn.Handshake() passthrough.c = toconn outbuf.transport = tlsConn @@ -948,54 +1168,48 @@ initiate_connection: } } - login := login{ - TDSVersion: verTDS74, - PacketSize: uint32(outbuf.PackageSize()), - Database: p.database, - OptionFlags2: fODBC, // to get unlimited TEXTSIZE - HostName: p.workstation, - ServerName: p.host, - AppName: p.appname, - TypeFlags: p.typeFlags, - } - auth, authOk := getAuth(p.user, p.password, p.serverSPN, p.workstation) - switch { - case p.fedAuthAccessToken != "": // accesstoken ignores user/password - featurext := &featureExtFedAuthSTS{ - FedAuthEcho: len(fields[preloginFEDAUTHREQUIRED]) > 0 && fields[preloginFEDAUTHREQUIRED][0] == 1, - FedAuthToken: p.fedAuthAccessToken, - Nonce: fields[preloginNONCEOPT], - } - login.FeatureExt.Add(featurext) - case authOk: - login.SSPI, err = auth.InitialBytes() - if err != nil { - return nil, err - } - login.OptionFlags2 |= fIntSecurity + auth, authOk := getAuth(p.User, p.Password, p.ServerSPN, p.Workstation) + if authOk { defer auth.Free() - default: - login.UserName = p.user - login.Password = p.password + } else { + auth = nil } + + login, err := prepareLogin(ctx, c, p, log, auth, fedAuth, uint32(outbuf.PackageSize())) + if err != nil { + return nil, err + } + err = sendLogin(outbuf, login) if err != nil { return nil, err } - // processing login response - success := false - for { - tokchan := make(chan tokenStruct, 5) - go processResponse(context.Background(), &sess, tokchan, nil) - for tok := range tokchan { + // Loop until a packet containing a login acknowledgement is received. + // SSPI and federated authentication scenarios may require multiple + // packet exchanges to complete the login sequence. + for loginAck := false; !loginAck; { + reader := startReading(&sess, ctx, outputs{}) + // don't send attention or wait for cancel confirmation during login + reader.noAttn = true + + for { + tok, err := reader.nextToken() + if err != nil { + return nil, err + } + + if tok == nil { + break + } + switch token := tok.(type) { case sspiMsg: sspi_msg, err := auth.NextBytes(token) if err != nil { return nil, err } - if sspi_msg != nil && len(sspi_msg) > 0 { + if len(sspi_msg) > 0 { outbuf.BeginPacket(packSSPIMessage, false) _, err = outbuf.Write(sspi_msg) if err != nil { @@ -1007,31 +1221,59 @@ initiate_connection: } sspi_msg = nil } + // TODO: for Live ID authentication it may be necessary to + // compare fedAuth.Nonce == token.Nonce and keep track of signature + //case fedAuthAckStruct: + //fedAuth.Signature = token.Signature + case fedAuthInfoStruct: + // For ADAL workflows this contains the STS URL and server SPN. + // If received outside of an ADAL workflow, ignore. + if c == nil || c.adalTokenProvider == nil { + continue + } + + // Request the AD token given the server SPN and STS URL + fedAuth.FedAuthToken, err = c.adalTokenProvider(ctx, token.ServerSPN, token.STSURL) + if err != nil { + return nil, err + } + + // Now need to send the token as a FEDINFO packet + err = sendFedAuthInfo(outbuf, fedAuth) + if err != nil { + return nil, err + } case loginAckStruct: - success = true sess.loginAck = token - case error: - return nil, fmt.Errorf("Login error: %s", token.Error()) + loginAck = true case doneStruct: if token.isError() { - return nil, fmt.Errorf("Login error: %s", token.getError()) + tokenErr := token.getError() + tokenErr.Message = "login error: " + tokenErr.Message + return nil, tokenErr } - goto loginEnd + case error: + return nil, fmt.Errorf("login error: %s", token.Error()) } } } -loginEnd: - if !success { - return nil, fmt.Errorf("Login failed") - } + if sess.routedServer != "" { toconn.Close() - p.host = sess.routedServer - p.port = uint64(sess.routedPort) - if !p.hostInCertificateProvided { - p.hostInCertificate = sess.routedServer + p.Host = sess.routedServer + p.Port = uint64(sess.routedPort) + if !p.HostInCertificateProvided && p.TLSConfig != nil { + p.TLSConfig.ServerName = sess.routedServer } goto initiate_connection } return &sess, nil } + +func resolveServerPort(port uint64) uint64 { + if port == 0 { + return defaultServerPort + } + + return port +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/token.go b/vendor/github.com/denisenkom/go-mssqldb/token.go index 25385e89..6000fb96 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/token.go +++ b/vendor/github.com/denisenkom/go-mssqldb/token.go @@ -6,12 +6,11 @@ import ( "errors" "fmt" "io" - "net" + "io/ioutil" "strconv" - "strings" ) -//go:generate stringer -type token +//go:generate go run golang.org/x/tools/cmd/stringer -type token type token byte @@ -29,6 +28,7 @@ const ( tokenNbcRow token = 210 // 0xd2 tokenEnvChange token = 227 // 0xE3 tokenSSPI token = 237 // 0xED + tokenFedAuthInfo token = 238 // 0xEE tokenDone token = 253 // 0xFD tokenDoneProc token = 254 tokenDoneInProc token = 255 @@ -70,6 +70,11 @@ const ( envRouting = 20 ) +const ( + fedAuthInfoSTSURL = 0x01 + fedAuthInfoSPN = 0x02 +) + // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( @@ -96,35 +101,18 @@ func (d doneStruct) isError() bool { } func (d doneStruct) getError() Error { - if len(d.errors) > 0 { - return d.errors[len(d.errors)-1] - } else { + n := len(d.errors) + if n == 0 { return Error{Message: "Request failed but didn't provide reason"} } + err := d.errors[n-1] + err.All = make([]Error, n) + copy(err.All, d.errors) + return err } type doneInProcStruct doneStruct -var doneFlags2str = map[uint16]string{ - doneFinal: "final", - doneMore: "more", - doneError: "error", - doneInxact: "inxact", - doneCount: "count", - doneAttn: "attn", - doneSrvError: "srverror", -} - -func doneFlags2Str(flags uint16) string { - strs := make([]string, 0, len(doneFlags2str)) - for flag, tag := range doneFlags2str { - if flags&flag != 0 { - strs = append(strs, tag) - } - } - return strings.Join(strs, "|") -} - // ENVCHANGE stream // http://msdn.microsoft.com/en-us/library/dd303449.aspx func processEnvChg(sess *tdsSession) { @@ -380,9 +368,8 @@ func processEnvChg(sess *tdsSession) { default: // ignore rest of records because we don't know how to skip those sess.log.Printf("WARN: Unknown ENVCHANGE record detected with type id = %d\n", envtype) - break + return } - } } @@ -425,6 +412,78 @@ func parseSSPIMsg(r *tdsBuffer) sspiMsg { return sspiMsg(buf) } +type fedAuthInfoStruct struct { + STSURL string + ServerSPN string +} + +type fedAuthInfoOpt struct { + fedAuthInfoID byte + dataLength, dataOffset uint32 +} + +func parseFedAuthInfo(r *tdsBuffer) fedAuthInfoStruct { + size := r.uint32() + + var STSURL, SPN string + var err error + + // Each fedAuthInfoOpt is one byte to indicate the info ID, + // then a four byte offset and a four byte length. + count := r.uint32() + offset := uint32(4) + opts := make([]fedAuthInfoOpt, count) + + for i := uint32(0); i < count; i++ { + fedAuthInfoID := r.byte() + dataLength := r.uint32() + dataOffset := r.uint32() + offset += 1 + 4 + 4 + + opts[i] = fedAuthInfoOpt{ + fedAuthInfoID: fedAuthInfoID, + dataLength: dataLength, + dataOffset: dataOffset, + } + } + + data := make([]byte, size-offset) + r.ReadFull(data) + + for i := uint32(0); i < count; i++ { + if opts[i].dataOffset < offset { + badStreamPanicf("Fed auth info opt stated data offset %d is before data begins in packet at %d", + opts[i].dataOffset, offset) + // returns via panic + } + + if opts[i].dataOffset+opts[i].dataLength > size { + badStreamPanicf("Fed auth info opt stated data length %d added to stated offset exceeds size of packet %d", + opts[i].dataOffset+opts[i].dataLength, size) + // returns via panic + } + + optData := data[opts[i].dataOffset-offset : opts[i].dataOffset-offset+opts[i].dataLength] + switch opts[i].fedAuthInfoID { + case fedAuthInfoSTSURL: + STSURL, err = ucs22str(optData) + case fedAuthInfoSPN: + SPN, err = ucs22str(optData) + default: + err = fmt.Errorf("unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID)) + } + + if err != nil { + badStreamPanic(err) + } + } + + return fedAuthInfoStruct{ + STSURL: STSURL, + ServerSPN: SPN, + } +} + type loginAckStruct struct { Interface uint8 TDSVersion uint32 @@ -449,19 +508,43 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct { } // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/2eb82f8e-11f0-46dc-b42d-27302fa4701a -func parseFeatureExtAck(r *tdsBuffer) { - // at most 1 featureAck per feature in featureExt - // go-mssqldb will add at most 1 feature, the spec defines 7 different features - for i := 0; i < 8; i++ { - featureID := r.byte() // FeatureID - if featureID == 0xff { - return +type fedAuthAckStruct struct { + Nonce []byte + Signature []byte +} + +func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { + ack := map[byte]interface{}{} + + for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() { + length := r.uint32() + + switch feature { + case featExtFEDAUTH: + // In theory we need to know the federated authentication library to + // know how to parse, but the alternatives provide compatible structures. + fedAuthAck := fedAuthAckStruct{} + if length >= 32 { + fedAuthAck.Nonce = make([]byte, 32) + r.ReadFull(fedAuthAck.Nonce) + length -= 32 + } + if length >= 32 { + fedAuthAck.Signature = make([]byte, 32) + r.ReadFull(fedAuthAck.Signature) + length -= 32 + } + ack[feature] = fedAuthAck + + } + + // Skip unprocessed bytes + if length > 0 { + io.CopyN(ioutil.Discard, r, int64(length)) } - size := r.uint32() // FeatureAckDataLen - d := make([]byte, size) - r.ReadFull(d) } - panic("parsed more than 7 featureAck's, protocol implementation error?") + + return ack } // http://msdn.microsoft.com/en-us/library/dd357363.aspx @@ -555,7 +638,7 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) { return } -func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) { +func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs outputs) { defer func() { if err := recover(); err != nil { if sess.logFlags&logErrors != 0 { @@ -579,7 +662,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin } var columns []columnStruct errs := make([]Error, 0, 5) - for { + for tokens := 0; ; tokens += 1 { token := token(sess.buf.byte()) if sess.logFlags&logDebug != 0 { sess.log.Printf("got token %v", token) @@ -588,6 +671,9 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin case tokenSSPI: ch <- parseSSPIMsg(sess.buf) return + case tokenFedAuthInfo: + ch <- parseFedAuthInfo(sess.buf) + return case tokenReturnStatus: returnStatus := parseReturnStatus(sess.buf) ch <- returnStatus @@ -595,7 +681,8 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin loginAck := parseLoginAck(sess.buf) ch <- loginAck case tokenFeatureExtAck: - parseFeatureExtAck(sess.buf) + featureExtAck := parseFeatureExtAck(sess.buf) + ch <- featureExtAck case tokenOrder: order := parseOrder(sess.buf) ch <- order @@ -612,7 +699,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin sess.log.Printf("got DONE or DONEPROC status=%d", done.Status) } if done.Status&doneSrvError != 0 { - ch <- errors.New("SQL Server had internal error") + ch <- ServerError{done.getError()} return } if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { @@ -656,7 +743,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin nv := parseReturnValue(sess.buf) if len(nv.Name) > 0 { name := nv.Name[1:] // Remove the leading "@". - if ov, has := outs[name]; has { + if ov, has := outs.params[name]; has { err = scanIntoOut(name, nv.Value, ov) if err != nil { fmt.Println("scan error", err) @@ -670,154 +757,144 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin } } -type parseRespIter byte - -const ( - parseRespIterContinue parseRespIter = iota // Continue parsing current token. - parseRespIterNext // Fetch the next token. - parseRespIterDone // Done with parsing the response. -) - -type parseRespState byte - -const ( - parseRespStateNormal parseRespState = iota // Normal response state. - parseRespStateCancel // Query is canceled, wait for server to confirm. - parseRespStateClosing // Waiting for tokens to come through. -) - -type parseResp struct { - sess *tdsSession - ctxDone <-chan struct{} - state parseRespState - cancelError error +type tokenProcessor struct { + tokChan chan tokenStruct + ctx context.Context + sess *tdsSession + outs outputs + lastRow []interface{} + rowCount int64 + firstError error + // whether to skip sending attention when ctx is done + noAttn bool } -func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter { - if err := sendAttention(ts.sess.buf); err != nil { - ts.dlogf("failed to send attention signal %v", err) - ch <- err - return parseRespIterDone - } - ts.state = parseRespStateCancel - return parseRespIterContinue -} - -func (ts *parseResp) dlog(msg string) { - if ts.sess.logFlags&logDebug != 0 { - ts.sess.log.Println(msg) - } -} -func (ts *parseResp) dlogf(f string, v ...interface{}) { - if ts.sess.logFlags&logDebug != 0 { - ts.sess.log.Printf(f, v...) - } -} - -func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter { - switch ts.state { - default: - panic("unknown state") - case parseRespStateNormal: - select { - case tok, ok := <-tokChan: - if !ok { - ts.dlog("response finished") - return parseRespIterDone - } - if err, ok := tok.(net.Error); ok && err.Timeout() { - ts.cancelError = err - ts.dlog("got timeout error, sending attention signal to server") - return ts.sendAttention(ch) - } - // Pass the token along. - ch <- tok - return parseRespIterContinue - - case <-ts.ctxDone: - ts.ctxDone = nil - ts.dlog("got cancel message, sending attention signal to server") - return ts.sendAttention(ch) - } - case parseRespStateCancel: // Read all responses until a DONE or error is received.Auth - select { - case tok, ok := <-tokChan: - if !ok { - ts.dlog("response finished but waiting for attention ack") - return parseRespIterNext - } - switch tok := tok.(type) { - default: - // Ignore all other tokens while waiting. - // The TDS spec says other tokens may arrive after an attention - // signal is sent. Ignore these tokens and continue looking for - // a DONE with attention confirm mark. - case doneStruct: - if tok.Status&doneAttn != 0 { - ts.dlog("got cancellation confirmation from server") - if ts.cancelError != nil { - ch <- ts.cancelError - ts.cancelError = nil - } else { - ch <- ctx.Err() - } - return parseRespIterDone - } - - // If an error happens during cancel, pass it along and just stop. - // We are uncertain to receive more tokens. - case error: - ch <- tok - ts.state = parseRespStateClosing - } - return parseRespIterContinue - case <-ts.ctxDone: - ts.ctxDone = nil - ts.state = parseRespStateClosing - return parseRespIterContinue - } - case parseRespStateClosing: // Wait for current token chan to close. - if _, ok := <-tokChan; !ok { - ts.dlog("response finished") - return parseRespIterDone - } - return parseRespIterContinue - } -} - -func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) { - ts := &parseResp{ +func startReading(sess *tdsSession, ctx context.Context, outs outputs) *tokenProcessor { + tokChan := make(chan tokenStruct, 5) + go processSingleResponse(sess, tokChan, outs) + return &tokenProcessor{ + tokChan: tokChan, + ctx: ctx, sess: sess, - ctxDone: ctx.Done(), + outs: outs, } - defer func() { - // Ensure any remaining error is piped through - // or the query may look like it executed when it actually failed. - if ts.cancelError != nil { - ch <- ts.cancelError - ts.cancelError = nil - } - close(ch) - }() +} - // Loop over multiple responses. +func (t *tokenProcessor) iterateResponse() error { for { - ts.dlog("initiating response reading") - - tokChan := make(chan tokenStruct) - go processSingleResponse(sess, tokChan, outs) - - // Loop over multiple tokens in response. - tokensLoop: - for { - switch ts.iter(ctx, ch, tokChan) { - case parseRespIterContinue: - // Nothing, continue to next token. - case parseRespIterNext: - break tokensLoop - case parseRespIterDone: - return + tok, err := t.nextToken() + if err == nil { + if tok == nil { + return t.firstError + } else { + switch token := tok.(type) { + case []columnStruct: + t.sess.columns = token + case []interface{}: + t.lastRow = token + case doneInProcStruct: + if token.Status&doneCount != 0 { + t.rowCount += int64(token.RowCount) + } + case doneStruct: + if token.Status&doneCount != 0 { + t.rowCount += int64(token.RowCount) + } + if token.isError() && t.firstError == nil { + t.firstError = token.getError() + } + case ReturnStatus: + if t.outs.returnStatus != nil { + *t.outs.returnStatus = token + } + /*case error: + if resultError == nil { + resultError = token + }*/ + } } + } else { + return err } } } + +func (t tokenProcessor) nextToken() (tokenStruct, error) { + // we do this separate non-blocking check on token channel to + // prioritize it over cancellation channel + select { + case tok, more := <-t.tokChan: + err, more := tok.(error) + if more { + // this is an error and not a token + return nil, err + } else { + return tok, nil + } + default: + // there are no tokens on the channel, will need to wait + } + + select { + case tok, more := <-t.tokChan: + if more { + err, ok := tok.(error) + if ok { + // this is an error and not a token + return nil, err + } else { + return tok, nil + } + } else { + // completed reading response + return nil, nil + } + case <-t.ctx.Done(): + if t.noAttn { + return nil, t.ctx.Err() + } + if err := sendAttention(t.sess.buf); err != nil { + // unable to send attention, current connection is bad + // notify caller and close channel + return nil, err + } + + // now the server should send cancellation confirmation + // it is possible that we already received full response + // just before we sent cancellation request + // in this case current response would not contain confirmation + // and we would need to read one more response + + // first lets finish reading current response and look + // for confirmation in it + if readCancelConfirmation(t.tokChan) { + // we got confirmation in current response + return nil, t.ctx.Err() + } + // we did not get cancellation confirmation in the current response + // read one more response, it must be there + t.tokChan = make(chan tokenStruct, 5) + go processSingleResponse(t.sess, t.tokChan, t.outs) + if readCancelConfirmation(t.tokChan) { + return nil, t.ctx.Err() + } + // we did not get cancellation confirmation, something is not + // right, this connection is not usable anymore + return nil, errors.New("did not get cancellation confirmation from the server") + } +} + +func readCancelConfirmation(tokChan chan tokenStruct) bool { + for tok := range tokChan { + switch tok := tok.(type) { + default: + // just skip token + case doneStruct: + if tok.Status&doneAttn != 0 { + // got cancellation confirmation, exit + return true + } + } + } + return false +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/token_string.go b/vendor/github.com/denisenkom/go-mssqldb/token_string.go index c075b23b..a473182c 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/token_string.go +++ b/vendor/github.com/denisenkom/go-mssqldb/token_string.go @@ -1,29 +1,24 @@ -// Code generated by "stringer -type token"; DO NOT EDIT +// Code generated by "stringer -type token"; DO NOT EDIT. package mssql -import "fmt" +import "strconv" const ( _token_name_0 = "tokenReturnStatus" _token_name_1 = "tokenColMetadata" - _token_name_2 = "tokenOrdertokenErrortokenInfo" - _token_name_3 = "tokenLoginAck" - _token_name_4 = "tokenRowtokenNbcRow" - _token_name_5 = "tokenEnvChange" - _token_name_6 = "tokenSSPI" - _token_name_7 = "tokenDonetokenDoneProctokenDoneInProc" + _token_name_2 = "tokenOrdertokenErrortokenInfotokenReturnValuetokenLoginAcktokenFeatureExtAck" + _token_name_3 = "tokenRowtokenNbcRow" + _token_name_4 = "tokenEnvChange" + _token_name_5 = "tokenSSPItokenFedAuthInfo" + _token_name_6 = "tokenDonetokenDoneProctokenDoneInProc" ) var ( - _token_index_0 = [...]uint8{0, 17} - _token_index_1 = [...]uint8{0, 16} - _token_index_2 = [...]uint8{0, 10, 20, 29} - _token_index_3 = [...]uint8{0, 13} - _token_index_4 = [...]uint8{0, 8, 19} - _token_index_5 = [...]uint8{0, 14} - _token_index_6 = [...]uint8{0, 9} - _token_index_7 = [...]uint8{0, 9, 22, 37} + _token_index_2 = [...]uint8{0, 10, 20, 29, 45, 58, 76} + _token_index_3 = [...]uint8{0, 8, 19} + _token_index_5 = [...]uint8{0, 9, 25} + _token_index_6 = [...]uint8{0, 9, 22, 37} ) func (i token) String() string { @@ -32,22 +27,21 @@ func (i token) String() string { return _token_name_0 case i == 129: return _token_name_1 - case 169 <= i && i <= 171: + case 169 <= i && i <= 174: i -= 169 return _token_name_2[_token_index_2[i]:_token_index_2[i+1]] - case i == 173: - return _token_name_3 case 209 <= i && i <= 210: i -= 209 - return _token_name_4[_token_index_4[i]:_token_index_4[i+1]] + return _token_name_3[_token_index_3[i]:_token_index_3[i+1]] case i == 227: - return _token_name_5 - case i == 237: - return _token_name_6 + return _token_name_4 + case 237 <= i && i <= 238: + i -= 237 + return _token_name_5[_token_index_5[i]:_token_index_5[i+1]] case 253 <= i && i <= 255: i -= 253 - return _token_name_7[_token_index_7[i]:_token_index_7[i+1]] + return _token_name_6[_token_index_6[i]:_token_index_6[i+1]] default: - return fmt.Sprintf("token(%d)", i) + return "token(" + strconv.FormatInt(int64(i), 10) + ")" } } diff --git a/vendor/github.com/denisenkom/go-mssqldb/tran.go b/vendor/github.com/denisenkom/go-mssqldb/tran.go index cb643681..9b219724 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/tran.go +++ b/vendor/github.com/denisenkom/go-mssqldb/tran.go @@ -21,11 +21,11 @@ type isoLevel uint8 const ( isolationUseCurrent isoLevel = 0 - isolationReadUncommited = 1 - isolationReadCommited = 2 - isolationRepeatableRead = 3 - isolationSerializable = 4 - isolationSnapshot = 5 + isolationReadUncommited isoLevel = 1 + isolationReadCommited isoLevel = 2 + isolationRepeatableRead isoLevel = 3 + isolationSerializable isoLevel = 4 + isolationSnapshot isoLevel = 5 ) func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) (err error) { diff --git a/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go b/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go index 64e5e21f..d3890af9 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go +++ b/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go @@ -4,6 +4,7 @@ package mssql import ( "bytes" + "database/sql" "encoding/binary" "errors" "fmt" @@ -97,6 +98,9 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd for columnStrIdx, fieldIdx := range tvpFieldIndexes { field := refStr.Field(fieldIdx) tvpVal := field.Interface() + if tvp.verifyStandardTypeOnNull(buf, tvpVal) { + continue + } valOf := reflect.ValueOf(tvpVal) elemKind := field.Kind() if elemKind == reflect.Ptr && valOf.IsNil() { @@ -155,7 +159,7 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) { defaultValues = append(defaultValues, v.Interface()) continue } - defaultValues = append(defaultValues, reflect.Zero(field.Type).Interface()) + defaultValues = append(defaultValues, tvp.createZeroType(reflect.Zero(field.Type).Interface())) } if columnCount-len(tvpFieldIndexes) == columnCount { @@ -209,19 +213,23 @@ func getSchemeAndName(tvpName string) (string, string, error) { } splitVal := strings.Split(tvpName, ".") if len(splitVal) > 2 { - return "", "", errors.New("wrong tvp name") + return "", "", ErrorObjectName } + const ( + openSquareBrackets = "[" + closeSquareBrackets = "]" + ) if len(splitVal) == 2 { res := make([]string, 2) for key, value := range splitVal { - tmp := strings.Replace(value, "[", "", -1) - tmp = strings.Replace(tmp, "]", "", -1) + tmp := strings.Replace(value, openSquareBrackets, "", -1) + tmp = strings.Replace(tmp, closeSquareBrackets, "", -1) res[key] = tmp } return res[0], res[1], nil } - tmp := strings.Replace(splitVal[0], "[", "", -1) - tmp = strings.Replace(tmp, "]", "", -1) + tmp := strings.Replace(splitVal[0], openSquareBrackets, "", -1) + tmp = strings.Replace(tmp, closeSquareBrackets, "", -1) return "", tmp, nil } @@ -229,3 +237,56 @@ func getSchemeAndName(tvpName string) (string, string, error) { func getCountSQLSeparators(str string) int { return strings.Count(str, sqlSeparator) } + +// verify types https://golang.org/pkg/database/sql/ +func (tvp TVP) createZeroType(fieldVal interface{}) interface{} { + const ( + defaultBool = false + defaultFloat64 = float64(0) + defaultInt64 = int64(0) + defaultString = "" + ) + + switch fieldVal.(type) { + case sql.NullBool: + return defaultBool + case sql.NullFloat64: + return defaultFloat64 + case sql.NullInt64: + return defaultInt64 + case sql.NullString: + return defaultString + } + return fieldVal +} + +// verify types https://golang.org/pkg/database/sql/ +func (tvp TVP) verifyStandardTypeOnNull(buf *bytes.Buffer, tvpVal interface{}) bool { + const ( + defaultNull = uint8(0) + ) + + switch val := tvpVal.(type) { + case sql.NullBool: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, defaultNull) + return true + } + case sql.NullFloat64: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, defaultNull) + return true + } + case sql.NullInt64: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, defaultNull) + return true + } + case sql.NullString: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL)) + return true + } + } + return false +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/types.go b/vendor/github.com/denisenkom/go-mssqldb/types.go index b6e7fb2b..cae19924 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/types.go +++ b/vendor/github.com/denisenkom/go-mssqldb/types.go @@ -665,7 +665,7 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} { default: buf = bytes.NewBuffer(make([]byte, 0, size)) } - for true { + for { chunksize := r.uint32() if chunksize == 0 { break @@ -690,6 +690,10 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} { } func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) { + if buf == nil { + err = binary.Write(w, binary.LittleEndian, uint64(_PLP_NULL)) + return + } if err = binary.Write(w, binary.LittleEndian, uint64(_UNKNOWN_PLP_LEN)); err != nil { return } @@ -807,7 +811,6 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) { default: badStreamPanicf("Invalid type %d", ti.TypeId) } - return } func decodeMoney(buf []byte) []byte { @@ -834,8 +837,7 @@ func decodeGuid(buf []byte) []byte { } func decodeDecimal(prec uint8, scale uint8, buf []byte) []byte { - var sign uint8 - sign = buf[0] + sign := buf[0] var dec decimal.Decimal dec.SetPositive(sign != 0) dec.SetPrec(prec) @@ -1187,7 +1189,7 @@ func makeDecl(ti typeInfo) string { return fmt.Sprintf("char(%d)", ti.Size) case typeBigVarChar, typeVarChar: if ti.Size > 8000 || ti.Size == 0 { - return fmt.Sprintf("varchar(max)") + return "varchar(max)" } else { return fmt.Sprintf("varchar(%d)", ti.Size) } diff --git a/vendor/github.com/go-asn1-ber/asn1-ber/.travis.yml b/vendor/github.com/go-asn1-ber/asn1-ber/.travis.yml new file mode 100644 index 00000000..764b5418 --- /dev/null +++ b/vendor/github.com/go-asn1-ber/asn1-ber/.travis.yml @@ -0,0 +1,43 @@ +language: go + +go: + - 1.2.x + - 1.6.x + - 1.9.x + - 1.10.x + - 1.11.x + - 1.12.x + - 1.14.x + - tip + +os: + - linux + +arch: + - amd64 + - ppc64le + +dist: xenial + +env: + - GOARCH=amd64 + +jobs: + include: + - os: windows + go: 1.14.x + - os: osx + go: 1.14.x + - os: linux + go: 1.14.x + arch: arm64 + - os: linux + go: 1.14.x + env: + - GOARCH=386 + +script: + - go test -v -cover ./... || go test -v ./... +matrix: + allowfailures: + go: 1.2.x diff --git a/vendor/gopkg.in/asn1-ber.v1/LICENSE b/vendor/github.com/go-asn1-ber/asn1-ber/LICENSE similarity index 100% rename from vendor/gopkg.in/asn1-ber.v1/LICENSE rename to vendor/github.com/go-asn1-ber/asn1-ber/LICENSE diff --git a/vendor/gopkg.in/asn1-ber.v1/README.md b/vendor/github.com/go-asn1-ber/asn1-ber/README.md similarity index 100% rename from vendor/gopkg.in/asn1-ber.v1/README.md rename to vendor/github.com/go-asn1-ber/asn1-ber/README.md diff --git a/vendor/gopkg.in/asn1-ber.v1/ber.go b/vendor/github.com/go-asn1-ber/asn1-ber/ber.go similarity index 69% rename from vendor/gopkg.in/asn1-ber.v1/ber.go rename to vendor/github.com/go-asn1-ber/asn1-ber/ber.go index 6153f460..4fd7a66e 100644 --- a/vendor/gopkg.in/asn1-ber.v1/ber.go +++ b/vendor/github.com/go-asn1-ber/asn1-ber/ber.go @@ -8,6 +8,8 @@ import ( "math" "os" "reflect" + "time" + "unicode/utf8" ) // MaxPacketLengthBytes specifies the maximum allowed packet size when calling ReadPacket or DecodePacket. Set to 0 for @@ -143,42 +145,46 @@ var TypeMap = map[Type]string{ TypeConstructed: "Constructed", } -var Debug bool = false +var Debug = false func PrintBytes(out io.Writer, buf []byte, indent string) { - data_lines := make([]string, (len(buf)/30)+1) - num_lines := make([]string, (len(buf)/30)+1) + dataLines := make([]string, (len(buf)/30)+1) + numLines := make([]string, (len(buf)/30)+1) for i, b := range buf { - data_lines[i/30] += fmt.Sprintf("%02x ", b) - num_lines[i/30] += fmt.Sprintf("%02d ", (i+1)%100) + dataLines[i/30] += fmt.Sprintf("%02x ", b) + numLines[i/30] += fmt.Sprintf("%02d ", (i+1)%100) } - for i := 0; i < len(data_lines); i++ { - out.Write([]byte(indent + data_lines[i] + "\n")) - out.Write([]byte(indent + num_lines[i] + "\n\n")) + for i := 0; i < len(dataLines); i++ { + _, _ = out.Write([]byte(indent + dataLines[i] + "\n")) + _, _ = out.Write([]byte(indent + numLines[i] + "\n\n")) } } +func WritePacket(out io.Writer, p *Packet) { + printPacket(out, p, 0, false) +} + func PrintPacket(p *Packet) { printPacket(os.Stdout, p, 0, false) } func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) { - indent_str := "" + indentStr := "" - for len(indent_str) != indent { - indent_str += " " + for len(indentStr) != indent { + indentStr += " " } - class_str := ClassMap[p.ClassType] + classStr := ClassMap[p.ClassType] - tagtype_str := TypeMap[p.TagType] + tagTypeStr := TypeMap[p.TagType] - tag_str := fmt.Sprintf("0x%02X", p.Tag) + tagStr := fmt.Sprintf("0x%02X", p.Tag) if p.ClassType == ClassUniversal { - tag_str = tagMap[p.Tag] + tagStr = tagMap[p.Tag] } value := fmt.Sprint(p.Value) @@ -188,10 +194,10 @@ func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) { description = p.Description + ": " } - fmt.Fprintf(out, "%s%s(%s, %s, %s) Len=%d %q\n", indent_str, description, class_str, tagtype_str, tag_str, p.Data.Len(), value) + _, _ = fmt.Fprintf(out, "%s%s(%s, %s, %s) Len=%d %q\n", indentStr, description, classStr, tagTypeStr, tagStr, p.Data.Len(), value) if printBytes { - PrintBytes(out, p.Bytes(), indent_str) + PrintBytes(out, p.Bytes(), indentStr) } for _, child := range p.Children { @@ -199,7 +205,7 @@ func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) { } } -// ReadPacket reads a single Packet from the reader +// ReadPacket reads a single Packet from the reader. func ReadPacket(reader io.Reader) (*Packet, error) { p, _, err := readPacket(reader) if err != nil { @@ -235,7 +241,7 @@ func encodeInteger(i int64) []byte { var j int for ; n > 0; n-- { - out[j] = (byte(i >> uint((n-1)*8))) + out[j] = byte(i >> uint((n-1)*8)) j++ } @@ -267,7 +273,7 @@ func DecodePacket(data []byte) *Packet { } // DecodePacketErr decodes the given bytes into a single Packet -// If a decode error is encountered, nil is returned +// If a decode error is encountered, nil is returned. func DecodePacketErr(data []byte) (*Packet, error) { p, _, err := readPacket(bytes.NewBuffer(data)) if err != nil { @@ -276,7 +282,7 @@ func DecodePacketErr(data []byte) (*Packet, error) { return p, nil } -// readPacket reads a single Packet from the reader, returning the number of bytes read +// readPacket reads a single Packet from the reader, returning the number of bytes read. func readPacket(reader io.Reader) (*Packet, int, error) { identifier, length, read, err := readHeader(reader) if err != nil { @@ -338,7 +344,7 @@ func readPacket(reader io.Reader) (*Packet, int, error) { if MaxPacketLengthBytes > 0 && int64(length) > MaxPacketLengthBytes { return nil, read, fmt.Errorf("length %d greater than maximum %d", length, MaxPacketLengthBytes) } - content := make([]byte, length, length) + content := make([]byte, length) if length > 0 { _, err := io.ReadFull(reader, content) if err != nil { @@ -373,22 +379,42 @@ func readPacket(reader io.Reader) (*Packet, int, error) { case TagObjectDescriptor: case TagExternal: case TagRealFloat: + p.Value, err = ParseReal(content) case TagEnumerated: p.Value, _ = ParseInt64(content) case TagEmbeddedPDV: case TagUTF8String: - p.Value = DecodeString(content) + val := DecodeString(content) + if !utf8.Valid([]byte(val)) { + err = errors.New("invalid UTF-8 string") + } else { + p.Value = val + } case TagRelativeOID: case TagSequence: case TagSet: case TagNumericString: case TagPrintableString: - p.Value = DecodeString(content) + val := DecodeString(content) + if err = isPrintableString(val); err == nil { + p.Value = val + } case TagT61String: case TagVideotexString: case TagIA5String: + val := DecodeString(content) + for i, c := range val { + if c >= 0x7F { + err = fmt.Errorf("invalid character for IA5String at pos %d: %c", i, c) + break + } + } + if err == nil { + p.Value = val + } case TagUTCTime: case TagGeneralizedTime: + p.Value, err = ParseGeneralizedTime(content) case TagGraphicString: case TagVisibleString: case TagGeneralString: @@ -400,7 +426,24 @@ func readPacket(reader io.Reader) (*Packet, int, error) { p.Data.Write(content) } - return p, read, nil + return p, read, err +} + +func isPrintableString(val string) error { + for i, c := range val { + switch { + case c >= 'a' && c <= 'z': + case c >= 'A' && c <= 'Z': + case c >= '0' && c <= '9': + default: + switch c { + case '\'', '(', ')', '+', ',', '-', '.', '=', '/', ':', '?', ' ': + default: + return fmt.Errorf("invalid character in position %d", i) + } + } + } + return nil } func (p *Packet) Bytes() []byte { @@ -418,61 +461,99 @@ func (p *Packet) AppendChild(child *Packet) { p.Children = append(p.Children, child) } -func Encode(ClassType Class, TagType Type, Tag Tag, Value interface{}, Description string) *Packet { +func Encode(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet { p := new(Packet) - p.ClassType = ClassType - p.TagType = TagType - p.Tag = Tag + p.ClassType = classType + p.TagType = tagType + p.Tag = tag p.Data = new(bytes.Buffer) p.Children = make([]*Packet, 0, 2) - p.Value = Value - p.Description = Description + p.Value = value + p.Description = description - if Value != nil { - v := reflect.ValueOf(Value) + if value != nil { + v := reflect.ValueOf(value) - if ClassType == ClassUniversal { - switch Tag { + if classType == ClassUniversal { + switch tag { case TagOctetString: sv, ok := v.Interface().(string) if ok { p.Data.Write([]byte(sv)) } + case TagEnumerated: + bv, ok := v.Interface().([]byte) + if ok { + p.Data.Write(bv) + } + case TagEmbeddedPDV: + bv, ok := v.Interface().([]byte) + if ok { + p.Data.Write(bv) + } + } + } else if classType == ClassContext { + switch tag { + case TagEnumerated: + bv, ok := v.Interface().([]byte) + if ok { + p.Data.Write(bv) + } + case TagEmbeddedPDV: + bv, ok := v.Interface().([]byte) + if ok { + p.Data.Write(bv) + } } } } - return p } -func NewSequence(Description string) *Packet { - return Encode(ClassUniversal, TypeConstructed, TagSequence, nil, Description) +func NewSequence(description string) *Packet { + return Encode(ClassUniversal, TypeConstructed, TagSequence, nil, description) } -func NewBoolean(ClassType Class, TagType Type, Tag Tag, Value bool, Description string) *Packet { +func NewBoolean(classType Class, tagType Type, tag Tag, value bool, description string) *Packet { intValue := int64(0) - if Value { + if value { intValue = 1 } - p := Encode(ClassType, TagType, Tag, nil, Description) + p := Encode(classType, tagType, tag, nil, description) - p.Value = Value + p.Value = value p.Data.Write(encodeInteger(intValue)) return p } -func NewInteger(ClassType Class, TagType Type, Tag Tag, Value interface{}, Description string) *Packet { - p := Encode(ClassType, TagType, Tag, nil, Description) +// NewLDAPBoolean returns a RFC 4511-compliant Boolean packet. +func NewLDAPBoolean(classType Class, tagType Type, tag Tag, value bool, description string) *Packet { + intValue := int64(0) - p.Value = Value - switch v := Value.(type) { + if value { + intValue = 255 + } + + p := Encode(classType, tagType, tag, nil, description) + + p.Value = value + p.Data.Write(encodeInteger(intValue)) + + return p +} + +func NewInteger(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet { + p := Encode(classType, tagType, tag, nil, description) + + p.Value = value + switch v := value.(type) { case int: p.Data.Write(encodeInteger(int64(v))) case uint: @@ -502,11 +583,38 @@ func NewInteger(ClassType Class, TagType Type, Tag Tag, Value interface{}, Descr return p } -func NewString(ClassType Class, TagType Type, Tag Tag, Value, Description string) *Packet { - p := Encode(ClassType, TagType, Tag, nil, Description) +func NewString(classType Class, tagType Type, tag Tag, value, description string) *Packet { + p := Encode(classType, tagType, tag, nil, description) - p.Value = Value - p.Data.Write([]byte(Value)) + p.Value = value + p.Data.Write([]byte(value)) return p } + +func NewGeneralizedTime(classType Class, tagType Type, tag Tag, value time.Time, description string) *Packet { + p := Encode(classType, tagType, tag, nil, description) + var s string + if value.Nanosecond() != 0 { + s = value.Format(`20060102150405.000000000Z`) + } else { + s = value.Format(`20060102150405Z`) + } + p.Value = s + p.Data.Write([]byte(s)) + return p +} + +func NewReal(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet { + p := Encode(classType, tagType, tag, nil, description) + + switch v := value.(type) { + case float64: + p.Data.Write(encodeFloat(v)) + case float32: + p.Data.Write(encodeFloat(float64(v))) + default: + panic(fmt.Sprintf("Invalid type %T, expected float{64|32}", v)) + } + return p +} diff --git a/vendor/gopkg.in/asn1-ber.v1/content_int.go b/vendor/github.com/go-asn1-ber/asn1-ber/content_int.go similarity index 87% rename from vendor/gopkg.in/asn1-ber.v1/content_int.go rename to vendor/github.com/go-asn1-ber/asn1-ber/content_int.go index 1858b74b..20b500f5 100644 --- a/vendor/gopkg.in/asn1-ber.v1/content_int.go +++ b/vendor/github.com/go-asn1-ber/asn1-ber/content_int.go @@ -6,7 +6,7 @@ func encodeUnsignedInteger(i uint64) []byte { var j int for ; n > 0; n-- { - out[j] = (byte(i >> uint((n-1)*8))) + out[j] = byte(i >> uint((n-1)*8)) j++ } diff --git a/vendor/github.com/go-asn1-ber/asn1-ber/generalizedTime.go b/vendor/github.com/go-asn1-ber/asn1-ber/generalizedTime.go new file mode 100644 index 00000000..51215f06 --- /dev/null +++ b/vendor/github.com/go-asn1-ber/asn1-ber/generalizedTime.go @@ -0,0 +1,105 @@ +package ber + +import ( + "bytes" + "errors" + "fmt" + "strconv" + "time" +) + +// ErrInvalidTimeFormat is returned when the generalizedTime string was not correct. +var ErrInvalidTimeFormat = errors.New("invalid time format") + +var zeroTime = time.Time{} + +// ParseGeneralizedTime parses a string value and if it conforms to +// GeneralizedTime[^0] format, will return a time.Time for that value. +// +// [^0]: https://www.itu.int/rec/T-REC-X.690-201508-I/en Section 11.7 +func ParseGeneralizedTime(v []byte) (time.Time, error) { + var format string + var fract time.Duration + + str := []byte(DecodeString(v)) + tzIndex := bytes.IndexAny(str, "Z+-") + if tzIndex < 0 { + return zeroTime, ErrInvalidTimeFormat + } + + dot := bytes.IndexAny(str, ".,") + switch dot { + case -1: + switch tzIndex { + case 10: + format = `2006010215Z` + case 12: + format = `200601021504Z` + case 14: + format = `20060102150405Z` + default: + return zeroTime, ErrInvalidTimeFormat + } + + case 10, 12: + if tzIndex < dot { + return zeroTime, ErrInvalidTimeFormat + } + // a "," is also allowed, but would not be parsed by time.Parse(): + str[dot] = '.' + + // If is omitted, then represents a fraction of an + // hour; otherwise, if and are omitted, then + // represents a fraction of a minute; otherwise, + // represents a fraction of a second. + + // parse as float from dot to timezone + f, err := strconv.ParseFloat(string(str[dot:tzIndex]), 64) + if err != nil { + return zeroTime, fmt.Errorf("failed to parse float: %s", err) + } + // ...and strip that part + str = append(str[:dot], str[tzIndex:]...) + tzIndex = dot + + if dot == 10 { + fract = time.Duration(int64(f * float64(time.Hour))) + format = `2006010215Z` + } else { + fract = time.Duration(int64(f * float64(time.Minute))) + format = `200601021504Z` + } + + case 14: + if tzIndex < dot { + return zeroTime, ErrInvalidTimeFormat + } + str[dot] = '.' + // no need for fractional seconds, time.Parse() handles that + format = `20060102150405Z` + + default: + return zeroTime, ErrInvalidTimeFormat + } + + l := len(str) + switch l - tzIndex { + case 1: + if str[l-1] != 'Z' { + return zeroTime, ErrInvalidTimeFormat + } + case 3: + format += `0700` + str = append(str, []byte("00")...) + case 5: + format += `0700` + default: + return zeroTime, ErrInvalidTimeFormat + } + + t, err := time.Parse(format, string(str)) + if err != nil { + return zeroTime, fmt.Errorf("%s: %s", ErrInvalidTimeFormat, err) + } + return t.Add(fract), nil +} diff --git a/vendor/gopkg.in/asn1-ber.v1/header.go b/vendor/github.com/go-asn1-ber/asn1-ber/header.go similarity index 75% rename from vendor/gopkg.in/asn1-ber.v1/header.go rename to vendor/github.com/go-asn1-ber/asn1-ber/header.go index 71615621..7dfa6b9a 100644 --- a/vendor/gopkg.in/asn1-ber.v1/header.go +++ b/vendor/github.com/go-asn1-ber/asn1-ber/header.go @@ -7,19 +7,22 @@ import ( ) func readHeader(reader io.Reader) (identifier Identifier, length int, read int, err error) { - if i, c, err := readIdentifier(reader); err != nil { - return Identifier{}, 0, read, err - } else { - identifier = i - read += c - } + var ( + c, l int + i Identifier + ) - if l, c, err := readLength(reader); err != nil { + if i, c, err = readIdentifier(reader); err != nil { return Identifier{}, 0, read, err - } else { - length = l - read += c } + identifier = i + read += c + + if l, c, err = readLength(reader); err != nil { + return Identifier{}, 0, read, err + } + length = l + read += c // Validate length type with identifier (x.600, 8.1.3.2.a) if length == LengthIndefinite && identifier.TagType == TypePrimitive { diff --git a/vendor/gopkg.in/asn1-ber.v1/identifier.go b/vendor/github.com/go-asn1-ber/asn1-ber/identifier.go similarity index 100% rename from vendor/gopkg.in/asn1-ber.v1/identifier.go rename to vendor/github.com/go-asn1-ber/asn1-ber/identifier.go diff --git a/vendor/gopkg.in/asn1-ber.v1/length.go b/vendor/github.com/go-asn1-ber/asn1-ber/length.go similarity index 85% rename from vendor/gopkg.in/asn1-ber.v1/length.go rename to vendor/github.com/go-asn1-ber/asn1-ber/length.go index 750e8f44..9cc195d0 100644 --- a/vendor/gopkg.in/asn1-ber.v1/length.go +++ b/vendor/github.com/go-asn1-ber/asn1-ber/length.go @@ -71,11 +71,11 @@ func readLength(reader io.Reader) (length int, read int, err error) { } func encodeLength(length int) []byte { - length_bytes := encodeUnsignedInteger(uint64(length)) - if length > 127 || len(length_bytes) > 1 { - longFormBytes := []byte{(LengthLongFormBitmask | byte(len(length_bytes)))} - longFormBytes = append(longFormBytes, length_bytes...) - length_bytes = longFormBytes + lengthBytes := encodeUnsignedInteger(uint64(length)) + if length > 127 || len(lengthBytes) > 1 { + longFormBytes := []byte{LengthLongFormBitmask | byte(len(lengthBytes))} + longFormBytes = append(longFormBytes, lengthBytes...) + lengthBytes = longFormBytes } - return length_bytes + return lengthBytes } diff --git a/vendor/github.com/go-asn1-ber/asn1-ber/real.go b/vendor/github.com/go-asn1-ber/asn1-ber/real.go new file mode 100644 index 00000000..610a003a --- /dev/null +++ b/vendor/github.com/go-asn1-ber/asn1-ber/real.go @@ -0,0 +1,157 @@ +package ber + +import ( + "bytes" + "errors" + "fmt" + "math" + "strconv" + "strings" +) + +func encodeFloat(v float64) []byte { + switch { + case math.IsInf(v, 1): + return []byte{0x40} + case math.IsInf(v, -1): + return []byte{0x41} + case math.IsNaN(v): + return []byte{0x42} + case v == 0.0: + if math.Signbit(v) { + return []byte{0x43} + } + return []byte{} + default: + // we take the easy part ;-) + value := []byte(strconv.FormatFloat(v, 'G', -1, 64)) + var ret []byte + if bytes.Contains(value, []byte{'E'}) { + ret = []byte{0x03} + } else { + ret = []byte{0x02} + } + ret = append(ret, value...) + return ret + } +} + +func ParseReal(v []byte) (val float64, err error) { + if len(v) == 0 { + return 0.0, nil + } + switch { + case v[0]&0x80 == 0x80: + val, err = parseBinaryFloat(v) + case v[0]&0xC0 == 0x40: + val, err = parseSpecialFloat(v) + case v[0]&0xC0 == 0x0: + val, err = parseDecimalFloat(v) + default: + return 0.0, fmt.Errorf("invalid info block") + } + if err != nil { + return 0.0, err + } + + if val == 0.0 && !math.Signbit(val) { + return 0.0, errors.New("REAL value +0 must be encoded with zero-length value block") + } + return val, nil +} + +func parseBinaryFloat(v []byte) (float64, error) { + var info byte + var buf []byte + + info, v = v[0], v[1:] + + var base int + switch info & 0x30 { + case 0x00: + base = 2 + case 0x10: + base = 8 + case 0x20: + base = 16 + case 0x30: + return 0.0, errors.New("bits 6 and 5 of information octet for REAL are equal to 11") + } + + scale := uint((info & 0x0c) >> 2) + + var expLen int + switch info & 0x03 { + case 0x00: + expLen = 1 + case 0x01: + expLen = 2 + case 0x02: + expLen = 3 + case 0x03: + expLen = int(v[0]) + if expLen > 8 { + return 0.0, errors.New("too big value of exponent") + } + v = v[1:] + } + buf, v = v[:expLen], v[expLen:] + exponent, err := ParseInt64(buf) + if err != nil { + return 0.0, err + } + + if len(v) > 8 { + return 0.0, errors.New("too big value of mantissa") + } + + mant, err := ParseInt64(v) + if err != nil { + return 0.0, err + } + mantissa := mant << scale + + if info&0x40 == 0x40 { + mantissa = -mantissa + } + + return float64(mantissa) * math.Pow(float64(base), float64(exponent)), nil +} + +func parseDecimalFloat(v []byte) (val float64, err error) { + switch v[0] & 0x3F { + case 0x01: // NR form 1 + var iVal int64 + iVal, err = strconv.ParseInt(strings.TrimLeft(string(v[1:]), " "), 10, 64) + val = float64(iVal) + case 0x02, 0x03: // NR form 2, 3 + val, err = strconv.ParseFloat(strings.Replace(strings.TrimLeft(string(v[1:]), " "), ",", ".", -1), 64) + default: + err = errors.New("incorrect NR form") + } + if err != nil { + return 0.0, err + } + + if val == 0.0 && math.Signbit(val) { + return 0.0, errors.New("REAL value -0 must be encoded as a special value") + } + return val, nil +} + +func parseSpecialFloat(v []byte) (float64, error) { + if len(v) != 1 { + return 0.0, errors.New(`encoding of "special value" must not contain exponent and mantissa`) + } + switch v[0] { + case 0x40: + return math.Inf(1), nil + case 0x41: + return math.Inf(-1), nil + case 0x42: + return math.NaN(), nil + case 0x43: + return math.Copysign(0, -1), nil + } + return 0.0, errors.New(`encoding of "special value" not from ASN.1 standard`) +} diff --git a/vendor/gopkg.in/asn1-ber.v1/util.go b/vendor/github.com/go-asn1-ber/asn1-ber/util.go similarity index 93% rename from vendor/gopkg.in/asn1-ber.v1/util.go rename to vendor/github.com/go-asn1-ber/asn1-ber/util.go index 3e56b66c..14dc87d7 100644 --- a/vendor/gopkg.in/asn1-ber.v1/util.go +++ b/vendor/github.com/go-asn1-ber/asn1-ber/util.go @@ -3,7 +3,7 @@ package ber import "io" func readByte(reader io.Reader) (byte, error) { - bytes := make([]byte, 1, 1) + bytes := make([]byte, 1) _, err := io.ReadFull(reader, bytes) if err != nil { if err == io.EOF { diff --git a/vendor/gopkg.in/ldap.v3/LICENSE b/vendor/github.com/go-ldap/ldap/v3/LICENSE similarity index 100% rename from vendor/gopkg.in/ldap.v3/LICENSE rename to vendor/github.com/go-ldap/ldap/v3/LICENSE diff --git a/vendor/gopkg.in/ldap.v3/add.go b/vendor/github.com/go-ldap/ldap/v3/add.go similarity index 91% rename from vendor/gopkg.in/ldap.v3/add.go rename to vendor/github.com/go-ldap/ldap/v3/add.go index e2cb9b06..baecd787 100644 --- a/vendor/gopkg.in/ldap.v3/add.go +++ b/vendor/github.com/go-ldap/ldap/v3/add.go @@ -1,18 +1,9 @@ -// -// https://tools.ietf.org/html/rfc4511 -// -// AddRequest ::= [APPLICATION 8] SEQUENCE { -// entry LDAPDN, -// attributes AttributeList } -// -// AttributeList ::= SEQUENCE OF attribute Attribute - package ldap import ( "log" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // Attribute represents an LDAP attribute diff --git a/vendor/github.com/go-ldap/ldap/v3/bind.go b/vendor/github.com/go-ldap/ldap/v3/bind.go new file mode 100644 index 00000000..9bc57482 --- /dev/null +++ b/vendor/github.com/go-ldap/ldap/v3/bind.go @@ -0,0 +1,540 @@ +package ldap + +import ( + "bytes" + "crypto/md5" + enchex "encoding/hex" + "errors" + "fmt" + "io/ioutil" + "math/rand" + "strings" + + "github.com/Azure/go-ntlmssp" + ber "github.com/go-asn1-ber/asn1-ber" +) + +// SimpleBindRequest represents a username/password bind operation +type SimpleBindRequest struct { + // Username is the name of the Directory object that the client wishes to bind as + Username string + // Password is the credentials to bind with + Password string + // Controls are optional controls to send with the bind request + Controls []Control + // AllowEmptyPassword sets whether the client allows binding with an empty password + // (normally used for unauthenticated bind). + AllowEmptyPassword bool +} + +// SimpleBindResult contains the response from the server +type SimpleBindResult struct { + Controls []Control +} + +// NewSimpleBindRequest returns a bind request +func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest { + return &SimpleBindRequest{ + Username: username, + Password: password, + Controls: controls, + AllowEmptyPassword: false, + } +} + +func (req *SimpleBindRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Username, "User Name")) + pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.Password, "Password")) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil +} + +// SimpleBind performs the simple bind operation defined in the given request +func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) { + if simpleBindRequest.Password == "" && !simpleBindRequest.AllowEmptyPassword { + return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client")) + } + + msgCtx, err := l.doRequest(simpleBindRequest) + if err != nil { + return nil, err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return nil, err + } + + result := &SimpleBindResult{ + Controls: make([]Control, 0), + } + + if len(packet.Children) == 3 { + for _, child := range packet.Children[2].Children { + decodedChild, decodeErr := DecodeControl(child) + if decodeErr != nil { + return nil, fmt.Errorf("failed to decode child control: %s", decodeErr) + } + result.Controls = append(result.Controls, decodedChild) + } + } + + err = GetLDAPError(packet) + return result, err +} + +// Bind performs a bind with the given username and password. +// +// It does not allow unauthenticated bind (i.e. empty password). Use the UnauthenticatedBind method +// for that. +func (l *Conn) Bind(username, password string) error { + req := &SimpleBindRequest{ + Username: username, + Password: password, + AllowEmptyPassword: false, + } + _, err := l.SimpleBind(req) + return err +} + +// UnauthenticatedBind performs an unauthenticated bind. +// +// A username may be provided for trace (e.g. logging) purpose only, but it is normally not +// authenticated or otherwise validated by the LDAP server. +// +// See https://tools.ietf.org/html/rfc4513#section-5.1.2 . +// See https://tools.ietf.org/html/rfc4513#section-6.3.1 . +func (l *Conn) UnauthenticatedBind(username string) error { + req := &SimpleBindRequest{ + Username: username, + Password: "", + AllowEmptyPassword: true, + } + _, err := l.SimpleBind(req) + return err +} + +// DigestMD5BindRequest represents a digest-md5 bind operation +type DigestMD5BindRequest struct { + Host string + // Username is the name of the Directory object that the client wishes to bind as + Username string + // Password is the credentials to bind with + Password string + // Controls are optional controls to send with the bind request + Controls []Control +} + +func (req *DigestMD5BindRequest) appendTo(envelope *ber.Packet) error { + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name")) + + auth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication") + auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "DIGEST-MD5", "SASL Mech")) + request.AppendChild(auth) + envelope.AppendChild(request) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + return nil +} + +// DigestMD5BindResult contains the response from the server +type DigestMD5BindResult struct { + Controls []Control +} + +// MD5Bind performs a digest-md5 bind with the given host, username and password. +func (l *Conn) MD5Bind(host, username, password string) error { + req := &DigestMD5BindRequest{ + Host: host, + Username: username, + Password: password, + } + _, err := l.DigestMD5Bind(req) + return err +} + +// DigestMD5Bind performs the digest-md5 bind operation defined in the given request +func (l *Conn) DigestMD5Bind(digestMD5BindRequest *DigestMD5BindRequest) (*DigestMD5BindResult, error) { + if digestMD5BindRequest.Password == "" { + return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client")) + } + + msgCtx, err := l.doRequest(digestMD5BindRequest) + if err != nil { + return nil, err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return nil, err + } + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if l.Debug { + if err = addLDAPDescriptions(packet); err != nil { + return nil, err + } + ber.PrintPacket(packet) + } + + result := &DigestMD5BindResult{ + Controls: make([]Control, 0), + } + var params map[string]string + if len(packet.Children) == 2 { + if len(packet.Children[1].Children) == 4 { + child := packet.Children[1].Children[0] + if child.Tag != ber.TagEnumerated { + return result, GetLDAPError(packet) + } + if child.Value.(int64) != 14 { + return result, GetLDAPError(packet) + } + child = packet.Children[1].Children[3] + if child.Tag != ber.TagObjectDescriptor { + return result, GetLDAPError(packet) + } + if child.Data == nil { + return result, GetLDAPError(packet) + } + data, _ := ioutil.ReadAll(child.Data) + params, err = parseParams(string(data)) + if err != nil { + return result, fmt.Errorf("parsing digest-challenge: %s", err) + } + } + } + + if params != nil { + resp := computeResponse( + params, + "ldap/"+strings.ToLower(digestMD5BindRequest.Host), + digestMD5BindRequest.Username, + digestMD5BindRequest.Password, + ) + packet = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name")) + + auth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication") + auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "DIGEST-MD5", "SASL Mech")) + auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, resp, "Credentials")) + request.AppendChild(auth) + packet.AppendChild(request) + msgCtx, err = l.sendMessage(packet) + if err != nil { + return nil, fmt.Errorf("send message: %s", err) + } + defer l.finishMessage(msgCtx) + packetResponse, ok := <-msgCtx.responses + if !ok { + return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + } + packet, err = packetResponse.ReadPacket() + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + return nil, fmt.Errorf("read packet: %s", err) + } + } + + err = GetLDAPError(packet) + return result, err +} + +func parseParams(str string) (map[string]string, error) { + m := make(map[string]string) + var key, value string + var state int + for i := 0; i <= len(str); i++ { + switch state { + case 0: //reading key + if i == len(str) { + return nil, fmt.Errorf("syntax error on %d", i) + } + if str[i] != '=' { + key += string(str[i]) + continue + } + state = 1 + case 1: //reading value + if i == len(str) { + m[key] = value + break + } + switch str[i] { + case ',': + m[key] = value + state = 0 + key = "" + value = "" + case '"': + if value != "" { + return nil, fmt.Errorf("syntax error on %d", i) + } + state = 2 + default: + value += string(str[i]) + } + case 2: //inside quotes + if i == len(str) { + return nil, fmt.Errorf("syntax error on %d", i) + } + if str[i] != '"' { + value += string(str[i]) + } else { + state = 1 + } + } + } + return m, nil +} + +func computeResponse(params map[string]string, uri, username, password string) string { + nc := "00000001" + qop := "auth" + cnonce := enchex.EncodeToString(randomBytes(16)) + x := username + ":" + params["realm"] + ":" + password + y := md5Hash([]byte(x)) + + a1 := bytes.NewBuffer(y) + a1.WriteString(":" + params["nonce"] + ":" + cnonce) + if len(params["authzid"]) > 0 { + a1.WriteString(":" + params["authzid"]) + } + a2 := bytes.NewBuffer([]byte("AUTHENTICATE")) + a2.WriteString(":" + uri) + ha1 := enchex.EncodeToString(md5Hash(a1.Bytes())) + ha2 := enchex.EncodeToString(md5Hash(a2.Bytes())) + + kd := ha1 + kd += ":" + params["nonce"] + kd += ":" + nc + kd += ":" + cnonce + kd += ":" + qop + kd += ":" + ha2 + resp := enchex.EncodeToString(md5Hash([]byte(kd))) + return fmt.Sprintf( + `username="%s",realm="%s",nonce="%s",cnonce="%s",nc=00000001,qop=%s,digest-uri="%s",response=%s`, + username, + params["realm"], + params["nonce"], + cnonce, + qop, + uri, + resp, + ) +} + +func md5Hash(b []byte) []byte { + hasher := md5.New() + hasher.Write(b) + return hasher.Sum(nil) +} + +func randomBytes(len int) []byte { + b := make([]byte, len) + for i := 0; i < len; i++ { + b[i] = byte(rand.Intn(256)) + } + return b +} + +var externalBindRequest = requestFunc(func(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name")) + + saslAuth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication") + saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "EXTERNAL", "SASL Mech")) + saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "SASL Cred")) + + pkt.AppendChild(saslAuth) + + envelope.AppendChild(pkt) + + return nil +}) + +// ExternalBind performs SASL/EXTERNAL authentication. +// +// Use ldap.DialURL("ldapi://") to connect to the Unix socket before ExternalBind. +// +// See https://tools.ietf.org/html/rfc4422#appendix-A +func (l *Conn) ExternalBind() error { + msgCtx, err := l.doRequest(externalBindRequest) + if err != nil { + return err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return err + } + + return GetLDAPError(packet) +} + +// NTLMBind performs an NTLMSSP bind leveraging https://github.com/Azure/go-ntlmssp + +// NTLMBindRequest represents an NTLMSSP bind operation +type NTLMBindRequest struct { + // Domain is the AD Domain to authenticate too. If not specified, it will be grabbed from the NTLMSSP Challenge + Domain string + // Username is the name of the Directory object that the client wishes to bind as + Username string + // Password is the credentials to bind with + Password string + // Hash is the hex NTLM hash to bind with. Password or hash must be provided + Hash string + // Controls are optional controls to send with the bind request + Controls []Control +} + +func (req *NTLMBindRequest) appendTo(envelope *ber.Packet) error { + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name")) + + // generate an NTLMSSP Negotiation message for the specified domain (it can be blank) + negMessage, err := ntlmssp.NewNegotiateMessage(req.Domain, "") + if err != nil { + return fmt.Errorf("err creating negmessage: %s", err) + } + + // append the generated NTLMSSP message as a TagEnumerated BER value + auth := ber.Encode(ber.ClassContext, ber.TypePrimitive, ber.TagEnumerated, negMessage, "authentication") + request.AppendChild(auth) + envelope.AppendChild(request) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + return nil +} + +// NTLMBindResult contains the response from the server +type NTLMBindResult struct { + Controls []Control +} + +// NTLMBind performs an NTLMSSP Bind with the given domain, username and password +func (l *Conn) NTLMBind(domain, username, password string) error { + req := &NTLMBindRequest{ + Domain: domain, + Username: username, + Password: password, + } + _, err := l.NTLMChallengeBind(req) + return err +} + +// NTLMBindWithHash performs an NTLM Bind with an NTLM hash instead of plaintext password (pass-the-hash) +func (l *Conn) NTLMBindWithHash(domain, username, hash string) error { + req := &NTLMBindRequest{ + Domain: domain, + Username: username, + Hash: hash, + } + _, err := l.NTLMChallengeBind(req) + return err +} + +// NTLMChallengeBind performs the NTLMSSP bind operation defined in the given request +func (l *Conn) NTLMChallengeBind(ntlmBindRequest *NTLMBindRequest) (*NTLMBindResult, error) { + if ntlmBindRequest.Password == "" && ntlmBindRequest.Hash == "" { + return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client")) + } + + msgCtx, err := l.doRequest(ntlmBindRequest) + if err != nil { + return nil, err + } + defer l.finishMessage(msgCtx) + packet, err := l.readPacket(msgCtx) + if err != nil { + return nil, err + } + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if l.Debug { + if err = addLDAPDescriptions(packet); err != nil { + return nil, err + } + ber.PrintPacket(packet) + } + result := &NTLMBindResult{ + Controls: make([]Control, 0), + } + var ntlmsspChallenge []byte + + // now find the NTLM Response Message + if len(packet.Children) == 2 { + if len(packet.Children[1].Children) == 3 { + child := packet.Children[1].Children[1] + ntlmsspChallenge = child.ByteValue + // Check to make sure we got the right message. It will always start with NTLMSSP + if len(ntlmsspChallenge) < 7 || !bytes.Equal(ntlmsspChallenge[:7], []byte("NTLMSSP")) { + return result, GetLDAPError(packet) + } + l.Debug.Printf("%d: found ntlmssp challenge", msgCtx.id) + } + } + if ntlmsspChallenge != nil { + var err error + var responseMessage []byte + // generate a response message to the challenge with the given Username/Password if password is provided + if ntlmBindRequest.Password != "" { + responseMessage, err = ntlmssp.ProcessChallenge(ntlmsspChallenge, ntlmBindRequest.Username, ntlmBindRequest.Password) + } else if ntlmBindRequest.Hash != "" { + responseMessage, err = ntlmssp.ProcessChallengeWithHash(ntlmsspChallenge, ntlmBindRequest.Username, ntlmBindRequest.Hash) + } else { + err = fmt.Errorf("need a password or hash to generate reply") + } + if err != nil { + return result, fmt.Errorf("parsing ntlm-challenge: %s", err) + } + packet = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name")) + + // append the challenge response message as a TagEmbeddedPDV BER value + auth := ber.Encode(ber.ClassContext, ber.TypePrimitive, ber.TagEmbeddedPDV, responseMessage, "authentication") + + request.AppendChild(auth) + packet.AppendChild(request) + msgCtx, err = l.sendMessage(packet) + if err != nil { + return nil, fmt.Errorf("send message: %s", err) + } + defer l.finishMessage(msgCtx) + packetResponse, ok := <-msgCtx.responses + if !ok { + return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + } + packet, err = packetResponse.ReadPacket() + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + return nil, fmt.Errorf("read packet: %s", err) + } + + } + + err = GetLDAPError(packet) + return result, err +} diff --git a/vendor/gopkg.in/ldap.v3/client.go b/vendor/github.com/go-ldap/ldap/v3/client.go similarity index 90% rename from vendor/gopkg.in/ldap.v3/client.go rename to vendor/github.com/go-ldap/ldap/v3/client.go index 619677c7..1fa4ad5a 100644 --- a/vendor/gopkg.in/ldap.v3/client.go +++ b/vendor/github.com/go-ldap/ldap/v3/client.go @@ -10,6 +10,7 @@ type Client interface { Start() StartTLS(*tls.Config) error Close() + IsClosing() bool SetTimeout(time.Duration) Bind(username, password string) error @@ -21,6 +22,7 @@ type Client interface { Del(*DelRequest) error Modify(*ModifyRequest) error ModifyDN(*ModifyDNRequest) error + ModifyWithResult(*ModifyRequest) (*ModifyResult, error) Compare(dn, attribute, value string) (bool, error) PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) diff --git a/vendor/gopkg.in/ldap.v3/compare.go b/vendor/github.com/go-ldap/ldap/v3/compare.go similarity index 74% rename from vendor/gopkg.in/ldap.v3/compare.go rename to vendor/github.com/go-ldap/ldap/v3/compare.go index 83694d82..cd43e4c5 100644 --- a/vendor/gopkg.in/ldap.v3/compare.go +++ b/vendor/github.com/go-ldap/ldap/v3/compare.go @@ -1,28 +1,9 @@ -// File contains Compare functionality -// -// https://tools.ietf.org/html/rfc4511 -// -// CompareRequest ::= [APPLICATION 14] SEQUENCE { -// entry LDAPDN, -// ava AttributeValueAssertion } -// -// AttributeValueAssertion ::= SEQUENCE { -// attributeDesc AttributeDescription, -// assertionValue AssertionValue } -// -// AttributeDescription ::= LDAPString -// -- Constrained to -// -- [RFC4512] -// -// AttributeValue ::= OCTET STRING -// - package ldap import ( "fmt" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // CompareRequest represents an LDAP CompareRequest operation. diff --git a/vendor/gopkg.in/ldap.v3/conn.go b/vendor/github.com/go-ldap/ldap/v3/conn.go similarity index 84% rename from vendor/gopkg.in/ldap.v3/conn.go rename to vendor/github.com/go-ldap/ldap/v3/conn.go index ab9bd4f9..ae5e19af 100644 --- a/vendor/gopkg.in/ldap.v3/conn.go +++ b/vendor/github.com/go-ldap/ldap/v3/conn.go @@ -1,6 +1,7 @@ package ldap import ( + "bufio" "crypto/tls" "errors" "fmt" @@ -11,7 +12,7 @@ import ( "sync/atomic" "time" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) const ( @@ -112,8 +113,72 @@ var _ Client = &Conn{} // multiple places will probably result in undesired behaviour. var DefaultTimeout = 60 * time.Second +// DialOpt configures DialContext. +type DialOpt func(*DialContext) + +// DialWithDialer updates net.Dialer in DialContext. +func DialWithDialer(d *net.Dialer) DialOpt { + return func(dc *DialContext) { + dc.d = d + } +} + +// DialWithTLSConfig updates tls.Config in DialContext. +func DialWithTLSConfig(tc *tls.Config) DialOpt { + return func(dc *DialContext) { + dc.tc = tc + } +} + +// DialWithTLSDialer is a wrapper for DialWithTLSConfig with the option to +// specify a net.Dialer to for example define a timeout or a custom resolver. +func DialWithTLSDialer(tlsConfig *tls.Config, dialer *net.Dialer) DialOpt { + return func(dc *DialContext) { + dc.tc = tlsConfig + dc.d = dialer + } +} + +// DialContext contains necessary parameters to dial the given ldap URL. +type DialContext struct { + d *net.Dialer + tc *tls.Config +} + +func (dc *DialContext) dial(u *url.URL) (net.Conn, error) { + if u.Scheme == "ldapi" { + if u.Path == "" || u.Path == "/" { + u.Path = "/var/run/slapd/ldapi" + } + return dc.d.Dial("unix", u.Path) + } + + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + // we assume that error is due to missing port + host = u.Host + port = "" + } + + switch u.Scheme { + case "ldap": + if port == "" { + port = DefaultLdapPort + } + return dc.d.Dial("tcp", net.JoinHostPort(host, port)) + case "ldaps": + if port == "" { + port = DefaultLdapsPort + } + return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc) + } + + return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme) +} + // Dial connects to the given address on the given network using net.Dial // and then returns a new Conn for the connection. +// @deprecated Use DialURL instead. func Dial(network, addr string) (*Conn, error) { c, err := net.DialTimeout(network, addr, DefaultTimeout) if err != nil { @@ -126,6 +191,7 @@ func Dial(network, addr string) (*Conn, error) { // DialTLS connects to the given address on the given network using tls.Dial // and then returns a new Conn for the connection. +// @deprecated Use DialURL instead. func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config) if err != nil { @@ -136,44 +202,31 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { return conn, nil } -// DialURL connects to the given ldap URL vie TCP using tls.Dial or net.Dial if ldaps:// -// or ldap:// specified as protocol. On success a new Conn for the connection -// is returned. -func DialURL(addr string) (*Conn, error) { - lurl, err := url.Parse(addr) +// DialURL connects to the given ldap URL. +// The following schemas are supported: ldap://, ldaps://, ldapi://. +// On success a new Conn for the connection is returned. +func DialURL(addr string, opts ...DialOpt) (*Conn, error) { + u, err := url.Parse(addr) if err != nil { return nil, NewError(ErrorNetwork, err) } - host, port, err := net.SplitHostPort(lurl.Host) + var dc DialContext + for _, opt := range opts { + opt(&dc) + } + if dc.d == nil { + dc.d = &net.Dialer{Timeout: DefaultTimeout} + } + + c, err := dc.dial(u) if err != nil { - // we asume that error is due to missing port - host = lurl.Host - port = "" + return nil, NewError(ErrorNetwork, err) } - switch lurl.Scheme { - case "ldapi": - if lurl.Path == "" || lurl.Path == "/" { - lurl.Path = "/var/run/slapd/ldapi" - } - return Dial("unix", lurl.Path) - case "ldap": - if port == "" { - port = DefaultLdapPort - } - return Dial("tcp", net.JoinHostPort(host, port)) - case "ldaps": - if port == "" { - port = DefaultLdapsPort - } - tlsConf := &tls.Config{ - ServerName: host, - } - return DialTLS("tcp", net.JoinHostPort(host, port), tlsConf) - } - - return nil, NewError(ErrorNetwork, fmt.Errorf("Unknown scheme '%s'", lurl.Scheme)) + conn := NewConn(c, u.Scheme == "ldaps") + conn.Start() + return conn, nil } // NewConn returns a new Conn using conn for network I/O. @@ -191,9 +244,9 @@ func NewConn(conn net.Conn, isTLS bool) *Conn { // Start initializes goroutines to read responses and process messages func (l *Conn) Start() { + l.wgClose.Add(1) go l.reader() go l.processMessages() - l.wgClose.Add(1) } // IsClosing returns whether or not we're currently closing. @@ -278,7 +331,7 @@ func (l *Conn) StartTLS(config *tls.Config) error { l.Close() return err } - ber.PrintPacket(packet) + l.Debug.PrintPacket(packet) } if err := GetLDAPError(packet); err == nil { @@ -347,7 +400,12 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) responses: responses, }, } - l.sendProcessMessage(message) + if !l.sendProcessMessage(message) { + if l.IsClosing() { + return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) + } + return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message for unknown reason")) + } return message.Context, nil } @@ -451,14 +509,14 @@ func (l *Conn) processMessages() { msgCtx.sendResponse(&PacketResponse{message.Packet, nil}) } else { log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing()) - ber.PrintPacket(message.Packet) + l.Debug.PrintPacket(message.Packet) } case MessageTimeout: // Handle the timeout by closing the channel // All reads will return immediately if msgCtx, ok := l.messageContexts[message.MessageID]; ok { l.Debug.Printf("Receiving message timeout for %d", message.MessageID) - msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")}) + msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}) delete(l.messageContexts, message.MessageID) close(msgCtx.responses) } @@ -484,12 +542,13 @@ func (l *Conn) reader() { } }() + bufConn := bufio.NewReader(l.conn) for { if cleanstop { l.Debug.Printf("reader clean stopping (without closing the connection)") return } - packet, err := ber.ReadPacket(l.conn) + packet, err := ber.ReadPacket(bufConn) if err != nil { // A read error is expected here if we are closing the connection... if !l.IsClosing() { diff --git a/vendor/gopkg.in/ldap.v3/control.go b/vendor/github.com/go-ldap/ldap/v3/control.go similarity index 86% rename from vendor/gopkg.in/ldap.v3/control.go rename to vendor/github.com/go-ldap/ldap/v3/control.go index 3f181912..64fb002a 100644 --- a/vendor/gopkg.in/ldap.v3/control.go +++ b/vendor/github.com/go-ldap/ldap/v3/control.go @@ -4,7 +4,7 @@ import ( "fmt" "strconv" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) const ( @@ -18,20 +18,25 @@ const ( ControlTypeVChuPasswordWarning = "2.16.840.1.113730.3.4.5" // ControlTypeManageDsaIT - https://tools.ietf.org/html/rfc3296 ControlTypeManageDsaIT = "2.16.840.1.113730.3.4.2" + // ControlTypeWhoAmI - https://tools.ietf.org/html/rfc4532 + ControlTypeWhoAmI = "1.3.6.1.4.1.4203.1.11.3" // ControlTypeMicrosoftNotification - https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx ControlTypeMicrosoftNotification = "1.2.840.113556.1.4.528" // ControlTypeMicrosoftShowDeleted - https://msdn.microsoft.com/en-us/library/aa366989(v=vs.85).aspx ControlTypeMicrosoftShowDeleted = "1.2.840.113556.1.4.417" + // ControlTypeMicrosoftServerLinkTTL - https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/f4f523a8-abc0-4b3a-a471-6b2fef135481?redirectedfrom=MSDN + ControlTypeMicrosoftServerLinkTTL = "1.2.840.113556.1.4.2309" ) // ControlTypeMap maps controls to text descriptions var ControlTypeMap = map[string]string{ - ControlTypePaging: "Paging", - ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft", - ControlTypeManageDsaIT: "Manage DSA IT", - ControlTypeMicrosoftNotification: "Change Notification - Microsoft", - ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft", + ControlTypePaging: "Paging", + ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft", + ControlTypeManageDsaIT: "Manage DSA IT", + ControlTypeMicrosoftNotification: "Change Notification - Microsoft", + ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft", + ControlTypeMicrosoftServerLinkTTL: "Return TTL-DNs for link values with associated expiry times - Microsoft", } // Control defines an interface controls provide to encode and describe themselves @@ -305,6 +310,35 @@ func NewControlMicrosoftShowDeleted() *ControlMicrosoftShowDeleted { return &ControlMicrosoftShowDeleted{} } +// ControlMicrosoftServerLinkTTL implements the control described in https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/f4f523a8-abc0-4b3a-a471-6b2fef135481?redirectedfrom=MSDN +type ControlMicrosoftServerLinkTTL struct{} + +// GetControlType returns the OID +func (c *ControlMicrosoftServerLinkTTL) GetControlType() string { + return ControlTypeMicrosoftServerLinkTTL +} + +// Encode returns the ber packet representation +func (c *ControlMicrosoftServerLinkTTL) Encode() *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeMicrosoftServerLinkTTL, "Control Type ("+ControlTypeMap[ControlTypeMicrosoftServerLinkTTL]+")")) + + return packet +} + +// String returns a human-readable description +func (c *ControlMicrosoftServerLinkTTL) String() string { + return fmt.Sprintf( + "Control Type: %s (%q)", + ControlTypeMap[ControlTypeMicrosoftServerLinkTTL], + ControlTypeMicrosoftServerLinkTTL) +} + +// NewControlMicrosoftServerLinkTTL returns a ControlMicrosoftServerLinkTTL control +func NewControlMicrosoftServerLinkTTL() *ControlMicrosoftServerLinkTTL { + return &ControlMicrosoftServerLinkTTL{} +} + // FindControl returns the first control of the given type in the list, or nil func FindControl(controls []Control, controlType string) Control { for _, c := range controls { @@ -404,33 +438,26 @@ func DecodeControl(packet *ber.Packet) (Control, error) { if child.Tag == 0 { //Warning warningPacket := child.Children[0] - packet, err := ber.DecodePacketErr(warningPacket.Data.Bytes()) + val, err := ber.ParseInt64(warningPacket.Data.Bytes()) if err != nil { return nil, fmt.Errorf("failed to decode data bytes: %s", err) } - val, ok := packet.Value.(int64) - if ok { - if warningPacket.Tag == 0 { - //timeBeforeExpiration - c.Expire = val - warningPacket.Value = c.Expire - } else if warningPacket.Tag == 1 { - //graceAuthNsRemaining - c.Grace = val - warningPacket.Value = c.Grace - } + if warningPacket.Tag == 0 { + //timeBeforeExpiration + c.Expire = val + warningPacket.Value = c.Expire + } else if warningPacket.Tag == 1 { + //graceAuthNsRemaining + c.Grace = val + warningPacket.Value = c.Grace } } else if child.Tag == 1 { // Error - packet, err := ber.DecodePacketErr(child.Data.Bytes()) - if err != nil { - return nil, fmt.Errorf("failed to decode data bytes: %s", err) - } - val, ok := packet.Value.(int8) - if !ok { - // what to do? - val = -1 + bs := child.Data.Bytes() + if len(bs) != 1 || bs[0] > 8 { + return nil, fmt.Errorf("failed to decode data bytes: %s", "invalid PasswordPolicyResponse enum value") } + val := int8(bs[0]) c.Error = val child.Value = c.Error c.ErrorString = BeheraPasswordPolicyErrorMap[c.Error] @@ -456,6 +483,8 @@ func DecodeControl(packet *ber.Packet) (Control, error) { return NewControlMicrosoftNotification(), nil case ControlTypeMicrosoftShowDeleted: return NewControlMicrosoftShowDeleted(), nil + case ControlTypeMicrosoftServerLinkTTL: + return NewControlMicrosoftServerLinkTTL(), nil default: c := new(ControlString) c.ControlType = ControlType diff --git a/vendor/gopkg.in/ldap.v3/debug.go b/vendor/github.com/go-ldap/ldap/v3/debug.go similarity index 85% rename from vendor/gopkg.in/ldap.v3/debug.go rename to vendor/github.com/go-ldap/ldap/v3/debug.go index 61628e3a..d0a8fc15 100644 --- a/vendor/gopkg.in/ldap.v3/debug.go +++ b/vendor/github.com/go-ldap/ldap/v3/debug.go @@ -3,7 +3,7 @@ package ldap import ( "log" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // debugging type @@ -25,6 +25,6 @@ func (debug debugging) Printf(format string, args ...interface{}) { // PrintPacket dumps a packet. func (debug debugging) PrintPacket(packet *ber.Packet) { if debug { - ber.PrintPacket(packet) + ber.WritePacket(log.Writer(), packet) } } diff --git a/vendor/gopkg.in/ldap.v3/del.go b/vendor/github.com/go-ldap/ldap/v3/del.go similarity index 91% rename from vendor/gopkg.in/ldap.v3/del.go rename to vendor/github.com/go-ldap/ldap/v3/del.go index 0e7767b2..6e987267 100644 --- a/vendor/gopkg.in/ldap.v3/del.go +++ b/vendor/github.com/go-ldap/ldap/v3/del.go @@ -1,14 +1,9 @@ -// -// https://tools.ietf.org/html/rfc4511 -// -// DelRequest ::= [APPLICATION 10] LDAPDN - package ldap import ( "log" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // DelRequest implements an LDAP deletion request diff --git a/vendor/gopkg.in/ldap.v3/dn.go b/vendor/github.com/go-ldap/ldap/v3/dn.go similarity index 73% rename from vendor/gopkg.in/ldap.v3/dn.go rename to vendor/github.com/go-ldap/ldap/v3/dn.go index f89e73a9..2b4cede9 100644 --- a/vendor/gopkg.in/ldap.v3/dn.go +++ b/vendor/github.com/go-ldap/ldap/v3/dn.go @@ -1,44 +1,3 @@ -// File contains DN parsing functionality -// -// https://tools.ietf.org/html/rfc4514 -// -// distinguishedName = [ relativeDistinguishedName -// *( COMMA relativeDistinguishedName ) ] -// relativeDistinguishedName = attributeTypeAndValue -// *( PLUS attributeTypeAndValue ) -// attributeTypeAndValue = attributeType EQUALS attributeValue -// attributeType = descr / numericoid -// attributeValue = string / hexstring -// -// ; The following characters are to be escaped when they appear -// ; in the value to be encoded: ESC, one of , leading -// ; SHARP or SPACE, trailing SPACE, and NULL. -// string = [ ( leadchar / pair ) [ *( stringchar / pair ) -// ( trailchar / pair ) ] ] -// -// leadchar = LUTF1 / UTFMB -// LUTF1 = %x01-1F / %x21 / %x24-2A / %x2D-3A / -// %x3D / %x3F-5B / %x5D-7F -// -// trailchar = TUTF1 / UTFMB -// TUTF1 = %x01-1F / %x21 / %x23-2A / %x2D-3A / -// %x3D / %x3F-5B / %x5D-7F -// -// stringchar = SUTF1 / UTFMB -// SUTF1 = %x01-21 / %x23-2A / %x2D-3A / -// %x3D / %x3F-5B / %x5D-7F -// -// pair = ESC ( ESC / special / hexpair ) -// special = escaped / SPACE / SHARP / EQUALS -// escaped = DQUOTE / PLUS / COMMA / SEMI / LANGLE / RANGLE -// hexstring = SHARP 1*hexpair -// hexpair = HEX HEX -// -// where the productions , , , , -// , , , , , , , , -// , , and are defined in [RFC4512]. -// - package ldap import ( @@ -48,7 +7,7 @@ import ( "fmt" "strings" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514 @@ -69,7 +28,8 @@ type DN struct { RDNs []*RelativeDN } -// ParseDN returns a distinguishedName or an error +// ParseDN returns a distinguishedName or an error. +// The function respects https://tools.ietf.org/html/rfc4514 func ParseDN(str string) (*DN, error) { dn := new(DN) dn.RDNs = make([]*RelativeDN, 0) @@ -245,3 +205,66 @@ func (r *RelativeDN) hasAllAttributes(attrs []*AttributeTypeAndValue) bool { func (a *AttributeTypeAndValue) Equal(other *AttributeTypeAndValue) bool { return strings.EqualFold(a.Type, other.Type) && a.Value == other.Value } + +// Equal returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch). +// Returns true if they have the same number of relative distinguished names +// and corresponding relative distinguished names (by position) are the same. +// Case of the attribute type and value is not significant +func (d *DN) EqualFold(other *DN) bool { + if len(d.RDNs) != len(other.RDNs) { + return false + } + for i := range d.RDNs { + if !d.RDNs[i].EqualFold(other.RDNs[i]) { + return false + } + } + return true +} + +// AncestorOfFold returns true if the other DN consists of at least one RDN followed by all the RDNs of the current DN. +// Case of the attribute type and value is not significant +func (d *DN) AncestorOfFold(other *DN) bool { + if len(d.RDNs) >= len(other.RDNs) { + return false + } + // Take the last `len(d.RDNs)` RDNs from the other DN to compare against + otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):] + for i := range d.RDNs { + if !d.RDNs[i].EqualFold(otherRDNs[i]) { + return false + } + } + return true +} + +// Equal returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch). +// Case of the attribute type is not significant +func (r *RelativeDN) EqualFold(other *RelativeDN) bool { + if len(r.Attributes) != len(other.Attributes) { + return false + } + return r.hasAllAttributesFold(other.Attributes) && other.hasAllAttributesFold(r.Attributes) +} + +func (r *RelativeDN) hasAllAttributesFold(attrs []*AttributeTypeAndValue) bool { + for _, attr := range attrs { + found := false + for _, myattr := range r.Attributes { + if myattr.EqualFold(attr) { + found = true + break + } + } + if !found { + return false + } + } + return true +} + +// EqualFold returns true if the AttributeTypeAndValue is equivalent to the specified AttributeTypeAndValue +// Case of the attribute type and value is not significant +func (a *AttributeTypeAndValue) EqualFold(other *AttributeTypeAndValue) bool { + return strings.EqualFold(a.Type, other.Type) && strings.EqualFold(a.Value, other.Value) +} diff --git a/vendor/gopkg.in/ldap.v3/doc.go b/vendor/github.com/go-ldap/ldap/v3/doc.go similarity index 100% rename from vendor/gopkg.in/ldap.v3/doc.go rename to vendor/github.com/go-ldap/ldap/v3/doc.go diff --git a/vendor/gopkg.in/ldap.v3/error.go b/vendor/github.com/go-ldap/ldap/v3/error.go similarity index 94% rename from vendor/gopkg.in/ldap.v3/error.go rename to vendor/github.com/go-ldap/ldap/v3/error.go index 53dedb95..3cdb7b31 100644 --- a/vendor/gopkg.in/ldap.v3/error.go +++ b/vendor/github.com/go-ldap/ldap/v3/error.go @@ -3,7 +3,7 @@ package ldap import ( "fmt" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // LDAP Result Codes @@ -184,6 +184,8 @@ type Error struct { ResultCode uint16 // MatchedDN is the matchedDN returned if any MatchedDN string + // Packet is the returned packet if any + Packet *ber.Packet } func (e *Error) Error() string { @@ -201,19 +203,23 @@ func GetLDAPError(packet *ber.Packet) error { if len(packet.Children) >= 2 { response := packet.Children[1] if response == nil { - return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet")} + return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet"), Packet: packet} } if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 { resultCode := uint16(response.Children[0].Value.(int64)) if resultCode == 0 { // No error return nil } - return &Error{ResultCode: resultCode, MatchedDN: response.Children[1].Value.(string), - Err: fmt.Errorf("%s", response.Children[2].Value.(string))} + return &Error{ + ResultCode: resultCode, + MatchedDN: response.Children[1].Value.(string), + Err: fmt.Errorf("%s", response.Children[2].Value.(string)), + Packet: packet, + } } } - return &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("Invalid packet format")} + return &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("Invalid packet format"), Packet: packet} } // NewError creates an LDAP error with the given code and underlying error @@ -221,8 +227,8 @@ func NewError(resultCode uint16, err error) error { return &Error{ResultCode: resultCode, Err: err} } -// IsErrorWithCode returns true if the given error is an LDAP error with the given result code -func IsErrorWithCode(err error, desiredResultCode uint16) bool { +// IsErrorAnyOf returns true if the given error is an LDAP error with any one of the given result codes +func IsErrorAnyOf(err error, codes ...uint16) bool { if err == nil { return false } @@ -232,5 +238,16 @@ func IsErrorWithCode(err error, desiredResultCode uint16) bool { return false } - return serverError.ResultCode == desiredResultCode + for _, code := range codes { + if serverError.ResultCode == code { + return true + } + } + + return false +} + +// IsErrorWithCode returns true if the given error is an LDAP error with the given result code +func IsErrorWithCode(err error, desiredResultCode uint16) bool { + return IsErrorAnyOf(err, desiredResultCode) } diff --git a/vendor/gopkg.in/ldap.v3/filter.go b/vendor/github.com/go-ldap/ldap/v3/filter.go similarity index 75% rename from vendor/gopkg.in/ldap.v3/filter.go rename to vendor/github.com/go-ldap/ldap/v3/filter.go index 4cc4207b..73505e79 100644 --- a/vendor/gopkg.in/ldap.v3/filter.go +++ b/vendor/github.com/go-ldap/ldap/v3/filter.go @@ -5,10 +5,12 @@ import ( hexpac "encoding/hex" "errors" "fmt" + "io" "strings" + "unicode" "unicode/utf8" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // Filter choices @@ -69,6 +71,8 @@ var MatchingRuleAssertionMap = map[uint64]string{ MatchingRuleAssertionDNAttributes: "Matching Rule Assertion DN Attributes", } +var _SymbolAny = []byte{'*'} + // CompileFilter converts a string representation of a filter into a BER-encoded packet func CompileFilter(filter string) (*ber.Packet, error) { if len(filter) == 0 || filter[0] != '(' { @@ -88,74 +92,75 @@ func CompileFilter(filter string) (*ber.Packet, error) { } // DecompileFilter converts a packet representation of a filter into a string representation -func DecompileFilter(packet *ber.Packet) (ret string, err error) { +func DecompileFilter(packet *ber.Packet) (_ string, err error) { defer func() { if r := recover(); r != nil { err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter")) } }() - ret = "(" - err = nil + + buf := bytes.NewBuffer(nil) + buf.WriteByte('(') childStr := "" switch packet.Tag { case FilterAnd: - ret += "&" + buf.WriteByte('&') for _, child := range packet.Children { childStr, err = DecompileFilter(child) if err != nil { return } - ret += childStr + buf.WriteString(childStr) } case FilterOr: - ret += "|" + buf.WriteByte('|') for _, child := range packet.Children { childStr, err = DecompileFilter(child) if err != nil { return } - ret += childStr + buf.WriteString(childStr) } case FilterNot: - ret += "!" + buf.WriteByte('!') childStr, err = DecompileFilter(packet.Children[0]) if err != nil { return } - ret += childStr + buf.WriteString(childStr) case FilterSubstrings: - ret += ber.DecodeString(packet.Children[0].Data.Bytes()) - ret += "=" + buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes())) + buf.WriteByte('=') for i, child := range packet.Children[1].Children { if i == 0 && child.Tag != FilterSubstringsInitial { - ret += "*" + buf.Write(_SymbolAny) } - ret += EscapeFilter(ber.DecodeString(child.Data.Bytes())) + buf.WriteString(EscapeFilter(ber.DecodeString(child.Data.Bytes()))) if child.Tag != FilterSubstringsFinal { - ret += "*" + buf.Write(_SymbolAny) } } case FilterEqualityMatch: - ret += ber.DecodeString(packet.Children[0].Data.Bytes()) - ret += "=" - ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())) + buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes())) + buf.WriteByte('=') + buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))) case FilterGreaterOrEqual: - ret += ber.DecodeString(packet.Children[0].Data.Bytes()) - ret += ">=" - ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())) + buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes())) + buf.WriteString(">=") + buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))) case FilterLessOrEqual: - ret += ber.DecodeString(packet.Children[0].Data.Bytes()) - ret += "<=" - ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())) + buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes())) + buf.WriteString("<=") + buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))) case FilterPresent: - ret += ber.DecodeString(packet.Data.Bytes()) - ret += "=*" + buf.WriteString(ber.DecodeString(packet.Data.Bytes())) + buf.WriteString("=*") case FilterApproxMatch: - ret += ber.DecodeString(packet.Children[0].Data.Bytes()) - ret += "~=" - ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())) + buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes())) + buf.WriteString("~=") + buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))) case FilterExtensibleMatch: attr := "" dnAttributes := false @@ -176,21 +181,22 @@ func DecompileFilter(packet *ber.Packet) (ret string, err error) { } if len(attr) > 0 { - ret += attr + buf.WriteString(attr) } if dnAttributes { - ret += ":dn" + buf.WriteString(":dn") } if len(matchingRule) > 0 { - ret += ":" - ret += matchingRule + buf.WriteString(":") + buf.WriteString(matchingRule) } - ret += ":=" - ret += EscapeFilter(value) + buf.WriteString(":=") + buf.WriteString(EscapeFilter(value)) } - ret += ")" - return + buf.WriteByte(')') + + return buf.String(), nil } func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) { @@ -253,11 +259,10 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { ) state := stateReadingAttr - - attribute := "" + attribute := bytes.NewBuffer(nil) extensibleDNAttributes := false - extensibleMatchingRule := "" - condition := "" + extensibleMatchingRule := bytes.NewBuffer(nil) + condition := bytes.NewBuffer(nil) for newPos < len(filter) { remainingFilter := filter[newPos:] @@ -324,7 +329,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { // Still reading the attribute name default: - attribute += fmt.Sprintf("%c", currentRune) + attribute.WriteRune(currentRune) newPos += currentWidth } @@ -338,13 +343,13 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { // Still reading the matching rule oid default: - extensibleMatchingRule += fmt.Sprintf("%c", currentRune) + extensibleMatchingRule.WriteRune(currentRune) newPos += currentWidth } case stateReadingCondition: // append to the condition - condition += fmt.Sprintf("%c", currentRune) + condition.WriteRune(currentRune) newPos += currentWidth } } @@ -368,17 +373,17 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { // } // Include the matching rule oid, if specified - if len(extensibleMatchingRule) > 0 { - packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule, MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule])) + if extensibleMatchingRule.Len() > 0 { + packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule.String(), MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule])) } // Include the attribute, if specified - if len(attribute) > 0 { - packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute, MatchingRuleAssertionMap[MatchingRuleAssertionType])) + if attribute.Len() > 0 { + packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute.String(), MatchingRuleAssertionMap[MatchingRuleAssertionType])) } // Add the value (only required child) - encodedString, encodeErr := escapedStringToEncodedBytes(condition) + encodedString, encodeErr := decodeEscapedSymbols(condition.Bytes()) if encodeErr != nil { return packet, newPos, encodeErr } @@ -389,16 +394,16 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { packet.AppendChild(ber.NewBoolean(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionDNAttributes, extensibleDNAttributes, MatchingRuleAssertionMap[MatchingRuleAssertionDNAttributes])) } - case packet.Tag == FilterEqualityMatch && condition == "*": - packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute, FilterMap[FilterPresent]) - case packet.Tag == FilterEqualityMatch && strings.Contains(condition, "*"): - packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute")) + case packet.Tag == FilterEqualityMatch && bytes.Equal(condition.Bytes(), _SymbolAny): + packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute.String(), FilterMap[FilterPresent]) + case packet.Tag == FilterEqualityMatch && bytes.Index(condition.Bytes(), _SymbolAny) > -1: + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute.String(), "Attribute")) packet.Tag = FilterSubstrings packet.Description = FilterMap[uint64(packet.Tag)] seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") - parts := strings.Split(condition, "*") + parts := bytes.Split(condition.Bytes(), _SymbolAny) for i, part := range parts { - if part == "" { + if len(part) == 0 { continue } var tag ber.Tag @@ -410,7 +415,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { default: tag = FilterSubstringsAny } - encodedString, encodeErr := escapedStringToEncodedBytes(part) + encodedString, encodeErr := decodeEscapedSymbols(part) if encodeErr != nil { return packet, newPos, encodeErr } @@ -418,11 +423,11 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { } packet.AppendChild(seq) default: - encodedString, encodeErr := escapedStringToEncodedBytes(condition) + encodedString, encodeErr := decodeEscapedSymbols(condition.Bytes()) if encodeErr != nil { return packet, newPos, encodeErr } - packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute")) + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute.String(), "Attribute")) packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, encodedString, "Condition")) } @@ -432,34 +437,51 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { } // Convert from "ABC\xx\xx\xx" form to literal bytes for transport -func escapedStringToEncodedBytes(escapedString string) (string, error) { - var buffer bytes.Buffer - i := 0 - for i < len(escapedString) { - currentRune, currentWidth := utf8.DecodeRuneInString(escapedString[i:]) - if currentRune == utf8.RuneError { - return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", i)) +func decodeEscapedSymbols(src []byte) (string, error) { + + var ( + buffer bytes.Buffer + offset int + reader = bytes.NewReader(src) + byteHex []byte + byteVal []byte + ) + + for { + runeVal, runeSize, err := reader.ReadRune() + if err == io.EOF { + return buffer.String(), nil + } else if err != nil { + return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: failed to read filter: %v", err)) + } else if runeVal == unicode.ReplacementChar { + return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", offset)) } - // Check for escaped hex characters and convert them to their literal value for transport. - if currentRune == '\\' { + if runeVal == '\\' { // http://tools.ietf.org/search/rfc4515 // \ (%x5C) is not a valid character unless it is followed by two HEX characters due to not // being a member of UTF1SUBSET. - if i+2 > len(escapedString) { - return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter")) + if byteHex == nil { + byteHex = make([]byte, 2) + byteVal = make([]byte, 1) } - escByte, decodeErr := hexpac.DecodeString(escapedString[i+1 : i+3]) - if decodeErr != nil { - return "", NewError(ErrorFilterCompile, errors.New("ldap: invalid characters for escape in filter")) + + if _, err := io.ReadFull(reader, byteHex); err != nil { + if err == io.ErrUnexpectedEOF { + return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter")) + } + return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: invalid characters for escape in filter: %v", err)) } - buffer.WriteByte(escByte[0]) - i += 2 // +1 from end of loop, so 3 total for \xx. + + if _, err := hexpac.Decode(byteVal, byteHex); err != nil { + return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: invalid characters for escape in filter: %v", err)) + } + + buffer.Write(byteVal) } else { - buffer.WriteRune(currentRune) + buffer.WriteRune(runeVal) } - i += currentWidth + offset += runeSize } - return buffer.String(), nil } diff --git a/vendor/gopkg.in/ldap.v3/ldap.go b/vendor/github.com/go-ldap/ldap/v3/ldap.go similarity index 91% rename from vendor/gopkg.in/ldap.v3/ldap.go rename to vendor/github.com/go-ldap/ldap/v3/ldap.go index 5b694bf3..7ae6dfe2 100644 --- a/vendor/gopkg.in/ldap.v3/ldap.go +++ b/vendor/github.com/go-ldap/ldap/v3/ldap.go @@ -5,7 +5,7 @@ import ( "io/ioutil" "os" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // LDAP Application Codes @@ -223,32 +223,26 @@ func addControlDescriptions(packet *ber.Packet) error { if child.Tag == 0 { //Warning warningPacket := child.Children[0] - packet, err := ber.DecodePacketErr(warningPacket.Data.Bytes()) + val, err := ber.ParseInt64(warningPacket.Data.Bytes()) if err != nil { return fmt.Errorf("failed to decode data bytes: %s", err) } - val, ok := packet.Value.(int64) - if ok { - if warningPacket.Tag == 0 { - //timeBeforeExpiration - value.Description += " (TimeBeforeExpiration)" - warningPacket.Value = val - } else if warningPacket.Tag == 1 { - //graceAuthNsRemaining - value.Description += " (GraceAuthNsRemaining)" - warningPacket.Value = val - } + if warningPacket.Tag == 0 { + //timeBeforeExpiration + value.Description += " (TimeBeforeExpiration)" + warningPacket.Value = val + } else if warningPacket.Tag == 1 { + //graceAuthNsRemaining + value.Description += " (GraceAuthNsRemaining)" + warningPacket.Value = val } } else if child.Tag == 1 { // Error - packet, err := ber.DecodePacketErr(child.Data.Bytes()) - if err != nil { - return fmt.Errorf("failed to decode data bytes: %s", err) - } - val, ok := packet.Value.(int8) - if !ok { - val = -1 + bs := child.Data.Bytes() + if len(bs) != 1 || bs[0] > 8 { + return fmt.Errorf("failed to decode data bytes: %s", "invalid PasswordPolicyResponse enum value") } + val := int8(bs[0]) child.Description = "Error" child.Value = val } @@ -269,13 +263,18 @@ func addRequestDescriptions(packet *ber.Packet) error { } func addDefaultLDAPResponseDescriptions(packet *ber.Packet) error { - err := GetLDAPError(packet) - if err == nil { - return nil + resultCode := uint16(LDAPResultSuccess) + matchedDN := "" + description := "Success" + if err := GetLDAPError(packet); err != nil { + resultCode = err.(*Error).ResultCode + matchedDN = err.(*Error).MatchedDN + description = "Error Message" } - packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[err.(*Error).ResultCode] + ")" - packet.Children[1].Children[1].Description = "Matched DN (" + err.(*Error).MatchedDN + ")" - packet.Children[1].Children[2].Description = "Error Message" + + packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[resultCode] + ")" + packet.Children[1].Children[1].Description = "Matched DN (" + matchedDN + ")" + packet.Children[1].Children[2].Description = description if len(packet.Children[1].Children) > 3 { packet.Children[1].Children[3].Description = "Referral" } diff --git a/vendor/gopkg.in/ldap.v3/moddn.go b/vendor/github.com/go-ldap/ldap/v3/moddn.go similarity index 68% rename from vendor/gopkg.in/ldap.v3/moddn.go rename to vendor/github.com/go-ldap/ldap/v3/moddn.go index 889a82ac..71cdcd0b 100644 --- a/vendor/gopkg.in/ldap.v3/moddn.go +++ b/vendor/github.com/go-ldap/ldap/v3/moddn.go @@ -1,19 +1,9 @@ -// Package ldap - moddn.go contains ModifyDN functionality -// -// https://tools.ietf.org/html/rfc4511 -// ModifyDNRequest ::= [APPLICATION 12] SEQUENCE { -// entry LDAPDN, -// newrdn RelativeLDAPDN, -// deleteoldrdn BOOLEAN, -// newSuperior [0] LDAPDN OPTIONAL } -// -// package ldap import ( "log" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // ModifyDNRequest holds the request to modify a DN @@ -22,6 +12,8 @@ type ModifyDNRequest struct { NewRDN string DeleteOldRDN bool NewSuperior string + // Controls hold optional controls to send with the request + Controls []Control } // NewModifyDNRequest creates a new request which can be passed to ModifyDN(). @@ -45,16 +37,39 @@ func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *Modi } } +// NewModifyDNWithControlsRequest creates a new request which can be passed to ModifyDN() +// and also allows setting LDAP request controls. +// +// Refer NewModifyDNRequest for other parameters +func NewModifyDNWithControlsRequest(dn string, rdn string, delOld bool, + newSup string, controls []Control) *ModifyDNRequest { + return &ModifyDNRequest{ + DN: dn, + NewRDN: rdn, + DeleteOldRDN: delOld, + NewSuperior: newSup, + Controls: controls, + } +} + func (req *ModifyDNRequest) appendTo(envelope *ber.Packet) error { pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request") pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.NewRDN, "New RDN")) - pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.DeleteOldRDN, "Delete old RDN")) + if req.DeleteOldRDN { + buf := []byte{0xff} + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, string(buf), "Delete old RDN")) + } else { + pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.DeleteOldRDN, "Delete old RDN")) + } if req.NewSuperior != "" { pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.NewSuperior, "New Superior")) } envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } return nil } diff --git a/vendor/gopkg.in/ldap.v3/modify.go b/vendor/github.com/go-ldap/ldap/v3/modify.go similarity index 71% rename from vendor/gopkg.in/ldap.v3/modify.go rename to vendor/github.com/go-ldap/ldap/v3/modify.go index 7e09b507..1821413d 100644 --- a/vendor/gopkg.in/ldap.v3/modify.go +++ b/vendor/github.com/go-ldap/ldap/v3/modify.go @@ -1,41 +1,18 @@ -// File contains Modify functionality -// -// https://tools.ietf.org/html/rfc4511 -// -// ModifyRequest ::= [APPLICATION 6] SEQUENCE { -// object LDAPDN, -// changes SEQUENCE OF change SEQUENCE { -// operation ENUMERATED { -// add (0), -// delete (1), -// replace (2), -// ... }, -// modification PartialAttribute } } -// -// PartialAttribute ::= SEQUENCE { -// type AttributeDescription, -// vals SET OF value AttributeValue } -// -// AttributeDescription ::= LDAPString -// -- Constrained to -// -- [RFC4512] -// -// AttributeValue ::= OCTET STRING -// - package ldap import ( + "errors" "log" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // Change operation choices const ( - AddAttribute = 0 - DeleteAttribute = 1 - ReplaceAttribute = 2 + AddAttribute = 0 + DeleteAttribute = 1 + ReplaceAttribute = 2 + IncrementAttribute = 3 // (https://tools.ietf.org/html/rfc4525) ) // PartialAttribute for a ModifyRequest as defined in https://tools.ietf.org/html/rfc4511 @@ -97,6 +74,11 @@ func (req *ModifyRequest) Replace(attrType string, attrVals []string) { req.appendChange(ReplaceAttribute, attrType, attrVals) } +// Increment appends the given attribute to the list of changes to be made +func (req *ModifyRequest) Increment(attrType string, attrVal string) { + req.appendChange(IncrementAttribute, attrType, []string{attrVal}) +} + func (req *ModifyRequest) appendChange(operation uint, attrType string, attrVals []string) { req.Changes = append(req.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}}) } @@ -149,3 +131,47 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error { } return nil } + +// ModifyResult holds the server's response to a modify request +type ModifyResult struct { + // Controls are the returned controls + Controls []Control +} + +// ModifyWithResult performs the ModifyRequest and returns the result +func (l *Conn) ModifyWithResult(modifyRequest *ModifyRequest) (*ModifyResult, error) { + msgCtx, err := l.doRequest(modifyRequest) + if err != nil { + return nil, err + } + defer l.finishMessage(msgCtx) + + result := &ModifyResult{ + Controls: make([]Control, 0), + } + + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packet, err := l.readPacket(msgCtx) + if err != nil { + return nil, err + } + + switch packet.Children[1].Tag { + case ApplicationModifyResponse: + err := GetLDAPError(packet) + if err != nil { + return nil, err + } + if len(packet.Children) == 3 { + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + return nil, errors.New("failed to decode child control: " + err.Error()) + } + result.Controls = append(result.Controls, decodedChild) + } + } + } + l.Debug.Printf("%d: returning", msgCtx.id) + return result, nil +} diff --git a/vendor/gopkg.in/ldap.v3/passwdmodify.go b/vendor/github.com/go-ldap/ldap/v3/passwdmodify.go similarity index 96% rename from vendor/gopkg.in/ldap.v3/passwdmodify.go rename to vendor/github.com/go-ldap/ldap/v3/passwdmodify.go index bfaceff3..62a11084 100644 --- a/vendor/gopkg.in/ldap.v3/passwdmodify.go +++ b/vendor/github.com/go-ldap/ldap/v3/passwdmodify.go @@ -1,14 +1,9 @@ -// This file contains the password modify extended operation as specified in rfc 3062 -// -// https://tools.ietf.org/html/rfc3062 -// - package ldap import ( "fmt" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) const ( @@ -61,7 +56,7 @@ func (req *PasswordModifyRequest) appendTo(envelope *ber.Packet) error { // NewPasswordModifyRequest creates a new PasswordModifyRequest // -// According to the RFC 3602: +// According to the RFC 3602 (https://tools.ietf.org/html/rfc3062): // userIdentity is a string representing the user associated with the request. // This string may or may not be an LDAPDN (RFC 2253). // If userIdentity is empty then the operation will act on the user associated diff --git a/vendor/gopkg.in/ldap.v3/request.go b/vendor/github.com/go-ldap/ldap/v3/request.go similarity index 85% rename from vendor/gopkg.in/ldap.v3/request.go rename to vendor/github.com/go-ldap/ldap/v3/request.go index 814e29fe..4ea31e90 100644 --- a/vendor/gopkg.in/ldap.v3/request.go +++ b/vendor/github.com/go-ldap/ldap/v3/request.go @@ -3,12 +3,13 @@ package ldap import ( "errors" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) var ( errRespChanClosed = errors.New("ldap: response channel closed") errCouldNotRetMsg = errors.New("ldap: could not retrieve message") + ErrNilConnection = errors.New("ldap: conn is nil, expected net.Conn") ) type request interface { @@ -22,6 +23,10 @@ func (f requestFunc) appendTo(p *ber.Packet) error { } func (l *Conn) doRequest(req request) (*messageContext, error) { + if l == nil || l.conn == nil { + return nil, ErrNilConnection + } + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) if err := req.appendTo(packet); err != nil { @@ -29,7 +34,7 @@ func (l *Conn) doRequest(req request) (*messageContext, error) { } if l.Debug { - ber.PrintPacket(packet) + l.Debug.PrintPacket(packet) } msgCtx, err := l.sendMessage(packet) @@ -60,7 +65,7 @@ func (l *Conn) readPacket(msgCtx *messageContext) (*ber.Packet, error) { if err = addLDAPDescriptions(packet); err != nil { return nil, err } - ber.PrintPacket(packet) + l.Debug.PrintPacket(packet) } return packet, nil } diff --git a/vendor/gopkg.in/ldap.v3/search.go b/vendor/github.com/go-ldap/ldap/v3/search.go similarity index 78% rename from vendor/gopkg.in/ldap.v3/search.go rename to vendor/github.com/go-ldap/ldap/v3/search.go index 51eb7dc6..915e4203 100644 --- a/vendor/gopkg.in/ldap.v3/search.go +++ b/vendor/github.com/go-ldap/ldap/v3/search.go @@ -1,58 +1,3 @@ -// File contains Search functionality -// -// https://tools.ietf.org/html/rfc4511 -// -// SearchRequest ::= [APPLICATION 3] SEQUENCE { -// baseObject LDAPDN, -// scope ENUMERATED { -// baseObject (0), -// singleLevel (1), -// wholeSubtree (2), -// ... }, -// derefAliases ENUMERATED { -// neverDerefAliases (0), -// derefInSearching (1), -// derefFindingBaseObj (2), -// derefAlways (3) }, -// sizeLimit INTEGER (0 .. maxInt), -// timeLimit INTEGER (0 .. maxInt), -// typesOnly BOOLEAN, -// filter Filter, -// attributes AttributeSelection } -// -// AttributeSelection ::= SEQUENCE OF selector LDAPString -// -- The LDAPString is constrained to -// -- in Section 4.5.1.8 -// -// Filter ::= CHOICE { -// and [0] SET SIZE (1..MAX) OF filter Filter, -// or [1] SET SIZE (1..MAX) OF filter Filter, -// not [2] Filter, -// equalityMatch [3] AttributeValueAssertion, -// substrings [4] SubstringFilter, -// greaterOrEqual [5] AttributeValueAssertion, -// lessOrEqual [6] AttributeValueAssertion, -// present [7] AttributeDescription, -// approxMatch [8] AttributeValueAssertion, -// extensibleMatch [9] MatchingRuleAssertion, -// ... } -// -// SubstringFilter ::= SEQUENCE { -// type AttributeDescription, -// substrings SEQUENCE SIZE (1..MAX) OF substring CHOICE { -// initial [0] AssertionValue, -- can occur at most once -// any [1] AssertionValue, -// final [2] AssertionValue } -- can occur at most once -// } -// -// MatchingRuleAssertion ::= SEQUENCE { -// matchingRule [1] MatchingRuleId OPTIONAL, -// type [2] AttributeDescription OPTIONAL, -// matchValue [3] AssertionValue, -// dnAttributes [4] BOOLEAN DEFAULT FALSE } -// -// - package ldap import ( @@ -61,7 +6,7 @@ import ( "sort" "strings" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // scope choices @@ -132,6 +77,17 @@ func (e *Entry) GetAttributeValues(attribute string) []string { return []string{} } +// GetEqualFoldAttributeValues returns the values for the named attribute, or an +// empty list. Attribute matching is done with strings.EqualFold. +func (e *Entry) GetEqualFoldAttributeValues(attribute string) []string { + for _, attr := range e.Attributes { + if strings.EqualFold(attribute, attr.Name) { + return attr.Values + } + } + return []string{} +} + // GetRawAttributeValues returns the byte values for the named attribute, or an empty list func (e *Entry) GetRawAttributeValues(attribute string) [][]byte { for _, attr := range e.Attributes { @@ -142,6 +98,16 @@ func (e *Entry) GetRawAttributeValues(attribute string) [][]byte { return [][]byte{} } +// GetEqualFoldRawAttributeValues returns the byte values for the named attribute, or an empty list +func (e *Entry) GetEqualFoldRawAttributeValues(attribute string) [][]byte { + for _, attr := range e.Attributes { + if strings.EqualFold(attr.Name, attribute) { + return attr.ByteValues + } + } + return [][]byte{} +} + // GetAttributeValue returns the first value for the named attribute, or "" func (e *Entry) GetAttributeValue(attribute string) string { values := e.GetAttributeValues(attribute) @@ -151,6 +117,16 @@ func (e *Entry) GetAttributeValue(attribute string) string { return values[0] } +// GetEqualFoldAttributeValue returns the first value for the named attribute, or "". +// Attribute comparison is done with strings.EqualFold. +func (e *Entry) GetEqualFoldAttributeValue(attribute string) string { + values := e.GetEqualFoldAttributeValues(attribute) + if len(values) == 0 { + return "" + } + return values[0] +} + // GetRawAttributeValue returns the first value for the named attribute, or an empty slice func (e *Entry) GetRawAttributeValue(attribute string) []byte { values := e.GetRawAttributeValues(attribute) @@ -160,6 +136,15 @@ func (e *Entry) GetRawAttributeValue(attribute string) []byte { return values[0] } +// GetEqualFoldRawAttributeValue returns the first value for the named attribute, or an empty slice +func (e *Entry) GetEqualFoldRawAttributeValue(attribute string) []byte { + values := e.GetEqualFoldRawAttributeValues(attribute) + if len(values) == 0 { + return []byte{} + } + return values[0] +} + // Print outputs a human-readable description func (e *Entry) Print() { fmt.Printf("DN: %s\n", e.DN) @@ -386,33 +371,26 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { for { packet, err := l.readPacket(msgCtx) if err != nil { - return nil, err + return result, err } switch packet.Children[1].Tag { case 4: - entry := new(Entry) - entry.DN = packet.Children[1].Children[0].Value.(string) - for _, child := range packet.Children[1].Children[1].Children { - attr := new(EntryAttribute) - attr.Name = child.Children[0].Value.(string) - for _, value := range child.Children[1].Children { - attr.Values = append(attr.Values, value.Value.(string)) - attr.ByteValues = append(attr.ByteValues, value.ByteValue) - } - entry.Attributes = append(entry.Attributes, attr) + entry := &Entry{ + DN: packet.Children[1].Children[0].Value.(string), + Attributes: unpackAttributes(packet.Children[1].Children[1].Children), } result.Entries = append(result.Entries, entry) case 5: err := GetLDAPError(packet) if err != nil { - return nil, err + return result, err } if len(packet.Children) == 3 { for _, child := range packet.Children[2].Children { decodedChild, err := DecodeControl(child) if err != nil { - return nil, fmt.Errorf("failed to decode child control: %s", err) + return result, fmt.Errorf("failed to decode child control: %s", err) } result.Controls = append(result.Controls, decodedChild) } @@ -423,3 +401,27 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } } } + +// unpackAttributes will extract all given LDAP attributes and it's values +// from the ber.Packet +func unpackAttributes(children []*ber.Packet) []*EntryAttribute { + entries := make([]*EntryAttribute, len(children)) + for i, child := range children { + length := len(child.Children[1].Children) + entry := &EntryAttribute{ + Name: child.Children[0].Value.(string), + // pre-allocate the slice since we can determine + // the number of attributes at this point + Values: make([]string, length), + ByteValues: make([][]byte, length), + } + + for i, value := range child.Children[1].Children { + entry.ByteValues[i] = value.ByteValue + entry.Values[i] = value.Value.(string) + } + entries[i] = entry + } + + return entries +} diff --git a/vendor/github.com/go-ldap/ldap/v3/unbind.go b/vendor/github.com/go-ldap/ldap/v3/unbind.go new file mode 100644 index 00000000..6c411cd1 --- /dev/null +++ b/vendor/github.com/go-ldap/ldap/v3/unbind.go @@ -0,0 +1,37 @@ +package ldap + +import ( + "errors" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +var ErrConnUnbound = NewError(ErrorNetwork, errors.New("ldap: connection is closed")) + +type unbindRequest struct{} + +func (unbindRequest) appendTo(envelope *ber.Packet) error { + envelope.AppendChild(ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationUnbindRequest, nil, ApplicationMap[ApplicationUnbindRequest])) + return nil +} + +// Unbind will perform an unbind request. The Unbind operation +// should be thought of as the "quit" operation. +// See https://datatracker.ietf.org/doc/html/rfc4511#section-4.3 +func (l *Conn) Unbind() error { + if l.IsClosing() { + return ErrConnUnbound + } + + _, err := l.doRequest(unbindRequest{}) + if err != nil { + return err + } + + // Sending an unbindRequest will make the connection unusable. + // Pending requests will fail with: + // LDAP Result Code 200 "Network Error": ldap: response channel closed + l.Close() + + return nil +} diff --git a/vendor/github.com/go-ldap/ldap/v3/whoami.go b/vendor/github.com/go-ldap/ldap/v3/whoami.go new file mode 100644 index 00000000..10c523d0 --- /dev/null +++ b/vendor/github.com/go-ldap/ldap/v3/whoami.go @@ -0,0 +1,91 @@ +package ldap + +// This file contains the "Who Am I?" extended operation as specified in rfc 4532 +// +// https://tools.ietf.org/html/rfc4532 + +import ( + "errors" + "fmt" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +type whoAmIRequest bool + +// WhoAmIResult is returned by the WhoAmI() call +type WhoAmIResult struct { + AuthzID string +} + +func (r whoAmIRequest) encode() (*ber.Packet, error) { + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Who Am I? Extended Operation") + request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, ControlTypeWhoAmI, "Extended Request Name: Who Am I? OID")) + return request, nil +} + +// WhoAmI returns the authzId the server thinks we are, you may pass controls +// like a Proxied Authorization control +func (l *Conn) WhoAmI(controls []Control) (*WhoAmIResult, error) { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + req := whoAmIRequest(true) + encodedWhoAmIRequest, err := req.encode() + if err != nil { + return nil, err + } + packet.AppendChild(encodedWhoAmIRequest) + + if len(controls) != 0 { + packet.AppendChild(encodeControls(controls)) + } + + l.Debug.PrintPacket(packet) + + msgCtx, err := l.sendMessage(packet) + if err != nil { + return nil, err + } + defer l.finishMessage(msgCtx) + + result := &WhoAmIResult{} + + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + } + packet, err = packetResponse.ReadPacket() + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + return nil, err + } + + if packet == nil { + return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve message")) + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + return nil, err + } + ber.PrintPacket(packet) + } + + if packet.Children[1].Tag == ApplicationExtendedResponse { + if err := GetLDAPError(packet); err != nil { + return nil, err + } + } else { + return nil, NewError(ErrorUnexpectedResponse, fmt.Errorf("Unexpected Response: %d", packet.Children[1].Tag)) + } + + extendedResponse := packet.Children[1] + for _, child := range extendedResponse.Children { + if child.Tag == 11 { + result.AuthzID = ber.DecodeString(child.Data.Bytes()) + } + } + + return result, nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/.travis.yml b/vendor/github.com/go-sql-driver/mysql/.travis.yml deleted file mode 100644 index 56fcf25f..00000000 --- a/vendor/github.com/go-sql-driver/mysql/.travis.yml +++ /dev/null @@ -1,129 +0,0 @@ -sudo: false -language: go -go: - - 1.10.x - - 1.11.x - - 1.12.x - - 1.13.x - - master - -before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - -before_script: - - echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB" | sudo tee -a /etc/mysql/my.cnf - - sudo service mysql restart - - .travis/wait_mysql.sh - - mysql -e 'create database gotest;' - -matrix: - include: - - env: DB=MYSQL8 - sudo: required - dist: trusty - go: 1.10.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - - docker pull mysql:8.0 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret - mysql:8.0 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - before_script: - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3307 - - export MYSQL_TEST_CONCURRENT=1 - - - env: DB=MYSQL57 - sudo: required - dist: trusty - go: 1.10.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - - docker pull mysql:5.7 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret - mysql:5.7 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - before_script: - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3307 - - export MYSQL_TEST_CONCURRENT=1 - - - env: DB=MARIA55 - sudo: required - dist: trusty - go: 1.10.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - - docker pull mariadb:5.5 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret - mariadb:5.5 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - before_script: - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3307 - - export MYSQL_TEST_CONCURRENT=1 - - - env: DB=MARIA10_1 - sudo: required - dist: trusty - go: 1.10.x - services: - - docker - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - - docker pull mariadb:10.1 - - docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret - mariadb:10.1 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1 - - cp .travis/docker.cnf ~/.my.cnf - - .travis/wait_mysql.sh - before_script: - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3307 - - export MYSQL_TEST_CONCURRENT=1 - - - os: osx - osx_image: xcode10.1 - addons: - homebrew: - packages: - - mysql - update: true - go: 1.12.x - before_install: - - go get golang.org/x/tools/cmd/cover - - go get github.com/mattn/goveralls - before_script: - - echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB\nlocal_infile=1" >> /usr/local/etc/my.cnf - - mysql.server start - - mysql -uroot -e 'CREATE USER gotest IDENTIFIED BY "secret"' - - mysql -uroot -e 'GRANT ALL ON *.* TO gotest' - - mysql -uroot -e 'create database gotest;' - - export MYSQL_TEST_USER=gotest - - export MYSQL_TEST_PASS=secret - - export MYSQL_TEST_ADDR=127.0.0.1:3306 - - export MYSQL_TEST_CONCURRENT=1 - -script: - - go test -v -covermode=count -coverprofile=coverage.out - - go vet ./... - - .travis/gofmt.sh -after_script: - - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci diff --git a/vendor/github.com/go-sql-driver/mysql/AUTHORS b/vendor/github.com/go-sql-driver/mysql/AUTHORS index ad598980..50afa2c8 100644 --- a/vendor/github.com/go-sql-driver/mysql/AUTHORS +++ b/vendor/github.com/go-sql-driver/mysql/AUTHORS @@ -13,11 +13,15 @@ Aaron Hopkins Achille Roussel +Alex Snast Alexey Palazhchenko Andrew Reid +Animesh Ray Arne Hormann +Ariel Mashraki Asta Xie Bulat Gaifullin +Caine Jette Carlos Nieto Chris Moos Craig Wilson @@ -52,6 +56,7 @@ Julien Schmidt Justin Li Justin Nuß Kamil Dziedzic +Kei Kamikawa Kevin Malachowski Kieron Woodhouse Lennart Rudolph @@ -74,20 +79,26 @@ Reed Allman Richard Wilkes Robert Russell Runrioter Wung +Sho Iizuka +Sho Ikeda Shuode Li Simon J Mudd Soroush Pour Stan Putrya Stanley Gunawan Steven Hartland +Tan Jinhua <312841925 at qq.com> Thomas Wodarek Tim Ruffles Tom Jenkinson Vladimir Kovpak +Vladyslav Zhelezniak Xiangyu Hu Xiaobing Jiang Xiuming Chen +Xuehong Chan Zhenye Xie +Zhixin Wen # Organizations @@ -103,3 +114,4 @@ Multiplay Ltd. Percona LLC Pivotal Inc. Stripe Inc. +Zendesk Inc. diff --git a/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md b/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md index 9cb97b38..72a738ed 100644 --- a/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md +++ b/vendor/github.com/go-sql-driver/mysql/CHANGELOG.md @@ -1,3 +1,29 @@ +## Version 1.6 (2021-04-01) + +Changes: + + - Migrate the CI service from travis-ci to GitHub Actions (#1176, #1183, #1190) + - `NullTime` is deprecated (#960, #1144) + - Reduce allocations when building SET command (#1111) + - Performance improvement for time formatting (#1118) + - Performance improvement for time parsing (#1098, #1113) + +New Features: + + - Implement `driver.Validator` interface (#1106, #1174) + - Support returning `uint64` from `Valuer` in `ConvertValue` (#1143) + - Add `json.RawMessage` for converter and prepared statement (#1059) + - Interpolate `json.RawMessage` as `string` (#1058) + - Implements `CheckNamedValue` (#1090) + +Bugfixes: + + - Stop rounding times (#1121, #1172) + - Put zero filler into the SSL handshake packet (#1066) + - Fix checking cancelled connections back into the connection pool (#1095) + - Fix remove last 0 byte for mysql_old_password when password is empty (#1133) + + ## Version 1.5 (2020-01-07) Changes: diff --git a/vendor/github.com/go-sql-driver/mysql/README.md b/vendor/github.com/go-sql-driver/mysql/README.md index d2627a41..0b13154f 100644 --- a/vendor/github.com/go-sql-driver/mysql/README.md +++ b/vendor/github.com/go-sql-driver/mysql/README.md @@ -35,7 +35,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Supports queries larger than 16MB * Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support. * Intelligent `LONG DATA` handling in prepared statements - * Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support + * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support * Optional `time.Time` parsing * Optional placeholder interpolation @@ -56,15 +56,37 @@ Make sure [Git is installed](https://git-scm.com/downloads) on your machine and _Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then. Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: + ```go -import "database/sql" -import _ "github.com/go-sql-driver/mysql" +import ( + "database/sql" + "time" + + _ "github.com/go-sql-driver/mysql" +) + +// ... db, err := sql.Open("mysql", "user:password@/dbname") +if err != nil { + panic(err) +} +// See "Important settings" section. +db.SetConnMaxLifetime(time.Minute * 3) +db.SetMaxOpenConns(10) +db.SetMaxIdleConns(10) ``` [Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples"). +### Important settings + +`db.SetConnMaxLifetime()` is required to ensure connections are closed by the driver safely before connection is closed by MySQL server, OS, or other middlewares. Since some middlewares close idle connections by 5 minutes, we recommend timeout shorter than 5 minutes. This setting helps load balancing and changing system variables too. + +`db.SetMaxOpenConns()` is highly recommended to limit the number of connection used by the application. There is no recommended limit number because it depends on application and MySQL server. + +`db.SetMaxIdleConns()` is recommended to be set same to (or greater than) `db.SetMaxOpenConns()`. When it is smaller than `SetMaxOpenConns()`, connections can be opened and closed very frequently than you expect. Idle connections can be closed by the `db.SetConnMaxLifetime()`. If you want to close idle connections more rapidly, you can use `db.SetConnMaxIdleTime()` since Go 1.15. + ### DSN (Data Source Name) @@ -122,7 +144,7 @@ Valid Values: true, false Default: false ``` -`allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. +`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files. [*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html) ##### `allowCleartextPasswords` @@ -133,7 +155,7 @@ Valid Values: true, false Default: false ``` -`allowCleartextPasswords=true` allows using the [cleartext client side plugin](http://dev.mysql.com/doc/en/cleartext-authentication-plugin.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. +`allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. ##### `allowNativePasswords` @@ -230,7 +252,7 @@ Default: false If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`. -*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are blacklisted as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!* +*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are rejected as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!* ##### `loc` @@ -376,7 +398,7 @@ Rules: Examples: * `autocommit=1`: `SET autocommit=1` * [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'` - * [`tx_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `SET tx_isolation='REPEATABLE-READ'` + * [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'` #### Examples @@ -445,7 +467,7 @@ For this feature you need direct access to the package. Therefore you must chang import "github.com/go-sql-driver/mysql" ``` -Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). +Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore. @@ -459,8 +481,6 @@ However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` v **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes). -Alternatively you can use the [`NullTime`](https://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`. - ### Unicode support Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default. @@ -477,7 +497,7 @@ To run the driver tests you may need to adjust the configuration. See the [Testi Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated. If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open) or review a [pull request](https://github.com/go-sql-driver/mysql/pulls). -See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/CONTRIBUTING.md) for details. +See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/.github/CONTRIBUTING.md) for details. --------------------------------------- @@ -498,4 +518,3 @@ Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE). ![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow") - diff --git a/vendor/github.com/go-sql-driver/mysql/auth.go b/vendor/github.com/go-sql-driver/mysql/auth.go index fec7040d..b2f19e8f 100644 --- a/vendor/github.com/go-sql-driver/mysql/auth.go +++ b/vendor/github.com/go-sql-driver/mysql/auth.go @@ -15,6 +15,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/pem" + "fmt" "sync" ) @@ -136,10 +137,6 @@ func pwHash(password []byte) (result [2]uint32) { // Hash password using insecure pre 4.1 method func scrambleOldPassword(scramble []byte, password string) []byte { - if len(password) == 0 { - return nil - } - scramble = scramble[:8] hashPw := pwHash([]byte(password)) @@ -247,6 +244,9 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { if !mc.cfg.AllowOldPasswords { return nil, ErrOldPassword } + if len(mc.cfg.Passwd) == 0 { + return nil, nil + } // Note: there are edge cases where this should work but doesn't; // this is currently "wontfix": // https://github.com/go-sql-driver/mysql/issues/184 @@ -372,7 +372,10 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { return err } - block, _ := pem.Decode(data[1:]) + block, rest := pem.Decode(data[1:]) + if block == nil { + return fmt.Errorf("No Pem data found, data: %s", rest) + } pkix, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return err diff --git a/vendor/github.com/go-sql-driver/mysql/collations.go b/vendor/github.com/go-sql-driver/mysql/collations.go index 8d2b5567..326a9f7f 100644 --- a/vendor/github.com/go-sql-driver/mysql/collations.go +++ b/vendor/github.com/go-sql-driver/mysql/collations.go @@ -247,7 +247,7 @@ var collations = map[string]byte{ "utf8mb4_0900_ai_ci": 255, } -// A blacklist of collations which is unsafe to interpolate parameters. +// A denylist of collations which is unsafe to interpolate parameters. // These multibyte encodings may contains 0x5c (`\`) in their trailing bytes. var unsafeCollations = map[string]bool{ "big5_chinese_ci": true, diff --git a/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/go-sql-driver/mysql/connection.go index e4bb59e6..835f8972 100644 --- a/vendor/github.com/go-sql-driver/mysql/connection.go +++ b/vendor/github.com/go-sql-driver/mysql/connection.go @@ -12,6 +12,7 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/json" "io" "net" "strconv" @@ -46,9 +47,10 @@ type mysqlConn struct { // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { + var cmdSet strings.Builder for param, val := range mc.cfg.Params { switch param { - // Charset + // Charset: character_set_connection, character_set_client, character_set_results case "charset": charsets := strings.Split(val, ",") for i := range charsets { @@ -62,12 +64,25 @@ func (mc *mysqlConn) handleParams() (err error) { return } - // System Vars + // Other system vars accumulated in a single SET command default: - err = mc.exec("SET " + param + "=" + val + "") - if err != nil { - return + if cmdSet.Len() == 0 { + // Heuristic: 29 chars for each other key=value to reduce reallocations + cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1)) + cmdSet.WriteString("SET ") + } else { + cmdSet.WriteByte(',') } + cmdSet.WriteString(param) + cmdSet.WriteByte('=') + cmdSet.WriteString(val) + } + } + + if cmdSet.Len() > 0 { + err = mc.exec(cmdSet.String()) + if err != nil { + return } } @@ -230,47 +245,21 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if v.IsZero() { buf = append(buf, "'0000-00-00'"...) } else { - v := v.In(mc.cfg.Loc) - v = v.Add(time.Nanosecond * 500) // To round under microsecond - year := v.Year() - year100 := year / 100 - year1 := year % 100 - month := v.Month() - day := v.Day() - hour := v.Hour() - minute := v.Minute() - second := v.Second() - micro := v.Nanosecond() / 1000 - - buf = append(buf, []byte{ - '\'', - digits10[year100], digits01[year100], - digits10[year1], digits01[year1], - '-', - digits10[month], digits01[month], - '-', - digits10[day], digits01[day], - ' ', - digits10[hour], digits01[hour], - ':', - digits10[minute], digits01[minute], - ':', - digits10[second], digits01[second], - }...) - - if micro != 0 { - micro10000 := micro / 10000 - micro100 := micro / 100 % 100 - micro1 := micro % 100 - buf = append(buf, []byte{ - '.', - digits10[micro10000], digits01[micro10000], - digits10[micro100], digits01[micro100], - digits10[micro1], digits01[micro1], - }...) + buf = append(buf, '\'') + buf, err = appendDateTime(buf, v.In(mc.cfg.Loc)) + if err != nil { + return "", err } buf = append(buf, '\'') } + case json.RawMessage: + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + buf = append(buf, '\'') case []byte: if v == nil { buf = append(buf, "NULL"...) @@ -480,6 +469,10 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { // BeginTx implements driver.ConnBeginTx interface func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if mc.closed.IsSet() { + return nil, driver.ErrBadConn + } + if err := mc.watchCancel(ctx); err != nil { return nil, err } @@ -649,3 +642,9 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { mc.reset = true return nil } + +// IsValid implements driver.Validator interface +// (From Go 1.15) +func (mc *mysqlConn) IsValid() bool { + return !mc.closed.IsSet() +} diff --git a/vendor/github.com/go-sql-driver/mysql/dsn.go b/vendor/github.com/go-sql-driver/mysql/dsn.go index 75c8c248..93f3548c 100644 --- a/vendor/github.com/go-sql-driver/mysql/dsn.go +++ b/vendor/github.com/go-sql-driver/mysql/dsn.go @@ -375,7 +375,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { // cfg params switch value := param[1]; param[0] { - // Disable INFILE whitelist / enable all files + // Disable INFILE allowlist / enable all files case "allowAllFiles": var isBool bool cfg.AllowAllFiles, isBool = readBool(value) diff --git a/vendor/github.com/go-sql-driver/mysql/fields.go b/vendor/github.com/go-sql-driver/mysql/fields.go index e1e2ece4..ed6c7a37 100644 --- a/vendor/github.com/go-sql-driver/mysql/fields.go +++ b/vendor/github.com/go-sql-driver/mysql/fields.go @@ -106,7 +106,7 @@ var ( scanTypeInt64 = reflect.TypeOf(int64(0)) scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) - scanTypeNullTime = reflect.TypeOf(NullTime{}) + scanTypeNullTime = reflect.TypeOf(nullTime{}) scanTypeUint8 = reflect.TypeOf(uint8(0)) scanTypeUint16 = reflect.TypeOf(uint16(0)) scanTypeUint32 = reflect.TypeOf(uint32(0)) diff --git a/vendor/github.com/go-sql-driver/mysql/fuzz.go b/vendor/github.com/go-sql-driver/mysql/fuzz.go new file mode 100644 index 00000000..fa75adf6 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/fuzz.go @@ -0,0 +1,24 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. +// +// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build gofuzz + +package mysql + +import ( + "database/sql" +) + +func Fuzz(data []byte) int { + db, err := sql.Open("mysql", string(data)) + if err != nil { + return 0 + } + db.Close() + return 1 +} diff --git a/vendor/github.com/go-sql-driver/mysql/infile.go b/vendor/github.com/go-sql-driver/mysql/infile.go index 273cb0ba..60effdfc 100644 --- a/vendor/github.com/go-sql-driver/mysql/infile.go +++ b/vendor/github.com/go-sql-driver/mysql/infile.go @@ -23,7 +23,7 @@ var ( readerRegisterLock sync.RWMutex ) -// RegisterLocalFile adds the given file to the file whitelist, +// RegisterLocalFile adds the given file to the file allowlist, // so that it can be used by "LOAD DATA LOCAL INFILE ". // Alternatively you can allow the use of all local files with // the DSN parameter 'allowAllFiles=true' @@ -45,7 +45,7 @@ func RegisterLocalFile(filePath string) { fileRegisterLock.Unlock() } -// DeregisterLocalFile removes the given filepath from the whitelist. +// DeregisterLocalFile removes the given filepath from the allowlist. func DeregisterLocalFile(filePath string) { fileRegisterLock.Lock() delete(fileRegister, strings.Trim(filePath, `"`)) diff --git a/vendor/github.com/go-sql-driver/mysql/nulltime.go b/vendor/github.com/go-sql-driver/mysql/nulltime.go index afa8a89e..651723a9 100644 --- a/vendor/github.com/go-sql-driver/mysql/nulltime.go +++ b/vendor/github.com/go-sql-driver/mysql/nulltime.go @@ -28,11 +28,11 @@ func (nt *NullTime) Scan(value interface{}) (err error) { nt.Time, nt.Valid = v, true return case []byte: - nt.Time, err = parseDateTime(string(v), time.UTC) + nt.Time, err = parseDateTime(v, time.UTC) nt.Valid = (err == nil) return case string: - nt.Time, err = parseDateTime(v, time.UTC) + nt.Time, err = parseDateTime([]byte(v), time.UTC) nt.Valid = (err == nil) return } diff --git a/vendor/github.com/go-sql-driver/mysql/nulltime_go113.go b/vendor/github.com/go-sql-driver/mysql/nulltime_go113.go index c392594d..453b4b39 100644 --- a/vendor/github.com/go-sql-driver/mysql/nulltime_go113.go +++ b/vendor/github.com/go-sql-driver/mysql/nulltime_go113.go @@ -28,4 +28,13 @@ import ( // } // // This NullTime implementation is not driver-specific +// +// Deprecated: NullTime doesn't honor the loc DSN parameter. +// NullTime.Scan interprets a time as UTC, not the loc DSN parameter. +// Use sql.NullTime instead. type NullTime sql.NullTime + +// for internal use. +// the mysql package uses sql.NullTime if it is available. +// if not, the package uses mysql.NullTime. +type nullTime = sql.NullTime // sql.NullTime is available diff --git a/vendor/github.com/go-sql-driver/mysql/nulltime_legacy.go b/vendor/github.com/go-sql-driver/mysql/nulltime_legacy.go index 86d159d4..9f7ae27a 100644 --- a/vendor/github.com/go-sql-driver/mysql/nulltime_legacy.go +++ b/vendor/github.com/go-sql-driver/mysql/nulltime_legacy.go @@ -32,3 +32,8 @@ type NullTime struct { Time time.Time Valid bool // Valid is true if Time is not NULL } + +// for internal use. +// the mysql package uses sql.NullTime if it is available. +// if not, the package uses mysql.NullTime. +type nullTime = NullTime // sql.NullTime is not available diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go index 82ad7a20..6664e5ae 100644 --- a/vendor/github.com/go-sql-driver/mysql/packets.go +++ b/vendor/github.com/go-sql-driver/mysql/packets.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "database/sql/driver" "encoding/binary" + "encoding/json" "errors" "fmt" "io" @@ -348,6 +349,12 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string return errors.New("unknown collation") } + // Filler [23 bytes] (all 0x00) + pos := 13 + for ; pos < 13+23; pos++ { + data[pos] = 0 + } + // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest if mc.cfg.tls != nil { @@ -366,12 +373,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string mc.buf.nc = tlsConn } - // Filler [23 bytes] (all 0x00) - pos := 13 - for ; pos < 13+23; pos++ { - data[pos] = 0 - } - // User [null terminated string] if len(mc.cfg.User) > 0 { pos += copy(data[pos:], mc.cfg.User) @@ -777,7 +778,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeDate, fieldTypeNewDate: dest[i], err = parseDateTime( - string(dest[i].([]byte)), + dest[i].([]byte), mc.cfg.Loc, ) if err == nil { @@ -1003,6 +1004,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { continue } + if v, ok := arg.(json.RawMessage); ok { + arg = []byte(v) + } // cache types and values switch v := arg.(type) { case int64: @@ -1112,7 +1116,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if v.IsZero() { b = append(b, "0000-00-00"...) } else { - b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) + b, err = appendDateTime(b, v.In(mc.cfg.Loc)) + if err != nil { + return err + } } paramValues = appendLengthEncodedInteger(paramValues, diff --git a/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/go-sql-driver/mysql/statement.go index f7e37093..18a3ae49 100644 --- a/vendor/github.com/go-sql-driver/mysql/statement.go +++ b/vendor/github.com/go-sql-driver/mysql/statement.go @@ -10,6 +10,7 @@ package mysql import ( "database/sql/driver" + "encoding/json" "fmt" "io" "reflect" @@ -43,6 +44,11 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { return converter{} } +func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} + func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { if stmt.mc.closed.IsSet() { errLog.Print(ErrInvalidConn) @@ -129,6 +135,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { return rows, err } +var jsonType = reflect.TypeOf(json.RawMessage{}) + type converter struct{} // ConvertValue mirrors the reference/default converter in database/sql/driver @@ -146,12 +154,17 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { if err != nil { return nil, err } - if !driver.IsValue(sv) { - return nil, fmt.Errorf("non-Value type %T returned from Value", sv) + if driver.IsValue(sv) { + return sv, nil } - return sv, nil + // A value returend from the Valuer interface can be "a type handled by + // a database driver's NamedValueChecker interface" so we should accept + // uint64 here as well. + if u, ok := sv.(uint64); ok { + return u, nil + } + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) } - rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: @@ -170,11 +183,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { case reflect.Bool: return rv.Bool(), nil case reflect.Slice: - ek := rv.Type().Elem().Kind() - if ek == reflect.Uint8 { + switch t := rv.Type(); { + case t == jsonType: + return v, nil + case t.Elem().Kind() == reflect.Uint8: return rv.Bytes(), nil + default: + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind()) } - return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) case reflect.String: return rv.String(), nil } diff --git a/vendor/github.com/go-sql-driver/mysql/utils.go b/vendor/github.com/go-sql-driver/mysql/utils.go index 9552e80b..d6545f5b 100644 --- a/vendor/github.com/go-sql-driver/mysql/utils.go +++ b/vendor/github.com/go-sql-driver/mysql/utils.go @@ -106,27 +106,136 @@ func readBool(input string) (value bool, valid bool) { * Time related utils * ******************************************************************************/ -func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { - base := "0000-00-00 00:00:00.0000000" - switch len(str) { +func parseDateTime(b []byte, loc *time.Location) (time.Time, error) { + const base = "0000-00-00 00:00:00.000000" + switch len(b) { case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" - if str == base[:len(str)] { - return + if string(b) == base[:len(b)] { + return time.Time{}, nil } - t, err = time.Parse(timeFormat[:len(str)], str) + + year, err := parseByteYear(b) + if err != nil { + return time.Time{}, err + } + if year <= 0 { + year = 1 + } + + if b[4] != '-' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4]) + } + + m, err := parseByte2Digits(b[5], b[6]) + if err != nil { + return time.Time{}, err + } + if m <= 0 { + m = 1 + } + month := time.Month(m) + + if b[7] != '-' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7]) + } + + day, err := parseByte2Digits(b[8], b[9]) + if err != nil { + return time.Time{}, err + } + if day <= 0 { + day = 1 + } + if len(b) == 10 { + return time.Date(year, month, day, 0, 0, 0, 0, loc), nil + } + + if b[10] != ' ' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10]) + } + + hour, err := parseByte2Digits(b[11], b[12]) + if err != nil { + return time.Time{}, err + } + if b[13] != ':' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13]) + } + + min, err := parseByte2Digits(b[14], b[15]) + if err != nil { + return time.Time{}, err + } + if b[16] != ':' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16]) + } + + sec, err := parseByte2Digits(b[17], b[18]) + if err != nil { + return time.Time{}, err + } + if len(b) == 19 { + return time.Date(year, month, day, hour, min, sec, 0, loc), nil + } + + if b[19] != '.' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19]) + } + nsec, err := parseByteNanoSec(b[20:]) + if err != nil { + return time.Time{}, err + } + return time.Date(year, month, day, hour, min, sec, nsec, loc), nil default: - err = fmt.Errorf("invalid time string: %s", str) - return + return time.Time{}, fmt.Errorf("invalid time bytes: %s", b) } +} - // Adjust location - if err == nil && loc != time.UTC { - y, mo, d := t.Date() - h, mi, s := t.Clock() - t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil +func parseByteYear(b []byte) (int, error) { + year, n := 0, 1000 + for i := 0; i < 4; i++ { + v, err := bToi(b[i]) + if err != nil { + return 0, err + } + year += v * n + n = n / 10 } + return year, nil +} - return +func parseByte2Digits(b1, b2 byte) (int, error) { + d1, err := bToi(b1) + if err != nil { + return 0, err + } + d2, err := bToi(b2) + if err != nil { + return 0, err + } + return d1*10 + d2, nil +} + +func parseByteNanoSec(b []byte) (int, error) { + ns, digit := 0, 100000 // max is 6-digits + for i := 0; i < len(b); i++ { + v, err := bToi(b[i]) + if err != nil { + return 0, err + } + ns += v * digit + digit /= 10 + } + // nanoseconds has 10-digits. (needs to scale digits) + // 10 - 6 = 4, so we have to multiple 1000. + return ns * 1000, nil +} + +func bToi(b byte) (int, error) { + if b < '0' || b > '9' { + return 0, errors.New("not [0-9]") + } + return int(b - '0'), nil } func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { @@ -167,6 +276,64 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va return nil, fmt.Errorf("invalid DATETIME packet length %d", num) } +func appendDateTime(buf []byte, t time.Time) ([]byte, error) { + year, month, day := t.Date() + hour, min, sec := t.Clock() + nsec := t.Nanosecond() + + if year < 1 || year > 9999 { + return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap + } + year100 := year / 100 + year1 := year % 100 + + var localBuf [len("2006-01-02T15:04:05.999999999")]byte // does not escape + localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1] + localBuf[4] = '-' + localBuf[5], localBuf[6] = digits10[month], digits01[month] + localBuf[7] = '-' + localBuf[8], localBuf[9] = digits10[day], digits01[day] + + if hour == 0 && min == 0 && sec == 0 && nsec == 0 { + return append(buf, localBuf[:10]...), nil + } + + localBuf[10] = ' ' + localBuf[11], localBuf[12] = digits10[hour], digits01[hour] + localBuf[13] = ':' + localBuf[14], localBuf[15] = digits10[min], digits01[min] + localBuf[16] = ':' + localBuf[17], localBuf[18] = digits10[sec], digits01[sec] + + if nsec == 0 { + return append(buf, localBuf[:19]...), nil + } + nsec100000000 := nsec / 100000000 + nsec1000000 := (nsec / 1000000) % 100 + nsec10000 := (nsec / 10000) % 100 + nsec100 := (nsec / 100) % 100 + nsec1 := nsec % 100 + localBuf[19] = '.' + + // milli second + localBuf[20], localBuf[21], localBuf[22] = + digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000] + // micro second + localBuf[23], localBuf[24], localBuf[25] = + digits10[nsec10000], digits01[nsec10000], digits10[nsec100] + // nano second + localBuf[26], localBuf[27], localBuf[28] = + digits01[nsec100], digits10[nsec1], digits01[nsec1] + + // trim trailing zeros + n := len(localBuf) + for n > 0 && localBuf[n-1] == '0' { + n-- + } + + return append(buf, localBuf[:n]...), nil +} + // zeroDateTime is used in formatBinaryDateTime to avoid an allocation // if the DATE or DATETIME has the zero value. // It must never be changed. diff --git a/vendor/github.com/lib/pq/.gitignore b/vendor/github.com/lib/pq/.gitignore index 0f1d00e1..3243952a 100644 --- a/vendor/github.com/lib/pq/.gitignore +++ b/vendor/github.com/lib/pq/.gitignore @@ -2,3 +2,5 @@ *.test *~ *.swp +.idea +.vscode \ No newline at end of file diff --git a/vendor/github.com/lib/pq/.travis.sh b/vendor/github.com/lib/pq/.travis.sh index ebf44703..15607b50 100644 --- a/vendor/github.com/lib/pq/.travis.sh +++ b/vendor/github.com/lib/pq/.travis.sh @@ -1,17 +1,15 @@ #!/bin/bash -set -eu +set -eux client_configure() { sudo chmod 600 $PQSSLCERTTEST_PATH/postgresql.key } pgdg_repository() { - local sourcelist='sources.list.d/postgresql.list' - curl -sS 'https://www.postgresql.org/media/keys/ACCC4CF8.asc' | sudo apt-key add - - echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION | sudo tee "/etc/apt/$sourcelist" - sudo apt-get -o Dir::Etc::sourcelist="$sourcelist" -o Dir::Etc::sourceparts='-' -o APT::Get::List-Cleanup='0' update + echo "deb http://apt.postgresql.org/pub/repos/apt/ `lsb_release -cs`-pgdg main" | sudo tee /etc/apt/sources.list.d/pgdg.list + sudo apt-get update } postgresql_configure() { @@ -51,10 +49,10 @@ postgresql_configure() { } postgresql_install() { - xargs sudo apt-get -y -o Dpkg::Options::='--force-confdef' -o Dpkg::Options::='--force-confnew' install <<-packages + xargs sudo apt-get -y install <<-packages postgresql-$PGVERSION + postgresql-client-$PGVERSION postgresql-server-dev-$PGVERSION - postgresql-contrib-$PGVERSION packages } diff --git a/vendor/github.com/lib/pq/.travis.yml b/vendor/github.com/lib/pq/.travis.yml index 3498c53d..283f35f2 100644 --- a/vendor/github.com/lib/pq/.travis.yml +++ b/vendor/github.com/lib/pq/.travis.yml @@ -1,9 +1,9 @@ language: go go: - - 1.13.x - 1.14.x - - master + - 1.15.x + - 1.16.x sudo: true @@ -13,6 +13,7 @@ env: - PQGOSSLTESTS=1 - PQSSLCERTTEST_PATH=$PWD/certs - PGHOST=127.0.0.1 + - GODEBUG=x509ignoreCN=0 matrix: - PGVERSION=10 - PGVERSION=9.6 diff --git a/vendor/github.com/lib/pq/README.md b/vendor/github.com/lib/pq/README.md index 16fc31cd..c972a86a 100644 --- a/vendor/github.com/lib/pq/README.md +++ b/vendor/github.com/lib/pq/README.md @@ -19,6 +19,7 @@ * Unix socket support * Notifications: `LISTEN`/`NOTIFY` * pgpass support +* GSS (Kerberos) auth ## Tests diff --git a/vendor/github.com/lib/pq/array.go b/vendor/github.com/lib/pq/array.go index e4933e22..7806a31f 100644 --- a/vendor/github.com/lib/pq/array.go +++ b/vendor/github.com/lib/pq/array.go @@ -22,7 +22,7 @@ var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() // db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) // // var x []sql.NullInt64 -// db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x)) +// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) // // Scanning multi-dimensional arrays is not supported. Arrays where the lower // bound is not one (such as `[0:0]={1}') are not supported. @@ -35,19 +35,31 @@ func Array(a interface{}) interface { return (*BoolArray)(&a) case []float64: return (*Float64Array)(&a) + case []float32: + return (*Float32Array)(&a) case []int64: return (*Int64Array)(&a) + case []int32: + return (*Int32Array)(&a) case []string: return (*StringArray)(&a) + case [][]byte: + return (*ByteaArray)(&a) case *[]bool: return (*BoolArray)(a) case *[]float64: return (*Float64Array)(a) + case *[]float32: + return (*Float32Array)(a) case *[]int64: return (*Int64Array)(a) + case *[]int32: + return (*Int32Array)(a) case *[]string: return (*StringArray)(a) + case *[][]byte: + return (*ByteaArray)(a) } return GenericArray{a} @@ -267,6 +279,70 @@ func (a Float64Array) Value() (driver.Value, error) { return "{}", nil } +// Float32Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float32Array []float32 + +// Scan implements the sql.Scanner interface. +func (a *Float32Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Float32Array", src) +} + +func (a *Float32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float32Array, len(elems)) + for i, v := range elems { + var x float64 + if x, err = strconv.ParseFloat(string(v), 32); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + b[i] = float32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + // GenericArray implements the driver.Valuer and sql.Scanner interfaces for // an array or slice of any dimension. type GenericArray struct{ A interface{} } @@ -483,6 +559,69 @@ func (a Int64Array) Value() (driver.Value, error) { return "{}", nil } +// Int32Array represents a one-dimensional array of the PostgreSQL integer types. +type Int32Array []int32 + +// Scan implements the sql.Scanner interface. +func (a *Int32Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Int32Array", src) +} + +func (a *Int32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int32Array, len(elems)) + for i, v := range elems { + var x int + if x, err = strconv.Atoi(string(v)); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + b[i] = int32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, int64(a[0]), 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, int64(a[i]), 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + // StringArray represents a one-dimensional array of the PostgreSQL character types. type StringArray []string diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go index 77c36134..b09a1704 100644 --- a/vendor/github.com/lib/pq/conn.go +++ b/vendor/github.com/lib/pq/conn.go @@ -18,6 +18,7 @@ import ( "path/filepath" "strconv" "strings" + "sync/atomic" "time" "unicode" @@ -38,13 +39,18 @@ var ( errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) +// Compile time validation that our types implement the expected interfaces +var ( + _ driver.Driver = Driver{} +) + // Driver is the Postgres database driver. type Driver struct{} // Open opens a new connection to the database. name is a connection string. // Most users should only use it through database/sql package from the standard // library. -func (d *Driver) Open(name string) (driver.Conn, error) { +func (d Driver) Open(name string) (driver.Conn, error) { return Open(name) } @@ -136,7 +142,7 @@ type conn struct { // If true, this connection is bad and all public-facing functions should // return ErrBadConn. - bad bool + bad *atomic.Value // If set, this connection should never use the binary format when // receiving query results from prepared statements. Only provided for @@ -155,6 +161,9 @@ type conn struct { // If not nil, notifications will be synchronously sent here notificationHandler func(*Notification) + + // GSSAPI context + gss GSS } // Handle driver-side settings in parsed connection string. @@ -289,11 +298,20 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) { // the user. defer errRecoverNoErrBadConn(&err) - o := c.opts + // Create a new values map (copy). This makes it so maps in different + // connections do not reference the same underlying data structure, so it + // is safe for multiple connections to concurrently write to their opts. + o := make(values) + for k, v := range c.opts { + o[k] = v + } + bad := &atomic.Value{} + bad.Store(false) cn = &conn{ opts: o, dialer: c.dialer, + bad: bad, } err = cn.handleDriverSettings(o) if err != nil { @@ -335,10 +353,6 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) { func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { network, address := network(o) - // SSL is not necessary or supported over UNIX domain sockets - if network == "unix" { - o["sslmode"] = "disable" - } // Zero or not specified means wait indefinitely. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { @@ -502,9 +516,22 @@ func (cn *conn) isInTransaction() bool { cn.txnStatus == txnStatusInFailedTransaction } +func (cn *conn) setBad() { + if cn.bad != nil { + cn.bad.Store(true) + } +} + +func (cn *conn) getBad() bool { + if cn.bad != nil { + return cn.bad.Load().(bool) + } + return false +} + func (cn *conn) checkIsInTransaction(intxn bool) { if cn.isInTransaction() != intxn { - cn.bad = true + cn.setBad() errorf("unexpected transaction status %v", cn.txnStatus) } } @@ -514,7 +541,7 @@ func (cn *conn) Begin() (_ driver.Tx, err error) { } func (cn *conn) begin(mode string) (_ driver.Tx, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } defer cn.errRecover(&err) @@ -525,11 +552,11 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) { return nil, err } if commandTag != "BEGIN" { - cn.bad = true + cn.setBad() return nil, fmt.Errorf("unexpected command tag %s", commandTag) } if cn.txnStatus != txnStatusIdleInTransaction { - cn.bad = true + cn.setBad() return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) } return cn, nil @@ -543,7 +570,7 @@ func (cn *conn) closeTxn() { func (cn *conn) Commit() (err error) { defer cn.closeTxn() - if cn.bad { + if cn.getBad() { return driver.ErrBadConn } defer cn.errRecover(&err) @@ -565,12 +592,12 @@ func (cn *conn) Commit() (err error) { _, commandTag, err := cn.simpleExec("COMMIT") if err != nil { if cn.isInTransaction() { - cn.bad = true + cn.setBad() } return err } if commandTag != "COMMIT" { - cn.bad = true + cn.setBad() return fmt.Errorf("unexpected command tag %s", commandTag) } cn.checkIsInTransaction(false) @@ -579,7 +606,7 @@ func (cn *conn) Commit() (err error) { func (cn *conn) Rollback() (err error) { defer cn.closeTxn() - if cn.bad { + if cn.getBad() { return driver.ErrBadConn } defer cn.errRecover(&err) @@ -591,7 +618,7 @@ func (cn *conn) rollback() (err error) { _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { if cn.isInTransaction() { - cn.bad = true + cn.setBad() } return err } @@ -631,7 +658,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err case 'T', 'D': // ignore any results default: - cn.bad = true + cn.setBad() errorf("unknown response for simple query: %q", t) } } @@ -653,7 +680,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // the user can close, though, to avoid connections from being // leaked. A "rows" with done=true works fine for that purpose. if err != nil { - cn.bad = true + cn.setBad() errorf("unexpected message %q in simple query execution", t) } if res == nil { @@ -664,8 +691,11 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // Set the result and tag to the last command complete if there wasn't a // query already run. Although queries usually return from here and cede // control to Next, a query with zero results does not. - if t == 'C' && res.colNames == nil { + if t == 'C' { res.result, res.tag = cn.parseComplete(r.string()) + if res.colNames != nil { + return + } } res.done = true case 'Z': @@ -677,7 +707,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { err = parseError(r) case 'D': if res == nil { - cn.bad = true + cn.setBad() errorf("unexpected DataRow in simple query execution") } // the query didn't fail; kick off to Next @@ -692,7 +722,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. default: - cn.bad = true + cn.setBad() errorf("unknown response for simple query: %q", t) } } @@ -785,7 +815,7 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt { } func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } defer cn.errRecover(&err) @@ -824,7 +854,7 @@ func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { } func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } if cn.inCopy { @@ -858,7 +888,7 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { // Implement the optional "Execer" interface for one-shot queries func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } defer cn.errRecover(&err) @@ -892,9 +922,20 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err return r, err } +type safeRetryError struct { + Err error +} + +func (se *safeRetryError) Error() string { + return se.Err.Error() +} + func (cn *conn) send(m *writeBuf) { - _, err := cn.c.Write(m.wrap()) + n, err := cn.c.Write(m.wrap()) if err != nil { + if n == 0 { + err = &safeRetryError{Err: err} + } panic(err) } } @@ -919,7 +960,7 @@ func (cn *conn) sendSimpleMessage(typ byte) (err error) { // the message yourself. func (cn *conn) saveMessage(typ byte, buf *readBuf) { if cn.saveMessageType != 0 { - cn.bad = true + cn.setBad() errorf("unexpected saveMessageType %d", cn.saveMessageType) } cn.saveMessageType = typ @@ -1065,7 +1106,7 @@ func isDriverSetting(key string) bool { return true case "password": return true - case "sslmode", "sslcert", "sslkey", "sslrootcert": + case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline": return true case "fallback_application_name": return true @@ -1075,7 +1116,10 @@ func isDriverSetting(key string) bool { return true case "binary_parameters": return true - + case "krbsrvname": + return true + case "krbspn": + return true default: return false } @@ -1155,6 +1199,59 @@ func (cn *conn) auth(r *readBuf, o values) { if r.int32() != 0 { errorf("unexpected authentication response: %q", t) } + case 7: // GSSAPI, startup + if newGss == nil { + errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)") + } + cli, err := newGss() + if err != nil { + errorf("kerberos error: %s", err.Error()) + } + + var token []byte + + if spn, ok := o["krbspn"]; ok { + // Use the supplied SPN if provided.. + token, err = cli.GetInitTokenFromSpn(spn) + } else { + // Allow the kerberos service name to be overridden + service := "postgres" + if val, ok := o["krbsrvname"]; ok { + service = val + } + + token, err = cli.GetInitToken(o["host"], service) + } + + if err != nil { + errorf("failed to get Kerberos ticket: %q", err) + } + + w := cn.writeBuf('p') + w.bytes(token) + cn.send(w) + + // Store for GSSAPI continue message + cn.gss = cli + + case 8: // GSSAPI continue + + if cn.gss == nil { + errorf("GSSAPI protocol error") + } + + b := []byte(*r) + + done, tokOut, err := cn.gss.Continue(b) + if err == nil && !done { + w := cn.writeBuf('p') + w.bytes(tokOut) + cn.send(w) + } + + // Errors fall through and read the more detailed message + // from the server.. + case 10: sc := scram.NewClient(sha256.New, o["user"], o["password"]) sc.Step(nil) @@ -1233,7 +1330,7 @@ func (st *stmt) Close() (err error) { if st.closed { return nil } - if st.cn.bad { + if st.cn.getBad() { return driver.ErrBadConn } defer st.cn.errRecover(&err) @@ -1247,14 +1344,14 @@ func (st *stmt) Close() (err error) { t, _ := st.cn.recv1() if t != '3' { - st.cn.bad = true + st.cn.setBad() errorf("unexpected close response: %q", t) } st.closed = true t, r := st.cn.recv1() if t != 'Z' { - st.cn.bad = true + st.cn.setBad() errorf("expected ready for query, but got: %q", t) } st.cn.processReadyForQuery(r) @@ -1263,7 +1360,7 @@ func (st *stmt) Close() (err error) { } func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { - if st.cn.bad { + if st.cn.getBad() { return nil, driver.ErrBadConn } defer st.cn.errRecover(&err) @@ -1276,7 +1373,7 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { } func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { - if st.cn.bad { + if st.cn.getBad() { return nil, driver.ErrBadConn } defer st.cn.errRecover(&err) @@ -1363,7 +1460,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { parts := strings.Split(commandTag, " ") if len(parts) != 3 { - cn.bad = true + cn.setBad() errorf("unexpected INSERT command tag %s", commandTag) } affectedRows = &parts[len(parts)-1] @@ -1375,7 +1472,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { } n, err := strconv.ParseInt(*affectedRows, 10, 64) if err != nil { - cn.bad = true + cn.setBad() errorf("could not parse commandTag: %s", err) } return driver.RowsAffected(n), commandTag @@ -1442,7 +1539,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { } conn := rs.cn - if conn.bad { + if conn.getBad() { return driver.ErrBadConn } defer conn.errRecover(&err) @@ -1467,7 +1564,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'D': n := rs.rb.int16() if err != nil { - conn.bad = true + conn.setBad() errorf("unexpected DataRow after error %s", err) } if n < len(dest) { @@ -1634,10 +1731,9 @@ func (cn *conn) processParameterStatus(r *readBuf) { case "server_version": var major1 int var major2 int - var minor int - _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) + _, err = fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) if err == nil { - cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor + cn.parameterStatus.serverVersion = major1*10000 + major2*100 } case "TimeZone": @@ -1662,7 +1758,7 @@ func (cn *conn) readReadyForQuery() { cn.processReadyForQuery(r) return default: - cn.bad = true + cn.setBad() errorf("unexpected message %q; expected ReadyForQuery", t) } } @@ -1682,7 +1778,7 @@ func (cn *conn) readParseResponse() { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Parse response %q", t) } } @@ -1707,7 +1803,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Describe statement response %q", t) } } @@ -1725,7 +1821,7 @@ func (cn *conn) readPortalDescribeResponse() rowsHeader { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Describe response %q", t) } panic("not reached") @@ -1741,7 +1837,7 @@ func (cn *conn) readBindResponse() { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Bind response %q", t) } } @@ -1768,7 +1864,7 @@ func (cn *conn) postExecuteWorkaround() { cn.saveMessage(t, r) return default: - cn.bad = true + cn.setBad() errorf("unexpected message during extended query execution: %q", t) } } @@ -1781,7 +1877,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co switch t { case 'C': if err != nil { - cn.bad = true + cn.setBad() errorf("unexpected CommandComplete after error %s", err) } res, commandTag = cn.parseComplete(r.string()) @@ -1795,7 +1891,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co err = parseError(r) case 'T', 'D', 'I': if err != nil { - cn.bad = true + cn.setBad() errorf("unexpected %q after error %s", t, err) } if t == 'I' { @@ -1803,7 +1899,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co } // ignore any results default: - cn.bad = true + cn.setBad() errorf("unknown %s response: %q", protocolState, t) } } diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go index 09e2ea46..2b9a9599 100644 --- a/vendor/github.com/lib/pq/conn_go18.go +++ b/vendor/github.com/lib/pq/conn_go18.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "io/ioutil" + "sync/atomic" "time" ) @@ -89,10 +90,21 @@ func (cn *conn) Ping(ctx context.Context) error { func (cn *conn) watchCancel(ctx context.Context) func() { if done := ctx.Done(); done != nil { - finished := make(chan struct{}) + finished := make(chan struct{}, 1) go func() { select { case <-done: + select { + case finished <- struct{}{}: + default: + // We raced with the finish func, let the next query handle this with the + // context. + return + } + + // Set the connection state to bad so it does not get reused. + cn.setBad() + // At this point the function level context is canceled, // so it must not be used for the additional network // request to cancel the query. @@ -101,13 +113,14 @@ func (cn *conn) watchCancel(ctx context.Context) func() { defer cancel() _ = cn.cancel(ctxCancel) - finished <- struct{}{} case <-finished: } }() return func() { select { case <-finished: + cn.setBad() + cn.Close() case finished <- struct{}{}: } } @@ -116,17 +129,29 @@ func (cn *conn) watchCancel(ctx context.Context) func() { } func (cn *conn) cancel(ctx context.Context) error { - c, err := dial(ctx, cn.dialer, cn.opts) + // Create a new values map (copy). This makes sure the connection created + // in this method cannot write to the same underlying data, which could + // cause a concurrent map write panic. This is necessary because cancel + // is called from a goroutine in watchCancel. + o := make(values) + for k, v := range cn.opts { + o[k] = v + } + + c, err := dial(ctx, cn.dialer, o) if err != nil { return err } defer c.Close() { + bad := &atomic.Value{} + bad.Store(false) can := conn{ - c: c, + c: c, + bad: bad, } - err = can.ssl(cn.opts) + err = can.ssl(o) if err != nil { return err } diff --git a/vendor/github.com/lib/pq/connector.go b/vendor/github.com/lib/pq/connector.go index 2f8ced67..d7d47261 100644 --- a/vendor/github.com/lib/pq/connector.go +++ b/vendor/github.com/lib/pq/connector.go @@ -27,7 +27,7 @@ func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { return c.open(ctx) } -// Driver returnst the underlying driver of this Connector. +// Driver returns the underlying driver of this Connector. func (c *Connector) Driver() driver.Driver { return &Driver{} } @@ -106,5 +106,10 @@ func NewConnector(dsn string) (*Connector, error) { o["user"] = u } + // SSL is not necessary or supported over UNIX domain sockets + if network, _ := network(o); network == "unix" { + o["sslmode"] = "disable" + } + return &Connector{opts: o, dialer: defaultDialer{}}, nil } diff --git a/vendor/github.com/lib/pq/copy.go b/vendor/github.com/lib/pq/copy.go index d3bc1edd..bb3cbd7b 100644 --- a/vendor/github.com/lib/pq/copy.go +++ b/vendor/github.com/lib/pq/copy.go @@ -49,6 +49,7 @@ type copyin struct { buffer []byte rowData chan []byte done chan bool + driver.Result closed bool @@ -151,6 +152,8 @@ func (ci *copyin) resploop() { switch t { case 'C': // complete + res, _ := ci.cn.parseComplete(r.string()) + ci.setResult(res) case 'N': if n := ci.cn.noticeHandler; n != nil { n(parseError(&r)) @@ -173,13 +176,13 @@ func (ci *copyin) resploop() { func (ci *copyin) setBad() { ci.Lock() - ci.cn.bad = true + ci.cn.setBad() ci.Unlock() } func (ci *copyin) isBad() bool { ci.Lock() - b := ci.cn.bad + b := ci.cn.getBad() ci.Unlock() return b } @@ -201,6 +204,22 @@ func (ci *copyin) setError(err error) { ci.Unlock() } +func (ci *copyin) setResult(result driver.Result) { + ci.Lock() + ci.Result = result + ci.Unlock() +} + +func (ci *copyin) getResult() driver.Result { + ci.Lock() + result := ci.Result + ci.Unlock() + if result == nil { + return driver.RowsAffected(0) + } + return result +} + func (ci *copyin) NumInput() int { return -1 } @@ -231,7 +250,11 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { } if len(v) == 0 { - return driver.RowsAffected(0), ci.Close() + if err := ci.Close(); err != nil { + return driver.RowsAffected(0), err + } + + return ci.getResult(), nil } numValues := len(v) diff --git a/vendor/github.com/lib/pq/doc.go b/vendor/github.com/lib/pq/doc.go index 2a60054e..b5718480 100644 --- a/vendor/github.com/lib/pq/doc.go +++ b/vendor/github.com/lib/pq/doc.go @@ -241,5 +241,28 @@ bytes by the PostgreSQL server. You can find a complete, working example of Listener usage at https://godoc.org/github.com/lib/pq/example/listen. + +Kerberos Support + + +If you need support for Kerberos authentication, add the following to your main +package: + + import "github.com/lib/pq/auth/kerberos" + + func init() { + pq.RegisterGSSProvider(func() (pq.Gss, error) { return kerberos.NewGSS() }) + } + +This package is in a separate module so that users who don't need Kerberos +don't have to download unnecessary dependencies. + +When imported, additional connection string parameters are supported: + + * krbsrvname - GSS (Kerberos) service name when constructing the + SPN (default is `postgres`). This will be combined with the host + to form the full SPN: `krbsrvname/host`. + * krbspn - GSS (Kerberos) SPN. This takes priority over + `krbsrvname` if present. */ package pq diff --git a/vendor/github.com/lib/pq/encode.go b/vendor/github.com/lib/pq/encode.go index c4dafe27..51c143ee 100644 --- a/vendor/github.com/lib/pq/encode.go +++ b/vendor/github.com/lib/pq/encode.go @@ -200,11 +200,17 @@ func appendEscapedText(buf []byte, text string) []byte { func mustParse(f string, typ oid.Oid, s []byte) time.Time { str := string(s) - // check for a 30-minute-offset timezone - if (typ == oid.T_timestamptz || typ == oid.T_timetz) && - str[len(str)-3] == ':' { - f += ":00" + // Check for a minute and second offset in the timezone. + if typ == oid.T_timestamptz || typ == oid.T_timetz { + for i := 3; i <= 6; i += 3 { + if str[len(str)-i] == ':' { + f += ":00" + continue + } + break + } } + // Special case for 24:00 time. // Unfortunately, golang does not parse 24:00 as a proper time. // In this case, we want to try "round to the next day", to differentiate. diff --git a/vendor/github.com/lib/pq/error.go b/vendor/github.com/lib/pq/error.go index 3d66ba7c..c19c349f 100644 --- a/vendor/github.com/lib/pq/error.go +++ b/vendor/github.com/lib/pq/error.go @@ -484,7 +484,7 @@ func (cn *conn) errRecover(err *error) { case nil: // Do nothing case runtime.Error: - cn.bad = true + cn.setBad() panic(v) case *Error: if v.Fatal() { @@ -493,8 +493,11 @@ func (cn *conn) errRecover(err *error) { *err = v } case *net.OpError: - cn.bad = true + cn.setBad() *err = v + case *safeRetryError: + cn.setBad() + *err = driver.ErrBadConn case error: if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { *err = driver.ErrBadConn @@ -503,13 +506,13 @@ func (cn *conn) errRecover(err *error) { } default: - cn.bad = true + cn.setBad() panic(fmt.Sprintf("unknown error: %#v", e)) } // Any time we return ErrBadConn, we need to remember it since *Tx doesn't // mark the connection bad in database/sql. if *err == driver.ErrBadConn { - cn.bad = true + cn.setBad() } } diff --git a/vendor/github.com/lib/pq/krb.go b/vendor/github.com/lib/pq/krb.go new file mode 100644 index 00000000..408ec01f --- /dev/null +++ b/vendor/github.com/lib/pq/krb.go @@ -0,0 +1,27 @@ +package pq + +// NewGSSFunc creates a GSS authentication provider, for use with +// RegisterGSSProvider. +type NewGSSFunc func() (GSS, error) + +var newGss NewGSSFunc + +// RegisterGSSProvider registers a GSS authentication provider. For example, if +// you need to use Kerberos to authenticate with your server, add this to your +// main package: +// +// import "github.com/lib/pq/auth/kerberos" +// +// func init() { +// pq.RegisterGSSProvider(func() (pq.GSS, error) { return kerberos.NewGSS() }) +// } +func RegisterGSSProvider(newGssArg NewGSSFunc) { + newGss = newGssArg +} + +// GSS provides GSSAPI authentication (e.g., Kerberos). +type GSS interface { + GetInitToken(host string, service string) ([]byte, error) + GetInitTokenFromSpn(spn string) ([]byte, error) + Continue(inToken []byte) (done bool, outToken []byte, err error) +} diff --git a/vendor/github.com/lib/pq/ssl.go b/vendor/github.com/lib/pq/ssl.go index d9020845..e5eb9289 100644 --- a/vendor/github.com/lib/pq/ssl.go +++ b/vendor/github.com/lib/pq/ssl.go @@ -83,6 +83,16 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { // in the user's home directory. The configured files must exist and have // the correct permissions. func sslClientCertificates(tlsConf *tls.Config, o values) error { + sslinline := o["sslinline"] + if sslinline == "true" { + cert, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"])) + if err != nil { + return err + } + tlsConf.Certificates = []tls.Certificate{cert} + return nil + } + // user.Current() might fail when cross-compiling. We have to ignore the // error and continue without home directory defaults, since we wouldn't // know from where to load them. @@ -137,9 +147,17 @@ func sslCertificateAuthority(tlsConf *tls.Config, o values) error { if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { tlsConf.RootCAs = x509.NewCertPool() - cert, err := ioutil.ReadFile(sslrootcert) - if err != nil { - return err + sslinline := o["sslinline"] + + var cert []byte + if sslinline == "true" { + cert = []byte(sslrootcert) + } else { + var err error + cert, err = ioutil.ReadFile(sslrootcert) + if err != nil { + return err + } } if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { diff --git a/vendor/github.com/lib/pq/url.go b/vendor/github.com/lib/pq/url.go index f4d8a7c2..aec6e95b 100644 --- a/vendor/github.com/lib/pq/url.go +++ b/vendor/github.com/lib/pq/url.go @@ -40,10 +40,10 @@ func ParseURL(url string) (string, error) { } var kvs []string - escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) + escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`) accrue := func(k, v string) { if v != "" { - kvs = append(kvs, k+"="+escaper.Replace(v)) + kvs = append(kvs, k+"='"+escaper.Replace(v)+"'") } } diff --git a/vendor/github.com/lib/pq/user_other.go b/vendor/github.com/lib/pq/user_other.go new file mode 100644 index 00000000..f1c33134 --- /dev/null +++ b/vendor/github.com/lib/pq/user_other.go @@ -0,0 +1,9 @@ +// Package pq is a pure Go Postgres driver for the database/sql package. + +// +build js android hurd illumos zos + +package pq + +func userCurrent() (string, error) { + return "", ErrCouldNotDetectUsername +} diff --git a/vendor/gopkg.in/asn1-ber.v1/.travis.yml b/vendor/gopkg.in/asn1-ber.v1/.travis.yml deleted file mode 100644 index ecf41325..00000000 --- a/vendor/gopkg.in/asn1-ber.v1/.travis.yml +++ /dev/null @@ -1,36 +0,0 @@ -language: go -matrix: - include: - - go: 1.2.x - env: GOOS=linux GOARCH=amd64 - - go: 1.2.x - env: GOOS=linux GOARCH=386 - - go: 1.2.x - env: GOOS=windows GOARCH=amd64 - - go: 1.2.x - env: GOOS=windows GOARCH=386 - - go: 1.3.x - - go: 1.4.x - - go: 1.5.x - - go: 1.6.x - - go: 1.7.x - - go: 1.8.x - - go: 1.9.x - - go: 1.10.x - - go: 1.11.x - env: GOOS=linux GOARCH=amd64 - - go: 1.11.x - env: GOOS=linux GOARCH=386 - - go: 1.11.x - env: GOOS=windows GOARCH=amd64 - - go: 1.11.x - env: GOOS=windows GOARCH=386 - - go: tip -go_import_path: gopkg.in/asn-ber.v1 -install: - - go list -f '{{range .Imports}}{{.}} {{end}}' ./... | xargs go get -v - - go list -f '{{range .TestImports}}{{.}} {{end}}' ./... | xargs go get -v - - go get code.google.com/p/go.tools/cmd/cover || go get golang.org/x/tools/cmd/cover - - go build -v ./... -script: - - go test -v -cover ./... || go test -v ./... diff --git a/vendor/gopkg.in/ldap.v3/.gitignore b/vendor/gopkg.in/ldap.v3/.gitignore deleted file mode 100644 index e69de29b..00000000 diff --git a/vendor/gopkg.in/ldap.v3/.travis.yml b/vendor/gopkg.in/ldap.v3/.travis.yml deleted file mode 100644 index 107aa786..00000000 --- a/vendor/gopkg.in/ldap.v3/.travis.yml +++ /dev/null @@ -1,32 +0,0 @@ -sudo: false -language: go -go: - - "1.5.x" - - "1.6.x" - - "1.7.x" - - "1.8.x" - - "1.9.x" - - "1.10.x" - - "1.11.x" - - "1.12.x" - - "1.13.x" - - tip - -git: - depth: 1 - -matrix: - fast_finish: true - allow_failures: - - go: tip -go_import_path: gopkg.in/ldap.v3 -install: - - go get gopkg.in/asn1-ber.v1 - - go get code.google.com/p/go.tools/cmd/cover || go get golang.org/x/tools/cmd/cover - - go get github.com/golang/lint/golint || go get golang.org/x/lint/golint || true - - go build -v ./... -script: - - make test - - make fmt - - make vet - - make lint diff --git a/vendor/gopkg.in/ldap.v3/CONTRIBUTING.md b/vendor/gopkg.in/ldap.v3/CONTRIBUTING.md deleted file mode 100644 index a7885231..00000000 --- a/vendor/gopkg.in/ldap.v3/CONTRIBUTING.md +++ /dev/null @@ -1,12 +0,0 @@ -# Contribution Guidelines - -We welcome contribution and improvements. - -## Guiding Principles - -To begin with here is a draft from an email exchange: - - * take compatibility seriously (our semvers, compatibility with older go versions, etc) - * don't tag untested code for release - * beware of baking in implicit behavior based on other libraries/tools choices - * be as high-fidelity as possible in plumbing through LDAP data (don't mask errors or reduce power of someone using the library) diff --git a/vendor/gopkg.in/ldap.v3/Makefile b/vendor/gopkg.in/ldap.v3/Makefile deleted file mode 100644 index c4966472..00000000 --- a/vendor/gopkg.in/ldap.v3/Makefile +++ /dev/null @@ -1,82 +0,0 @@ -.PHONY: default install build test quicktest fmt vet lint - -# List of all release tags "supported" by our current Go version -# E.g. ":go1.1:go1.2:go1.3:go1.4:go1.5:go1.6:go1.7:go1.8:go1.9:go1.10:go1.11:go1.12:" -GO_RELEASE_TAGS := $(shell go list -f ':{{join (context.ReleaseTags) ":"}}:' runtime) - -# Only use the `-race` flag on newer versions of Go (version 1.3 and newer) -ifeq (,$(findstring :go1.3:,$(GO_RELEASE_TAGS))) - RACE_FLAG := -else - RACE_FLAG := -race -cpu 1,2,4 -endif - -# Run `go vet` on Go 1.12 and newer. For Go 1.5-1.11, use `go tool vet` -ifneq (,$(findstring :go1.12:,$(GO_RELEASE_TAGS))) - GO_VET := go vet \ - -atomic \ - -bool \ - -copylocks \ - -nilfunc \ - -printf \ - -rangeloops \ - -unreachable \ - -unsafeptr \ - -unusedresult \ - . -else ifneq (,$(findstring :go1.5:,$(GO_RELEASE_TAGS))) - GO_VET := go tool vet \ - -atomic \ - -bool \ - -copylocks \ - -nilfunc \ - -printf \ - -shadow \ - -rangeloops \ - -unreachable \ - -unsafeptr \ - -unusedresult \ - . -else - GO_VET := @echo "go vet skipped -- not supported on this version of Go" -endif - -default: fmt vet lint build quicktest - -install: - go get -t -v ./... - -build: - go build -v ./... - -test: - go test -v $(RACE_FLAG) -cover ./... - -quicktest: - go test ./... - -# Capture output and force failure when there is non-empty output -fmt: - @echo gofmt -l . - @OUTPUT=`gofmt -l . 2>&1`; \ - if [ "$$OUTPUT" ]; then \ - echo "gofmt must be run on the following files:"; \ - echo "$$OUTPUT"; \ - exit 1; \ - fi - -vet: - $(GO_VET) - -# https://github.com/golang/lint -# go get github.com/golang/lint/golint -# Capture output and force failure when there is non-empty output -# Only run on go1.5+ -lint: - @echo golint ./... - @OUTPUT=`command -v golint >/dev/null 2>&1 && golint ./... 2>&1`; \ - if [ "$$OUTPUT" ]; then \ - echo "golint errors:"; \ - echo "$$OUTPUT"; \ - exit 1; \ - fi diff --git a/vendor/gopkg.in/ldap.v3/README.md b/vendor/gopkg.in/ldap.v3/README.md deleted file mode 100644 index 25cf730b..00000000 --- a/vendor/gopkg.in/ldap.v3/README.md +++ /dev/null @@ -1,54 +0,0 @@ -[![GoDoc](https://godoc.org/gopkg.in/ldap.v3?status.svg)](https://godoc.org/gopkg.in/ldap.v3) -[![Build Status](https://travis-ci.org/go-ldap/ldap.svg)](https://travis-ci.org/go-ldap/ldap) - -# Basic LDAP v3 functionality for the GO programming language. - -## Install - -For the latest version use: - - go get gopkg.in/ldap.v3 - -Import the latest version with: - - import "gopkg.in/ldap.v3" - -## Required Libraries: - - - gopkg.in/asn1-ber.v1 - -## Features: - - - Connecting to LDAP server (non-TLS, TLS, STARTTLS) - - Binding to LDAP server - - Searching for entries - - Filter Compile / Decompile - - Paging Search Results - - Modify Requests / Responses - - Add Requests / Responses - - Delete Requests / Responses - - Modify DN Requests / Responses - -## Examples: - - - search - - modify - -## Contributing: - -Bug reports and pull requests are welcome! - -Before submitting a pull request, please make sure tests and verification scripts pass: -``` -make all -``` - -To set up a pre-push hook to run the tests and verify scripts before pushing: -``` -ln -s ../../.githooks/pre-push .git/hooks/pre-push -``` - ---- -The Go gopher was designed by Renee French. (http://reneefrench.blogspot.com/) -The design is licensed under the Creative Commons 3.0 Attributions license. -Read this article for more details: http://blog.golang.org/gopher diff --git a/vendor/gopkg.in/ldap.v3/bind.go b/vendor/gopkg.in/ldap.v3/bind.go deleted file mode 100644 index 7b5e657a..00000000 --- a/vendor/gopkg.in/ldap.v3/bind.go +++ /dev/null @@ -1,152 +0,0 @@ -package ldap - -import ( - "errors" - "fmt" - - ber "gopkg.in/asn1-ber.v1" -) - -// SimpleBindRequest represents a username/password bind operation -type SimpleBindRequest struct { - // Username is the name of the Directory object that the client wishes to bind as - Username string - // Password is the credentials to bind with - Password string - // Controls are optional controls to send with the bind request - Controls []Control - // AllowEmptyPassword sets whether the client allows binding with an empty password - // (normally used for unauthenticated bind). - AllowEmptyPassword bool -} - -// SimpleBindResult contains the response from the server -type SimpleBindResult struct { - Controls []Control -} - -// NewSimpleBindRequest returns a bind request -func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest { - return &SimpleBindRequest{ - Username: username, - Password: password, - Controls: controls, - AllowEmptyPassword: false, - } -} - -func (req *SimpleBindRequest) appendTo(envelope *ber.Packet) error { - pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") - pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) - pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Username, "User Name")) - pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.Password, "Password")) - - envelope.AppendChild(pkt) - if len(req.Controls) > 0 { - envelope.AppendChild(encodeControls(req.Controls)) - } - - return nil -} - -// SimpleBind performs the simple bind operation defined in the given request -func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) { - if simpleBindRequest.Password == "" && !simpleBindRequest.AllowEmptyPassword { - return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client")) - } - - msgCtx, err := l.doRequest(simpleBindRequest) - if err != nil { - return nil, err - } - defer l.finishMessage(msgCtx) - - packet, err := l.readPacket(msgCtx) - if err != nil { - return nil, err - } - - result := &SimpleBindResult{ - Controls: make([]Control, 0), - } - - if len(packet.Children) == 3 { - for _, child := range packet.Children[2].Children { - decodedChild, decodeErr := DecodeControl(child) - if decodeErr != nil { - return nil, fmt.Errorf("failed to decode child control: %s", decodeErr) - } - result.Controls = append(result.Controls, decodedChild) - } - } - - err = GetLDAPError(packet) - return result, err -} - -// Bind performs a bind with the given username and password. -// -// It does not allow unauthenticated bind (i.e. empty password). Use the UnauthenticatedBind method -// for that. -func (l *Conn) Bind(username, password string) error { - req := &SimpleBindRequest{ - Username: username, - Password: password, - AllowEmptyPassword: false, - } - _, err := l.SimpleBind(req) - return err -} - -// UnauthenticatedBind performs an unauthenticated bind. -// -// A username may be provided for trace (e.g. logging) purpose only, but it is normally not -// authenticated or otherwise validated by the LDAP server. -// -// See https://tools.ietf.org/html/rfc4513#section-5.1.2 . -// See https://tools.ietf.org/html/rfc4513#section-6.3.1 . -func (l *Conn) UnauthenticatedBind(username string) error { - req := &SimpleBindRequest{ - Username: username, - Password: "", - AllowEmptyPassword: true, - } - _, err := l.SimpleBind(req) - return err -} - -var externalBindRequest = requestFunc(func(envelope *ber.Packet) error { - pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") - pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) - pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name")) - - saslAuth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication") - saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "EXTERNAL", "SASL Mech")) - saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "SASL Cred")) - - pkt.AppendChild(saslAuth) - - envelope.AppendChild(pkt) - - return nil -}) - -// ExternalBind performs SASL/EXTERNAL authentication. -// -// Use ldap.DialURL("ldapi://") to connect to the Unix socket before ExternalBind. -// -// See https://tools.ietf.org/html/rfc4422#appendix-A -func (l *Conn) ExternalBind() error { - msgCtx, err := l.doRequest(externalBindRequest) - if err != nil { - return err - } - defer l.finishMessage(msgCtx) - - packet, err := l.readPacket(msgCtx) - if err != nil { - return err - } - - return GetLDAPError(packet) -} diff --git a/vendor/modules.txt b/vendor/modules.txt index df1b54ae..13f621ca 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1,3 +1,6 @@ +# github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c +## explicit +github.com/Azure/go-ntlmssp # github.com/BurntSushi/toml v0.3.1 ## explicit github.com/BurntSushi/toml @@ -7,12 +10,13 @@ github.com/andygrunwald/go-jira # github.com/codegangsta/negroni v1.0.0 ## explicit github.com/codegangsta/negroni -# github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc +# github.com/denisenkom/go-mssqldb v0.10.1-0.20210728001037-ee2fbc25fd8f ## explicit; go 1.11 github.com/denisenkom/go-mssqldb github.com/denisenkom/go-mssqldb/internal/cp github.com/denisenkom/go-mssqldb/internal/decimal github.com/denisenkom/go-mssqldb/internal/querytext +github.com/denisenkom/go-mssqldb/msdsn # github.com/dgrijalva/jwt-go v3.2.0+incompatible ## explicit github.com/dgrijalva/jwt-go @@ -31,7 +35,13 @@ github.com/documize/slug # github.com/fatih/structs v1.0.0 ## explicit github.com/fatih/structs -# github.com/go-sql-driver/mysql v1.5.0 +# github.com/go-asn1-ber/asn1-ber v1.5.3 +## explicit; go 1.13 +github.com/go-asn1-ber/asn1-ber +# github.com/go-ldap/ldap/v3 v3.4.1 +## explicit; go 1.13 +github.com/go-ldap/ldap/v3 +# github.com/go-sql-driver/mysql v1.6.0 ## explicit; go 1.10 github.com/go-sql-driver/mysql # github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe @@ -61,8 +71,8 @@ github.com/jmoiron/sqlx github.com/jmoiron/sqlx/reflectx # github.com/kr/pretty v0.2.0 ## explicit; go 1.12 -# github.com/lib/pq v1.5.2 -## explicit +# github.com/lib/pq v1.10.2 +## explicit; go 1.13 github.com/lib/pq github.com/lib/pq/oid github.com/lib/pq/scram @@ -88,7 +98,7 @@ github.com/shurcooL/sanitized_anchor_name ## explicit github.com/trivago/tgo/tcontainer github.com/trivago/tgo/treflect -# golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 +# golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 ## explicit; go 1.11 golang.org/x/crypto/bcrypt golang.org/x/crypto/blowfish @@ -146,17 +156,11 @@ google.golang.org/protobuf/runtime/protoimpl # gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc ## explicit gopkg.in/alexcesaro/quotedprintable.v3 -# gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d -## explicit -gopkg.in/asn1-ber.v1 # gopkg.in/cas.v2 v2.1.0 ## explicit gopkg.in/cas.v2 # gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 ## explicit -# gopkg.in/ldap.v3 v3.1.0 -## explicit -gopkg.in/ldap.v3 # gopkg.in/yaml.v2 v2.2.2 ## explicit gopkg.in/yaml.v2