mirror of
https://github.com/documize/community.git
synced 2025-07-18 20:59:43 +02:00
Sync with Community
This commit is contained in:
parent
df8f650319
commit
989b7cd62c
123 changed files with 5054 additions and 2015 deletions
|
@ -20,8 +20,8 @@ import (
|
|||
"github.com/documize/community/core/stringutil"
|
||||
lm "github.com/documize/community/model/auth"
|
||||
"github.com/documize/community/model/user"
|
||||
ld "github.com/go-ldap/ldap/v3"
|
||||
"github.com/pkg/errors"
|
||||
ld "gopkg.in/ldap.v3"
|
||||
)
|
||||
|
||||
// Connect establishes connection to LDAP server.
|
||||
|
|
13
go.mod
13
go.mod
|
@ -6,13 +6,14 @@ require (
|
|||
github.com/BurntSushi/toml v0.3.1
|
||||
github.com/andygrunwald/go-jira v1.12.0
|
||||
github.com/codegangsta/negroni v1.0.0
|
||||
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc
|
||||
github.com/denisenkom/go-mssqldb v0.10.1-0.20210728001037-ee2fbc25fd8f
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
||||
github.com/documize/blackfriday v2.0.0+incompatible
|
||||
github.com/documize/glick v0.0.0-20160503134043-a8ccbef88237
|
||||
github.com/documize/html-diff v0.0.0-20160503140253-f61c192c7796
|
||||
github.com/documize/slug v1.1.1
|
||||
github.com/go-sql-driver/mysql v1.5.0
|
||||
github.com/go-ldap/ldap/v3 v3.4.1
|
||||
github.com/go-sql-driver/mysql v1.6.0
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect
|
||||
github.com/golang/protobuf v1.4.0 // indirect
|
||||
github.com/google/go-github v17.0.0+incompatible
|
||||
|
@ -21,29 +22,29 @@ require (
|
|||
github.com/gorilla/mux v1.7.4
|
||||
github.com/jmoiron/sqlx v1.2.0
|
||||
github.com/kr/pretty v0.2.0 // indirect
|
||||
github.com/lib/pq v1.5.2
|
||||
github.com/lib/pq v1.10.2
|
||||
github.com/mb0/diff v0.0.0-20131118162322-d8d9a906c24d // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.2
|
||||
github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/rainycape/unidecode v0.0.0-20150907023854-cb7f23ec59be // indirect
|
||||
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37
|
||||
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9
|
||||
golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d
|
||||
google.golang.org/appengine v1.6.6 // indirect
|
||||
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc
|
||||
gopkg.in/cas.v2 v2.1.0
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||
gopkg.in/ldap.v3 v3.1.0
|
||||
gopkg.in/yaml.v2 v2.2.2 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c // indirect
|
||||
github.com/fatih/structs v1.0.0 // indirect
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.3 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect
|
||||
github.com/trivago/tgo v1.0.1 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
google.golang.org/protobuf v1.21.0 // indirect
|
||||
gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d // indirect
|
||||
)
|
||||
|
|
27
go.sum
27
go.sum
|
@ -1,12 +1,14 @@
|
|||
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c h1:/IBSNwUN8+eKzUzbJPqhK839ygXJ82sde8x3ogr6R28=
|
||||
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
|
||||
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/andygrunwald/go-jira v1.12.0 h1:JJi2cEDmDxVtTXxC8ruLDbtOU6pA4OLeL0niyfNcoWw=
|
||||
github.com/andygrunwald/go-jira v1.12.0/go.mod h1:jYi4kFDbRPZTJdJOVJO4mpMMIwdB+rcZwSO58DzPd2I=
|
||||
github.com/codegangsta/negroni v1.0.0 h1:+aYywywx4bnKXWvoWtRfJ91vC59NbEhEY03sZjQhbVY=
|
||||
github.com/codegangsta/negroni v1.0.0/go.mod h1:v0y3T5G7Y1UlFfyxFn/QLRU4a2EuNau2iZY63YTKWo0=
|
||||
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc h1:VRRKCwnzqk8QCaRC4os14xoKDdbHqqlJtJA0oc1ZAjg=
|
||||
github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
|
||||
github.com/denisenkom/go-mssqldb v0.10.1-0.20210728001037-ee2fbc25fd8f h1:3UtVZFKTqZwLZi65UbfSIqYR75aUTP8FYUAEQnMXSJs=
|
||||
github.com/denisenkom/go-mssqldb v0.10.1-0.20210728001037-ee2fbc25fd8f/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||
github.com/documize/blackfriday v2.0.0+incompatible h1:qjRGAIVwZlHBtA/b9u0LtseYM3v3WpIXofPCwNjcUsE=
|
||||
|
@ -19,9 +21,14 @@ github.com/documize/slug v1.1.1 h1:OCJRbWxbOgrgiBYSbVzuFwxb9wVu4oy1LxvLJOC2s8Y=
|
|||
github.com/documize/slug v1.1.1/go.mod h1:Vi7fQ5PzeOpXAiIrk1WCEDRihjTfU/bf4eWUPSD7tkU=
|
||||
github.com/fatih/structs v1.0.0 h1:BrX964Rv5uQ3wwS+KRUAJCBBw5PQmgJfJ6v4yly5QwU=
|
||||
github.com/fatih/structs v1.0.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.1/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.3 h1:u7utq56RUFiynqUzgVMFDymapcOtQ/MZkh3H4QYkxag=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.3/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
|
||||
github.com/go-ldap/ldap/v3 v3.4.1 h1:fU/0xli6HY02ocbMuozHAYsaHLcnkLjvho2r5a34BUU=
|
||||
github.com/go-ldap/ldap/v3 v3.4.1/go.mod h1:iYS1MdmrmceOJ1QOTnRXrIs7i3kloqtmGQjRvjKpyMg=
|
||||
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
|
||||
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
|
||||
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
|
||||
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=
|
||||
|
@ -55,8 +62,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
|||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.5.2 h1:yTSXVswvWUOQ3k1sd7vJfDrbSl8lKuscqFJRqjC0ifw=
|
||||
github.com/lib/pq v1.5.2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8=
|
||||
github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-sqlite3 v1.9.0 h1:pDRiWfl+++eC2FEFRy6jXmQlvp4Yh3z1MJKg4UeYM/4=
|
||||
github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/mb0/diff v0.0.0-20131118162322-d8d9a906c24d h1:eAS2t2Vy+6psf9LZ4T5WXWsbkBt3Tu5PWekJy5AGyEU=
|
||||
|
@ -77,8 +84,8 @@ github.com/trivago/tgo v1.0.1/go.mod h1:w4dpD+3tzNIIiIfkWWa85w5/B77tlvdZckQ+6PkF
|
|||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw=
|
||||
golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 h1:vEg9joUBmeBcK9iSJftGNf3coIG4HqZElCPehJsfAYM=
|
||||
golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
|
@ -109,14 +116,10 @@ google.golang.org/protobuf v1.21.0 h1:qdOKuR/EIArgaWNjetjgTzgVTAZ+S/WXVrq9HW9zim
|
|||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
|
||||
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk=
|
||||
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk=
|
||||
gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d h1:TxyelI5cVkbREznMhfzycHdkp5cLA7DpE+GKjSslYhM=
|
||||
gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d/go.mod h1:cuepJuh7vyXfUyUwEgHQXw849cJrilpS5NeIjOWESAw=
|
||||
gopkg.in/cas.v2 v2.1.0 h1:sbYBMWtpanwLH75GAWjIp5JnON9wa3NodLZhouu0G9I=
|
||||
gopkg.in/cas.v2 v2.1.0/go.mod h1:M291I/o/u3eeMl9SkXMPYpWasHp7weFY9G/pM5DbB+g=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/ldap.v3 v3.1.0 h1:DIDWEjI7vQWREh0S8X5/NFPCZ3MCVd55LmXKPW4XLGE=
|
||||
gopkg.in/ldap.v3 v3.1.0/go.mod h1:dQjCc0R0kfyFjIlWNMH1DORwUASZyDxo2Ry1B51dXaQ=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
/* eslint-disable ember/no-actions-hash */
|
||||
/* eslint-disable ember/no-classic-classes */
|
||||
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
|
||||
//
|
||||
// This software (Documize Community Edition) is licensed under
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
/* eslint-disable ember/no-classic-classes */
|
||||
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
|
||||
//
|
||||
// This software (Documize Community Edition) is licensed under
|
||||
|
|
|
@ -186,7 +186,8 @@ export default Router.map(function () {
|
|||
path: 'updates'
|
||||
});
|
||||
|
||||
this.route('not-found', {
|
||||
this.route('auth/login', {
|
||||
path: '/*wildcard'
|
||||
// path: '/*wildcard'
|
||||
});
|
||||
});
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
/* eslint-disable ember/no-classic-classes */
|
||||
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
|
||||
//
|
||||
// This software (Documize Community Edition) is licensed under
|
||||
|
|
|
@ -43,12 +43,11 @@ func (m *middleware) cors(w http.ResponseWriter, r *http.Request, next http.Hand
|
|||
w.Header().Set("Access-Control-Allow-Headers", "host, content-type, accept, authorization, origin, referer, user-agent, cache-control, x-requested-with, range")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "x-documize-version, x-documize-status, x-documize-filename, x-documize-subscription, Content-Disposition, Content-Length")
|
||||
|
||||
w.Header().Add("X-Documize-Version", m.Runtime.Product.Version)
|
||||
w.Header().Add("Cache-Control", "no-cache")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.Header().Add("X-Documize-Version", m.Runtime.Product.Version)
|
||||
w.Header().Add("Cache-Control", "no-cache")
|
||||
|
||||
w.Write([]byte(""))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
17
vendor/github.com/Azure/go-ntlmssp/.travis.yml
generated
vendored
Normal file
17
vendor/github.com/Azure/go-ntlmssp/.travis.yml
generated
vendored
Normal file
|
@ -0,0 +1,17 @@
|
|||
sudo: false
|
||||
|
||||
language: go
|
||||
|
||||
before_script:
|
||||
- go get -u golang.org/x/lint/golint
|
||||
|
||||
go:
|
||||
- 1.10.x
|
||||
- master
|
||||
|
||||
script:
|
||||
- test -z "$(gofmt -s -l . | tee /dev/stderr)"
|
||||
- test -z "$(golint ./... | tee /dev/stderr)"
|
||||
- go vet ./...
|
||||
- go build -v ./...
|
||||
- go test -v ./...
|
21
vendor/github.com/Azure/go-ntlmssp/LICENSE
generated
vendored
Normal file
21
vendor/github.com/Azure/go-ntlmssp/LICENSE
generated
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2016 Microsoft
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
29
vendor/github.com/Azure/go-ntlmssp/README.md
generated
vendored
Normal file
29
vendor/github.com/Azure/go-ntlmssp/README.md
generated
vendored
Normal file
|
@ -0,0 +1,29 @@
|
|||
# go-ntlmssp
|
||||
Golang package that provides NTLM/Negotiate authentication over HTTP
|
||||
|
||||
[](https://godoc.org/github.com/Azure/go-ntlmssp) [](https://travis-ci.org/Azure/go-ntlmssp)
|
||||
|
||||
Protocol details from https://msdn.microsoft.com/en-us/library/cc236621.aspx
|
||||
Implementation hints from http://davenport.sourceforge.net/ntlm.html
|
||||
|
||||
This package only implements authentication, no key exchange or encryption. It
|
||||
only supports Unicode (UTF16LE) encoding of protocol strings, no OEM encoding.
|
||||
This package implements NTLMv2.
|
||||
|
||||
# Usage
|
||||
|
||||
```
|
||||
url, user, password := "http://www.example.com/secrets", "robpike", "pw123"
|
||||
client := &http.Client{
|
||||
Transport: ntlmssp.Negotiator{
|
||||
RoundTripper:&http.Transport{},
|
||||
},
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", url, nil)
|
||||
req.SetBasicAuth(user, password)
|
||||
res, _ := client.Do(req)
|
||||
```
|
||||
|
||||
-----
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
183
vendor/github.com/Azure/go-ntlmssp/authenticate_message.go
generated
vendored
Normal file
183
vendor/github.com/Azure/go-ntlmssp/authenticate_message.go
generated
vendored
Normal file
|
@ -0,0 +1,183 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type authenicateMessage struct {
|
||||
LmChallengeResponse []byte
|
||||
NtChallengeResponse []byte
|
||||
|
||||
TargetName string
|
||||
UserName string
|
||||
|
||||
// only set if negotiateFlag_NTLMSSP_NEGOTIATE_KEY_EXCH
|
||||
EncryptedRandomSessionKey []byte
|
||||
|
||||
NegotiateFlags negotiateFlags
|
||||
|
||||
MIC []byte
|
||||
}
|
||||
|
||||
type authenticateMessageFields struct {
|
||||
messageHeader
|
||||
LmChallengeResponse varField
|
||||
NtChallengeResponse varField
|
||||
TargetName varField
|
||||
UserName varField
|
||||
Workstation varField
|
||||
_ [8]byte
|
||||
NegotiateFlags negotiateFlags
|
||||
}
|
||||
|
||||
func (m authenicateMessage) MarshalBinary() ([]byte, error) {
|
||||
if !m.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEUNICODE) {
|
||||
return nil, errors.New("Only unicode is supported")
|
||||
}
|
||||
|
||||
target, user := toUnicode(m.TargetName), toUnicode(m.UserName)
|
||||
workstation := toUnicode("go-ntlmssp")
|
||||
|
||||
ptr := binary.Size(&authenticateMessageFields{})
|
||||
f := authenticateMessageFields{
|
||||
messageHeader: newMessageHeader(3),
|
||||
NegotiateFlags: m.NegotiateFlags,
|
||||
LmChallengeResponse: newVarField(&ptr, len(m.LmChallengeResponse)),
|
||||
NtChallengeResponse: newVarField(&ptr, len(m.NtChallengeResponse)),
|
||||
TargetName: newVarField(&ptr, len(target)),
|
||||
UserName: newVarField(&ptr, len(user)),
|
||||
Workstation: newVarField(&ptr, len(workstation)),
|
||||
}
|
||||
|
||||
f.NegotiateFlags.Unset(negotiateFlagNTLMSSPNEGOTIATEVERSION)
|
||||
|
||||
b := bytes.Buffer{}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &m.LmChallengeResponse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &m.NtChallengeResponse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &workstation); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
//ProcessChallenge crafts an AUTHENTICATE message in response to the CHALLENGE message
|
||||
//that was received from the server
|
||||
func ProcessChallenge(challengeMessageData []byte, user, password string) ([]byte, error) {
|
||||
if user == "" && password == "" {
|
||||
return nil, errors.New("Anonymous authentication not supported")
|
||||
}
|
||||
|
||||
var cm challengeMessage
|
||||
if err := cm.UnmarshalBinary(challengeMessageData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATELMKEY) {
|
||||
return nil, errors.New("Only NTLM v2 is supported, but server requested v1 (NTLMSSP_NEGOTIATE_LM_KEY)")
|
||||
}
|
||||
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEKEYEXCH) {
|
||||
return nil, errors.New("Key exchange requested but not supported (NTLMSSP_NEGOTIATE_KEY_EXCH)")
|
||||
}
|
||||
|
||||
am := authenicateMessage{
|
||||
UserName: user,
|
||||
TargetName: cm.TargetName,
|
||||
NegotiateFlags: cm.NegotiateFlags,
|
||||
}
|
||||
|
||||
timestamp := cm.TargetInfo[avIDMsvAvTimestamp]
|
||||
if timestamp == nil { // no time sent, take current time
|
||||
ft := uint64(time.Now().UnixNano()) / 100
|
||||
ft += 116444736000000000 // add time between unix & windows offset
|
||||
timestamp = make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(timestamp, ft)
|
||||
}
|
||||
|
||||
clientChallenge := make([]byte, 8)
|
||||
rand.Reader.Read(clientChallenge)
|
||||
|
||||
ntlmV2Hash := getNtlmV2Hash(password, user, cm.TargetName)
|
||||
|
||||
am.NtChallengeResponse = computeNtlmV2Response(ntlmV2Hash,
|
||||
cm.ServerChallenge[:], clientChallenge, timestamp, cm.TargetInfoRaw)
|
||||
|
||||
if cm.TargetInfoRaw == nil {
|
||||
am.LmChallengeResponse = computeLmV2Response(ntlmV2Hash,
|
||||
cm.ServerChallenge[:], clientChallenge)
|
||||
}
|
||||
return am.MarshalBinary()
|
||||
}
|
||||
|
||||
func ProcessChallengeWithHash(challengeMessageData []byte, user, hash string) ([]byte, error) {
|
||||
if user == "" && hash == "" {
|
||||
return nil, errors.New("Anonymous authentication not supported")
|
||||
}
|
||||
|
||||
var cm challengeMessage
|
||||
if err := cm.UnmarshalBinary(challengeMessageData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATELMKEY) {
|
||||
return nil, errors.New("Only NTLM v2 is supported, but server requested v1 (NTLMSSP_NEGOTIATE_LM_KEY)")
|
||||
}
|
||||
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEKEYEXCH) {
|
||||
return nil, errors.New("Key exchange requested but not supported (NTLMSSP_NEGOTIATE_KEY_EXCH)")
|
||||
}
|
||||
|
||||
am := authenicateMessage{
|
||||
UserName: user,
|
||||
TargetName: cm.TargetName,
|
||||
NegotiateFlags: cm.NegotiateFlags,
|
||||
}
|
||||
|
||||
timestamp := cm.TargetInfo[avIDMsvAvTimestamp]
|
||||
if timestamp == nil { // no time sent, take current time
|
||||
ft := uint64(time.Now().UnixNano()) / 100
|
||||
ft += 116444736000000000 // add time between unix & windows offset
|
||||
timestamp = make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(timestamp, ft)
|
||||
}
|
||||
|
||||
clientChallenge := make([]byte, 8)
|
||||
rand.Reader.Read(clientChallenge)
|
||||
|
||||
hashParts := strings.Split(hash, ":")
|
||||
if len(hashParts) > 1 {
|
||||
hash = hashParts[1]
|
||||
}
|
||||
hashBytes, err := hex.DecodeString(hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ntlmV2Hash := hmacMd5(hashBytes, toUnicode(strings.ToUpper(user)+cm.TargetName))
|
||||
|
||||
am.NtChallengeResponse = computeNtlmV2Response(ntlmV2Hash,
|
||||
cm.ServerChallenge[:], clientChallenge, timestamp, cm.TargetInfoRaw)
|
||||
|
||||
if cm.TargetInfoRaw == nil {
|
||||
am.LmChallengeResponse = computeLmV2Response(ntlmV2Hash,
|
||||
cm.ServerChallenge[:], clientChallenge)
|
||||
}
|
||||
return am.MarshalBinary()
|
||||
}
|
37
vendor/github.com/Azure/go-ntlmssp/authheader.go
generated
vendored
Normal file
37
vendor/github.com/Azure/go-ntlmssp/authheader.go
generated
vendored
Normal file
|
@ -0,0 +1,37 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type authheader string
|
||||
|
||||
func (h authheader) IsBasic() bool {
|
||||
return strings.HasPrefix(string(h), "Basic ")
|
||||
}
|
||||
|
||||
func (h authheader) IsNegotiate() bool {
|
||||
return strings.HasPrefix(string(h), "Negotiate")
|
||||
}
|
||||
|
||||
func (h authheader) IsNTLM() bool {
|
||||
return strings.HasPrefix(string(h), "NTLM")
|
||||
}
|
||||
|
||||
func (h authheader) GetData() ([]byte, error) {
|
||||
p := strings.Split(string(h), " ")
|
||||
if len(p) < 2 {
|
||||
return nil, nil
|
||||
}
|
||||
return base64.StdEncoding.DecodeString(string(p[1]))
|
||||
}
|
||||
|
||||
func (h authheader) GetBasicCreds() (username, password string, err error) {
|
||||
d, err := h.GetData()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
parts := strings.SplitN(string(d), ":", 2)
|
||||
return parts[0], parts[1], nil
|
||||
}
|
17
vendor/github.com/Azure/go-ntlmssp/avids.go
generated
vendored
Normal file
17
vendor/github.com/Azure/go-ntlmssp/avids.go
generated
vendored
Normal file
|
@ -0,0 +1,17 @@
|
|||
package ntlmssp
|
||||
|
||||
type avID uint16
|
||||
|
||||
const (
|
||||
avIDMsvAvEOL avID = iota
|
||||
avIDMsvAvNbComputerName
|
||||
avIDMsvAvNbDomainName
|
||||
avIDMsvAvDNSComputerName
|
||||
avIDMsvAvDNSDomainName
|
||||
avIDMsvAvDNSTreeName
|
||||
avIDMsvAvFlags
|
||||
avIDMsvAvTimestamp
|
||||
avIDMsvAvSingleHost
|
||||
avIDMsvAvTargetName
|
||||
avIDMsvChannelBindings
|
||||
)
|
82
vendor/github.com/Azure/go-ntlmssp/challenge_message.go
generated
vendored
Normal file
82
vendor/github.com/Azure/go-ntlmssp/challenge_message.go
generated
vendored
Normal file
|
@ -0,0 +1,82 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type challengeMessageFields struct {
|
||||
messageHeader
|
||||
TargetName varField
|
||||
NegotiateFlags negotiateFlags
|
||||
ServerChallenge [8]byte
|
||||
_ [8]byte
|
||||
TargetInfo varField
|
||||
}
|
||||
|
||||
func (m challengeMessageFields) IsValid() bool {
|
||||
return m.messageHeader.IsValid() && m.MessageType == 2
|
||||
}
|
||||
|
||||
type challengeMessage struct {
|
||||
challengeMessageFields
|
||||
TargetName string
|
||||
TargetInfo map[avID][]byte
|
||||
TargetInfoRaw []byte
|
||||
}
|
||||
|
||||
func (m *challengeMessage) UnmarshalBinary(data []byte) error {
|
||||
r := bytes.NewReader(data)
|
||||
err := binary.Read(r, binary.LittleEndian, &m.challengeMessageFields)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !m.challengeMessageFields.IsValid() {
|
||||
return fmt.Errorf("Message is not a valid challenge message: %+v", m.challengeMessageFields.messageHeader)
|
||||
}
|
||||
|
||||
if m.challengeMessageFields.TargetName.Len > 0 {
|
||||
m.TargetName, err = m.challengeMessageFields.TargetName.ReadStringFrom(data, m.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEUNICODE))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if m.challengeMessageFields.TargetInfo.Len > 0 {
|
||||
d, err := m.challengeMessageFields.TargetInfo.ReadFrom(data)
|
||||
m.TargetInfoRaw = d
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.TargetInfo = make(map[avID][]byte)
|
||||
r := bytes.NewReader(d)
|
||||
for {
|
||||
var id avID
|
||||
var l uint16
|
||||
err = binary.Read(r, binary.LittleEndian, &id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if id == avIDMsvAvEOL {
|
||||
break
|
||||
}
|
||||
|
||||
err = binary.Read(r, binary.LittleEndian, &l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value := make([]byte, l)
|
||||
n, err := r.Read(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n != int(l) {
|
||||
return fmt.Errorf("Expected to read %d bytes, got only %d", l, n)
|
||||
}
|
||||
m.TargetInfo[id] = value
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
21
vendor/github.com/Azure/go-ntlmssp/messageheader.go
generated
vendored
Normal file
21
vendor/github.com/Azure/go-ntlmssp/messageheader.go
generated
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
)
|
||||
|
||||
var signature = [8]byte{'N', 'T', 'L', 'M', 'S', 'S', 'P', 0}
|
||||
|
||||
type messageHeader struct {
|
||||
Signature [8]byte
|
||||
MessageType uint32
|
||||
}
|
||||
|
||||
func (h messageHeader) IsValid() bool {
|
||||
return bytes.Equal(h.Signature[:], signature[:]) &&
|
||||
h.MessageType > 0 && h.MessageType < 4
|
||||
}
|
||||
|
||||
func newMessageHeader(messageType uint32) messageHeader {
|
||||
return messageHeader{signature, messageType}
|
||||
}
|
52
vendor/github.com/Azure/go-ntlmssp/negotiate_flags.go
generated
vendored
Normal file
52
vendor/github.com/Azure/go-ntlmssp/negotiate_flags.go
generated
vendored
Normal file
|
@ -0,0 +1,52 @@
|
|||
package ntlmssp
|
||||
|
||||
type negotiateFlags uint32
|
||||
|
||||
const (
|
||||
/*A*/ negotiateFlagNTLMSSPNEGOTIATEUNICODE negotiateFlags = 1 << 0
|
||||
/*B*/ negotiateFlagNTLMNEGOTIATEOEM = 1 << 1
|
||||
/*C*/ negotiateFlagNTLMSSPREQUESTTARGET = 1 << 2
|
||||
|
||||
/*D*/
|
||||
negotiateFlagNTLMSSPNEGOTIATESIGN = 1 << 4
|
||||
/*E*/ negotiateFlagNTLMSSPNEGOTIATESEAL = 1 << 5
|
||||
/*F*/ negotiateFlagNTLMSSPNEGOTIATEDATAGRAM = 1 << 6
|
||||
/*G*/ negotiateFlagNTLMSSPNEGOTIATELMKEY = 1 << 7
|
||||
|
||||
/*H*/
|
||||
negotiateFlagNTLMSSPNEGOTIATENTLM = 1 << 9
|
||||
|
||||
/*J*/
|
||||
negotiateFlagANONYMOUS = 1 << 11
|
||||
/*K*/ negotiateFlagNTLMSSPNEGOTIATEOEMDOMAINSUPPLIED = 1 << 12
|
||||
/*L*/ negotiateFlagNTLMSSPNEGOTIATEOEMWORKSTATIONSUPPLIED = 1 << 13
|
||||
|
||||
/*M*/
|
||||
negotiateFlagNTLMSSPNEGOTIATEALWAYSSIGN = 1 << 15
|
||||
/*N*/ negotiateFlagNTLMSSPTARGETTYPEDOMAIN = 1 << 16
|
||||
/*O*/ negotiateFlagNTLMSSPTARGETTYPESERVER = 1 << 17
|
||||
|
||||
/*P*/
|
||||
negotiateFlagNTLMSSPNEGOTIATEEXTENDEDSESSIONSECURITY = 1 << 19
|
||||
/*Q*/ negotiateFlagNTLMSSPNEGOTIATEIDENTIFY = 1 << 20
|
||||
|
||||
/*R*/
|
||||
negotiateFlagNTLMSSPREQUESTNONNTSESSIONKEY = 1 << 22
|
||||
/*S*/ negotiateFlagNTLMSSPNEGOTIATETARGETINFO = 1 << 23
|
||||
|
||||
/*T*/
|
||||
negotiateFlagNTLMSSPNEGOTIATEVERSION = 1 << 25
|
||||
|
||||
/*U*/
|
||||
negotiateFlagNTLMSSPNEGOTIATE128 = 1 << 29
|
||||
/*V*/ negotiateFlagNTLMSSPNEGOTIATEKEYEXCH = 1 << 30
|
||||
/*W*/ negotiateFlagNTLMSSPNEGOTIATE56 = 1 << 31
|
||||
)
|
||||
|
||||
func (field negotiateFlags) Has(flags negotiateFlags) bool {
|
||||
return field&flags == flags
|
||||
}
|
||||
|
||||
func (field *negotiateFlags) Unset(flags negotiateFlags) {
|
||||
*field = *field ^ (*field & flags)
|
||||
}
|
64
vendor/github.com/Azure/go-ntlmssp/negotiate_message.go
generated
vendored
Normal file
64
vendor/github.com/Azure/go-ntlmssp/negotiate_message.go
generated
vendored
Normal file
|
@ -0,0 +1,64 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const expMsgBodyLen = 40
|
||||
|
||||
type negotiateMessageFields struct {
|
||||
messageHeader
|
||||
NegotiateFlags negotiateFlags
|
||||
|
||||
Domain varField
|
||||
Workstation varField
|
||||
|
||||
Version
|
||||
}
|
||||
|
||||
var defaultFlags = negotiateFlagNTLMSSPNEGOTIATETARGETINFO |
|
||||
negotiateFlagNTLMSSPNEGOTIATE56 |
|
||||
negotiateFlagNTLMSSPNEGOTIATE128 |
|
||||
negotiateFlagNTLMSSPNEGOTIATEUNICODE |
|
||||
negotiateFlagNTLMSSPNEGOTIATEEXTENDEDSESSIONSECURITY
|
||||
|
||||
//NewNegotiateMessage creates a new NEGOTIATE message with the
|
||||
//flags that this package supports.
|
||||
func NewNegotiateMessage(domainName, workstationName string) ([]byte, error) {
|
||||
payloadOffset := expMsgBodyLen
|
||||
flags := defaultFlags
|
||||
|
||||
if domainName != "" {
|
||||
flags |= negotiateFlagNTLMSSPNEGOTIATEOEMDOMAINSUPPLIED
|
||||
}
|
||||
|
||||
if workstationName != "" {
|
||||
flags |= negotiateFlagNTLMSSPNEGOTIATEOEMWORKSTATIONSUPPLIED
|
||||
}
|
||||
|
||||
msg := negotiateMessageFields{
|
||||
messageHeader: newMessageHeader(1),
|
||||
NegotiateFlags: flags,
|
||||
Domain: newVarField(&payloadOffset, len(domainName)),
|
||||
Workstation: newVarField(&payloadOffset, len(workstationName)),
|
||||
Version: DefaultVersion(),
|
||||
}
|
||||
|
||||
b := bytes.Buffer{}
|
||||
if err := binary.Write(&b, binary.LittleEndian, &msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if b.Len() != expMsgBodyLen {
|
||||
return nil, errors.New("incorrect body length")
|
||||
}
|
||||
|
||||
payload := strings.ToUpper(domainName + workstationName)
|
||||
if _, err := b.WriteString(payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.Bytes(), nil
|
||||
}
|
144
vendor/github.com/Azure/go-ntlmssp/negotiator.go
generated
vendored
Normal file
144
vendor/github.com/Azure/go-ntlmssp/negotiator.go
generated
vendored
Normal file
|
@ -0,0 +1,144 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GetDomain : parse domain name from based on slashes in the input
|
||||
func GetDomain(user string) (string, string) {
|
||||
domain := ""
|
||||
|
||||
if strings.Contains(user, "\\") {
|
||||
ucomponents := strings.SplitN(user, "\\", 2)
|
||||
domain = ucomponents[0]
|
||||
user = ucomponents[1]
|
||||
}
|
||||
return user, domain
|
||||
}
|
||||
|
||||
//Negotiator is a http.Roundtripper decorator that automatically
|
||||
//converts basic authentication to NTLM/Negotiate authentication when appropriate.
|
||||
type Negotiator struct{ http.RoundTripper }
|
||||
|
||||
//RoundTrip sends the request to the server, handling any authentication
|
||||
//re-sends as needed.
|
||||
func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) {
|
||||
// Use default round tripper if not provided
|
||||
rt := l.RoundTripper
|
||||
if rt == nil {
|
||||
rt = http.DefaultTransport
|
||||
}
|
||||
// If it is not basic auth, just round trip the request as usual
|
||||
reqauth := authheader(req.Header.Get("Authorization"))
|
||||
if !reqauth.IsBasic() {
|
||||
return rt.RoundTrip(req)
|
||||
}
|
||||
// Save request body
|
||||
body := bytes.Buffer{}
|
||||
if req.Body != nil {
|
||||
_, err = body.ReadFrom(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Body.Close()
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
||||
}
|
||||
// first try anonymous, in case the server still finds us
|
||||
// authenticated from previous traffic
|
||||
req.Header.Del("Authorization")
|
||||
res, err = rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode != http.StatusUnauthorized {
|
||||
return res, err
|
||||
}
|
||||
|
||||
resauth := authheader(res.Header.Get("Www-Authenticate"))
|
||||
if !resauth.IsNegotiate() && !resauth.IsNTLM() {
|
||||
// Unauthorized, Negotiate not requested, let's try with basic auth
|
||||
req.Header.Set("Authorization", string(reqauth))
|
||||
io.Copy(ioutil.Discard, res.Body)
|
||||
res.Body.Close()
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
||||
|
||||
res, err = rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode != http.StatusUnauthorized {
|
||||
return res, err
|
||||
}
|
||||
resauth = authheader(res.Header.Get("Www-Authenticate"))
|
||||
}
|
||||
|
||||
if resauth.IsNegotiate() || resauth.IsNTLM() {
|
||||
// 401 with request:Basic and response:Negotiate
|
||||
io.Copy(ioutil.Discard, res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
// recycle credentials
|
||||
u, p, err := reqauth.GetBasicCreds()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// get domain from username
|
||||
domain := ""
|
||||
u, domain = GetDomain(u)
|
||||
|
||||
// send negotiate
|
||||
negotiateMessage, err := NewNegotiateMessage(domain, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resauth.IsNTLM() {
|
||||
req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(negotiateMessage))
|
||||
} else {
|
||||
req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(negotiateMessage))
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
||||
|
||||
res, err = rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// receive challenge?
|
||||
resauth = authheader(res.Header.Get("Www-Authenticate"))
|
||||
challengeMessage, err := resauth.GetData()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !(resauth.IsNegotiate() || resauth.IsNTLM()) || len(challengeMessage) == 0 {
|
||||
// Negotiation failed, let client deal with response
|
||||
return res, nil
|
||||
}
|
||||
io.Copy(ioutil.Discard, res.Body)
|
||||
res.Body.Close()
|
||||
|
||||
// send authenticate
|
||||
authenticateMessage, err := ProcessChallenge(challengeMessage, u, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resauth.IsNTLM() {
|
||||
req.Header.Set("Authorization", "NTLM "+base64.StdEncoding.EncodeToString(authenticateMessage))
|
||||
} else {
|
||||
req.Header.Set("Authorization", "Negotiate "+base64.StdEncoding.EncodeToString(authenticateMessage))
|
||||
}
|
||||
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
|
||||
|
||||
return rt.RoundTrip(req)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
51
vendor/github.com/Azure/go-ntlmssp/nlmp.go
generated
vendored
Normal file
51
vendor/github.com/Azure/go-ntlmssp/nlmp.go
generated
vendored
Normal file
|
@ -0,0 +1,51 @@
|
|||
// Package ntlmssp provides NTLM/Negotiate authentication over HTTP
|
||||
//
|
||||
// Protocol details from https://msdn.microsoft.com/en-us/library/cc236621.aspx,
|
||||
// implementation hints from http://davenport.sourceforge.net/ntlm.html .
|
||||
// This package only implements authentication, no key exchange or encryption. It
|
||||
// only supports Unicode (UTF16LE) encoding of protocol strings, no OEM encoding.
|
||||
// This package implements NTLMv2.
|
||||
package ntlmssp
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"golang.org/x/crypto/md4"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getNtlmV2Hash(password, username, target string) []byte {
|
||||
return hmacMd5(getNtlmHash(password), toUnicode(strings.ToUpper(username)+target))
|
||||
}
|
||||
|
||||
func getNtlmHash(password string) []byte {
|
||||
hash := md4.New()
|
||||
hash.Write(toUnicode(password))
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
func computeNtlmV2Response(ntlmV2Hash, serverChallenge, clientChallenge,
|
||||
timestamp, targetInfo []byte) []byte {
|
||||
|
||||
temp := []byte{1, 1, 0, 0, 0, 0, 0, 0}
|
||||
temp = append(temp, timestamp...)
|
||||
temp = append(temp, clientChallenge...)
|
||||
temp = append(temp, 0, 0, 0, 0)
|
||||
temp = append(temp, targetInfo...)
|
||||
temp = append(temp, 0, 0, 0, 0)
|
||||
|
||||
NTProofStr := hmacMd5(ntlmV2Hash, serverChallenge, temp)
|
||||
return append(NTProofStr, temp...)
|
||||
}
|
||||
|
||||
func computeLmV2Response(ntlmV2Hash, serverChallenge, clientChallenge []byte) []byte {
|
||||
return append(hmacMd5(ntlmV2Hash, serverChallenge, clientChallenge), clientChallenge...)
|
||||
}
|
||||
|
||||
func hmacMd5(key []byte, data ...[]byte) []byte {
|
||||
mac := hmac.New(md5.New, key)
|
||||
for _, d := range data {
|
||||
mac.Write(d)
|
||||
}
|
||||
return mac.Sum(nil)
|
||||
}
|
29
vendor/github.com/Azure/go-ntlmssp/unicode.go
generated
vendored
Normal file
29
vendor/github.com/Azure/go-ntlmssp/unicode.go
generated
vendored
Normal file
|
@ -0,0 +1,29 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"unicode/utf16"
|
||||
)
|
||||
|
||||
// helper func's for dealing with Windows Unicode (UTF16LE)
|
||||
|
||||
func fromUnicode(d []byte) (string, error) {
|
||||
if len(d)%2 > 0 {
|
||||
return "", errors.New("Unicode (UTF 16 LE) specified, but uneven data length")
|
||||
}
|
||||
s := make([]uint16, len(d)/2)
|
||||
err := binary.Read(bytes.NewReader(d), binary.LittleEndian, &s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(utf16.Decode(s)), nil
|
||||
}
|
||||
|
||||
func toUnicode(s string) []byte {
|
||||
uints := utf16.Encode([]rune(s))
|
||||
b := bytes.Buffer{}
|
||||
binary.Write(&b, binary.LittleEndian, &uints)
|
||||
return b.Bytes()
|
||||
}
|
40
vendor/github.com/Azure/go-ntlmssp/varfield.go
generated
vendored
Normal file
40
vendor/github.com/Azure/go-ntlmssp/varfield.go
generated
vendored
Normal file
|
@ -0,0 +1,40 @@
|
|||
package ntlmssp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
type varField struct {
|
||||
Len uint16
|
||||
MaxLen uint16
|
||||
BufferOffset uint32
|
||||
}
|
||||
|
||||
func (f varField) ReadFrom(buffer []byte) ([]byte, error) {
|
||||
if len(buffer) < int(f.BufferOffset+uint32(f.Len)) {
|
||||
return nil, errors.New("Error reading data, varField extends beyond buffer")
|
||||
}
|
||||
return buffer[f.BufferOffset : f.BufferOffset+uint32(f.Len)], nil
|
||||
}
|
||||
|
||||
func (f varField) ReadStringFrom(buffer []byte, unicode bool) (string, error) {
|
||||
d, err := f.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if unicode { // UTF-16LE encoding scheme
|
||||
return fromUnicode(d)
|
||||
}
|
||||
// OEM encoding, close enough to ASCII, since no code page is specified
|
||||
return string(d), err
|
||||
}
|
||||
|
||||
func newVarField(ptr *int, fieldsize int) varField {
|
||||
f := varField{
|
||||
Len: uint16(fieldsize),
|
||||
MaxLen: uint16(fieldsize),
|
||||
BufferOffset: uint32(*ptr),
|
||||
}
|
||||
*ptr += fieldsize
|
||||
return f
|
||||
}
|
20
vendor/github.com/Azure/go-ntlmssp/version.go
generated
vendored
Normal file
20
vendor/github.com/Azure/go-ntlmssp/version.go
generated
vendored
Normal file
|
@ -0,0 +1,20 @@
|
|||
package ntlmssp
|
||||
|
||||
// Version is a struct representing https://msdn.microsoft.com/en-us/library/cc236654.aspx
|
||||
type Version struct {
|
||||
ProductMajorVersion uint8
|
||||
ProductMinorVersion uint8
|
||||
ProductBuild uint16
|
||||
_ [3]byte
|
||||
NTLMRevisionCurrent uint8
|
||||
}
|
||||
|
||||
// DefaultVersion returns a Version with "sensible" defaults (Windows 7)
|
||||
func DefaultVersion() Version {
|
||||
return Version{
|
||||
ProductMajorVersion: 6,
|
||||
ProductMinorVersion: 1,
|
||||
ProductBuild: 7601,
|
||||
NTLMRevisionCurrent: 15,
|
||||
}
|
||||
}
|
13
vendor/github.com/denisenkom/go-mssqldb/.gitignore
generated
vendored
Normal file
13
vendor/github.com/denisenkom/go-mssqldb/.gitignore
generated
vendored
Normal file
|
@ -0,0 +1,13 @@
|
|||
/.idea
|
||||
/.connstr
|
||||
.vscode
|
||||
.terraform
|
||||
*.tfstate*
|
||||
*.log
|
||||
*.swp
|
||||
*~
|
||||
coverage.json
|
||||
coverage.txt
|
||||
coverage.xml
|
||||
testresults.xml
|
||||
|
10
vendor/github.com/denisenkom/go-mssqldb/.golangci.yml
generated
vendored
Normal file
10
vendor/github.com/denisenkom/go-mssqldb/.golangci.yml
generated
vendored
Normal file
|
@ -0,0 +1,10 @@
|
|||
linters:
|
||||
enable:
|
||||
# basic go linters
|
||||
- gofmt
|
||||
- golint
|
||||
- govet
|
||||
|
||||
# sql related linters
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
168
vendor/github.com/denisenkom/go-mssqldb/README.md
generated
vendored
168
vendor/github.com/denisenkom/go-mssqldb/README.md
generated
vendored
|
@ -1,6 +1,6 @@
|
|||
# A pure Go MSSQL driver for Go's database/sql package
|
||||
|
||||
[](http://godoc.org/github.com/denisenkom/go-mssqldb)
|
||||
[](https://pkg.go.dev/github.com/denisenkom/go-mssqldb)
|
||||
[](https://ci.appveyor.com/project/denisenkom/go-mssqldb)
|
||||
[](https://codecov.io/gh/denisenkom/go-mssqldb)
|
||||
|
||||
|
@ -16,7 +16,7 @@ The recommended connection string uses a URL format:
|
|||
`sqlserver://username:password@host/instance?param1=value¶m2=value`
|
||||
Other supported formats are listed below.
|
||||
|
||||
### Common parameters:
|
||||
### Common parameters
|
||||
|
||||
* `user id` - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used. The user domain sensitive to the case which is defined in the connection string.
|
||||
* `password`
|
||||
|
@ -29,24 +29,24 @@ Other supported formats are listed below.
|
|||
* `true` - Data sent between client and server is encrypted.
|
||||
* `app name` - The application name (default is go-mssqldb)
|
||||
|
||||
### Connection parameters for ODBC and ADO style connection strings:
|
||||
### Connection parameters for ODBC and ADO style connection strings
|
||||
|
||||
* `server` - host or host\instance (default localhost)
|
||||
* `port` - used only when there is no instance in server (default 1433)
|
||||
|
||||
### Less common parameters:
|
||||
### Less common parameters
|
||||
|
||||
* `keepAlive` - in seconds; 0 to disable (default is 30)
|
||||
* `failoverpartner` - host or host\instance (default is no partner).
|
||||
* `failoverpartner` - host or host\instance (default is no partner).
|
||||
* `failoverport` - used only when there is no instance in failoverpartner (default 1433)
|
||||
* `packet size` - in bytes; 512 to 32767 (default is 4096)
|
||||
* Encrypted connections have a maximum packet size of 16383 bytes
|
||||
* Further information on usage: https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
|
||||
* Further information on usage: <https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option>
|
||||
* `log` - logging flags (default 0/no logging, 63 for full logging)
|
||||
* 1 log errors
|
||||
* 2 log messages
|
||||
* 4 log rows affected
|
||||
* 8 trace sql statements
|
||||
* 1 log errors
|
||||
* 2 log messages
|
||||
* 4 log rows affected
|
||||
* 8 trace sql statements
|
||||
* 16 log statement parameters
|
||||
* 32 log transaction begin/end
|
||||
* `TrustServerCertificate`
|
||||
|
@ -56,72 +56,82 @@ Other supported formats are listed below.
|
|||
* `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host.
|
||||
* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
|
||||
* `Workstation ID` - The workstation name (default is the host name)
|
||||
* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`.
|
||||
|
||||
### The connection string can be specified in one of three formats:
|
||||
* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`.
|
||||
|
||||
### The connection string can be specified in one of three formats
|
||||
|
||||
1. URL: with `sqlserver` scheme. username and password appears before the host. Any instance appears as
|
||||
the first segment in the path. All other options are query parameters. Examples:
|
||||
|
||||
* `sqlserver://username:password@host/instance?param1=value¶m2=value`
|
||||
* `sqlserver://username:password@host:port?param1=value¶m2=value`
|
||||
* `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress instance.
|
||||
* `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass.
|
||||
* `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30` // port 1234 on localhost.
|
||||
* `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass"
|
||||
* `sqlserver://username:password@host/instance?param1=value¶m2=value`
|
||||
* `sqlserver://username:password@host:port?param1=value¶m2=value`
|
||||
* `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress instance.
|
||||
* `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass.
|
||||
* `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30` // port 1234 on localhost.
|
||||
* `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass"
|
||||
A string of this format can be constructed using the `URL` type in the `net/url` package.
|
||||
|
||||
A string of this format can be constructed using the `URL` type in the `net/url` package.
|
||||
```go
|
||||
|
||||
```go
|
||||
query := url.Values{}
|
||||
query.Add("app name", "MyAppName")
|
||||
query := url.Values{}
|
||||
query.Add("app name", "MyAppName")
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "sqlserver",
|
||||
User: url.UserPassword(username, password),
|
||||
Host: fmt.Sprintf("%s:%d", hostname, port),
|
||||
// Path: instance, // if connecting to an instance instead of a port
|
||||
RawQuery: query.Encode(),
|
||||
}
|
||||
db, err := sql.Open("sqlserver", u.String())
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: "sqlserver",
|
||||
User: url.UserPassword(username, password),
|
||||
Host: fmt.Sprintf("%s:%d", hostname, port),
|
||||
// Path: instance, // if connecting to an instance instead of a port
|
||||
RawQuery: query.Encode(),
|
||||
}
|
||||
db, err := sql.Open("sqlserver", u.String())
|
||||
```
|
||||
```
|
||||
|
||||
2. ADO: `key=value` pairs separated by `;`. Values may not contain `;`, leading and trailing whitespace is ignored.
|
||||
Examples:
|
||||
|
||||
* `server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName`
|
||||
* `server=localhost;user id=sa;database=master;app name=MyAppName`
|
||||
|
||||
* `server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName`
|
||||
* `server=localhost;user id=sa;database=master;app name=MyAppName`
|
||||
|
||||
ADO strings support synonyms for database, app name, user id, and server
|
||||
* server <= addr, address, network address, data source
|
||||
* user id <= user, uid
|
||||
* database <= initial catalog
|
||||
* app name <= application name
|
||||
|
||||
3. ODBC: Prefix with `odbc`, `key=value` pairs separated by `;`. Allow `;` by wrapping
|
||||
values in `{}`. Examples:
|
||||
|
||||
* `odbc:server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName`
|
||||
* `odbc:server=localhost;user id=sa;database=master;app name=MyAppName`
|
||||
* `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar"
|
||||
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar"
|
||||
* `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar "
|
||||
* `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar"
|
||||
* `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar"
|
||||
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar"
|
||||
* `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with `}}`, password is "foo}bar"
|
||||
|
||||
* `odbc:server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName`
|
||||
* `odbc:server=localhost;user id=sa;database=master;app name=MyAppName`
|
||||
* `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar"
|
||||
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar"
|
||||
* `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar "
|
||||
* `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar"
|
||||
* `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar"
|
||||
* `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar"
|
||||
* `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with`}}`, password is "foo}bar"
|
||||
|
||||
### Azure Active Directory authentication - preview
|
||||
|
||||
The configuration of functionality might change in the future.
|
||||
|
||||
Azure Active Directory (AAD) access tokens are relatively short lived and need to be
|
||||
Azure Active Directory (AAD) access tokens are relatively short lived and need to be
|
||||
valid when a new connection is made. Authentication is supported using a callback func that
|
||||
provides a fresh and valid token using a connector:
|
||||
``` golang
|
||||
|
||||
``` go
|
||||
|
||||
conn, err := mssql.NewAccessTokenConnector(
|
||||
"Server=test.database.windows.net;Database=testdb",
|
||||
tokenProvider)
|
||||
"Server=test.database.windows.net;Database=testdb",
|
||||
tokenProvider)
|
||||
if err != nil {
|
||||
// handle errors in DSN
|
||||
}
|
||||
db := sql.OpenDB(conn)
|
||||
|
||||
```
|
||||
|
||||
Where `tokenProvider` is a function that returns a fresh access token or an error. None of these statements
|
||||
actually trigger the retrieval of a token, this happens when the first statment is issued and a connection
|
||||
is created.
|
||||
|
@ -129,18 +139,23 @@ is created.
|
|||
## Executing Stored Procedures
|
||||
|
||||
To run a stored procedure, set the query text to the procedure name:
|
||||
|
||||
```go
|
||||
|
||||
var account = "abc"
|
||||
_, err := db.ExecContext(ctx, "sp_RunMe",
|
||||
sql.Named("ID", 123),
|
||||
sql.Named("Account", sql.Out{Dest: &account}),
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
## Reading Output Parameters from a Stored Procedure with Resultset
|
||||
|
||||
To read output parameters from a stored procedure with resultset, make sure you read all the rows before reading the output parameters:
|
||||
|
||||
```go
|
||||
|
||||
sqltextcreate := `
|
||||
CREATE PROCEDURE spwithoutputandrows
|
||||
@bitparam BIT OUTPUT
|
||||
|
@ -156,6 +171,7 @@ for rows.Next() {
|
|||
err = rows.Scan(&strrow)
|
||||
}
|
||||
fmt.Printf("bitparam is %d", bitout)
|
||||
|
||||
```
|
||||
|
||||
## Caveat for local temporary tables
|
||||
|
@ -201,21 +217,25 @@ _, err := conn.ExecContext(ctx, "insert into #mytemp (x) values (@p1)", 1)
|
|||
|
||||
To get the procedure return status, pass into the parameters a
|
||||
`*mssql.ReturnStatus`. For example:
|
||||
```
|
||||
|
||||
```go
|
||||
|
||||
var rs mssql.ReturnStatus
|
||||
_, err := db.ExecContext(ctx, "theproc", &rs)
|
||||
log.Printf("status=%d", rs)
|
||||
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
```go
|
||||
var rs mssql.ReturnStatus
|
||||
_, err := db.QueryContext(ctx, "theproc", &rs)
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&val)
|
||||
}
|
||||
log.Printf("status=%d", rs)
|
||||
|
||||
```
|
||||
|
||||
Limitation: ReturnStatus cannot be retrieved using `QueryRow`.
|
||||
|
@ -226,7 +246,9 @@ The `sqlserver` driver uses normal MS SQL Server syntax and expects parameters i
|
|||
the sql query to be in the form of either `@Name` or `@p1` to `@pN` (ordinal position).
|
||||
|
||||
```go
|
||||
|
||||
db.QueryContext(ctx, `select * from t where ID = @ID and Name = @p2;`, sql.Named("ID", 6), "Bob")
|
||||
|
||||
```
|
||||
|
||||
### Parameter Types
|
||||
|
@ -235,30 +257,30 @@ To pass specific types to the query parameters, say `varchar` or `date` types,
|
|||
you must convert the types to the type before passing in. The following types
|
||||
are supported:
|
||||
|
||||
* string -> nvarchar
|
||||
* mssql.VarChar -> varchar
|
||||
* time.Time -> datetimeoffset or datetime (TDS version dependent)
|
||||
* mssql.DateTime1 -> datetime
|
||||
* mssql.DateTimeOffset -> datetimeoffset
|
||||
* "github.com/golang-sql/civil".Date -> date
|
||||
* "github.com/golang-sql/civil".DateTime -> datetime2
|
||||
* "github.com/golang-sql/civil".Time -> time
|
||||
* mssql.TVP -> Table Value Parameter (TDS version dependent)
|
||||
* string -> nvarchar
|
||||
* mssql.VarChar -> varchar
|
||||
* time.Time -> datetimeoffset or datetime (TDS version dependent)
|
||||
* mssql.DateTime1 -> datetime
|
||||
* mssql.DateTimeOffset -> datetimeoffset
|
||||
* "github.com/golang-sql/civil".Date -> date
|
||||
* "github.com/golang-sql/civil".DateTime -> datetime2
|
||||
* "github.com/golang-sql/civil".Time -> time
|
||||
* mssql.TVP -> Table Value Parameter (TDS version dependent)
|
||||
|
||||
## Important Notes
|
||||
|
||||
* [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should
|
||||
* [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should
|
||||
not be used with this driver (or SQL Server) due to how the TDS protocol
|
||||
works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql)
|
||||
or add a `select ID = convert(bigint, SCOPE_IDENTITY());` to the end of your
|
||||
query (ref [SCOPE_IDENTITY](https://docs.microsoft.com/en-us/sql/t-sql/functions/scope-identity-transact-sql)).
|
||||
This will ensure you are getting the correct ID and will prevent a network round trip.
|
||||
* [NewConnector](https://godoc.org/github.com/denisenkom/go-mssqldb#NewConnector)
|
||||
works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql)
|
||||
or add a `select ID = convert(bigint, SCOPE_IDENTITY());` to the end of your
|
||||
query (ref [SCOPE_IDENTITY](https://docs.microsoft.com/en-us/sql/t-sql/functions/scope-identity-transact-sql)).
|
||||
This will ensure you are getting the correct ID and will prevent a network round trip.
|
||||
* [NewConnector](https://godoc.org/github.com/denisenkom/go-mssqldb#NewConnector)
|
||||
may be used with [OpenDB](https://golang.org/pkg/database/sql/#OpenDB).
|
||||
* [Connector.SessionInitSQL](https://godoc.org/github.com/denisenkom/go-mssqldb#Connector.SessionInitSQL)
|
||||
may be set to set any driver specific session settings after the session
|
||||
has been reset. If empty the session will still be reset but use the database
|
||||
defaults in Go1.10+.
|
||||
* [Connector.SessionInitSQL](https://godoc.org/github.com/denisenkom/go-mssqldb#Connector.SessionInitSQL)
|
||||
may be set to set any driver specific session settings after the session
|
||||
has been reset. If empty the session will still be reset but use the database
|
||||
defaults in Go1.10+.
|
||||
|
||||
## Features
|
||||
|
||||
|
@ -280,7 +302,9 @@ Environment variables are used to pass login information.
|
|||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
env SQLSERVER_DSN=sqlserver://user:pass@hostname/instance?database=test1 go test
|
||||
```
|
||||
|
||||
## Deprecated
|
||||
|
||||
|
@ -296,7 +320,7 @@ will be loosly parsed and an attempt to extract identifiers using one of
|
|||
* :nnn
|
||||
* $nnn
|
||||
|
||||
will be used. This is not recommended with SQL Server.
|
||||
will be used. This is not recommended with SQL Server.
|
||||
There is at least one existing `won't fix` issue with the query parsing.
|
||||
|
||||
Use the native "@Name" parameters instead with the "sqlserver" driver name.
|
||||
|
@ -306,4 +330,4 @@ Use the native "@Name" parameters instead with the "sqlserver" driver name.
|
|||
* SQL Server 2008 and 2008 R2 engine cannot handle login records when SSL encryption is not disabled.
|
||||
To fix SQL Server 2008 R2 issue, install SQL Server 2008 R2 Service Pack 2.
|
||||
To fix SQL Server 2008 issue, install Microsoft SQL Server 2008 Service Pack 3 and Cumulative update package 3 for SQL Server 2008 SP3.
|
||||
More information: http://support.microsoft.com/kb/2653857
|
||||
More information: <http://support.microsoft.com/kb/2653857>
|
||||
|
|
30
vendor/github.com/denisenkom/go-mssqldb/accesstokenconnector.go
generated
vendored
30
vendor/github.com/denisenkom/go-mssqldb/accesstokenconnector.go
generated
vendored
|
@ -6,19 +6,8 @@ import (
|
|||
"context"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var _ driver.Connector = &accessTokenConnector{}
|
||||
|
||||
// accessTokenConnector wraps Connector and injects a
|
||||
// fresh access token when connecting to the database
|
||||
type accessTokenConnector struct {
|
||||
Connector
|
||||
|
||||
accessTokenProvider func() (string, error)
|
||||
}
|
||||
|
||||
// NewAccessTokenConnector creates a new connector from a DSN and a token provider.
|
||||
// The token provider func will be called when a new connection is requested and should return a valid access token.
|
||||
// The returned connector may be used with sql.OpenDB.
|
||||
|
@ -32,20 +21,11 @@ func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) (
|
|||
return nil, err
|
||||
}
|
||||
|
||||
c := &accessTokenConnector{
|
||||
Connector: *conn,
|
||||
accessTokenProvider: tokenProvider,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Connect returns a new database connection
|
||||
func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
var err error
|
||||
c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err)
|
||||
conn.fedAuthRequired = true
|
||||
conn.fedAuthLibrary = fedAuthLibrarySecurityToken
|
||||
conn.securityTokenProvider = func(ctx context.Context) (string, error) {
|
||||
return tokenProvider()
|
||||
}
|
||||
|
||||
return c.Connector.Connect(ctx)
|
||||
return conn, nil
|
||||
}
|
||||
|
|
40
vendor/github.com/denisenkom/go-mssqldb/appveyor.yml
generated
vendored
40
vendor/github.com/denisenkom/go-mssqldb/appveyor.yml
generated
vendored
|
@ -1,6 +1,7 @@
|
|||
version: 1.0.{build}
|
||||
|
||||
os: Windows Server 2012 R2
|
||||
image:
|
||||
- Visual Studio 2015
|
||||
|
||||
clone_folder: c:\gopath\src\github.com\denisenkom\go-mssqldb
|
||||
|
||||
|
@ -9,21 +10,39 @@ environment:
|
|||
HOST: localhost
|
||||
SQLUSER: sa
|
||||
SQLPASSWORD: Password12!
|
||||
DATABASE: test
|
||||
GOVERSION: 111
|
||||
DATABASE: test
|
||||
GOVERSION: 113
|
||||
matrix:
|
||||
- GOVERSION: 18
|
||||
SQLINSTANCE: SQL2016
|
||||
SQLINSTANCE: SQL2017
|
||||
- GOVERSION: 19
|
||||
SQLINSTANCE: SQL2016
|
||||
SQLINSTANCE: SQL2017
|
||||
- GOVERSION: 110
|
||||
SQLINSTANCE: SQL2016
|
||||
SQLINSTANCE: SQL2017
|
||||
- GOVERSION: 111
|
||||
SQLINSTANCE: SQL2016
|
||||
SQLINSTANCE: SQL2017
|
||||
- GOVERSION: 112
|
||||
SQLINSTANCE: SQL2017
|
||||
- SQLINSTANCE: SQL2017
|
||||
- SQLINSTANCE: SQL2016
|
||||
- SQLINSTANCE: SQL2014
|
||||
- SQLINSTANCE: SQL2012SP1
|
||||
- SQLINSTANCE: SQL2008R2SP2
|
||||
|
||||
|
||||
# Go 1.14+ and SQL2019 are available on the Visual Studio 2019 image only
|
||||
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
|
||||
GOVERSION: 114
|
||||
SQLINSTANCE: SQL2019
|
||||
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
|
||||
GOVERSION: 115
|
||||
SQLINSTANCE: SQL2019
|
||||
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
|
||||
GOVERSION: 115
|
||||
SQLINSTANCE: SQL2017
|
||||
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
|
||||
GOVERSION: 116
|
||||
SQLINSTANCE: SQL2017
|
||||
|
||||
install:
|
||||
- set GOROOT=c:\go%GOVERSION%
|
||||
- set PATH=%GOPATH%\bin;%GOROOT%\bin;%PATH%
|
||||
|
@ -35,15 +54,14 @@ build_script:
|
|||
- go build
|
||||
|
||||
before_test:
|
||||
# setup SQL Server
|
||||
- ps: |
|
||||
# setup SQL Server
|
||||
- ps: |
|
||||
$instanceName = $env:SQLINSTANCE
|
||||
Start-Service "MSSQL`$$instanceName"
|
||||
Start-Service "SQLBrowser"
|
||||
- sqlcmd -S "(local)\%SQLINSTANCE%" -Q "Use [master]; CREATE DATABASE test;"
|
||||
- sqlcmd -S "(local)\%SQLINSTANCE%" -h -1 -Q "set nocount on; Select @@version"
|
||||
- pip install codecov
|
||||
|
||||
|
||||
test_script:
|
||||
- go test -race -cpu 4 -coverprofile=coverage.txt -covermode=atomic
|
||||
|
|
23
vendor/github.com/denisenkom/go-mssqldb/buf.go
generated
vendored
23
vendor/github.com/denisenkom/go-mssqldb/buf.go
generated
vendored
|
@ -48,8 +48,8 @@ type tdsBuffer struct {
|
|||
func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
|
||||
return &tdsBuffer{
|
||||
packetSize: int(bufsize),
|
||||
wbuf: make([]byte, 1<<16),
|
||||
rbuf: make([]byte, 1<<16),
|
||||
wbuf: make([]byte, bufsize),
|
||||
rbuf: make([]byte, bufsize),
|
||||
rpos: 8,
|
||||
transport: transport,
|
||||
}
|
||||
|
@ -137,19 +137,28 @@ func (w *tdsBuffer) FinishPacket() error {
|
|||
var headerSize = binary.Size(header{})
|
||||
|
||||
func (r *tdsBuffer) readNextPacket() error {
|
||||
h := header{}
|
||||
var err error
|
||||
err = binary.Read(r.transport, binary.BigEndian, &h)
|
||||
buf := r.rbuf[:headerSize]
|
||||
_, err := io.ReadFull(r.transport, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h := header{
|
||||
PacketType: packetType(buf[0]),
|
||||
Status: buf[1],
|
||||
Size: binary.BigEndian.Uint16(buf[2:4]),
|
||||
Spid: binary.BigEndian.Uint16(buf[4:6]),
|
||||
PacketNo: buf[6],
|
||||
Pad: buf[7],
|
||||
}
|
||||
if int(h.Size) > r.packetSize {
|
||||
return errors.New("Invalid packet size, it is longer than buffer size")
|
||||
return errors.New("invalid packet size, it is longer than buffer size")
|
||||
}
|
||||
if headerSize > int(h.Size) {
|
||||
return errors.New("Invalid packet size, it is shorter than header size")
|
||||
return errors.New("invalid packet size, it is shorter than header size")
|
||||
}
|
||||
_, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size])
|
||||
//s := base64.StdEncoding.EncodeToString(r.rbuf[headerSize:h.Size])
|
||||
//fmt.Print(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
66
vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go
generated
vendored
66
vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go
generated
vendored
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -44,8 +45,9 @@ type BulkOptions struct {
|
|||
type DataValue interface{}
|
||||
|
||||
const (
|
||||
sqlDateFormat = "2006-01-02"
|
||||
sqlTimeFormat = "2006-01-02 15:04:05.999999999Z07:00"
|
||||
sqlDateFormat = "2006-01-02"
|
||||
sqlDateTimeFormat = "2006-01-02 15:04:05.999999999Z07:00"
|
||||
sqlTimeFormat = "15:04:05.9999999"
|
||||
)
|
||||
|
||||
func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
|
||||
|
@ -86,7 +88,7 @@ func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) {
|
|||
b.bulkColumns = append(b.bulkColumns, *bulkCol)
|
||||
b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId)
|
||||
} else {
|
||||
return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename)
|
||||
return fmt.Errorf("column %s does not exist in destination table %s", colname, b.tablename)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -166,7 +168,7 @@ func (b *Bulk) AddRow(row []interface{}) (err error) {
|
|||
}
|
||||
|
||||
if len(row) != len(b.bulkColumns) {
|
||||
return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d",
|
||||
return fmt.Errorf("row does not have the same number of columns than the destination table %d %d",
|
||||
len(row), len(b.bulkColumns))
|
||||
}
|
||||
|
||||
|
@ -215,7 +217,7 @@ func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
|
|||
}
|
||||
|
||||
func (b *Bulk) Done() (rowcount int64, err error) {
|
||||
if b.headerSent == false {
|
||||
if !b.headerSent {
|
||||
//no rows had been sent
|
||||
return 0, nil
|
||||
}
|
||||
|
@ -233,24 +235,13 @@ func (b *Bulk) Done() (rowcount int64, err error) {
|
|||
|
||||
buf.FinishPacket()
|
||||
|
||||
tokchan := make(chan tokenStruct, 5)
|
||||
go processResponse(b.ctx, b.cn.sess, tokchan, nil)
|
||||
|
||||
var rowCount int64
|
||||
for token := range tokchan {
|
||||
switch token := token.(type) {
|
||||
case doneStruct:
|
||||
if token.Status&doneCount != 0 {
|
||||
rowCount = int64(token.RowCount)
|
||||
}
|
||||
if token.isError() {
|
||||
return 0, token.getError()
|
||||
}
|
||||
case error:
|
||||
return 0, b.cn.checkBadConn(token)
|
||||
}
|
||||
reader := startReading(b.cn.sess, b.ctx, outputs{})
|
||||
err = reader.iterateResponse()
|
||||
if err != nil {
|
||||
return 0, b.cn.checkBadConn(err, false)
|
||||
}
|
||||
return rowCount, nil
|
||||
|
||||
return reader.rowCount, nil
|
||||
}
|
||||
|
||||
func (b *Bulk) createColMetadata() []byte {
|
||||
|
@ -339,6 +330,10 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
|
|||
intvalue = int64(val)
|
||||
case int64:
|
||||
intvalue = val
|
||||
case float32:
|
||||
intvalue = int64(val)
|
||||
case float64:
|
||||
intvalue = int64(val)
|
||||
default:
|
||||
err = fmt.Errorf("mssql: invalid type for int column: %T", val)
|
||||
return
|
||||
|
@ -383,6 +378,8 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
|
|||
switch val := val.(type) {
|
||||
case string:
|
||||
res.buffer = str2ucs2(val)
|
||||
case int64:
|
||||
res.buffer = []byte(strconv.FormatInt(val, 10))
|
||||
case []byte:
|
||||
res.buffer = val
|
||||
default:
|
||||
|
@ -397,6 +394,8 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
|
|||
res.buffer = []byte(val)
|
||||
case []byte:
|
||||
res.buffer = val
|
||||
case int64:
|
||||
res.buffer = []byte(strconv.FormatInt(val, 10))
|
||||
default:
|
||||
err = fmt.Errorf("mssql: invalid type for varchar column: %T %s", val, val)
|
||||
return
|
||||
|
@ -421,7 +420,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
|
|||
res.ti.Size = len(res.buffer)
|
||||
case string:
|
||||
var t time.Time
|
||||
if t, err = time.Parse(sqlTimeFormat, val); err != nil {
|
||||
if t, err = time.Parse(sqlDateTimeFormat, val); err != nil {
|
||||
return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
|
||||
}
|
||||
res.buffer = encodeDateTime2(t, int(col.ti.Scale))
|
||||
|
@ -437,7 +436,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
|
|||
res.ti.Size = len(res.buffer)
|
||||
case string:
|
||||
var t time.Time
|
||||
if t, err = time.Parse(sqlTimeFormat, val); err != nil {
|
||||
if t, err = time.Parse(sqlDateTimeFormat, val); err != nil {
|
||||
return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
|
||||
}
|
||||
res.buffer = encodeDateTimeOffset(t, int(col.ti.Scale))
|
||||
|
@ -468,7 +467,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
|
|||
case time.Time:
|
||||
t = val
|
||||
case string:
|
||||
if t, err = time.Parse(sqlTimeFormat, val); err != nil {
|
||||
if t, err = time.Parse(sqlDateTimeFormat, val); err != nil {
|
||||
return res, fmt.Errorf("bulk: unable to convert string to date: %v", err)
|
||||
}
|
||||
default:
|
||||
|
@ -485,7 +484,22 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
|
|||
} else {
|
||||
err = fmt.Errorf("mssql: invalid size of column %d", col.ti.Size)
|
||||
}
|
||||
|
||||
case typeTimeN:
|
||||
var t time.Time
|
||||
switch val := val.(type) {
|
||||
case time.Time:
|
||||
res.buffer = encodeTime(val.Hour(), val.Minute(), val.Second(), val.Nanosecond(), int(col.ti.Scale))
|
||||
res.ti.Size = len(res.buffer)
|
||||
case string:
|
||||
if t, err = time.Parse(sqlTimeFormat, val); err != nil {
|
||||
return res, fmt.Errorf("bulk: unable to convert string to time: %v", err)
|
||||
}
|
||||
res.buffer = encodeTime(t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), int(col.ti.Scale))
|
||||
res.ti.Size = len(res.buffer)
|
||||
default:
|
||||
err = fmt.Errorf("mssql: invalid type for time column: %T %s", val, val)
|
||||
return
|
||||
}
|
||||
// case typeMoney, typeMoney4, typeMoneyN:
|
||||
case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
|
||||
prec := col.ti.Prec
|
||||
|
|
50
vendor/github.com/denisenkom/go-mssqldb/error.go
generated
vendored
50
vendor/github.com/denisenkom/go-mssqldb/error.go
generated
vendored
|
@ -1,6 +1,7 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
|
@ -17,6 +18,9 @@ type Error struct {
|
|||
ServerName string
|
||||
ProcName string
|
||||
LineNo int32
|
||||
// All lists all errors that were received from first to last.
|
||||
// This includes the last one, which is described in the other members.
|
||||
All []Error
|
||||
}
|
||||
|
||||
func (e Error) Error() string {
|
||||
|
@ -65,9 +69,53 @@ func streamErrorf(format string, v ...interface{}) StreamError {
|
|||
}
|
||||
|
||||
func badStreamPanic(err error) {
|
||||
panic(err)
|
||||
panic(streamErrorf("%v", err))
|
||||
}
|
||||
|
||||
func badStreamPanicf(format string, v ...interface{}) {
|
||||
panic(streamErrorf(format, v...))
|
||||
}
|
||||
|
||||
// ServerError is returned when the server got a fatal error
|
||||
// that aborts the process and severs the connection.
|
||||
//
|
||||
// To get the errors returned before the process was aborted,
|
||||
// unwrap this error or call errors.As with a pointer to an
|
||||
// mssql.Error variable.
|
||||
type ServerError struct {
|
||||
sqlError Error
|
||||
}
|
||||
|
||||
func (e ServerError) Error() string {
|
||||
return "SQL Server had internal error"
|
||||
}
|
||||
|
||||
func (e ServerError) Unwrap() error {
|
||||
return e.sqlError
|
||||
}
|
||||
|
||||
// RetryableError is returned when an error was caused by a bad
|
||||
// connection at the start of a query and can be safely retried
|
||||
// using database/sql's automatic retry logic.
|
||||
//
|
||||
// In many cases database/sql's retry logic will transparently
|
||||
// handle this error, the retried call will return successfully,
|
||||
// and you won't even see this error. However, you may see this
|
||||
// error if the retry logic cannot successfully handle the error.
|
||||
// In that case you can get the underlying error by calling this
|
||||
// error's UnWrap function.
|
||||
type RetryableError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (r RetryableError) Error() string {
|
||||
return r.err.Error()
|
||||
}
|
||||
|
||||
func (r RetryableError) Unwrap() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
func (r RetryableError) Is(err error) bool {
|
||||
return err == driver.ErrBadConn
|
||||
}
|
||||
|
|
78
vendor/github.com/denisenkom/go-mssqldb/fedauth.go
generated
vendored
Normal file
78
vendor/github.com/denisenkom/go-mssqldb/fedauth.go
generated
vendored
Normal file
|
@ -0,0 +1,78 @@
|
|||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/denisenkom/go-mssqldb/msdsn"
|
||||
)
|
||||
|
||||
// Federated authentication library affects the login data structure and message sequence.
|
||||
const (
|
||||
// fedAuthLibraryLiveIDCompactToken specifies the Microsoft Live ID Compact Token authentication scheme
|
||||
fedAuthLibraryLiveIDCompactToken = 0x00
|
||||
|
||||
// fedAuthLibrarySecurityToken specifies a token-based authentication where the token is available
|
||||
// without additional information provided during the login sequence.
|
||||
fedAuthLibrarySecurityToken = 0x01
|
||||
|
||||
// fedAuthLibraryADAL specifies a token-based authentication where a token is obtained during the
|
||||
// login sequence using the server SPN and STS URL provided by the server during login.
|
||||
fedAuthLibraryADAL = 0x02
|
||||
|
||||
// fedAuthLibraryReserved is used to indicate that no federated authentication scheme applies.
|
||||
fedAuthLibraryReserved = 0x7F
|
||||
)
|
||||
|
||||
// Federated authentication ADAL workflow affects the mechanism used to authenticate.
|
||||
const (
|
||||
// fedAuthADALWorkflowPassword uses a username/password to obtain a token from Active Directory
|
||||
fedAuthADALWorkflowPassword = 0x01
|
||||
|
||||
// fedAuthADALWorkflowPassword uses the Windows identity to obtain a token from Active Directory
|
||||
fedAuthADALWorkflowIntegrated = 0x02
|
||||
|
||||
// fedAuthADALWorkflowMSI uses the managed identity service to obtain a token
|
||||
fedAuthADALWorkflowMSI = 0x03
|
||||
)
|
||||
|
||||
// newSecurityTokenConnector creates a new connector from a Config and a token provider.
|
||||
// When invoked, token provider implementations should contact the security token
|
||||
// service specified and obtain the appropriate token, or return an error
|
||||
// to indicate why a token is not available.
|
||||
// The returned connector may be used with sql.OpenDB.
|
||||
func newSecurityTokenConnector(config msdsn.Config, tokenProvider func(ctx context.Context) (string, error)) (*Connector, error) {
|
||||
if tokenProvider == nil {
|
||||
return nil, errors.New("mssql: tokenProvider cannot be nil")
|
||||
}
|
||||
|
||||
conn := NewConnectorConfig(config)
|
||||
conn.fedAuthRequired = true
|
||||
conn.fedAuthLibrary = fedAuthLibrarySecurityToken
|
||||
conn.securityTokenProvider = tokenProvider
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// newADALTokenConnector creates a new connector from a Config and a Active Directory token provider.
|
||||
// Token provider implementations are called during federated
|
||||
// authentication login sequences where the server provides a service
|
||||
// principal name and security token service endpoint that should be used
|
||||
// to obtain the token. Implementations should contact the security token
|
||||
// service specified and obtain the appropriate token, or return an error
|
||||
// to indicate why a token is not available.
|
||||
//
|
||||
// The returned connector may be used with sql.OpenDB.
|
||||
func newActiveDirectoryTokenConnector(config msdsn.Config, adalWorkflow byte, tokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)) (*Connector, error) {
|
||||
if tokenProvider == nil {
|
||||
return nil, errors.New("mssql: tokenProvider cannot be nil")
|
||||
}
|
||||
|
||||
conn := NewConnectorConfig(config)
|
||||
conn.fedAuthRequired = true
|
||||
conn.fedAuthLibrary = fedAuthLibraryADAL
|
||||
conn.fedAuthADALWorkflow = adalWorkflow
|
||||
conn.adalTokenProvider = tokenProvider
|
||||
|
||||
return conn, nil
|
||||
}
|
|
@ -1,7 +1,10 @@
|
|||
package mssql
|
||||
package msdsn
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
|
@ -11,49 +14,104 @@ import (
|
|||
"unicode"
|
||||
)
|
||||
|
||||
const defaultServerPort = 1433
|
||||
type (
|
||||
Encryption int
|
||||
Log uint64
|
||||
)
|
||||
|
||||
type connectParams struct {
|
||||
logFlags uint64
|
||||
port uint64
|
||||
host string
|
||||
instance string
|
||||
database string
|
||||
user string
|
||||
password string
|
||||
dial_timeout time.Duration
|
||||
conn_timeout time.Duration
|
||||
keepAlive time.Duration
|
||||
encrypt bool
|
||||
disableEncryption bool
|
||||
trustServerCertificate bool
|
||||
certificate string
|
||||
hostInCertificate string
|
||||
hostInCertificateProvided bool
|
||||
serverSPN string
|
||||
workstation string
|
||||
appname string
|
||||
typeFlags uint8
|
||||
failOverPartner string
|
||||
failOverPort uint64
|
||||
packetSize uint16
|
||||
fedAuthAccessToken string
|
||||
const (
|
||||
EncryptionOff = 0
|
||||
EncryptionRequired = 1
|
||||
EncryptionDisabled = 3
|
||||
)
|
||||
|
||||
const (
|
||||
LogErrors Log = 1
|
||||
LogMessages Log = 2
|
||||
LogRows Log = 4
|
||||
LogSQL Log = 8
|
||||
LogParams Log = 16
|
||||
LogTransaction Log = 32
|
||||
LogDebug Log = 64
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Port uint64
|
||||
Host string
|
||||
Instance string
|
||||
Database string
|
||||
User string
|
||||
Password string
|
||||
Encryption Encryption
|
||||
TLSConfig *tls.Config
|
||||
|
||||
FailOverPartner string
|
||||
FailOverPort uint64
|
||||
|
||||
// If true the TLSConfig servername should use the routed server.
|
||||
HostInCertificateProvided bool
|
||||
|
||||
// Read Only intent for application database.
|
||||
// NOTE: This does not make queries to most databases read-only.
|
||||
ReadOnlyIntent bool
|
||||
|
||||
LogFlags Log
|
||||
|
||||
ServerSPN string
|
||||
Workstation string
|
||||
AppName string
|
||||
|
||||
// If true disables database/sql's automatic retry of queries
|
||||
// that start on bad connections.
|
||||
DisableRetry bool
|
||||
|
||||
// Do not use the following.
|
||||
|
||||
DialTimeout time.Duration // DialTimeout defaults to 15s. Set negative to disable.
|
||||
ConnTimeout time.Duration // Use context for timeouts.
|
||||
KeepAlive time.Duration // Leave at default.
|
||||
PacketSize uint16
|
||||
}
|
||||
|
||||
func parseConnectParams(dsn string) (connectParams, error) {
|
||||
var p connectParams
|
||||
func SetupTLS(certificate string, insecureSkipVerify bool, hostInCertificate string) (*tls.Config, error) {
|
||||
var config tls.Config
|
||||
if certificate != "" {
|
||||
pem, err := ioutil.ReadFile(certificate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read certificate %q: %v", certificate, err)
|
||||
}
|
||||
certs := x509.NewCertPool()
|
||||
certs.AppendCertsFromPEM(pem)
|
||||
config.RootCAs = certs
|
||||
}
|
||||
if insecureSkipVerify {
|
||||
config.InsecureSkipVerify = true
|
||||
}
|
||||
config.ServerName = hostInCertificate
|
||||
|
||||
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
|
||||
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
|
||||
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
|
||||
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
|
||||
config.DynamicRecordSizingDisabled = true
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func Parse(dsn string) (Config, map[string]string, error) {
|
||||
p := Config{}
|
||||
|
||||
var params map[string]string
|
||||
if strings.HasPrefix(dsn, "odbc:") {
|
||||
parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
|
||||
if err != nil {
|
||||
return p, err
|
||||
return p, params, err
|
||||
}
|
||||
params = parameters
|
||||
} else if strings.HasPrefix(dsn, "sqlserver://") {
|
||||
parameters, err := splitConnectionStringURL(dsn)
|
||||
if err != nil {
|
||||
return p, err
|
||||
return p, params, err
|
||||
}
|
||||
params = parameters
|
||||
} else {
|
||||
|
@ -62,57 +120,55 @@ func parseConnectParams(dsn string) (connectParams, error) {
|
|||
|
||||
strlog, ok := params["log"]
|
||||
if ok {
|
||||
var err error
|
||||
p.logFlags, err = strconv.ParseUint(strlog, 10, 64)
|
||||
flags, err := strconv.ParseUint(strlog, 10, 64)
|
||||
if err != nil {
|
||||
return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
|
||||
return p, params, fmt.Errorf("invalid log parameter '%s': %s", strlog, err.Error())
|
||||
}
|
||||
p.LogFlags = Log(flags)
|
||||
}
|
||||
server := params["server"]
|
||||
parts := strings.SplitN(server, `\`, 2)
|
||||
p.host = parts[0]
|
||||
if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
|
||||
p.host = "localhost"
|
||||
p.Host = parts[0]
|
||||
if p.Host == "." || strings.ToUpper(p.Host) == "(LOCAL)" || p.Host == "" {
|
||||
p.Host = "localhost"
|
||||
}
|
||||
if len(parts) > 1 {
|
||||
p.instance = parts[1]
|
||||
p.Instance = parts[1]
|
||||
}
|
||||
p.database = params["database"]
|
||||
p.user = params["user id"]
|
||||
p.password = params["password"]
|
||||
p.Database = params["database"]
|
||||
p.User = params["user id"]
|
||||
p.Password = params["password"]
|
||||
|
||||
p.port = 0
|
||||
p.Port = 0
|
||||
strport, ok := params["port"]
|
||||
if ok {
|
||||
var err error
|
||||
p.port, err = strconv.ParseUint(strport, 10, 16)
|
||||
p.Port, err = strconv.ParseUint(strport, 10, 16)
|
||||
if err != nil {
|
||||
f := "Invalid tcp port '%v': %v"
|
||||
return p, fmt.Errorf(f, strport, err.Error())
|
||||
f := "invalid tcp port '%v': %v"
|
||||
return p, params, fmt.Errorf(f, strport, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option
|
||||
// Default packet size remains at 4096 bytes
|
||||
p.packetSize = 4096
|
||||
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option\
|
||||
strpsize, ok := params["packet size"]
|
||||
if ok {
|
||||
var err error
|
||||
psize, err := strconv.ParseUint(strpsize, 0, 16)
|
||||
if err != nil {
|
||||
f := "Invalid packet size '%v': %v"
|
||||
return p, fmt.Errorf(f, strpsize, err.Error())
|
||||
f := "invalid packet size '%v': %v"
|
||||
return p, params, fmt.Errorf(f, strpsize, err.Error())
|
||||
}
|
||||
|
||||
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
|
||||
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
|
||||
// a higher packet size, the server will respond with an ENVCHANGE request to
|
||||
// alter the packet size to 16383 bytes.
|
||||
p.packetSize = uint16(psize)
|
||||
if p.packetSize < 512 {
|
||||
p.packetSize = 512
|
||||
} else if p.packetSize > 32767 {
|
||||
p.packetSize = 32767
|
||||
p.PacketSize = uint16(psize)
|
||||
if p.PacketSize < 512 {
|
||||
p.PacketSize = 512
|
||||
} else if p.PacketSize > 32767 {
|
||||
p.PacketSize = 32767
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -123,79 +179,95 @@ func parseConnectParams(dsn string) (connectParams, error) {
|
|||
if strconntimeout, ok := params["connection timeout"]; ok {
|
||||
timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
|
||||
if err != nil {
|
||||
f := "Invalid connection timeout '%v': %v"
|
||||
return p, fmt.Errorf(f, strconntimeout, err.Error())
|
||||
f := "invalid connection timeout '%v': %v"
|
||||
return p, params, fmt.Errorf(f, strconntimeout, err.Error())
|
||||
}
|
||||
p.conn_timeout = time.Duration(timeout) * time.Second
|
||||
p.ConnTimeout = time.Duration(timeout) * time.Second
|
||||
}
|
||||
p.dial_timeout = 15 * time.Second
|
||||
p.DialTimeout = 15 * time.Second
|
||||
if strdialtimeout, ok := params["dial timeout"]; ok {
|
||||
timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
|
||||
if err != nil {
|
||||
f := "Invalid dial timeout '%v': %v"
|
||||
return p, fmt.Errorf(f, strdialtimeout, err.Error())
|
||||
f := "invalid dial timeout '%v': %v"
|
||||
return p, params, fmt.Errorf(f, strdialtimeout, err.Error())
|
||||
}
|
||||
p.dial_timeout = time.Duration(timeout) * time.Second
|
||||
p.DialTimeout = time.Duration(timeout) * time.Second
|
||||
}
|
||||
|
||||
// default keep alive should be 30 seconds according to spec:
|
||||
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
|
||||
p.keepAlive = 30 * time.Second
|
||||
p.KeepAlive = 30 * time.Second
|
||||
if keepAlive, ok := params["keepalive"]; ok {
|
||||
timeout, err := strconv.ParseUint(keepAlive, 10, 64)
|
||||
if err != nil {
|
||||
f := "Invalid keepAlive value '%s': %s"
|
||||
return p, fmt.Errorf(f, keepAlive, err.Error())
|
||||
f := "invalid keepAlive value '%s': %s"
|
||||
return p, params, fmt.Errorf(f, keepAlive, err.Error())
|
||||
}
|
||||
p.keepAlive = time.Duration(timeout) * time.Second
|
||||
p.KeepAlive = time.Duration(timeout) * time.Second
|
||||
}
|
||||
|
||||
var (
|
||||
trustServerCert = false
|
||||
certificate = ""
|
||||
hostInCertificate = ""
|
||||
)
|
||||
encrypt, ok := params["encrypt"]
|
||||
if ok {
|
||||
if strings.EqualFold(encrypt, "DISABLE") {
|
||||
p.disableEncryption = true
|
||||
p.Encryption = EncryptionDisabled
|
||||
} else {
|
||||
var err error
|
||||
p.encrypt, err = strconv.ParseBool(encrypt)
|
||||
e, err := strconv.ParseBool(encrypt)
|
||||
if err != nil {
|
||||
f := "Invalid encrypt '%s': %s"
|
||||
return p, fmt.Errorf(f, encrypt, err.Error())
|
||||
f := "invalid encrypt '%s': %s"
|
||||
return p, params, fmt.Errorf(f, encrypt, err.Error())
|
||||
}
|
||||
if e {
|
||||
p.Encryption = EncryptionRequired
|
||||
}
|
||||
}
|
||||
} else {
|
||||
p.trustServerCertificate = true
|
||||
trustServerCert = true
|
||||
}
|
||||
trust, ok := params["trustservercertificate"]
|
||||
if ok {
|
||||
var err error
|
||||
p.trustServerCertificate, err = strconv.ParseBool(trust)
|
||||
trustServerCert, err = strconv.ParseBool(trust)
|
||||
if err != nil {
|
||||
f := "Invalid trust server certificate '%s': %s"
|
||||
return p, fmt.Errorf(f, trust, err.Error())
|
||||
f := "invalid trust server certificate '%s': %s"
|
||||
return p, params, fmt.Errorf(f, trust, err.Error())
|
||||
}
|
||||
}
|
||||
p.certificate = params["certificate"]
|
||||
p.hostInCertificate, ok = params["hostnameincertificate"]
|
||||
certificate = params["certificate"]
|
||||
hostInCertificate, ok = params["hostnameincertificate"]
|
||||
if ok {
|
||||
p.hostInCertificateProvided = true
|
||||
p.HostInCertificateProvided = true
|
||||
} else {
|
||||
p.hostInCertificate = p.host
|
||||
p.hostInCertificateProvided = false
|
||||
hostInCertificate = p.Host
|
||||
p.HostInCertificateProvided = false
|
||||
}
|
||||
|
||||
if p.Encryption != EncryptionDisabled {
|
||||
var err error
|
||||
p.TLSConfig, err = SetupTLS(certificate, trustServerCert, hostInCertificate)
|
||||
if err != nil {
|
||||
return p, params, fmt.Errorf("failed to setup TLS: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
serverSPN, ok := params["serverspn"]
|
||||
if ok {
|
||||
p.serverSPN = serverSPN
|
||||
p.ServerSPN = serverSPN
|
||||
} else {
|
||||
p.serverSPN = generateSpn(p.host, resolveServerPort(p.port))
|
||||
p.ServerSPN = generateSpn(p.Host, p.Port)
|
||||
}
|
||||
|
||||
workstation, ok := params["workstation id"]
|
||||
if ok {
|
||||
p.workstation = workstation
|
||||
p.Workstation = workstation
|
||||
} else {
|
||||
workstation, err := os.Hostname()
|
||||
if err == nil {
|
||||
p.workstation = workstation
|
||||
p.Workstation = workstation
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -203,34 +275,86 @@ func parseConnectParams(dsn string) (connectParams, error) {
|
|||
if !ok {
|
||||
appname = "go-mssqldb"
|
||||
}
|
||||
p.appname = appname
|
||||
p.AppName = appname
|
||||
|
||||
appintent, ok := params["applicationintent"]
|
||||
if ok {
|
||||
if appintent == "ReadOnly" {
|
||||
if p.database == "" {
|
||||
return p, fmt.Errorf("Database must be specified when ApplicationIntent is ReadOnly")
|
||||
if p.Database == "" {
|
||||
return p, params, fmt.Errorf("database must be specified when ApplicationIntent is ReadOnly")
|
||||
}
|
||||
p.typeFlags |= fReadOnlyIntent
|
||||
p.ReadOnlyIntent = true
|
||||
}
|
||||
}
|
||||
|
||||
failOverPartner, ok := params["failoverpartner"]
|
||||
if ok {
|
||||
p.failOverPartner = failOverPartner
|
||||
p.FailOverPartner = failOverPartner
|
||||
}
|
||||
|
||||
failOverPort, ok := params["failoverport"]
|
||||
if ok {
|
||||
var err error
|
||||
p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
|
||||
p.FailOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
|
||||
if err != nil {
|
||||
f := "Invalid tcp port '%v': %v"
|
||||
return p, fmt.Errorf(f, failOverPort, err.Error())
|
||||
f := "invalid failover port '%v': %v"
|
||||
return p, params, fmt.Errorf(f, failOverPort, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
disableRetry, ok := params["disableretry"]
|
||||
if ok {
|
||||
var err error
|
||||
p.DisableRetry, err = strconv.ParseBool(disableRetry)
|
||||
if err != nil {
|
||||
f := "invalid disableRetry '%s': %s"
|
||||
return p, params, fmt.Errorf(f, disableRetry, err.Error())
|
||||
}
|
||||
} else {
|
||||
p.DisableRetry = disableRetryDefault
|
||||
}
|
||||
|
||||
return p, params, nil
|
||||
}
|
||||
|
||||
// convert connectionParams to url style connection string
|
||||
// used mostly for testing
|
||||
func (p Config) URL() *url.URL {
|
||||
q := url.Values{}
|
||||
if p.Database != "" {
|
||||
q.Add("database", p.Database)
|
||||
}
|
||||
if p.LogFlags != 0 {
|
||||
q.Add("log", strconv.FormatUint(uint64(p.LogFlags), 10))
|
||||
}
|
||||
host := p.Host
|
||||
if p.Port > 0 {
|
||||
host = fmt.Sprintf("%s:%d", p.Host, p.Port)
|
||||
}
|
||||
q.Add("disableRetry", fmt.Sprintf("%t", p.DisableRetry))
|
||||
res := url.URL{
|
||||
Scheme: "sqlserver",
|
||||
Host: host,
|
||||
User: url.UserPassword(p.User, p.Password),
|
||||
}
|
||||
if p.Instance != "" {
|
||||
res.Path = p.Instance
|
||||
}
|
||||
if len(q) > 0 {
|
||||
res.RawQuery = q.Encode()
|
||||
}
|
||||
return &res
|
||||
}
|
||||
|
||||
var adoSynonyms = map[string]string{
|
||||
"application name": "app name",
|
||||
"data source": "server",
|
||||
"address": "server",
|
||||
"network address": "server",
|
||||
"addr": "server",
|
||||
"user": "user id",
|
||||
"uid": "user id",
|
||||
"initial catalog": "database",
|
||||
}
|
||||
|
||||
func splitConnectionString(dsn string) (res map[string]string) {
|
||||
|
@ -249,6 +373,20 @@ func splitConnectionString(dsn string) (res map[string]string) {
|
|||
if len(lst) > 1 {
|
||||
value = strings.TrimSpace(lst[1])
|
||||
}
|
||||
synonym, hasSynonym := adoSynonyms[name]
|
||||
if hasSynonym {
|
||||
name = synonym
|
||||
}
|
||||
// "server" in ADO can include a protocol and a port.
|
||||
// We only support tcp protocol
|
||||
if name == "server" {
|
||||
value = strings.TrimPrefix(value, "tcp:")
|
||||
serverParts := strings.Split(value, ",")
|
||||
if len(serverParts) == 2 && len(serverParts[1]) > 0 {
|
||||
value = serverParts[0]
|
||||
res["port"] = serverParts[1]
|
||||
}
|
||||
}
|
||||
res[name] = value
|
||||
}
|
||||
return res
|
||||
|
@ -340,7 +478,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
|
|||
case parserStateBeforeKey:
|
||||
switch {
|
||||
case c == '=':
|
||||
return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
|
||||
return res, fmt.Errorf("unexpected character = at index %d. Expected start of key or semi-colon or whitespace", i)
|
||||
case !unicode.IsSpace(c) && c != ';':
|
||||
state = parserStateKey
|
||||
key += string(c)
|
||||
|
@ -419,7 +557,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
|
|||
case unicode.IsSpace(c):
|
||||
// Ignore whitespace
|
||||
default:
|
||||
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
|
||||
return res, fmt.Errorf("unexpected character %c at index %d. Expected semi-colon or whitespace", c, i)
|
||||
}
|
||||
|
||||
case parserStateEndValue:
|
||||
|
@ -429,7 +567,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
|
|||
case unicode.IsSpace(c):
|
||||
// Ignore whitespace
|
||||
default:
|
||||
return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
|
||||
return res, fmt.Errorf("unexpected character %c at index %d. Expected semi-colon or whitespace", c, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -444,7 +582,7 @@ func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
|
|||
case parserStateBareValue:
|
||||
res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
|
||||
case parserStateBracedValue:
|
||||
return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
|
||||
return res, fmt.Errorf("unexpected end of braced value at index %d", len(dsn))
|
||||
case parserStateBracedValueClosingBrace: // End of braced value
|
||||
res[key] = value
|
||||
case parserStateEndValue: // Okay
|
||||
|
@ -458,14 +596,6 @@ func normalizeOdbcKey(s string) string {
|
|||
return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
|
||||
}
|
||||
|
||||
func resolveServerPort(port uint64) uint64 {
|
||||
if port == 0 {
|
||||
return defaultServerPort
|
||||
}
|
||||
|
||||
return port
|
||||
}
|
||||
|
||||
func generateSpn(host string, port uint64) string {
|
||||
return fmt.Sprintf("MSSQLSvc/%s:%d", host, port)
|
||||
}
|
9
vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str_go118.go
generated
vendored
Normal file
9
vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str_go118.go
generated
vendored
Normal file
|
@ -0,0 +1,9 @@
|
|||
// +build go1.18
|
||||
|
||||
package msdsn
|
||||
|
||||
// disableRetryDefault is false for Go versions 1.18 and higher. This matches
|
||||
// the behavior requested in issue #586. A query that fails at the start due to
|
||||
// a bad connection is automatically retried. An error is returned only if the
|
||||
// query fails all of its retries.
|
||||
const disableRetryDefault bool = false
|
9
vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str_go118pre.go
generated
vendored
Normal file
9
vendor/github.com/denisenkom/go-mssqldb/msdsn/conn_str_go118pre.go
generated
vendored
Normal file
|
@ -0,0 +1,9 @@
|
|||
// +build !go1.18
|
||||
|
||||
package msdsn
|
||||
|
||||
// disableRetryDefault is true for versions of Go less than 1.18. This matches
|
||||
// the behavior requested in issue #275. A query that fails at the start due to
|
||||
// a bad connection is not retried. Instead, the detailed error is immediately
|
||||
// returned to the caller.
|
||||
const disableRetryDefault bool = true
|
396
vendor/github.com/denisenkom/go-mssqldb/mssql.go
generated
vendored
396
vendor/github.com/denisenkom/go-mssqldb/mssql.go
generated
vendored
|
@ -16,6 +16,7 @@ import (
|
|||
"unicode"
|
||||
|
||||
"github.com/denisenkom/go-mssqldb/internal/querytext"
|
||||
"github.com/denisenkom/go-mssqldb/msdsn"
|
||||
)
|
||||
|
||||
// ReturnStatus may be used to return the return value from a proc.
|
||||
|
@ -31,12 +32,16 @@ var driverInstanceNoProcess = &Driver{processQueryText: false}
|
|||
func init() {
|
||||
sql.Register("mssql", driverInstance)
|
||||
sql.Register("sqlserver", driverInstanceNoProcess)
|
||||
createDialer = func(p *connectParams) Dialer {
|
||||
return netDialer{&net.Dialer{KeepAlive: p.keepAlive}}
|
||||
createDialer = func(p *msdsn.Config) Dialer {
|
||||
ka := p.KeepAlive
|
||||
if ka == 0 {
|
||||
ka = 30 * time.Second
|
||||
}
|
||||
return netDialer{&net.Dialer{KeepAlive: ka}}
|
||||
}
|
||||
}
|
||||
|
||||
var createDialer func(p *connectParams) Dialer
|
||||
var createDialer func(p *msdsn.Config) Dialer
|
||||
|
||||
type netDialer struct {
|
||||
nd *net.Dialer
|
||||
|
@ -54,10 +59,11 @@ type Driver struct {
|
|||
|
||||
// OpenConnector opens a new connector. Useful to dial with a context.
|
||||
func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
|
||||
params, err := parseConnectParams(dsn)
|
||||
params, _, err := msdsn.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Connector{
|
||||
params: params,
|
||||
driver: d,
|
||||
|
@ -80,7 +86,7 @@ func (d *Driver) SetLogger(logger Logger) {
|
|||
// NewConnector creates a new connector from a DSN.
|
||||
// The returned connector may be used with sql.OpenDB.
|
||||
func NewConnector(dsn string) (*Connector, error) {
|
||||
params, err := parseConnectParams(dsn)
|
||||
params, _, err := msdsn.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -91,15 +97,34 @@ func NewConnector(dsn string) (*Connector, error) {
|
|||
return c, nil
|
||||
}
|
||||
|
||||
// NewConnectorConfig creates a new Connector for a DSN Config struct.
|
||||
// The returned connector may be used with sql.OpenDB.
|
||||
func NewConnectorConfig(config msdsn.Config) *Connector {
|
||||
return &Connector{
|
||||
params: config,
|
||||
driver: driverInstanceNoProcess,
|
||||
}
|
||||
}
|
||||
|
||||
// Connector holds the parsed DSN and is ready to make a new connection
|
||||
// at any time.
|
||||
//
|
||||
// In the future, settings that cannot be passed through a string DSN
|
||||
// may be set directly on the connector.
|
||||
type Connector struct {
|
||||
params connectParams
|
||||
params msdsn.Config
|
||||
driver *Driver
|
||||
|
||||
fedAuthRequired bool
|
||||
fedAuthLibrary int
|
||||
fedAuthADALWorkflow byte
|
||||
|
||||
// callback that can provide a security token during login
|
||||
securityTokenProvider func(ctx context.Context) (string, error)
|
||||
|
||||
// callback that can provide a security token during ADAL login
|
||||
adalTokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)
|
||||
|
||||
// SessionInitSQL is executed after marking a given session to be reset.
|
||||
// When not present, the next query will still reset the session to the
|
||||
// database defaults.
|
||||
|
@ -132,7 +157,7 @@ type Dialer interface {
|
|||
DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func (c *Connector) getDialer(p *connectParams) Dialer {
|
||||
func (c *Connector) getDialer(p *msdsn.Config) Dialer {
|
||||
if c != nil && c.Dialer != nil {
|
||||
return c.Dialer
|
||||
}
|
||||
|
@ -148,34 +173,32 @@ type Conn struct {
|
|||
processQueryText bool
|
||||
connectionGood bool
|
||||
|
||||
outs map[string]interface{}
|
||||
outs outputs
|
||||
}
|
||||
|
||||
type outputs struct {
|
||||
params map[string]interface{}
|
||||
returnStatus *ReturnStatus
|
||||
}
|
||||
|
||||
func (c *Conn) setReturnStatus(s ReturnStatus) {
|
||||
if c.returnStatus == nil {
|
||||
return
|
||||
}
|
||||
*c.returnStatus = s
|
||||
// IsValid satisfies the driver.Validator interface.
|
||||
func (c *Conn) IsValid() bool {
|
||||
return c.connectionGood
|
||||
}
|
||||
|
||||
func (c *Conn) checkBadConn(err error) error {
|
||||
// this is a hack to address Issue #275
|
||||
// we set connectionGood flag to false if
|
||||
// error indicates that connection is not usable
|
||||
// but we return actual error instead of ErrBadConn
|
||||
// this will cause connection to stay in a pool
|
||||
// but next request to this connection will return ErrBadConn
|
||||
|
||||
// it might be possible to revise this hack after
|
||||
// https://github.com/golang/go/issues/20807
|
||||
// is implemented
|
||||
// checkBadConn marks the connection as bad based on the characteristics
|
||||
// of the supplied error. Bad connections will be dropped from the connection
|
||||
// pool rather than reused.
|
||||
//
|
||||
// If bad connection retry is enabled and the error + connection state permits
|
||||
// retrying, checkBadConn will return a RetryableError that allows database/sql
|
||||
// to automatically retry the query with another connection.
|
||||
func (c *Conn) checkBadConn(err error, mayRetry bool) error {
|
||||
switch err {
|
||||
case nil:
|
||||
return nil
|
||||
case io.EOF:
|
||||
c.connectionGood = false
|
||||
return driver.ErrBadConn
|
||||
case driver.ErrBadConn:
|
||||
// It is an internal programming error if driver.ErrBadConn
|
||||
// is ever passed to this function. driver.ErrBadConn should
|
||||
|
@ -187,34 +210,33 @@ func (c *Conn) checkBadConn(err error) error {
|
|||
switch err.(type) {
|
||||
case net.Error:
|
||||
c.connectionGood = false
|
||||
return err
|
||||
case StreamError:
|
||||
c.connectionGood = false
|
||||
return err
|
||||
default:
|
||||
return err
|
||||
case ServerError:
|
||||
c.connectionGood = false
|
||||
}
|
||||
|
||||
if !c.connectionGood && mayRetry && !c.connector.params.DisableRetry {
|
||||
return newRetryableError(err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) clearOuts() {
|
||||
c.outs = nil
|
||||
c.outs = outputs{}
|
||||
}
|
||||
|
||||
func (c *Conn) simpleProcessResp(ctx context.Context) error {
|
||||
tokchan := make(chan tokenStruct, 5)
|
||||
go processResponse(ctx, c.sess, tokchan, c.outs)
|
||||
reader := startReading(c.sess, ctx, c.outs)
|
||||
c.clearOuts()
|
||||
for tok := range tokchan {
|
||||
switch token := tok.(type) {
|
||||
case doneStruct:
|
||||
if token.isError() {
|
||||
return c.checkBadConn(token.getError())
|
||||
}
|
||||
case error:
|
||||
return c.checkBadConn(token)
|
||||
}
|
||||
|
||||
var resultError error
|
||||
err := reader.iterateResponse()
|
||||
if err != nil {
|
||||
return c.checkBadConn(err, false)
|
||||
}
|
||||
return nil
|
||||
return resultError
|
||||
}
|
||||
|
||||
func (c *Conn) Commit() error {
|
||||
|
@ -222,7 +244,7 @@ func (c *Conn) Commit() error {
|
|||
return driver.ErrBadConn
|
||||
}
|
||||
if err := c.sendCommitRequest(); err != nil {
|
||||
return c.checkBadConn(err)
|
||||
return c.checkBadConn(err, true)
|
||||
}
|
||||
return c.simpleProcessResp(c.transactionCtx)
|
||||
}
|
||||
|
@ -239,7 +261,7 @@ func (c *Conn) sendCommitRequest() error {
|
|||
c.sess.log.Printf("Failed to send CommitXact with %v", err)
|
||||
}
|
||||
c.connectionGood = false
|
||||
return fmt.Errorf("Faild to send CommitXact: %v", err)
|
||||
return fmt.Errorf("faild to send CommitXact: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -249,7 +271,7 @@ func (c *Conn) Rollback() error {
|
|||
return driver.ErrBadConn
|
||||
}
|
||||
if err := c.sendRollbackRequest(); err != nil {
|
||||
return c.checkBadConn(err)
|
||||
return c.checkBadConn(err, true)
|
||||
}
|
||||
return c.simpleProcessResp(c.transactionCtx)
|
||||
}
|
||||
|
@ -266,7 +288,7 @@ func (c *Conn) sendRollbackRequest() error {
|
|||
c.sess.log.Printf("Failed to send RollbackXact with %v", err)
|
||||
}
|
||||
c.connectionGood = false
|
||||
return fmt.Errorf("Failed to send RollbackXact: %v", err)
|
||||
return fmt.Errorf("failed to send RollbackXact: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -281,11 +303,11 @@ func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx,
|
|||
}
|
||||
err = c.sendBeginRequest(ctx, tdsIsolation)
|
||||
if err != nil {
|
||||
return nil, c.checkBadConn(err)
|
||||
return nil, c.checkBadConn(err, true)
|
||||
}
|
||||
tx, err = c.processBeginResponse(ctx)
|
||||
if err != nil {
|
||||
return nil, c.checkBadConn(err)
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -303,7 +325,7 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro
|
|||
c.sess.log.Printf("Failed to send BeginXact with %v", err)
|
||||
}
|
||||
c.connectionGood = false
|
||||
return fmt.Errorf("Failed to send BeginXact: %v", err)
|
||||
return fmt.Errorf("failed to send BeginXact: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -318,25 +340,26 @@ func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
|
|||
}
|
||||
|
||||
func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
|
||||
params, err := parseConnectParams(dsn)
|
||||
params, _, err := msdsn.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.connect(ctx, nil, params)
|
||||
c := &Connector{params: params}
|
||||
return d.connect(ctx, c, params)
|
||||
}
|
||||
|
||||
// connect to the server, using the provided context for dialing only.
|
||||
func (d *Driver) connect(ctx context.Context, c *Connector, params connectParams) (*Conn, error) {
|
||||
func (d *Driver) connect(ctx context.Context, c *Connector, params msdsn.Config) (*Conn, error) {
|
||||
sess, err := connect(ctx, c, d.log, params)
|
||||
if err != nil {
|
||||
// main server failed, try fail-over partner
|
||||
if params.failOverPartner == "" {
|
||||
if params.FailOverPartner == "" {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params.host = params.failOverPartner
|
||||
if params.failOverPort != 0 {
|
||||
params.port = params.failOverPort
|
||||
params.Host = params.FailOverPartner
|
||||
if params.FailOverPort != 0 {
|
||||
params.Port = params.FailOverPort
|
||||
}
|
||||
|
||||
sess, err = connect(ctx, c, d.log, params)
|
||||
|
@ -447,7 +470,8 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
|
|||
|
||||
reset := conn.resetSession
|
||||
conn.resetSession = false
|
||||
if len(args) == 0 {
|
||||
isProc := isProc(s.query)
|
||||
if len(args) == 0 && !isProc {
|
||||
if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
|
||||
if conn.sess.logFlags&logErrors != 0 {
|
||||
conn.sess.log.Printf("Failed to send SqlBatch with %v", err)
|
||||
|
@ -458,7 +482,7 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
|
|||
} else {
|
||||
proc := sp_ExecuteSql
|
||||
var params []param
|
||||
if isProc(s.query) {
|
||||
if isProc {
|
||||
proc.name = s.query
|
||||
params, _, err = s.makeRPCParams(args, true)
|
||||
if err != nil {
|
||||
|
@ -478,7 +502,7 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
|
|||
conn.sess.log.Printf("Failed to send Rpc with %v", err)
|
||||
}
|
||||
conn.connectionGood = false
|
||||
return fmt.Errorf("Failed to send RPC: %v", err)
|
||||
return fmt.Errorf("failed to send RPC: %v", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -500,30 +524,38 @@ func isProc(s string) bool {
|
|||
for _, r := range s {
|
||||
rPrev = rn1
|
||||
rn1 = r
|
||||
switch r {
|
||||
// No newlines or string sequences.
|
||||
case '\n', '\r', '\'', ';':
|
||||
return false
|
||||
if st != escaped {
|
||||
switch r {
|
||||
// No newlines or string sequences.
|
||||
case '\n', '\r', '\'', ';':
|
||||
return false
|
||||
}
|
||||
}
|
||||
switch st {
|
||||
case outside:
|
||||
switch {
|
||||
case unicode.IsSpace(r):
|
||||
return false
|
||||
case r == '[':
|
||||
st = escaped
|
||||
continue
|
||||
case r == ']' && rPrev == ']':
|
||||
st = escaped
|
||||
continue
|
||||
case unicode.IsLetter(r):
|
||||
st = text
|
||||
case r == '_':
|
||||
st = text
|
||||
case r == '#':
|
||||
st = text
|
||||
case r == '.':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
case text:
|
||||
switch {
|
||||
case r == '.':
|
||||
st = outside
|
||||
continue
|
||||
case r == '[':
|
||||
return false
|
||||
case r == '(':
|
||||
return false
|
||||
case unicode.IsSpace(r):
|
||||
return false
|
||||
}
|
||||
|
@ -531,7 +563,6 @@ func isProc(s string) bool {
|
|||
switch {
|
||||
case r == ']':
|
||||
st = outside
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -558,7 +589,13 @@ func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string,
|
|||
name = fmt.Sprintf("@p%d", val.Ordinal)
|
||||
}
|
||||
params[i+offset].Name = name
|
||||
decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+offset].ti))
|
||||
const outputSuffix = " output"
|
||||
var output string
|
||||
if isOutputValue(val.Value) {
|
||||
output = outputSuffix
|
||||
}
|
||||
decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(params[i+offset].ti), output)
|
||||
|
||||
}
|
||||
return params, decls, nil
|
||||
}
|
||||
|
@ -581,6 +618,8 @@ func convertOldArgs(args []driver.Value) []namedValue {
|
|||
}
|
||||
|
||||
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
defer s.c.clearOuts()
|
||||
|
||||
return s.queryContext(context.Background(), convertOldArgs(args))
|
||||
}
|
||||
|
||||
|
@ -589,48 +628,60 @@ func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver
|
|||
return nil, driver.ErrBadConn
|
||||
}
|
||||
if err = s.sendQuery(args); err != nil {
|
||||
return nil, s.c.checkBadConn(err)
|
||||
return nil, s.c.checkBadConn(err, true)
|
||||
}
|
||||
return s.processQueryResponse(ctx)
|
||||
}
|
||||
|
||||
func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
|
||||
tokchan := make(chan tokenStruct, 5)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
|
||||
reader := startReading(s.c.sess, ctx, s.c.outs)
|
||||
s.c.clearOuts()
|
||||
// process metadata
|
||||
var cols []columnStruct
|
||||
loop:
|
||||
for tok := range tokchan {
|
||||
switch token := tok.(type) {
|
||||
// By ignoring DONE token we effectively
|
||||
// skip empty result-sets.
|
||||
// This improves results in queries like that:
|
||||
// set nocount on; select 1
|
||||
// see TestIgnoreEmptyResults test
|
||||
//case doneStruct:
|
||||
//break loop
|
||||
case []columnStruct:
|
||||
cols = token
|
||||
break loop
|
||||
case doneStruct:
|
||||
if token.isError() {
|
||||
cancel()
|
||||
return nil, s.c.checkBadConn(token.getError())
|
||||
for {
|
||||
tok, err := reader.nextToken()
|
||||
if err == nil {
|
||||
if tok == nil {
|
||||
break
|
||||
} else {
|
||||
switch token := tok.(type) {
|
||||
// By ignoring DONE token we effectively
|
||||
// skip empty result-sets.
|
||||
// This improves results in queries like that:
|
||||
// set nocount on; select 1
|
||||
// see TestIgnoreEmptyResults test
|
||||
//case doneStruct:
|
||||
//break loop
|
||||
case []columnStruct:
|
||||
cols = token
|
||||
break loop
|
||||
case doneStruct:
|
||||
if token.isError() {
|
||||
// need to cleanup cancellable context
|
||||
cancel()
|
||||
return nil, s.c.checkBadConn(token.getError(), false)
|
||||
}
|
||||
case ReturnStatus:
|
||||
if reader.outs.returnStatus != nil {
|
||||
*reader.outs.returnStatus = token
|
||||
}
|
||||
}
|
||||
}
|
||||
case ReturnStatus:
|
||||
s.c.setReturnStatus(token)
|
||||
case error:
|
||||
} else {
|
||||
// need to cleanup cancellable context
|
||||
cancel()
|
||||
return nil, s.c.checkBadConn(token)
|
||||
return nil, s.c.checkBadConn(err, false)
|
||||
}
|
||||
}
|
||||
res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel}
|
||||
res = &Rows{stmt: s, reader: reader, cols: cols, cancel: cancel}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
defer s.c.clearOuts()
|
||||
|
||||
return s.exec(context.Background(), convertOldArgs(args))
|
||||
}
|
||||
|
||||
|
@ -639,57 +690,55 @@ func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result,
|
|||
return nil, driver.ErrBadConn
|
||||
}
|
||||
if err = s.sendQuery(args); err != nil {
|
||||
return nil, s.c.checkBadConn(err)
|
||||
return nil, s.c.checkBadConn(err, true)
|
||||
}
|
||||
if res, err = s.processExec(ctx); err != nil {
|
||||
return nil, s.c.checkBadConn(err)
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
|
||||
tokchan := make(chan tokenStruct, 5)
|
||||
go processResponse(ctx, s.c.sess, tokchan, s.c.outs)
|
||||
reader := startReading(s.c.sess, ctx, s.c.outs)
|
||||
s.c.clearOuts()
|
||||
var rowCount int64
|
||||
for token := range tokchan {
|
||||
switch token := token.(type) {
|
||||
case doneInProcStruct:
|
||||
if token.Status&doneCount != 0 {
|
||||
rowCount += int64(token.RowCount)
|
||||
}
|
||||
case doneStruct:
|
||||
if token.Status&doneCount != 0 {
|
||||
rowCount += int64(token.RowCount)
|
||||
}
|
||||
if token.isError() {
|
||||
return nil, token.getError()
|
||||
}
|
||||
case ReturnStatus:
|
||||
s.c.setReturnStatus(token)
|
||||
case error:
|
||||
return nil, token
|
||||
}
|
||||
err = reader.iterateResponse()
|
||||
if err != nil {
|
||||
return nil, s.c.checkBadConn(err, false)
|
||||
}
|
||||
return &Result{s.c, rowCount}, nil
|
||||
return &Result{s.c, reader.rowCount}, nil
|
||||
}
|
||||
|
||||
type Rows struct {
|
||||
stmt *Stmt
|
||||
cols []columnStruct
|
||||
tokchan chan tokenStruct
|
||||
|
||||
stmt *Stmt
|
||||
cols []columnStruct
|
||||
reader *tokenProcessor
|
||||
nextCols []columnStruct
|
||||
|
||||
cancel func()
|
||||
}
|
||||
|
||||
func (rc *Rows) Close() error {
|
||||
// need to add a test which returns lots of rows
|
||||
// and check closing after reading only few rows
|
||||
rc.cancel()
|
||||
for _ = range rc.tokchan {
|
||||
|
||||
for {
|
||||
tok, err := rc.reader.nextToken()
|
||||
if err == nil {
|
||||
if tok == nil {
|
||||
return nil
|
||||
} else {
|
||||
// continue consuming tokens
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if err == rc.reader.ctx.Err() {
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
rc.tokchan = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rc *Rows) Columns() (res []string) {
|
||||
|
@ -707,27 +756,36 @@ func (rc *Rows) Next(dest []driver.Value) error {
|
|||
if rc.nextCols != nil {
|
||||
return io.EOF
|
||||
}
|
||||
for tok := range rc.tokchan {
|
||||
switch tokdata := tok.(type) {
|
||||
case []columnStruct:
|
||||
rc.nextCols = tokdata
|
||||
return io.EOF
|
||||
case []interface{}:
|
||||
for i := range dest {
|
||||
dest[i] = tokdata[i]
|
||||
for {
|
||||
tok, err := rc.reader.nextToken()
|
||||
if err == nil {
|
||||
if tok == nil {
|
||||
return io.EOF
|
||||
} else {
|
||||
switch tokdata := tok.(type) {
|
||||
case []columnStruct:
|
||||
rc.nextCols = tokdata
|
||||
return io.EOF
|
||||
case []interface{}:
|
||||
for i := range dest {
|
||||
dest[i] = tokdata[i]
|
||||
}
|
||||
return nil
|
||||
case doneStruct:
|
||||
if tokdata.isError() {
|
||||
return rc.stmt.c.checkBadConn(tokdata.getError(), false)
|
||||
}
|
||||
case ReturnStatus:
|
||||
if rc.reader.outs.returnStatus != nil {
|
||||
*rc.reader.outs.returnStatus = tokdata
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case doneStruct:
|
||||
if tokdata.isError() {
|
||||
return rc.stmt.c.checkBadConn(tokdata.getError())
|
||||
}
|
||||
case ReturnStatus:
|
||||
rc.stmt.c.setReturnStatus(tokdata)
|
||||
case error:
|
||||
return rc.stmt.c.checkBadConn(tokdata)
|
||||
|
||||
} else {
|
||||
return rc.stmt.c.checkBadConn(err, false)
|
||||
}
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (rc *Rows) HasNextResultSet() bool {
|
||||
|
@ -895,35 +953,41 @@ func (c *Conn) Ping(ctx context.Context) error {
|
|||
|
||||
var _ driver.ConnBeginTx = &Conn{}
|
||||
|
||||
func convertIsolationLevel(level sql.IsolationLevel) (isoLevel, error) {
|
||||
switch level {
|
||||
case sql.LevelDefault:
|
||||
return isolationUseCurrent, nil
|
||||
case sql.LevelReadUncommitted:
|
||||
return isolationReadUncommited, nil
|
||||
case sql.LevelReadCommitted:
|
||||
return isolationReadCommited, nil
|
||||
case sql.LevelWriteCommitted:
|
||||
return isolationUseCurrent, errors.New("LevelWriteCommitted isolation level is not supported")
|
||||
case sql.LevelRepeatableRead:
|
||||
return isolationRepeatableRead, nil
|
||||
case sql.LevelSnapshot:
|
||||
return isolationSnapshot, nil
|
||||
case sql.LevelSerializable:
|
||||
return isolationSerializable, nil
|
||||
case sql.LevelLinearizable:
|
||||
return isolationUseCurrent, errors.New("LevelLinearizable isolation level is not supported")
|
||||
default:
|
||||
return isolationUseCurrent, errors.New("isolation level is not supported or unknown")
|
||||
}
|
||||
}
|
||||
|
||||
// BeginTx satisfies ConnBeginTx.
|
||||
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
if !c.connectionGood {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
if opts.ReadOnly {
|
||||
return nil, errors.New("Read-only transactions are not supported")
|
||||
return nil, errors.New("read-only transactions are not supported")
|
||||
}
|
||||
|
||||
var tdsIsolation isoLevel
|
||||
switch sql.IsolationLevel(opts.Isolation) {
|
||||
case sql.LevelDefault:
|
||||
tdsIsolation = isolationUseCurrent
|
||||
case sql.LevelReadUncommitted:
|
||||
tdsIsolation = isolationReadUncommited
|
||||
case sql.LevelReadCommitted:
|
||||
tdsIsolation = isolationReadCommited
|
||||
case sql.LevelWriteCommitted:
|
||||
return nil, errors.New("LevelWriteCommitted isolation level is not supported")
|
||||
case sql.LevelRepeatableRead:
|
||||
tdsIsolation = isolationRepeatableRead
|
||||
case sql.LevelSnapshot:
|
||||
tdsIsolation = isolationSnapshot
|
||||
case sql.LevelSerializable:
|
||||
tdsIsolation = isolationSerializable
|
||||
case sql.LevelLinearizable:
|
||||
return nil, errors.New("LevelLinearizable isolation level is not supported")
|
||||
default:
|
||||
return nil, errors.New("Isolation level is not supported or unknown")
|
||||
tdsIsolation, err := convertIsolationLevel(sql.IsolationLevel(opts.Isolation))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.begin(ctx, tdsIsolation)
|
||||
}
|
||||
|
@ -940,6 +1004,8 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
|
|||
}
|
||||
|
||||
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
defer s.c.clearOuts()
|
||||
|
||||
if !s.c.connectionGood {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
@ -951,6 +1017,8 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
|
|||
}
|
||||
|
||||
func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
defer s.c.clearOuts()
|
||||
|
||||
if !s.c.connectionGood {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
|
2
vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go
generated
vendored
2
vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go
generated
vendored
|
@ -48,5 +48,5 @@ func (c *Connector) Driver() driver.Driver {
|
|||
}
|
||||
|
||||
func (r *Result) LastInsertId() (int64, error) {
|
||||
return -1, errors.New("LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query.")
|
||||
return -1, errors.New("LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query")
|
||||
}
|
||||
|
|
14
vendor/github.com/denisenkom/go-mssqldb/mssql_go118.go
generated
vendored
Normal file
14
vendor/github.com/denisenkom/go-mssqldb/mssql_go118.go
generated
vendored
Normal file
|
@ -0,0 +1,14 @@
|
|||
// +build go1.18
|
||||
|
||||
package mssql
|
||||
|
||||
// newRetryableError returns an error that allows the database/sql package
|
||||
// to automatically retry the failed query. Versions of Go 1.18 and higher
|
||||
// use errors.Is to determine whether or not a failed query can be retried.
|
||||
// Therefore, we wrap the underlying error in a RetryableError that both
|
||||
// implements errors.Is for automatic retry and maintains the error details.
|
||||
func newRetryableError(err error) error {
|
||||
return RetryableError{
|
||||
err: err,
|
||||
}
|
||||
}
|
17
vendor/github.com/denisenkom/go-mssqldb/mssql_go118pre.go
generated
vendored
Normal file
17
vendor/github.com/denisenkom/go-mssqldb/mssql_go118pre.go
generated
vendored
Normal file
|
@ -0,0 +1,17 @@
|
|||
// +build !go1.18
|
||||
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
)
|
||||
|
||||
// newRetryableError returns an error that allows the database/sql package
|
||||
// to automatically retry the failed query. Versions of Go lower than 1.18
|
||||
// compare directly to the sentinel error driver.ErrBadConn to determine
|
||||
// whether or not a failed query can be retried. Therefore, we replace the
|
||||
// actual error with driver.ErrBadConn, enabling retry but losing the error
|
||||
// details.
|
||||
func newRetryableError(err error) error {
|
||||
return driver.ErrBadConn
|
||||
}
|
13
vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go
generated
vendored
13
vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go
generated
vendored
|
@ -66,10 +66,10 @@ func convertInputParameter(val interface{}) (interface{}, error) {
|
|||
func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
|
||||
switch v := nv.Value.(type) {
|
||||
case sql.Out:
|
||||
if c.outs == nil {
|
||||
c.outs = make(map[string]interface{})
|
||||
if c.outs.params == nil {
|
||||
c.outs.params = make(map[string]interface{})
|
||||
}
|
||||
c.outs[nv.Name] = v.Dest
|
||||
c.outs.params[nv.Name] = v.Dest
|
||||
|
||||
if v.Dest == nil {
|
||||
return errors.New("destination is a nil pointer")
|
||||
|
@ -110,7 +110,7 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
|
|||
return nil
|
||||
case *ReturnStatus:
|
||||
*v = 0 // By default the return value should be zero.
|
||||
c.returnStatus = v
|
||||
c.outs.returnStatus = v
|
||||
return driver.ErrRemoveArgument
|
||||
case TVP:
|
||||
return nil
|
||||
|
@ -194,3 +194,8 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) {
|
|||
func scanIntoOut(name string, fromServer, scanInto interface{}) error {
|
||||
return convertAssign(scanInto, fromServer)
|
||||
}
|
||||
|
||||
func isOutputValue(val driver.Value) bool {
|
||||
_, out := val.(sql.Out)
|
||||
return out
|
||||
}
|
||||
|
|
4
vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go
generated
vendored
4
vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go
generated
vendored
|
@ -14,3 +14,7 @@ func (s *Stmt) makeParamExtra(val driver.Value) (param, error) {
|
|||
func scanIntoOut(name string, fromServer, scanInto interface{}) error {
|
||||
return fmt.Errorf("mssql: unsupported OUTPUT type, use a newer Go version")
|
||||
}
|
||||
|
||||
func isOutputValue(val driver.Value) bool {
|
||||
return false
|
||||
}
|
||||
|
|
44
vendor/github.com/denisenkom/go-mssqldb/net.go
generated
vendored
44
vendor/github.com/denisenkom/go-mssqldb/net.go
generated
vendored
|
@ -7,8 +7,8 @@ import (
|
|||
)
|
||||
|
||||
type timeoutConn struct {
|
||||
c net.Conn
|
||||
timeout time.Duration
|
||||
c net.Conn
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func newTimeoutConn(conn net.Conn, timeout time.Duration) *timeoutConn {
|
||||
|
@ -51,21 +51,21 @@ func (c timeoutConn) RemoteAddr() net.Addr {
|
|||
}
|
||||
|
||||
func (c timeoutConn) SetDeadline(t time.Time) error {
|
||||
panic("Not implemented")
|
||||
return c.c.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c timeoutConn) SetReadDeadline(t time.Time) error {
|
||||
panic("Not implemented")
|
||||
return c.c.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c timeoutConn) SetWriteDeadline(t time.Time) error {
|
||||
panic("Not implemented")
|
||||
return c.c.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// this connection is used during TLS Handshake
|
||||
// TDS protocol requires TLS handshake messages to be sent inside TDS packets
|
||||
type tlsHandshakeConn struct {
|
||||
buf *tdsBuffer
|
||||
buf *tdsBuffer
|
||||
packetPending bool
|
||||
continueRead bool
|
||||
}
|
||||
|
@ -75,7 +75,7 @@ func (c *tlsHandshakeConn) Read(b []byte) (n int, err error) {
|
|||
c.packetPending = false
|
||||
err = c.buf.FinishPacket()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Cannot send handshake packet: %s", err.Error())
|
||||
err = fmt.Errorf("cannot send handshake packet: %s", err.Error())
|
||||
return
|
||||
}
|
||||
c.continueRead = false
|
||||
|
@ -84,7 +84,7 @@ func (c *tlsHandshakeConn) Read(b []byte) (n int, err error) {
|
|||
var packet packetType
|
||||
packet, err = c.buf.BeginRead()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Cannot read handshake packet: %s", err.Error())
|
||||
err = fmt.Errorf("cannot read handshake packet: %s", err.Error())
|
||||
return
|
||||
}
|
||||
if packet != packPrelogin {
|
||||
|
@ -105,27 +105,27 @@ func (c *tlsHandshakeConn) Write(b []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
func (c *tlsHandshakeConn) Close() error {
|
||||
panic("Not implemented")
|
||||
return c.buf.transport.Close()
|
||||
}
|
||||
|
||||
func (c *tlsHandshakeConn) LocalAddr() net.Addr {
|
||||
panic("Not implemented")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *tlsHandshakeConn) RemoteAddr() net.Addr {
|
||||
panic("Not implemented")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *tlsHandshakeConn) SetDeadline(t time.Time) error {
|
||||
panic("Not implemented")
|
||||
func (c *tlsHandshakeConn) SetDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *tlsHandshakeConn) SetReadDeadline(t time.Time) error {
|
||||
panic("Not implemented")
|
||||
func (c *tlsHandshakeConn) SetReadDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *tlsHandshakeConn) SetWriteDeadline(t time.Time) error {
|
||||
panic("Not implemented")
|
||||
func (c *tlsHandshakeConn) SetWriteDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// this connection just delegates all methods to it's wrapped connection
|
||||
|
@ -148,21 +148,21 @@ func (c passthroughConn) Close() error {
|
|||
}
|
||||
|
||||
func (c passthroughConn) LocalAddr() net.Addr {
|
||||
panic("Not implemented")
|
||||
return c.c.LocalAddr()
|
||||
}
|
||||
|
||||
func (c passthroughConn) RemoteAddr() net.Addr {
|
||||
panic("Not implemented")
|
||||
return c.c.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c passthroughConn) SetDeadline(t time.Time) error {
|
||||
panic("Not implemented")
|
||||
return c.c.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c passthroughConn) SetReadDeadline(t time.Time) error {
|
||||
panic("Not implemented")
|
||||
return c.c.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c passthroughConn) SetWriteDeadline(t time.Time) error {
|
||||
panic("Not implemented")
|
||||
return c.c.SetWriteDeadline(t)
|
||||
}
|
||||
|
|
13
vendor/github.com/denisenkom/go-mssqldb/ntlm.go
generated
vendored
13
vendor/github.com/denisenkom/go-mssqldb/ntlm.go
generated
vendored
|
@ -14,6 +14,7 @@ import (
|
|||
"time"
|
||||
"unicode/utf16"
|
||||
|
||||
//lint:ignore SA1019 MD4 is used by legacy NTLM
|
||||
"golang.org/x/crypto/md4"
|
||||
)
|
||||
|
||||
|
@ -126,18 +127,6 @@ func createDesKey(bytes, material []byte) {
|
|||
material[7] = (byte)(bytes[6] << 1)
|
||||
}
|
||||
|
||||
func oddParity(bytes []byte) {
|
||||
for i := 0; i < len(bytes); i++ {
|
||||
b := bytes[i]
|
||||
needsParity := (((b >> 7) ^ (b >> 6) ^ (b >> 5) ^ (b >> 4) ^ (b >> 3) ^ (b >> 2) ^ (b >> 1)) & 0x01) == 0
|
||||
if needsParity {
|
||||
bytes[i] = bytes[i] | byte(0x01)
|
||||
} else {
|
||||
bytes[i] = bytes[i] & byte(0xfe)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func encryptDes(key []byte, cleartext []byte, ciphertext []byte) {
|
||||
var desKey [8]byte
|
||||
createDesKey(key, desKey[:])
|
||||
|
|
6
vendor/github.com/denisenkom/go-mssqldb/rpc.go
generated
vendored
6
vendor/github.com/denisenkom/go-mssqldb/rpc.go
generated
vendored
|
@ -22,12 +22,6 @@ type param struct {
|
|||
buffer []byte
|
||||
}
|
||||
|
||||
const (
|
||||
fWithRecomp = 1
|
||||
fNoMetaData = 2
|
||||
fReuseMetaData = 4
|
||||
)
|
||||
|
||||
var (
|
||||
sp_Cursor = procId{1, ""}
|
||||
sp_CursorOpen = procId{2, ""}
|
||||
|
|
602
vendor/github.com/denisenkom/go-mssqldb/tds.go
generated
vendored
602
vendor/github.com/denisenkom/go-mssqldb/tds.go
generated
vendored
|
@ -3,7 +3,6 @@ package mssql
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -13,8 +12,11 @@ import (
|
|||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf16"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/denisenkom/go-mssqldb/msdsn"
|
||||
)
|
||||
|
||||
func parseInstances(msg []byte) map[string]map[string]string {
|
||||
|
@ -82,19 +84,20 @@ const (
|
|||
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
|
||||
const (
|
||||
packSQLBatch packetType = 1
|
||||
packRPCRequest = 3
|
||||
packReply = 4
|
||||
packRPCRequest packetType = 3
|
||||
packReply packetType = 4
|
||||
|
||||
// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
|
||||
// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
|
||||
packAttention = 6
|
||||
packAttention packetType = 6
|
||||
|
||||
packBulkLoadBCP = 7
|
||||
packTransMgrReq = 14
|
||||
packNormal = 15
|
||||
packLogin7 = 16
|
||||
packSSPIMessage = 17
|
||||
packPrelogin = 18
|
||||
packBulkLoadBCP packetType = 7
|
||||
packFedAuthToken packetType = 8
|
||||
packTransMgrReq packetType = 14
|
||||
packNormal packetType = 15
|
||||
packLogin7 packetType = 16
|
||||
packSSPIMessage packetType = 17
|
||||
packPrelogin packetType = 18
|
||||
)
|
||||
|
||||
// prelogin fields
|
||||
|
@ -118,6 +121,17 @@ const (
|
|||
encryptReq = 3 // Encryption is required.
|
||||
)
|
||||
|
||||
const (
|
||||
featExtSESSIONRECOVERY byte = 0x01
|
||||
featExtFEDAUTH byte = 0x02
|
||||
featExtCOLUMNENCRYPTION byte = 0x04
|
||||
featExtGLOBALTRANSACTIONS byte = 0x05
|
||||
featExtAZURESQLSUPPORT byte = 0x08
|
||||
featExtDATACLASSIFICATION byte = 0x09
|
||||
featExtUTF8SUPPORT byte = 0x0A
|
||||
featExtTERMINATOR byte = 0xFF
|
||||
)
|
||||
|
||||
type tdsSession struct {
|
||||
buf *tdsBuffer
|
||||
loginAck loginAckStruct
|
||||
|
@ -132,13 +146,21 @@ type tdsSession struct {
|
|||
}
|
||||
|
||||
const (
|
||||
logErrors = 1
|
||||
logMessages = 2
|
||||
logRows = 4
|
||||
logSQL = 8
|
||||
logParams = 16
|
||||
logTransaction = 32
|
||||
logDebug = 64
|
||||
// Default packet size for a TDS buffer.
|
||||
defaultPacketSize = 4096
|
||||
|
||||
// Default port if no port given.
|
||||
defaultServerPort = 1433
|
||||
)
|
||||
|
||||
const (
|
||||
logErrors = uint64(msdsn.LogErrors)
|
||||
logMessages = uint64(msdsn.LogMessages)
|
||||
logRows = uint64(msdsn.LogRows)
|
||||
logSQL = uint64(msdsn.LogSQL)
|
||||
logParams = uint64(msdsn.LogParams)
|
||||
logTransaction = uint64(msdsn.LogTransaction)
|
||||
logDebug = uint64(msdsn.LogDebug)
|
||||
)
|
||||
|
||||
type columnStruct struct {
|
||||
|
@ -155,13 +177,13 @@ func (p keySlice) Less(i, j int) bool { return p[i] < p[j] }
|
|||
func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
|
||||
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
|
||||
func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
|
||||
func writePrelogin(packetType packetType, w *tdsBuffer, fields map[uint8][]byte) error {
|
||||
var err error
|
||||
|
||||
w.BeginPacket(packPrelogin, false)
|
||||
w.BeginPacket(packetType, false)
|
||||
offset := uint16(5*len(fields) + 1)
|
||||
keys := make(keySlice, 0, len(fields))
|
||||
for k, _ := range fields {
|
||||
for k := range fields {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Sort(keys)
|
||||
|
@ -210,12 +232,15 @@ func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if packet_type != 4 {
|
||||
return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
|
||||
if packet_type != packReply {
|
||||
return nil, errors.New("invalid respones, expected packet type 4, PRELOGIN RESPONSE")
|
||||
}
|
||||
if len(struct_buf) == 0 {
|
||||
return nil, errors.New("invalid empty PRELOGIN response, it must contain at least one byte")
|
||||
}
|
||||
offset := 0
|
||||
results := map[uint8][]byte{}
|
||||
for true {
|
||||
for {
|
||||
rec_type := struct_buf[offset]
|
||||
if rec_type == preloginTERMINATOR {
|
||||
break
|
||||
|
@ -240,6 +265,16 @@ const (
|
|||
fIntSecurity = 0x80
|
||||
)
|
||||
|
||||
// OptionFlags3
|
||||
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
|
||||
const (
|
||||
fChangePassword = 1
|
||||
fSendYukonBinaryXML = 2
|
||||
fUserInstance = 4
|
||||
fUnknownCollationHandling = 8
|
||||
fExtension = 0x10
|
||||
)
|
||||
|
||||
// TypeFlags
|
||||
const (
|
||||
// 4 bits for fSQLType
|
||||
|
@ -247,12 +282,6 @@ const (
|
|||
fReadOnlyIntent = 32
|
||||
)
|
||||
|
||||
// OptionFlags3
|
||||
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac
|
||||
const (
|
||||
fExtension = 0x10
|
||||
)
|
||||
|
||||
type login struct {
|
||||
TDSVersion uint32
|
||||
PacketSize uint32
|
||||
|
@ -295,7 +324,7 @@ func (e *featureExts) Add(f featureExt) error {
|
|||
}
|
||||
id := f.featureID()
|
||||
if _, exists := e.features[id]; exists {
|
||||
f := "Login error: Feature with ID '%v' is already present in FeatureExt block."
|
||||
f := "login error: Feature with ID '%v' is already present in FeatureExt block"
|
||||
return fmt.Errorf(f, id)
|
||||
}
|
||||
if e.features == nil {
|
||||
|
@ -326,37 +355,63 @@ func (e featureExts) toBytes() []byte {
|
|||
return d
|
||||
}
|
||||
|
||||
type featureExtFedAuthSTS struct {
|
||||
FedAuthEcho bool
|
||||
// featureExtFedAuth tracks federated authentication state before and during login
|
||||
type featureExtFedAuth struct {
|
||||
// FedAuthLibrary is populated by the federated authentication provider.
|
||||
FedAuthLibrary int
|
||||
|
||||
// ADALWorkflow is populated by the federated authentication provider.
|
||||
ADALWorkflow byte
|
||||
|
||||
// FedAuthEcho is populated from the prelogin response
|
||||
FedAuthEcho bool
|
||||
|
||||
// FedAuthToken is populated during login with the value from the provider.
|
||||
FedAuthToken string
|
||||
Nonce []byte
|
||||
|
||||
// Nonce is populated during login with the value from the provider.
|
||||
Nonce []byte
|
||||
|
||||
// Signature is populated during login with the value from the server.
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
func (e *featureExtFedAuthSTS) featureID() byte {
|
||||
return 0x02
|
||||
func (e *featureExtFedAuth) featureID() byte {
|
||||
return featExtFEDAUTH
|
||||
}
|
||||
|
||||
func (e *featureExtFedAuthSTS) toBytes() []byte {
|
||||
func (e *featureExtFedAuth) toBytes() []byte {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
options := byte(0x01) << 1 // 0x01 => STS bFedAuthLibrary 7BIT
|
||||
options := byte(e.FedAuthLibrary) << 1
|
||||
if e.FedAuthEcho {
|
||||
options |= 1 // fFedAuthEcho
|
||||
}
|
||||
|
||||
d := make([]byte, 5)
|
||||
d[0] = options
|
||||
// Feature extension format depends on the federated auth library.
|
||||
// Options are described at
|
||||
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac
|
||||
var d []byte
|
||||
|
||||
// looks like string in
|
||||
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508
|
||||
tokenBytes := str2ucs2(e.FedAuthToken)
|
||||
binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work
|
||||
d = append(d, tokenBytes...)
|
||||
switch e.FedAuthLibrary {
|
||||
case fedAuthLibrarySecurityToken:
|
||||
d = make([]byte, 5)
|
||||
d[0] = options
|
||||
|
||||
if len(e.Nonce) == 32 {
|
||||
d = append(d, e.Nonce...)
|
||||
// looks like string in
|
||||
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508
|
||||
tokenBytes := str2ucs2(e.FedAuthToken)
|
||||
binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work
|
||||
d = append(d, tokenBytes...)
|
||||
|
||||
if len(e.Nonce) == 32 {
|
||||
d = append(d, e.Nonce...)
|
||||
}
|
||||
|
||||
case fedAuthLibraryADAL:
|
||||
d = []byte{options, e.ADALWorkflow}
|
||||
}
|
||||
|
||||
return d
|
||||
|
@ -418,7 +473,7 @@ func str2ucs2(s string) []byte {
|
|||
|
||||
func ucs22str(s []byte) (string, error) {
|
||||
if len(s)%2 != 0 {
|
||||
return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))
|
||||
return "", fmt.Errorf("illegal UCS2 string length: %d", len(s))
|
||||
}
|
||||
buf := make([]uint16, len(s)/2)
|
||||
for i := 0; i < len(s); i += 2 {
|
||||
|
@ -436,7 +491,7 @@ func manglePassword(password string) []byte {
|
|||
}
|
||||
|
||||
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
|
||||
func sendLogin(w *tdsBuffer, login login) error {
|
||||
func sendLogin(w *tdsBuffer, login *login) error {
|
||||
w.BeginPacket(packLogin7, false)
|
||||
hostname := str2ucs2(login.HostName)
|
||||
username := str2ucs2(login.UserName)
|
||||
|
@ -572,6 +627,36 @@ func sendLogin(w *tdsBuffer, login login) error {
|
|||
return w.FinishPacket()
|
||||
}
|
||||
|
||||
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/827d9632-2957-4d54-b9ea-384530ae79d0
|
||||
func sendFedAuthInfo(w *tdsBuffer, fedAuth *featureExtFedAuth) (err error) {
|
||||
fedauthtoken := str2ucs2(fedAuth.FedAuthToken)
|
||||
tokenlen := len(fedauthtoken)
|
||||
datalen := 4 + tokenlen + len(fedAuth.Nonce)
|
||||
|
||||
w.BeginPacket(packFedAuthToken, false)
|
||||
err = binary.Write(w, binary.LittleEndian, uint32(datalen))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = binary.Write(w, binary.LittleEndian, uint32(tokenlen))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = w.Write(fedauthtoken)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = w.Write(fedAuth.Nonce)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return w.FinishPacket()
|
||||
}
|
||||
|
||||
func readUcs2(r io.Reader, numchars int) (res string, err error) {
|
||||
buf := make([]byte, numchars*2)
|
||||
_, err = io.ReadFull(r, buf)
|
||||
|
@ -768,26 +853,27 @@ type auth interface {
|
|||
// SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
|
||||
// list of IP addresses. So if there is more than one, try them all and
|
||||
// use the first one that allows a connection.
|
||||
func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) {
|
||||
func dialConnection(ctx context.Context, c *Connector, p msdsn.Config) (conn net.Conn, err error) {
|
||||
var ips []net.IP
|
||||
ips, err = net.LookupIP(p.host)
|
||||
if err != nil {
|
||||
ip := net.ParseIP(p.host)
|
||||
if ip == nil {
|
||||
return nil, err
|
||||
ip := net.ParseIP(p.Host)
|
||||
if ip == nil {
|
||||
ips, err = net.LookupIP(p.Host)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
ips = []net.IP{ip}
|
||||
}
|
||||
if len(ips) == 1 {
|
||||
d := c.getDialer(&p)
|
||||
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(resolveServerPort(p.port))))
|
||||
addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(resolveServerPort(p.Port))))
|
||||
conn, err = d.DialContext(ctx, "tcp", addr)
|
||||
|
||||
} else {
|
||||
//Try Dials in parallel to avoid waiting for timeouts.
|
||||
connChan := make(chan net.Conn, len(ips))
|
||||
errChan := make(chan error, len(ips))
|
||||
portStr := strconv.Itoa(int(resolveServerPort(p.port)))
|
||||
portStr := strconv.Itoa(int(resolveServerPort(p.Port)))
|
||||
for _, ip := range ips {
|
||||
go func(ip net.IP) {
|
||||
d := c.getDialer(&p)
|
||||
|
@ -802,7 +888,7 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne
|
|||
}
|
||||
// Wait for either the *first* successful connection, or all the errors
|
||||
wait_loop:
|
||||
for i, _ := range ips {
|
||||
for i := range ips {
|
||||
select {
|
||||
case conn = <-connChan:
|
||||
// Got a connection to use, close any others
|
||||
|
@ -824,66 +910,28 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne
|
|||
}
|
||||
// Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
|
||||
if conn == nil {
|
||||
f := "Unable to open tcp connection with host '%v:%v': %v"
|
||||
return nil, fmt.Errorf(f, p.host, resolveServerPort(p.port), err.Error())
|
||||
f := "unable to open tcp connection with host '%v:%v': %v"
|
||||
return nil, fmt.Errorf(f, p.Host, resolveServerPort(p.Port), err.Error())
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) {
|
||||
dialCtx := ctx
|
||||
if p.dial_timeout > 0 {
|
||||
var cancel func()
|
||||
dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout)
|
||||
defer cancel()
|
||||
}
|
||||
// if instance is specified use instance resolution service
|
||||
if p.instance != "" && p.port == 0 {
|
||||
p.instance = strings.ToUpper(p.instance)
|
||||
d := c.getDialer(&p)
|
||||
instances, err := getInstances(dialCtx, d, p.host)
|
||||
if err != nil {
|
||||
f := "Unable to get instances from Sql Server Browser on host %v: %v"
|
||||
return nil, fmt.Errorf(f, p.host, err.Error())
|
||||
}
|
||||
strport, ok := instances[p.instance]["tcp"]
|
||||
if !ok {
|
||||
f := "No instance matching '%v' returned from host '%v'"
|
||||
return nil, fmt.Errorf(f, p.instance, p.host)
|
||||
}
|
||||
port, err := strconv.ParseUint(strport, 0, 16)
|
||||
if err != nil {
|
||||
f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
|
||||
return nil, fmt.Errorf(f, strport, err.Error())
|
||||
}
|
||||
p.port = port
|
||||
}
|
||||
|
||||
initiate_connection:
|
||||
conn, err := dialConnection(dialCtx, c, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toconn := newTimeoutConn(conn, p.conn_timeout)
|
||||
|
||||
outbuf := newTdsBuffer(p.packetSize, toconn)
|
||||
sess := tdsSession{
|
||||
buf: outbuf,
|
||||
log: log,
|
||||
logFlags: p.logFlags,
|
||||
}
|
||||
|
||||
instance_buf := []byte(p.instance)
|
||||
func preparePreloginFields(p msdsn.Config, fe *featureExtFedAuth) map[uint8][]byte {
|
||||
instance_buf := []byte(p.Instance)
|
||||
instance_buf = append(instance_buf, 0) // zero terminate instance name
|
||||
|
||||
var encrypt byte
|
||||
if p.disableEncryption {
|
||||
switch p.Encryption {
|
||||
default:
|
||||
panic(fmt.Errorf("Unsupported Encryption Config %v", p.Encryption))
|
||||
case msdsn.EncryptionDisabled:
|
||||
encrypt = encryptNotSup
|
||||
} else if p.encrypt {
|
||||
case msdsn.EncryptionRequired:
|
||||
encrypt = encryptOn
|
||||
} else {
|
||||
case msdsn.EncryptionOff:
|
||||
encrypt = encryptOff
|
||||
}
|
||||
|
||||
fields := map[uint8][]byte{
|
||||
preloginVERSION: {0, 0, 0, 0, 0, 0},
|
||||
preloginENCRYPTION: {encrypt},
|
||||
|
@ -892,7 +940,182 @@ initiate_connection:
|
|||
preloginMARS: {0}, // MARS disabled
|
||||
}
|
||||
|
||||
err = writePrelogin(outbuf, fields)
|
||||
if fe.FedAuthLibrary != fedAuthLibraryReserved {
|
||||
fields[preloginFEDAUTHREQUIRED] = []byte{1}
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
func interpretPreloginResponse(p msdsn.Config, fe *featureExtFedAuth, fields map[uint8][]byte) (encrypt byte, err error) {
|
||||
// If the server returns the preloginFEDAUTHREQUIRED field, then federated authentication
|
||||
// is supported. The actual value may be 0 or 1, where 0 means either SSPI or federated
|
||||
// authentication is allowed, while 1 means only federated authentication is allowed.
|
||||
if fedAuthSupport, ok := fields[preloginFEDAUTHREQUIRED]; ok {
|
||||
if len(fedAuthSupport) != 1 {
|
||||
return 0, fmt.Errorf("federated authentication flag length should be 1: is %d", len(fedAuthSupport))
|
||||
}
|
||||
|
||||
// We need to be able to echo the value back to the server
|
||||
fe.FedAuthEcho = fedAuthSupport[0] != 0
|
||||
} else if fe.FedAuthLibrary != fedAuthLibraryReserved {
|
||||
return 0, fmt.Errorf("federated authentication is not supported by the server")
|
||||
}
|
||||
|
||||
encryptBytes, ok := fields[preloginENCRYPTION]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("encrypt negotiation failed")
|
||||
}
|
||||
encrypt = encryptBytes[0]
|
||||
if p.Encryption == msdsn.EncryptionRequired && (encrypt == encryptNotSup || encrypt == encryptOff) {
|
||||
return 0, fmt.Errorf("server does not support encryption")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, log optionalLogger, auth auth, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) {
|
||||
var typeFlags uint8
|
||||
if p.ReadOnlyIntent {
|
||||
typeFlags |= fReadOnlyIntent
|
||||
}
|
||||
l = &login{
|
||||
TDSVersion: verTDS74,
|
||||
PacketSize: packetSize,
|
||||
Database: p.Database,
|
||||
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
|
||||
HostName: p.Workstation,
|
||||
ServerName: p.Host,
|
||||
AppName: p.AppName,
|
||||
TypeFlags: typeFlags,
|
||||
}
|
||||
switch {
|
||||
case fe.FedAuthLibrary == fedAuthLibrarySecurityToken:
|
||||
if uint64(p.LogFlags)&logDebug != 0 {
|
||||
log.Println("Starting federated authentication using security token")
|
||||
}
|
||||
|
||||
fe.FedAuthToken, err = c.securityTokenProvider(ctx)
|
||||
if err != nil {
|
||||
if uint64(p.LogFlags)&logDebug != 0 {
|
||||
log.Printf("Failed to retrieve service principal token for federated authentication security token library: %v", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l.FeatureExt.Add(fe)
|
||||
|
||||
case fe.FedAuthLibrary == fedAuthLibraryADAL:
|
||||
if uint64(p.LogFlags)&logDebug != 0 {
|
||||
log.Println("Starting federated authentication using ADAL")
|
||||
}
|
||||
|
||||
l.FeatureExt.Add(fe)
|
||||
|
||||
case auth != nil:
|
||||
if uint64(p.LogFlags)&logDebug != 0 {
|
||||
log.Println("Starting SSPI login")
|
||||
}
|
||||
|
||||
l.SSPI, err = auth.InitialBytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l.OptionFlags2 |= fIntSecurity
|
||||
return l, nil
|
||||
|
||||
default:
|
||||
// Default to SQL server authentication with user and password
|
||||
l.UserName = p.User
|
||||
l.Password = p.Password
|
||||
}
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func connect(ctx context.Context, c *Connector, log optionalLogger, p msdsn.Config) (res *tdsSession, err error) {
|
||||
dialCtx := ctx
|
||||
if p.DialTimeout >= 0 {
|
||||
dt := p.DialTimeout
|
||||
if dt == 0 {
|
||||
dt = 15 * time.Second
|
||||
}
|
||||
var cancel func()
|
||||
dialCtx, cancel = context.WithTimeout(ctx, dt)
|
||||
defer cancel()
|
||||
}
|
||||
// if instance is specified use instance resolution service
|
||||
if len(p.Instance) > 0 && p.Port != 0 {
|
||||
// both instance name and port specified
|
||||
// when port is specified instance name is not used
|
||||
// you should not provide instance name when you provide port
|
||||
log.Println("WARN: You specified both instance name and port in the connection string, port will be used and instance name will be ignored")
|
||||
}
|
||||
if len(p.Instance) > 0 {
|
||||
p.Instance = strings.ToUpper(p.Instance)
|
||||
d := c.getDialer(&p)
|
||||
instances, err := getInstances(dialCtx, d, p.Host)
|
||||
if err != nil {
|
||||
f := "unable to get instances from Sql Server Browser on host %v: %v"
|
||||
return nil, fmt.Errorf(f, p.Host, err.Error())
|
||||
}
|
||||
strport, ok := instances[p.Instance]["tcp"]
|
||||
if !ok {
|
||||
f := "no instance matching '%v' returned from host '%v'"
|
||||
return nil, fmt.Errorf(f, p.Instance, p.Host)
|
||||
}
|
||||
port, err := strconv.ParseUint(strport, 0, 16)
|
||||
if err != nil {
|
||||
f := "invalid tcp port returned from Sql Server Browser '%v': %v"
|
||||
return nil, fmt.Errorf(f, strport, err.Error())
|
||||
}
|
||||
p.Port = port
|
||||
}
|
||||
if p.Port == 0 {
|
||||
p.Port = defaultServerPort
|
||||
}
|
||||
|
||||
packetSize := p.PacketSize
|
||||
if packetSize == 0 {
|
||||
packetSize = defaultPacketSize
|
||||
}
|
||||
// Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes
|
||||
// NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request
|
||||
// a higher packet size, the server will respond with an ENVCHANGE request to
|
||||
// alter the packet size to 16383 bytes.
|
||||
if packetSize < 512 {
|
||||
packetSize = 512
|
||||
} else if packetSize > 32767 {
|
||||
packetSize = 32767
|
||||
}
|
||||
|
||||
initiate_connection:
|
||||
conn, err := dialConnection(dialCtx, c, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toconn := newTimeoutConn(conn, p.ConnTimeout)
|
||||
|
||||
outbuf := newTdsBuffer(packetSize, toconn)
|
||||
sess := tdsSession{
|
||||
buf: outbuf,
|
||||
log: log,
|
||||
logFlags: uint64(p.LogFlags),
|
||||
}
|
||||
|
||||
fedAuth := &featureExtFedAuth{
|
||||
FedAuthLibrary: fedAuthLibraryReserved,
|
||||
}
|
||||
if c.fedAuthRequired {
|
||||
fedAuth.FedAuthLibrary = c.fedAuthLibrary
|
||||
fedAuth.ADALWorkflow = c.fedAuthADALWorkflow
|
||||
}
|
||||
|
||||
fields := preparePreloginFields(p, fedAuth)
|
||||
|
||||
err = writePrelogin(packPrelogin, outbuf, fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -902,39 +1125,36 @@ initiate_connection:
|
|||
return nil, err
|
||||
}
|
||||
|
||||
encryptBytes, ok := fields[preloginENCRYPTION]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Encrypt negotiation failed")
|
||||
}
|
||||
encrypt = encryptBytes[0]
|
||||
if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
|
||||
return nil, fmt.Errorf("Server does not support encryption")
|
||||
encrypt, err := interpretPreloginResponse(p, fedAuth, fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if encrypt != encryptNotSup {
|
||||
var config tls.Config
|
||||
if p.certificate != "" {
|
||||
pem, err := ioutil.ReadFile(p.certificate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err)
|
||||
var config *tls.Config
|
||||
if pc := p.TLSConfig; pc != nil {
|
||||
config = pc
|
||||
if config.DynamicRecordSizingDisabled == false {
|
||||
config = config.Clone()
|
||||
|
||||
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
|
||||
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
|
||||
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
|
||||
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
|
||||
config.DynamicRecordSizingDisabled = true
|
||||
}
|
||||
certs := x509.NewCertPool()
|
||||
certs.AppendCertsFromPEM(pem)
|
||||
config.RootCAs = certs
|
||||
}
|
||||
if p.trustServerCertificate {
|
||||
config.InsecureSkipVerify = true
|
||||
if config == nil {
|
||||
config, err = msdsn.SetupTLS("", false, p.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
config.ServerName = p.hostInCertificate
|
||||
// fix for https://github.com/denisenkom/go-mssqldb/issues/166
|
||||
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
|
||||
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
|
||||
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
|
||||
config.DynamicRecordSizingDisabled = true
|
||||
|
||||
// setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
|
||||
handshakeConn := tlsHandshakeConn{buf: outbuf}
|
||||
passthrough := passthroughConn{c: &handshakeConn}
|
||||
tlsConn := tls.Client(&passthrough, &config)
|
||||
tlsConn := tls.Client(&passthrough, config)
|
||||
err = tlsConn.Handshake()
|
||||
passthrough.c = toconn
|
||||
outbuf.transport = tlsConn
|
||||
|
@ -948,54 +1168,48 @@ initiate_connection:
|
|||
}
|
||||
}
|
||||
|
||||
login := login{
|
||||
TDSVersion: verTDS74,
|
||||
PacketSize: uint32(outbuf.PackageSize()),
|
||||
Database: p.database,
|
||||
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
|
||||
HostName: p.workstation,
|
||||
ServerName: p.host,
|
||||
AppName: p.appname,
|
||||
TypeFlags: p.typeFlags,
|
||||
}
|
||||
auth, authOk := getAuth(p.user, p.password, p.serverSPN, p.workstation)
|
||||
switch {
|
||||
case p.fedAuthAccessToken != "": // accesstoken ignores user/password
|
||||
featurext := &featureExtFedAuthSTS{
|
||||
FedAuthEcho: len(fields[preloginFEDAUTHREQUIRED]) > 0 && fields[preloginFEDAUTHREQUIRED][0] == 1,
|
||||
FedAuthToken: p.fedAuthAccessToken,
|
||||
Nonce: fields[preloginNONCEOPT],
|
||||
}
|
||||
login.FeatureExt.Add(featurext)
|
||||
case authOk:
|
||||
login.SSPI, err = auth.InitialBytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
login.OptionFlags2 |= fIntSecurity
|
||||
auth, authOk := getAuth(p.User, p.Password, p.ServerSPN, p.Workstation)
|
||||
if authOk {
|
||||
defer auth.Free()
|
||||
default:
|
||||
login.UserName = p.user
|
||||
login.Password = p.password
|
||||
} else {
|
||||
auth = nil
|
||||
}
|
||||
|
||||
login, err := prepareLogin(ctx, c, p, log, auth, fedAuth, uint32(outbuf.PackageSize()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = sendLogin(outbuf, login)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// processing login response
|
||||
success := false
|
||||
for {
|
||||
tokchan := make(chan tokenStruct, 5)
|
||||
go processResponse(context.Background(), &sess, tokchan, nil)
|
||||
for tok := range tokchan {
|
||||
// Loop until a packet containing a login acknowledgement is received.
|
||||
// SSPI and federated authentication scenarios may require multiple
|
||||
// packet exchanges to complete the login sequence.
|
||||
for loginAck := false; !loginAck; {
|
||||
reader := startReading(&sess, ctx, outputs{})
|
||||
// don't send attention or wait for cancel confirmation during login
|
||||
reader.noAttn = true
|
||||
|
||||
for {
|
||||
tok, err := reader.nextToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tok == nil {
|
||||
break
|
||||
}
|
||||
|
||||
switch token := tok.(type) {
|
||||
case sspiMsg:
|
||||
sspi_msg, err := auth.NextBytes(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sspi_msg != nil && len(sspi_msg) > 0 {
|
||||
if len(sspi_msg) > 0 {
|
||||
outbuf.BeginPacket(packSSPIMessage, false)
|
||||
_, err = outbuf.Write(sspi_msg)
|
||||
if err != nil {
|
||||
|
@ -1007,31 +1221,59 @@ initiate_connection:
|
|||
}
|
||||
sspi_msg = nil
|
||||
}
|
||||
// TODO: for Live ID authentication it may be necessary to
|
||||
// compare fedAuth.Nonce == token.Nonce and keep track of signature
|
||||
//case fedAuthAckStruct:
|
||||
//fedAuth.Signature = token.Signature
|
||||
case fedAuthInfoStruct:
|
||||
// For ADAL workflows this contains the STS URL and server SPN.
|
||||
// If received outside of an ADAL workflow, ignore.
|
||||
if c == nil || c.adalTokenProvider == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Request the AD token given the server SPN and STS URL
|
||||
fedAuth.FedAuthToken, err = c.adalTokenProvider(ctx, token.ServerSPN, token.STSURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Now need to send the token as a FEDINFO packet
|
||||
err = sendFedAuthInfo(outbuf, fedAuth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case loginAckStruct:
|
||||
success = true
|
||||
sess.loginAck = token
|
||||
case error:
|
||||
return nil, fmt.Errorf("Login error: %s", token.Error())
|
||||
loginAck = true
|
||||
case doneStruct:
|
||||
if token.isError() {
|
||||
return nil, fmt.Errorf("Login error: %s", token.getError())
|
||||
tokenErr := token.getError()
|
||||
tokenErr.Message = "login error: " + tokenErr.Message
|
||||
return nil, tokenErr
|
||||
}
|
||||
goto loginEnd
|
||||
case error:
|
||||
return nil, fmt.Errorf("login error: %s", token.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
loginEnd:
|
||||
if !success {
|
||||
return nil, fmt.Errorf("Login failed")
|
||||
}
|
||||
|
||||
if sess.routedServer != "" {
|
||||
toconn.Close()
|
||||
p.host = sess.routedServer
|
||||
p.port = uint64(sess.routedPort)
|
||||
if !p.hostInCertificateProvided {
|
||||
p.hostInCertificate = sess.routedServer
|
||||
p.Host = sess.routedServer
|
||||
p.Port = uint64(sess.routedPort)
|
||||
if !p.HostInCertificateProvided && p.TLSConfig != nil {
|
||||
p.TLSConfig.ServerName = sess.routedServer
|
||||
}
|
||||
goto initiate_connection
|
||||
}
|
||||
return &sess, nil
|
||||
}
|
||||
|
||||
func resolveServerPort(port uint64) uint64 {
|
||||
if port == 0 {
|
||||
return defaultServerPort
|
||||
}
|
||||
|
||||
return port
|
||||
}
|
||||
|
|
447
vendor/github.com/denisenkom/go-mssqldb/token.go
generated
vendored
447
vendor/github.com/denisenkom/go-mssqldb/token.go
generated
vendored
|
@ -6,12 +6,11 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"io/ioutil"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:generate stringer -type token
|
||||
//go:generate go run golang.org/x/tools/cmd/stringer -type token
|
||||
|
||||
type token byte
|
||||
|
||||
|
@ -29,6 +28,7 @@ const (
|
|||
tokenNbcRow token = 210 // 0xd2
|
||||
tokenEnvChange token = 227 // 0xE3
|
||||
tokenSSPI token = 237 // 0xED
|
||||
tokenFedAuthInfo token = 238 // 0xEE
|
||||
tokenDone token = 253 // 0xFD
|
||||
tokenDoneProc token = 254
|
||||
tokenDoneInProc token = 255
|
||||
|
@ -70,6 +70,11 @@ const (
|
|||
envRouting = 20
|
||||
)
|
||||
|
||||
const (
|
||||
fedAuthInfoSTSURL = 0x01
|
||||
fedAuthInfoSPN = 0x02
|
||||
)
|
||||
|
||||
// COLMETADATA flags
|
||||
// https://msdn.microsoft.com/en-us/library/dd357363.aspx
|
||||
const (
|
||||
|
@ -96,35 +101,18 @@ func (d doneStruct) isError() bool {
|
|||
}
|
||||
|
||||
func (d doneStruct) getError() Error {
|
||||
if len(d.errors) > 0 {
|
||||
return d.errors[len(d.errors)-1]
|
||||
} else {
|
||||
n := len(d.errors)
|
||||
if n == 0 {
|
||||
return Error{Message: "Request failed but didn't provide reason"}
|
||||
}
|
||||
err := d.errors[n-1]
|
||||
err.All = make([]Error, n)
|
||||
copy(err.All, d.errors)
|
||||
return err
|
||||
}
|
||||
|
||||
type doneInProcStruct doneStruct
|
||||
|
||||
var doneFlags2str = map[uint16]string{
|
||||
doneFinal: "final",
|
||||
doneMore: "more",
|
||||
doneError: "error",
|
||||
doneInxact: "inxact",
|
||||
doneCount: "count",
|
||||
doneAttn: "attn",
|
||||
doneSrvError: "srverror",
|
||||
}
|
||||
|
||||
func doneFlags2Str(flags uint16) string {
|
||||
strs := make([]string, 0, len(doneFlags2str))
|
||||
for flag, tag := range doneFlags2str {
|
||||
if flags&flag != 0 {
|
||||
strs = append(strs, tag)
|
||||
}
|
||||
}
|
||||
return strings.Join(strs, "|")
|
||||
}
|
||||
|
||||
// ENVCHANGE stream
|
||||
// http://msdn.microsoft.com/en-us/library/dd303449.aspx
|
||||
func processEnvChg(sess *tdsSession) {
|
||||
|
@ -380,9 +368,8 @@ func processEnvChg(sess *tdsSession) {
|
|||
default:
|
||||
// ignore rest of records because we don't know how to skip those
|
||||
sess.log.Printf("WARN: Unknown ENVCHANGE record detected with type id = %d\n", envtype)
|
||||
break
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -425,6 +412,78 @@ func parseSSPIMsg(r *tdsBuffer) sspiMsg {
|
|||
return sspiMsg(buf)
|
||||
}
|
||||
|
||||
type fedAuthInfoStruct struct {
|
||||
STSURL string
|
||||
ServerSPN string
|
||||
}
|
||||
|
||||
type fedAuthInfoOpt struct {
|
||||
fedAuthInfoID byte
|
||||
dataLength, dataOffset uint32
|
||||
}
|
||||
|
||||
func parseFedAuthInfo(r *tdsBuffer) fedAuthInfoStruct {
|
||||
size := r.uint32()
|
||||
|
||||
var STSURL, SPN string
|
||||
var err error
|
||||
|
||||
// Each fedAuthInfoOpt is one byte to indicate the info ID,
|
||||
// then a four byte offset and a four byte length.
|
||||
count := r.uint32()
|
||||
offset := uint32(4)
|
||||
opts := make([]fedAuthInfoOpt, count)
|
||||
|
||||
for i := uint32(0); i < count; i++ {
|
||||
fedAuthInfoID := r.byte()
|
||||
dataLength := r.uint32()
|
||||
dataOffset := r.uint32()
|
||||
offset += 1 + 4 + 4
|
||||
|
||||
opts[i] = fedAuthInfoOpt{
|
||||
fedAuthInfoID: fedAuthInfoID,
|
||||
dataLength: dataLength,
|
||||
dataOffset: dataOffset,
|
||||
}
|
||||
}
|
||||
|
||||
data := make([]byte, size-offset)
|
||||
r.ReadFull(data)
|
||||
|
||||
for i := uint32(0); i < count; i++ {
|
||||
if opts[i].dataOffset < offset {
|
||||
badStreamPanicf("Fed auth info opt stated data offset %d is before data begins in packet at %d",
|
||||
opts[i].dataOffset, offset)
|
||||
// returns via panic
|
||||
}
|
||||
|
||||
if opts[i].dataOffset+opts[i].dataLength > size {
|
||||
badStreamPanicf("Fed auth info opt stated data length %d added to stated offset exceeds size of packet %d",
|
||||
opts[i].dataOffset+opts[i].dataLength, size)
|
||||
// returns via panic
|
||||
}
|
||||
|
||||
optData := data[opts[i].dataOffset-offset : opts[i].dataOffset-offset+opts[i].dataLength]
|
||||
switch opts[i].fedAuthInfoID {
|
||||
case fedAuthInfoSTSURL:
|
||||
STSURL, err = ucs22str(optData)
|
||||
case fedAuthInfoSPN:
|
||||
SPN, err = ucs22str(optData)
|
||||
default:
|
||||
err = fmt.Errorf("unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
badStreamPanic(err)
|
||||
}
|
||||
}
|
||||
|
||||
return fedAuthInfoStruct{
|
||||
STSURL: STSURL,
|
||||
ServerSPN: SPN,
|
||||
}
|
||||
}
|
||||
|
||||
type loginAckStruct struct {
|
||||
Interface uint8
|
||||
TDSVersion uint32
|
||||
|
@ -449,19 +508,43 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct {
|
|||
}
|
||||
|
||||
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/2eb82f8e-11f0-46dc-b42d-27302fa4701a
|
||||
func parseFeatureExtAck(r *tdsBuffer) {
|
||||
// at most 1 featureAck per feature in featureExt
|
||||
// go-mssqldb will add at most 1 feature, the spec defines 7 different features
|
||||
for i := 0; i < 8; i++ {
|
||||
featureID := r.byte() // FeatureID
|
||||
if featureID == 0xff {
|
||||
return
|
||||
type fedAuthAckStruct struct {
|
||||
Nonce []byte
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} {
|
||||
ack := map[byte]interface{}{}
|
||||
|
||||
for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() {
|
||||
length := r.uint32()
|
||||
|
||||
switch feature {
|
||||
case featExtFEDAUTH:
|
||||
// In theory we need to know the federated authentication library to
|
||||
// know how to parse, but the alternatives provide compatible structures.
|
||||
fedAuthAck := fedAuthAckStruct{}
|
||||
if length >= 32 {
|
||||
fedAuthAck.Nonce = make([]byte, 32)
|
||||
r.ReadFull(fedAuthAck.Nonce)
|
||||
length -= 32
|
||||
}
|
||||
if length >= 32 {
|
||||
fedAuthAck.Signature = make([]byte, 32)
|
||||
r.ReadFull(fedAuthAck.Signature)
|
||||
length -= 32
|
||||
}
|
||||
ack[feature] = fedAuthAck
|
||||
|
||||
}
|
||||
|
||||
// Skip unprocessed bytes
|
||||
if length > 0 {
|
||||
io.CopyN(ioutil.Discard, r, int64(length))
|
||||
}
|
||||
size := r.uint32() // FeatureAckDataLen
|
||||
d := make([]byte, size)
|
||||
r.ReadFull(d)
|
||||
}
|
||||
panic("parsed more than 7 featureAck's, protocol implementation error?")
|
||||
|
||||
return ack
|
||||
}
|
||||
|
||||
// http://msdn.microsoft.com/en-us/library/dd357363.aspx
|
||||
|
@ -555,7 +638,7 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) {
|
|||
return
|
||||
}
|
||||
|
||||
func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
|
||||
func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs outputs) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
if sess.logFlags&logErrors != 0 {
|
||||
|
@ -579,7 +662,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
|
|||
}
|
||||
var columns []columnStruct
|
||||
errs := make([]Error, 0, 5)
|
||||
for {
|
||||
for tokens := 0; ; tokens += 1 {
|
||||
token := token(sess.buf.byte())
|
||||
if sess.logFlags&logDebug != 0 {
|
||||
sess.log.Printf("got token %v", token)
|
||||
|
@ -588,6 +671,9 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
|
|||
case tokenSSPI:
|
||||
ch <- parseSSPIMsg(sess.buf)
|
||||
return
|
||||
case tokenFedAuthInfo:
|
||||
ch <- parseFedAuthInfo(sess.buf)
|
||||
return
|
||||
case tokenReturnStatus:
|
||||
returnStatus := parseReturnStatus(sess.buf)
|
||||
ch <- returnStatus
|
||||
|
@ -595,7 +681,8 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
|
|||
loginAck := parseLoginAck(sess.buf)
|
||||
ch <- loginAck
|
||||
case tokenFeatureExtAck:
|
||||
parseFeatureExtAck(sess.buf)
|
||||
featureExtAck := parseFeatureExtAck(sess.buf)
|
||||
ch <- featureExtAck
|
||||
case tokenOrder:
|
||||
order := parseOrder(sess.buf)
|
||||
ch <- order
|
||||
|
@ -612,7 +699,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
|
|||
sess.log.Printf("got DONE or DONEPROC status=%d", done.Status)
|
||||
}
|
||||
if done.Status&doneSrvError != 0 {
|
||||
ch <- errors.New("SQL Server had internal error")
|
||||
ch <- ServerError{done.getError()}
|
||||
return
|
||||
}
|
||||
if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
|
||||
|
@ -656,7 +743,7 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
|
|||
nv := parseReturnValue(sess.buf)
|
||||
if len(nv.Name) > 0 {
|
||||
name := nv.Name[1:] // Remove the leading "@".
|
||||
if ov, has := outs[name]; has {
|
||||
if ov, has := outs.params[name]; has {
|
||||
err = scanIntoOut(name, nv.Value, ov)
|
||||
if err != nil {
|
||||
fmt.Println("scan error", err)
|
||||
|
@ -670,154 +757,144 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin
|
|||
}
|
||||
}
|
||||
|
||||
type parseRespIter byte
|
||||
|
||||
const (
|
||||
parseRespIterContinue parseRespIter = iota // Continue parsing current token.
|
||||
parseRespIterNext // Fetch the next token.
|
||||
parseRespIterDone // Done with parsing the response.
|
||||
)
|
||||
|
||||
type parseRespState byte
|
||||
|
||||
const (
|
||||
parseRespStateNormal parseRespState = iota // Normal response state.
|
||||
parseRespStateCancel // Query is canceled, wait for server to confirm.
|
||||
parseRespStateClosing // Waiting for tokens to come through.
|
||||
)
|
||||
|
||||
type parseResp struct {
|
||||
sess *tdsSession
|
||||
ctxDone <-chan struct{}
|
||||
state parseRespState
|
||||
cancelError error
|
||||
type tokenProcessor struct {
|
||||
tokChan chan tokenStruct
|
||||
ctx context.Context
|
||||
sess *tdsSession
|
||||
outs outputs
|
||||
lastRow []interface{}
|
||||
rowCount int64
|
||||
firstError error
|
||||
// whether to skip sending attention when ctx is done
|
||||
noAttn bool
|
||||
}
|
||||
|
||||
func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter {
|
||||
if err := sendAttention(ts.sess.buf); err != nil {
|
||||
ts.dlogf("failed to send attention signal %v", err)
|
||||
ch <- err
|
||||
return parseRespIterDone
|
||||
}
|
||||
ts.state = parseRespStateCancel
|
||||
return parseRespIterContinue
|
||||
}
|
||||
|
||||
func (ts *parseResp) dlog(msg string) {
|
||||
if ts.sess.logFlags&logDebug != 0 {
|
||||
ts.sess.log.Println(msg)
|
||||
}
|
||||
}
|
||||
func (ts *parseResp) dlogf(f string, v ...interface{}) {
|
||||
if ts.sess.logFlags&logDebug != 0 {
|
||||
ts.sess.log.Printf(f, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter {
|
||||
switch ts.state {
|
||||
default:
|
||||
panic("unknown state")
|
||||
case parseRespStateNormal:
|
||||
select {
|
||||
case tok, ok := <-tokChan:
|
||||
if !ok {
|
||||
ts.dlog("response finished")
|
||||
return parseRespIterDone
|
||||
}
|
||||
if err, ok := tok.(net.Error); ok && err.Timeout() {
|
||||
ts.cancelError = err
|
||||
ts.dlog("got timeout error, sending attention signal to server")
|
||||
return ts.sendAttention(ch)
|
||||
}
|
||||
// Pass the token along.
|
||||
ch <- tok
|
||||
return parseRespIterContinue
|
||||
|
||||
case <-ts.ctxDone:
|
||||
ts.ctxDone = nil
|
||||
ts.dlog("got cancel message, sending attention signal to server")
|
||||
return ts.sendAttention(ch)
|
||||
}
|
||||
case parseRespStateCancel: // Read all responses until a DONE or error is received.Auth
|
||||
select {
|
||||
case tok, ok := <-tokChan:
|
||||
if !ok {
|
||||
ts.dlog("response finished but waiting for attention ack")
|
||||
return parseRespIterNext
|
||||
}
|
||||
switch tok := tok.(type) {
|
||||
default:
|
||||
// Ignore all other tokens while waiting.
|
||||
// The TDS spec says other tokens may arrive after an attention
|
||||
// signal is sent. Ignore these tokens and continue looking for
|
||||
// a DONE with attention confirm mark.
|
||||
case doneStruct:
|
||||
if tok.Status&doneAttn != 0 {
|
||||
ts.dlog("got cancellation confirmation from server")
|
||||
if ts.cancelError != nil {
|
||||
ch <- ts.cancelError
|
||||
ts.cancelError = nil
|
||||
} else {
|
||||
ch <- ctx.Err()
|
||||
}
|
||||
return parseRespIterDone
|
||||
}
|
||||
|
||||
// If an error happens during cancel, pass it along and just stop.
|
||||
// We are uncertain to receive more tokens.
|
||||
case error:
|
||||
ch <- tok
|
||||
ts.state = parseRespStateClosing
|
||||
}
|
||||
return parseRespIterContinue
|
||||
case <-ts.ctxDone:
|
||||
ts.ctxDone = nil
|
||||
ts.state = parseRespStateClosing
|
||||
return parseRespIterContinue
|
||||
}
|
||||
case parseRespStateClosing: // Wait for current token chan to close.
|
||||
if _, ok := <-tokChan; !ok {
|
||||
ts.dlog("response finished")
|
||||
return parseRespIterDone
|
||||
}
|
||||
return parseRespIterContinue
|
||||
}
|
||||
}
|
||||
|
||||
func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
|
||||
ts := &parseResp{
|
||||
func startReading(sess *tdsSession, ctx context.Context, outs outputs) *tokenProcessor {
|
||||
tokChan := make(chan tokenStruct, 5)
|
||||
go processSingleResponse(sess, tokChan, outs)
|
||||
return &tokenProcessor{
|
||||
tokChan: tokChan,
|
||||
ctx: ctx,
|
||||
sess: sess,
|
||||
ctxDone: ctx.Done(),
|
||||
outs: outs,
|
||||
}
|
||||
defer func() {
|
||||
// Ensure any remaining error is piped through
|
||||
// or the query may look like it executed when it actually failed.
|
||||
if ts.cancelError != nil {
|
||||
ch <- ts.cancelError
|
||||
ts.cancelError = nil
|
||||
}
|
||||
close(ch)
|
||||
}()
|
||||
}
|
||||
|
||||
// Loop over multiple responses.
|
||||
func (t *tokenProcessor) iterateResponse() error {
|
||||
for {
|
||||
ts.dlog("initiating response reading")
|
||||
|
||||
tokChan := make(chan tokenStruct)
|
||||
go processSingleResponse(sess, tokChan, outs)
|
||||
|
||||
// Loop over multiple tokens in response.
|
||||
tokensLoop:
|
||||
for {
|
||||
switch ts.iter(ctx, ch, tokChan) {
|
||||
case parseRespIterContinue:
|
||||
// Nothing, continue to next token.
|
||||
case parseRespIterNext:
|
||||
break tokensLoop
|
||||
case parseRespIterDone:
|
||||
return
|
||||
tok, err := t.nextToken()
|
||||
if err == nil {
|
||||
if tok == nil {
|
||||
return t.firstError
|
||||
} else {
|
||||
switch token := tok.(type) {
|
||||
case []columnStruct:
|
||||
t.sess.columns = token
|
||||
case []interface{}:
|
||||
t.lastRow = token
|
||||
case doneInProcStruct:
|
||||
if token.Status&doneCount != 0 {
|
||||
t.rowCount += int64(token.RowCount)
|
||||
}
|
||||
case doneStruct:
|
||||
if token.Status&doneCount != 0 {
|
||||
t.rowCount += int64(token.RowCount)
|
||||
}
|
||||
if token.isError() && t.firstError == nil {
|
||||
t.firstError = token.getError()
|
||||
}
|
||||
case ReturnStatus:
|
||||
if t.outs.returnStatus != nil {
|
||||
*t.outs.returnStatus = token
|
||||
}
|
||||
/*case error:
|
||||
if resultError == nil {
|
||||
resultError = token
|
||||
}*/
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t tokenProcessor) nextToken() (tokenStruct, error) {
|
||||
// we do this separate non-blocking check on token channel to
|
||||
// prioritize it over cancellation channel
|
||||
select {
|
||||
case tok, more := <-t.tokChan:
|
||||
err, more := tok.(error)
|
||||
if more {
|
||||
// this is an error and not a token
|
||||
return nil, err
|
||||
} else {
|
||||
return tok, nil
|
||||
}
|
||||
default:
|
||||
// there are no tokens on the channel, will need to wait
|
||||
}
|
||||
|
||||
select {
|
||||
case tok, more := <-t.tokChan:
|
||||
if more {
|
||||
err, ok := tok.(error)
|
||||
if ok {
|
||||
// this is an error and not a token
|
||||
return nil, err
|
||||
} else {
|
||||
return tok, nil
|
||||
}
|
||||
} else {
|
||||
// completed reading response
|
||||
return nil, nil
|
||||
}
|
||||
case <-t.ctx.Done():
|
||||
if t.noAttn {
|
||||
return nil, t.ctx.Err()
|
||||
}
|
||||
if err := sendAttention(t.sess.buf); err != nil {
|
||||
// unable to send attention, current connection is bad
|
||||
// notify caller and close channel
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// now the server should send cancellation confirmation
|
||||
// it is possible that we already received full response
|
||||
// just before we sent cancellation request
|
||||
// in this case current response would not contain confirmation
|
||||
// and we would need to read one more response
|
||||
|
||||
// first lets finish reading current response and look
|
||||
// for confirmation in it
|
||||
if readCancelConfirmation(t.tokChan) {
|
||||
// we got confirmation in current response
|
||||
return nil, t.ctx.Err()
|
||||
}
|
||||
// we did not get cancellation confirmation in the current response
|
||||
// read one more response, it must be there
|
||||
t.tokChan = make(chan tokenStruct, 5)
|
||||
go processSingleResponse(t.sess, t.tokChan, t.outs)
|
||||
if readCancelConfirmation(t.tokChan) {
|
||||
return nil, t.ctx.Err()
|
||||
}
|
||||
// we did not get cancellation confirmation, something is not
|
||||
// right, this connection is not usable anymore
|
||||
return nil, errors.New("did not get cancellation confirmation from the server")
|
||||
}
|
||||
}
|
||||
|
||||
func readCancelConfirmation(tokChan chan tokenStruct) bool {
|
||||
for tok := range tokChan {
|
||||
switch tok := tok.(type) {
|
||||
default:
|
||||
// just skip token
|
||||
case doneStruct:
|
||||
if tok.Status&doneAttn != 0 {
|
||||
// got cancellation confirmation, exit
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
44
vendor/github.com/denisenkom/go-mssqldb/token_string.go
generated
vendored
44
vendor/github.com/denisenkom/go-mssqldb/token_string.go
generated
vendored
|
@ -1,29 +1,24 @@
|
|||
// Code generated by "stringer -type token"; DO NOT EDIT
|
||||
// Code generated by "stringer -type token"; DO NOT EDIT.
|
||||
|
||||
package mssql
|
||||
|
||||
import "fmt"
|
||||
import "strconv"
|
||||
|
||||
const (
|
||||
_token_name_0 = "tokenReturnStatus"
|
||||
_token_name_1 = "tokenColMetadata"
|
||||
_token_name_2 = "tokenOrdertokenErrortokenInfo"
|
||||
_token_name_3 = "tokenLoginAck"
|
||||
_token_name_4 = "tokenRowtokenNbcRow"
|
||||
_token_name_5 = "tokenEnvChange"
|
||||
_token_name_6 = "tokenSSPI"
|
||||
_token_name_7 = "tokenDonetokenDoneProctokenDoneInProc"
|
||||
_token_name_2 = "tokenOrdertokenErrortokenInfotokenReturnValuetokenLoginAcktokenFeatureExtAck"
|
||||
_token_name_3 = "tokenRowtokenNbcRow"
|
||||
_token_name_4 = "tokenEnvChange"
|
||||
_token_name_5 = "tokenSSPItokenFedAuthInfo"
|
||||
_token_name_6 = "tokenDonetokenDoneProctokenDoneInProc"
|
||||
)
|
||||
|
||||
var (
|
||||
_token_index_0 = [...]uint8{0, 17}
|
||||
_token_index_1 = [...]uint8{0, 16}
|
||||
_token_index_2 = [...]uint8{0, 10, 20, 29}
|
||||
_token_index_3 = [...]uint8{0, 13}
|
||||
_token_index_4 = [...]uint8{0, 8, 19}
|
||||
_token_index_5 = [...]uint8{0, 14}
|
||||
_token_index_6 = [...]uint8{0, 9}
|
||||
_token_index_7 = [...]uint8{0, 9, 22, 37}
|
||||
_token_index_2 = [...]uint8{0, 10, 20, 29, 45, 58, 76}
|
||||
_token_index_3 = [...]uint8{0, 8, 19}
|
||||
_token_index_5 = [...]uint8{0, 9, 25}
|
||||
_token_index_6 = [...]uint8{0, 9, 22, 37}
|
||||
)
|
||||
|
||||
func (i token) String() string {
|
||||
|
@ -32,22 +27,21 @@ func (i token) String() string {
|
|||
return _token_name_0
|
||||
case i == 129:
|
||||
return _token_name_1
|
||||
case 169 <= i && i <= 171:
|
||||
case 169 <= i && i <= 174:
|
||||
i -= 169
|
||||
return _token_name_2[_token_index_2[i]:_token_index_2[i+1]]
|
||||
case i == 173:
|
||||
return _token_name_3
|
||||
case 209 <= i && i <= 210:
|
||||
i -= 209
|
||||
return _token_name_4[_token_index_4[i]:_token_index_4[i+1]]
|
||||
return _token_name_3[_token_index_3[i]:_token_index_3[i+1]]
|
||||
case i == 227:
|
||||
return _token_name_5
|
||||
case i == 237:
|
||||
return _token_name_6
|
||||
return _token_name_4
|
||||
case 237 <= i && i <= 238:
|
||||
i -= 237
|
||||
return _token_name_5[_token_index_5[i]:_token_index_5[i+1]]
|
||||
case 253 <= i && i <= 255:
|
||||
i -= 253
|
||||
return _token_name_7[_token_index_7[i]:_token_index_7[i+1]]
|
||||
return _token_name_6[_token_index_6[i]:_token_index_6[i+1]]
|
||||
default:
|
||||
return fmt.Sprintf("token(%d)", i)
|
||||
return "token(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
}
|
||||
|
|
10
vendor/github.com/denisenkom/go-mssqldb/tran.go
generated
vendored
10
vendor/github.com/denisenkom/go-mssqldb/tran.go
generated
vendored
|
@ -21,11 +21,11 @@ type isoLevel uint8
|
|||
|
||||
const (
|
||||
isolationUseCurrent isoLevel = 0
|
||||
isolationReadUncommited = 1
|
||||
isolationReadCommited = 2
|
||||
isolationRepeatableRead = 3
|
||||
isolationSerializable = 4
|
||||
isolationSnapshot = 5
|
||||
isolationReadUncommited isoLevel = 1
|
||||
isolationReadCommited isoLevel = 2
|
||||
isolationRepeatableRead isoLevel = 3
|
||||
isolationSerializable isoLevel = 4
|
||||
isolationSnapshot isoLevel = 5
|
||||
)
|
||||
|
||||
func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) (err error) {
|
||||
|
|
73
vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go
generated
vendored
73
vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go
generated
vendored
|
@ -4,6 +4,7 @@ package mssql
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -97,6 +98,9 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
|
|||
for columnStrIdx, fieldIdx := range tvpFieldIndexes {
|
||||
field := refStr.Field(fieldIdx)
|
||||
tvpVal := field.Interface()
|
||||
if tvp.verifyStandardTypeOnNull(buf, tvpVal) {
|
||||
continue
|
||||
}
|
||||
valOf := reflect.ValueOf(tvpVal)
|
||||
elemKind := field.Kind()
|
||||
if elemKind == reflect.Ptr && valOf.IsNil() {
|
||||
|
@ -155,7 +159,7 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
|
|||
defaultValues = append(defaultValues, v.Interface())
|
||||
continue
|
||||
}
|
||||
defaultValues = append(defaultValues, reflect.Zero(field.Type).Interface())
|
||||
defaultValues = append(defaultValues, tvp.createZeroType(reflect.Zero(field.Type).Interface()))
|
||||
}
|
||||
|
||||
if columnCount-len(tvpFieldIndexes) == columnCount {
|
||||
|
@ -209,19 +213,23 @@ func getSchemeAndName(tvpName string) (string, string, error) {
|
|||
}
|
||||
splitVal := strings.Split(tvpName, ".")
|
||||
if len(splitVal) > 2 {
|
||||
return "", "", errors.New("wrong tvp name")
|
||||
return "", "", ErrorObjectName
|
||||
}
|
||||
const (
|
||||
openSquareBrackets = "["
|
||||
closeSquareBrackets = "]"
|
||||
)
|
||||
if len(splitVal) == 2 {
|
||||
res := make([]string, 2)
|
||||
for key, value := range splitVal {
|
||||
tmp := strings.Replace(value, "[", "", -1)
|
||||
tmp = strings.Replace(tmp, "]", "", -1)
|
||||
tmp := strings.Replace(value, openSquareBrackets, "", -1)
|
||||
tmp = strings.Replace(tmp, closeSquareBrackets, "", -1)
|
||||
res[key] = tmp
|
||||
}
|
||||
return res[0], res[1], nil
|
||||
}
|
||||
tmp := strings.Replace(splitVal[0], "[", "", -1)
|
||||
tmp = strings.Replace(tmp, "]", "", -1)
|
||||
tmp := strings.Replace(splitVal[0], openSquareBrackets, "", -1)
|
||||
tmp = strings.Replace(tmp, closeSquareBrackets, "", -1)
|
||||
|
||||
return "", tmp, nil
|
||||
}
|
||||
|
@ -229,3 +237,56 @@ func getSchemeAndName(tvpName string) (string, string, error) {
|
|||
func getCountSQLSeparators(str string) int {
|
||||
return strings.Count(str, sqlSeparator)
|
||||
}
|
||||
|
||||
// verify types https://golang.org/pkg/database/sql/
|
||||
func (tvp TVP) createZeroType(fieldVal interface{}) interface{} {
|
||||
const (
|
||||
defaultBool = false
|
||||
defaultFloat64 = float64(0)
|
||||
defaultInt64 = int64(0)
|
||||
defaultString = ""
|
||||
)
|
||||
|
||||
switch fieldVal.(type) {
|
||||
case sql.NullBool:
|
||||
return defaultBool
|
||||
case sql.NullFloat64:
|
||||
return defaultFloat64
|
||||
case sql.NullInt64:
|
||||
return defaultInt64
|
||||
case sql.NullString:
|
||||
return defaultString
|
||||
}
|
||||
return fieldVal
|
||||
}
|
||||
|
||||
// verify types https://golang.org/pkg/database/sql/
|
||||
func (tvp TVP) verifyStandardTypeOnNull(buf *bytes.Buffer, tvpVal interface{}) bool {
|
||||
const (
|
||||
defaultNull = uint8(0)
|
||||
)
|
||||
|
||||
switch val := tvpVal.(type) {
|
||||
case sql.NullBool:
|
||||
if !val.Valid {
|
||||
binary.Write(buf, binary.LittleEndian, defaultNull)
|
||||
return true
|
||||
}
|
||||
case sql.NullFloat64:
|
||||
if !val.Valid {
|
||||
binary.Write(buf, binary.LittleEndian, defaultNull)
|
||||
return true
|
||||
}
|
||||
case sql.NullInt64:
|
||||
if !val.Valid {
|
||||
binary.Write(buf, binary.LittleEndian, defaultNull)
|
||||
return true
|
||||
}
|
||||
case sql.NullString:
|
||||
if !val.Valid {
|
||||
binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL))
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
12
vendor/github.com/denisenkom/go-mssqldb/types.go
generated
vendored
12
vendor/github.com/denisenkom/go-mssqldb/types.go
generated
vendored
|
@ -665,7 +665,7 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} {
|
|||
default:
|
||||
buf = bytes.NewBuffer(make([]byte, 0, size))
|
||||
}
|
||||
for true {
|
||||
for {
|
||||
chunksize := r.uint32()
|
||||
if chunksize == 0 {
|
||||
break
|
||||
|
@ -690,6 +690,10 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} {
|
|||
}
|
||||
|
||||
func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) {
|
||||
if buf == nil {
|
||||
err = binary.Write(w, binary.LittleEndian, uint64(_PLP_NULL))
|
||||
return
|
||||
}
|
||||
if err = binary.Write(w, binary.LittleEndian, uint64(_UNKNOWN_PLP_LEN)); err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -807,7 +811,6 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) {
|
|||
default:
|
||||
badStreamPanicf("Invalid type %d", ti.TypeId)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func decodeMoney(buf []byte) []byte {
|
||||
|
@ -834,8 +837,7 @@ func decodeGuid(buf []byte) []byte {
|
|||
}
|
||||
|
||||
func decodeDecimal(prec uint8, scale uint8, buf []byte) []byte {
|
||||
var sign uint8
|
||||
sign = buf[0]
|
||||
sign := buf[0]
|
||||
var dec decimal.Decimal
|
||||
dec.SetPositive(sign != 0)
|
||||
dec.SetPrec(prec)
|
||||
|
@ -1187,7 +1189,7 @@ func makeDecl(ti typeInfo) string {
|
|||
return fmt.Sprintf("char(%d)", ti.Size)
|
||||
case typeBigVarChar, typeVarChar:
|
||||
if ti.Size > 8000 || ti.Size == 0 {
|
||||
return fmt.Sprintf("varchar(max)")
|
||||
return "varchar(max)"
|
||||
} else {
|
||||
return fmt.Sprintf("varchar(%d)", ti.Size)
|
||||
}
|
||||
|
|
43
vendor/github.com/go-asn1-ber/asn1-ber/.travis.yml
generated
vendored
Normal file
43
vendor/github.com/go-asn1-ber/asn1-ber/.travis.yml
generated
vendored
Normal file
|
@ -0,0 +1,43 @@
|
|||
language: go
|
||||
|
||||
go:
|
||||
- 1.2.x
|
||||
- 1.6.x
|
||||
- 1.9.x
|
||||
- 1.10.x
|
||||
- 1.11.x
|
||||
- 1.12.x
|
||||
- 1.14.x
|
||||
- tip
|
||||
|
||||
os:
|
||||
- linux
|
||||
|
||||
arch:
|
||||
- amd64
|
||||
- ppc64le
|
||||
|
||||
dist: xenial
|
||||
|
||||
env:
|
||||
- GOARCH=amd64
|
||||
|
||||
jobs:
|
||||
include:
|
||||
- os: windows
|
||||
go: 1.14.x
|
||||
- os: osx
|
||||
go: 1.14.x
|
||||
- os: linux
|
||||
go: 1.14.x
|
||||
arch: arm64
|
||||
- os: linux
|
||||
go: 1.14.x
|
||||
env:
|
||||
- GOARCH=386
|
||||
|
||||
script:
|
||||
- go test -v -cover ./... || go test -v ./...
|
||||
matrix:
|
||||
allowfailures:
|
||||
go: 1.2.x
|
0
vendor/gopkg.in/asn1-ber.v1/LICENSE → vendor/github.com/go-asn1-ber/asn1-ber/LICENSE
generated
vendored
0
vendor/gopkg.in/asn1-ber.v1/LICENSE → vendor/github.com/go-asn1-ber/asn1-ber/LICENSE
generated
vendored
208
vendor/gopkg.in/asn1-ber.v1/ber.go → vendor/github.com/go-asn1-ber/asn1-ber/ber.go
generated
vendored
208
vendor/gopkg.in/asn1-ber.v1/ber.go → vendor/github.com/go-asn1-ber/asn1-ber/ber.go
generated
vendored
|
@ -8,6 +8,8 @@ import (
|
|||
"math"
|
||||
"os"
|
||||
"reflect"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// MaxPacketLengthBytes specifies the maximum allowed packet size when calling ReadPacket or DecodePacket. Set to 0 for
|
||||
|
@ -143,42 +145,46 @@ var TypeMap = map[Type]string{
|
|||
TypeConstructed: "Constructed",
|
||||
}
|
||||
|
||||
var Debug bool = false
|
||||
var Debug = false
|
||||
|
||||
func PrintBytes(out io.Writer, buf []byte, indent string) {
|
||||
data_lines := make([]string, (len(buf)/30)+1)
|
||||
num_lines := make([]string, (len(buf)/30)+1)
|
||||
dataLines := make([]string, (len(buf)/30)+1)
|
||||
numLines := make([]string, (len(buf)/30)+1)
|
||||
|
||||
for i, b := range buf {
|
||||
data_lines[i/30] += fmt.Sprintf("%02x ", b)
|
||||
num_lines[i/30] += fmt.Sprintf("%02d ", (i+1)%100)
|
||||
dataLines[i/30] += fmt.Sprintf("%02x ", b)
|
||||
numLines[i/30] += fmt.Sprintf("%02d ", (i+1)%100)
|
||||
}
|
||||
|
||||
for i := 0; i < len(data_lines); i++ {
|
||||
out.Write([]byte(indent + data_lines[i] + "\n"))
|
||||
out.Write([]byte(indent + num_lines[i] + "\n\n"))
|
||||
for i := 0; i < len(dataLines); i++ {
|
||||
_, _ = out.Write([]byte(indent + dataLines[i] + "\n"))
|
||||
_, _ = out.Write([]byte(indent + numLines[i] + "\n\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func WritePacket(out io.Writer, p *Packet) {
|
||||
printPacket(out, p, 0, false)
|
||||
}
|
||||
|
||||
func PrintPacket(p *Packet) {
|
||||
printPacket(os.Stdout, p, 0, false)
|
||||
}
|
||||
|
||||
func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) {
|
||||
indent_str := ""
|
||||
indentStr := ""
|
||||
|
||||
for len(indent_str) != indent {
|
||||
indent_str += " "
|
||||
for len(indentStr) != indent {
|
||||
indentStr += " "
|
||||
}
|
||||
|
||||
class_str := ClassMap[p.ClassType]
|
||||
classStr := ClassMap[p.ClassType]
|
||||
|
||||
tagtype_str := TypeMap[p.TagType]
|
||||
tagTypeStr := TypeMap[p.TagType]
|
||||
|
||||
tag_str := fmt.Sprintf("0x%02X", p.Tag)
|
||||
tagStr := fmt.Sprintf("0x%02X", p.Tag)
|
||||
|
||||
if p.ClassType == ClassUniversal {
|
||||
tag_str = tagMap[p.Tag]
|
||||
tagStr = tagMap[p.Tag]
|
||||
}
|
||||
|
||||
value := fmt.Sprint(p.Value)
|
||||
|
@ -188,10 +194,10 @@ func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) {
|
|||
description = p.Description + ": "
|
||||
}
|
||||
|
||||
fmt.Fprintf(out, "%s%s(%s, %s, %s) Len=%d %q\n", indent_str, description, class_str, tagtype_str, tag_str, p.Data.Len(), value)
|
||||
_, _ = fmt.Fprintf(out, "%s%s(%s, %s, %s) Len=%d %q\n", indentStr, description, classStr, tagTypeStr, tagStr, p.Data.Len(), value)
|
||||
|
||||
if printBytes {
|
||||
PrintBytes(out, p.Bytes(), indent_str)
|
||||
PrintBytes(out, p.Bytes(), indentStr)
|
||||
}
|
||||
|
||||
for _, child := range p.Children {
|
||||
|
@ -199,7 +205,7 @@ func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) {
|
|||
}
|
||||
}
|
||||
|
||||
// ReadPacket reads a single Packet from the reader
|
||||
// ReadPacket reads a single Packet from the reader.
|
||||
func ReadPacket(reader io.Reader) (*Packet, error) {
|
||||
p, _, err := readPacket(reader)
|
||||
if err != nil {
|
||||
|
@ -235,7 +241,7 @@ func encodeInteger(i int64) []byte {
|
|||
|
||||
var j int
|
||||
for ; n > 0; n-- {
|
||||
out[j] = (byte(i >> uint((n-1)*8)))
|
||||
out[j] = byte(i >> uint((n-1)*8))
|
||||
j++
|
||||
}
|
||||
|
||||
|
@ -267,7 +273,7 @@ func DecodePacket(data []byte) *Packet {
|
|||
}
|
||||
|
||||
// DecodePacketErr decodes the given bytes into a single Packet
|
||||
// If a decode error is encountered, nil is returned
|
||||
// If a decode error is encountered, nil is returned.
|
||||
func DecodePacketErr(data []byte) (*Packet, error) {
|
||||
p, _, err := readPacket(bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
|
@ -276,7 +282,7 @@ func DecodePacketErr(data []byte) (*Packet, error) {
|
|||
return p, nil
|
||||
}
|
||||
|
||||
// readPacket reads a single Packet from the reader, returning the number of bytes read
|
||||
// readPacket reads a single Packet from the reader, returning the number of bytes read.
|
||||
func readPacket(reader io.Reader) (*Packet, int, error) {
|
||||
identifier, length, read, err := readHeader(reader)
|
||||
if err != nil {
|
||||
|
@ -338,7 +344,7 @@ func readPacket(reader io.Reader) (*Packet, int, error) {
|
|||
if MaxPacketLengthBytes > 0 && int64(length) > MaxPacketLengthBytes {
|
||||
return nil, read, fmt.Errorf("length %d greater than maximum %d", length, MaxPacketLengthBytes)
|
||||
}
|
||||
content := make([]byte, length, length)
|
||||
content := make([]byte, length)
|
||||
if length > 0 {
|
||||
_, err := io.ReadFull(reader, content)
|
||||
if err != nil {
|
||||
|
@ -373,22 +379,42 @@ func readPacket(reader io.Reader) (*Packet, int, error) {
|
|||
case TagObjectDescriptor:
|
||||
case TagExternal:
|
||||
case TagRealFloat:
|
||||
p.Value, err = ParseReal(content)
|
||||
case TagEnumerated:
|
||||
p.Value, _ = ParseInt64(content)
|
||||
case TagEmbeddedPDV:
|
||||
case TagUTF8String:
|
||||
p.Value = DecodeString(content)
|
||||
val := DecodeString(content)
|
||||
if !utf8.Valid([]byte(val)) {
|
||||
err = errors.New("invalid UTF-8 string")
|
||||
} else {
|
||||
p.Value = val
|
||||
}
|
||||
case TagRelativeOID:
|
||||
case TagSequence:
|
||||
case TagSet:
|
||||
case TagNumericString:
|
||||
case TagPrintableString:
|
||||
p.Value = DecodeString(content)
|
||||
val := DecodeString(content)
|
||||
if err = isPrintableString(val); err == nil {
|
||||
p.Value = val
|
||||
}
|
||||
case TagT61String:
|
||||
case TagVideotexString:
|
||||
case TagIA5String:
|
||||
val := DecodeString(content)
|
||||
for i, c := range val {
|
||||
if c >= 0x7F {
|
||||
err = fmt.Errorf("invalid character for IA5String at pos %d: %c", i, c)
|
||||
break
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
p.Value = val
|
||||
}
|
||||
case TagUTCTime:
|
||||
case TagGeneralizedTime:
|
||||
p.Value, err = ParseGeneralizedTime(content)
|
||||
case TagGraphicString:
|
||||
case TagVisibleString:
|
||||
case TagGeneralString:
|
||||
|
@ -400,7 +426,24 @@ func readPacket(reader io.Reader) (*Packet, int, error) {
|
|||
p.Data.Write(content)
|
||||
}
|
||||
|
||||
return p, read, nil
|
||||
return p, read, err
|
||||
}
|
||||
|
||||
func isPrintableString(val string) error {
|
||||
for i, c := range val {
|
||||
switch {
|
||||
case c >= 'a' && c <= 'z':
|
||||
case c >= 'A' && c <= 'Z':
|
||||
case c >= '0' && c <= '9':
|
||||
default:
|
||||
switch c {
|
||||
case '\'', '(', ')', '+', ',', '-', '.', '=', '/', ':', '?', ' ':
|
||||
default:
|
||||
return fmt.Errorf("invalid character in position %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Packet) Bytes() []byte {
|
||||
|
@ -418,61 +461,99 @@ func (p *Packet) AppendChild(child *Packet) {
|
|||
p.Children = append(p.Children, child)
|
||||
}
|
||||
|
||||
func Encode(ClassType Class, TagType Type, Tag Tag, Value interface{}, Description string) *Packet {
|
||||
func Encode(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet {
|
||||
p := new(Packet)
|
||||
|
||||
p.ClassType = ClassType
|
||||
p.TagType = TagType
|
||||
p.Tag = Tag
|
||||
p.ClassType = classType
|
||||
p.TagType = tagType
|
||||
p.Tag = tag
|
||||
p.Data = new(bytes.Buffer)
|
||||
|
||||
p.Children = make([]*Packet, 0, 2)
|
||||
|
||||
p.Value = Value
|
||||
p.Description = Description
|
||||
p.Value = value
|
||||
p.Description = description
|
||||
|
||||
if Value != nil {
|
||||
v := reflect.ValueOf(Value)
|
||||
if value != nil {
|
||||
v := reflect.ValueOf(value)
|
||||
|
||||
if ClassType == ClassUniversal {
|
||||
switch Tag {
|
||||
if classType == ClassUniversal {
|
||||
switch tag {
|
||||
case TagOctetString:
|
||||
sv, ok := v.Interface().(string)
|
||||
|
||||
if ok {
|
||||
p.Data.Write([]byte(sv))
|
||||
}
|
||||
case TagEnumerated:
|
||||
bv, ok := v.Interface().([]byte)
|
||||
if ok {
|
||||
p.Data.Write(bv)
|
||||
}
|
||||
case TagEmbeddedPDV:
|
||||
bv, ok := v.Interface().([]byte)
|
||||
if ok {
|
||||
p.Data.Write(bv)
|
||||
}
|
||||
}
|
||||
} else if classType == ClassContext {
|
||||
switch tag {
|
||||
case TagEnumerated:
|
||||
bv, ok := v.Interface().([]byte)
|
||||
if ok {
|
||||
p.Data.Write(bv)
|
||||
}
|
||||
case TagEmbeddedPDV:
|
||||
bv, ok := v.Interface().([]byte)
|
||||
if ok {
|
||||
p.Data.Write(bv)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func NewSequence(Description string) *Packet {
|
||||
return Encode(ClassUniversal, TypeConstructed, TagSequence, nil, Description)
|
||||
func NewSequence(description string) *Packet {
|
||||
return Encode(ClassUniversal, TypeConstructed, TagSequence, nil, description)
|
||||
}
|
||||
|
||||
func NewBoolean(ClassType Class, TagType Type, Tag Tag, Value bool, Description string) *Packet {
|
||||
func NewBoolean(classType Class, tagType Type, tag Tag, value bool, description string) *Packet {
|
||||
intValue := int64(0)
|
||||
|
||||
if Value {
|
||||
if value {
|
||||
intValue = 1
|
||||
}
|
||||
|
||||
p := Encode(ClassType, TagType, Tag, nil, Description)
|
||||
p := Encode(classType, tagType, tag, nil, description)
|
||||
|
||||
p.Value = Value
|
||||
p.Value = value
|
||||
p.Data.Write(encodeInteger(intValue))
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func NewInteger(ClassType Class, TagType Type, Tag Tag, Value interface{}, Description string) *Packet {
|
||||
p := Encode(ClassType, TagType, Tag, nil, Description)
|
||||
// NewLDAPBoolean returns a RFC 4511-compliant Boolean packet.
|
||||
func NewLDAPBoolean(classType Class, tagType Type, tag Tag, value bool, description string) *Packet {
|
||||
intValue := int64(0)
|
||||
|
||||
p.Value = Value
|
||||
switch v := Value.(type) {
|
||||
if value {
|
||||
intValue = 255
|
||||
}
|
||||
|
||||
p := Encode(classType, tagType, tag, nil, description)
|
||||
|
||||
p.Value = value
|
||||
p.Data.Write(encodeInteger(intValue))
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func NewInteger(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet {
|
||||
p := Encode(classType, tagType, tag, nil, description)
|
||||
|
||||
p.Value = value
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
p.Data.Write(encodeInteger(int64(v)))
|
||||
case uint:
|
||||
|
@ -502,11 +583,38 @@ func NewInteger(ClassType Class, TagType Type, Tag Tag, Value interface{}, Descr
|
|||
return p
|
||||
}
|
||||
|
||||
func NewString(ClassType Class, TagType Type, Tag Tag, Value, Description string) *Packet {
|
||||
p := Encode(ClassType, TagType, Tag, nil, Description)
|
||||
func NewString(classType Class, tagType Type, tag Tag, value, description string) *Packet {
|
||||
p := Encode(classType, tagType, tag, nil, description)
|
||||
|
||||
p.Value = Value
|
||||
p.Data.Write([]byte(Value))
|
||||
p.Value = value
|
||||
p.Data.Write([]byte(value))
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func NewGeneralizedTime(classType Class, tagType Type, tag Tag, value time.Time, description string) *Packet {
|
||||
p := Encode(classType, tagType, tag, nil, description)
|
||||
var s string
|
||||
if value.Nanosecond() != 0 {
|
||||
s = value.Format(`20060102150405.000000000Z`)
|
||||
} else {
|
||||
s = value.Format(`20060102150405Z`)
|
||||
}
|
||||
p.Value = s
|
||||
p.Data.Write([]byte(s))
|
||||
return p
|
||||
}
|
||||
|
||||
func NewReal(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet {
|
||||
p := Encode(classType, tagType, tag, nil, description)
|
||||
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
p.Data.Write(encodeFloat(v))
|
||||
case float32:
|
||||
p.Data.Write(encodeFloat(float64(v)))
|
||||
default:
|
||||
panic(fmt.Sprintf("Invalid type %T, expected float{64|32}", v))
|
||||
}
|
||||
return p
|
||||
}
|
|
@ -6,7 +6,7 @@ func encodeUnsignedInteger(i uint64) []byte {
|
|||
|
||||
var j int
|
||||
for ; n > 0; n-- {
|
||||
out[j] = (byte(i >> uint((n-1)*8)))
|
||||
out[j] = byte(i >> uint((n-1)*8))
|
||||
j++
|
||||
}
|
||||
|
105
vendor/github.com/go-asn1-ber/asn1-ber/generalizedTime.go
generated
vendored
Normal file
105
vendor/github.com/go-asn1-ber/asn1-ber/generalizedTime.go
generated
vendored
Normal file
|
@ -0,0 +1,105 @@
|
|||
package ber
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrInvalidTimeFormat is returned when the generalizedTime string was not correct.
|
||||
var ErrInvalidTimeFormat = errors.New("invalid time format")
|
||||
|
||||
var zeroTime = time.Time{}
|
||||
|
||||
// ParseGeneralizedTime parses a string value and if it conforms to
|
||||
// GeneralizedTime[^0] format, will return a time.Time for that value.
|
||||
//
|
||||
// [^0]: https://www.itu.int/rec/T-REC-X.690-201508-I/en Section 11.7
|
||||
func ParseGeneralizedTime(v []byte) (time.Time, error) {
|
||||
var format string
|
||||
var fract time.Duration
|
||||
|
||||
str := []byte(DecodeString(v))
|
||||
tzIndex := bytes.IndexAny(str, "Z+-")
|
||||
if tzIndex < 0 {
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
|
||||
dot := bytes.IndexAny(str, ".,")
|
||||
switch dot {
|
||||
case -1:
|
||||
switch tzIndex {
|
||||
case 10:
|
||||
format = `2006010215Z`
|
||||
case 12:
|
||||
format = `200601021504Z`
|
||||
case 14:
|
||||
format = `20060102150405Z`
|
||||
default:
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
|
||||
case 10, 12:
|
||||
if tzIndex < dot {
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
// a "," is also allowed, but would not be parsed by time.Parse():
|
||||
str[dot] = '.'
|
||||
|
||||
// If <minute> is omitted, then <fraction> represents a fraction of an
|
||||
// hour; otherwise, if <second> and <leap-second> are omitted, then
|
||||
// <fraction> represents a fraction of a minute; otherwise, <fraction>
|
||||
// represents a fraction of a second.
|
||||
|
||||
// parse as float from dot to timezone
|
||||
f, err := strconv.ParseFloat(string(str[dot:tzIndex]), 64)
|
||||
if err != nil {
|
||||
return zeroTime, fmt.Errorf("failed to parse float: %s", err)
|
||||
}
|
||||
// ...and strip that part
|
||||
str = append(str[:dot], str[tzIndex:]...)
|
||||
tzIndex = dot
|
||||
|
||||
if dot == 10 {
|
||||
fract = time.Duration(int64(f * float64(time.Hour)))
|
||||
format = `2006010215Z`
|
||||
} else {
|
||||
fract = time.Duration(int64(f * float64(time.Minute)))
|
||||
format = `200601021504Z`
|
||||
}
|
||||
|
||||
case 14:
|
||||
if tzIndex < dot {
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
str[dot] = '.'
|
||||
// no need for fractional seconds, time.Parse() handles that
|
||||
format = `20060102150405Z`
|
||||
|
||||
default:
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
|
||||
l := len(str)
|
||||
switch l - tzIndex {
|
||||
case 1:
|
||||
if str[l-1] != 'Z' {
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
case 3:
|
||||
format += `0700`
|
||||
str = append(str, []byte("00")...)
|
||||
case 5:
|
||||
format += `0700`
|
||||
default:
|
||||
return zeroTime, ErrInvalidTimeFormat
|
||||
}
|
||||
|
||||
t, err := time.Parse(format, string(str))
|
||||
if err != nil {
|
||||
return zeroTime, fmt.Errorf("%s: %s", ErrInvalidTimeFormat, err)
|
||||
}
|
||||
return t.Add(fract), nil
|
||||
}
|
23
vendor/gopkg.in/asn1-ber.v1/header.go → vendor/github.com/go-asn1-ber/asn1-ber/header.go
generated
vendored
23
vendor/gopkg.in/asn1-ber.v1/header.go → vendor/github.com/go-asn1-ber/asn1-ber/header.go
generated
vendored
|
@ -7,19 +7,22 @@ import (
|
|||
)
|
||||
|
||||
func readHeader(reader io.Reader) (identifier Identifier, length int, read int, err error) {
|
||||
if i, c, err := readIdentifier(reader); err != nil {
|
||||
return Identifier{}, 0, read, err
|
||||
} else {
|
||||
identifier = i
|
||||
read += c
|
||||
}
|
||||
var (
|
||||
c, l int
|
||||
i Identifier
|
||||
)
|
||||
|
||||
if l, c, err := readLength(reader); err != nil {
|
||||
if i, c, err = readIdentifier(reader); err != nil {
|
||||
return Identifier{}, 0, read, err
|
||||
} else {
|
||||
length = l
|
||||
read += c
|
||||
}
|
||||
identifier = i
|
||||
read += c
|
||||
|
||||
if l, c, err = readLength(reader); err != nil {
|
||||
return Identifier{}, 0, read, err
|
||||
}
|
||||
length = l
|
||||
read += c
|
||||
|
||||
// Validate length type with identifier (x.600, 8.1.3.2.a)
|
||||
if length == LengthIndefinite && identifier.TagType == TypePrimitive {
|
12
vendor/gopkg.in/asn1-ber.v1/length.go → vendor/github.com/go-asn1-ber/asn1-ber/length.go
generated
vendored
12
vendor/gopkg.in/asn1-ber.v1/length.go → vendor/github.com/go-asn1-ber/asn1-ber/length.go
generated
vendored
|
@ -71,11 +71,11 @@ func readLength(reader io.Reader) (length int, read int, err error) {
|
|||
}
|
||||
|
||||
func encodeLength(length int) []byte {
|
||||
length_bytes := encodeUnsignedInteger(uint64(length))
|
||||
if length > 127 || len(length_bytes) > 1 {
|
||||
longFormBytes := []byte{(LengthLongFormBitmask | byte(len(length_bytes)))}
|
||||
longFormBytes = append(longFormBytes, length_bytes...)
|
||||
length_bytes = longFormBytes
|
||||
lengthBytes := encodeUnsignedInteger(uint64(length))
|
||||
if length > 127 || len(lengthBytes) > 1 {
|
||||
longFormBytes := []byte{LengthLongFormBitmask | byte(len(lengthBytes))}
|
||||
longFormBytes = append(longFormBytes, lengthBytes...)
|
||||
lengthBytes = longFormBytes
|
||||
}
|
||||
return length_bytes
|
||||
return lengthBytes
|
||||
}
|
157
vendor/github.com/go-asn1-ber/asn1-ber/real.go
generated
vendored
Normal file
157
vendor/github.com/go-asn1-ber/asn1-ber/real.go
generated
vendored
Normal file
|
@ -0,0 +1,157 @@
|
|||
package ber
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func encodeFloat(v float64) []byte {
|
||||
switch {
|
||||
case math.IsInf(v, 1):
|
||||
return []byte{0x40}
|
||||
case math.IsInf(v, -1):
|
||||
return []byte{0x41}
|
||||
case math.IsNaN(v):
|
||||
return []byte{0x42}
|
||||
case v == 0.0:
|
||||
if math.Signbit(v) {
|
||||
return []byte{0x43}
|
||||
}
|
||||
return []byte{}
|
||||
default:
|
||||
// we take the easy part ;-)
|
||||
value := []byte(strconv.FormatFloat(v, 'G', -1, 64))
|
||||
var ret []byte
|
||||
if bytes.Contains(value, []byte{'E'}) {
|
||||
ret = []byte{0x03}
|
||||
} else {
|
||||
ret = []byte{0x02}
|
||||
}
|
||||
ret = append(ret, value...)
|
||||
return ret
|
||||
}
|
||||
}
|
||||
|
||||
func ParseReal(v []byte) (val float64, err error) {
|
||||
if len(v) == 0 {
|
||||
return 0.0, nil
|
||||
}
|
||||
switch {
|
||||
case v[0]&0x80 == 0x80:
|
||||
val, err = parseBinaryFloat(v)
|
||||
case v[0]&0xC0 == 0x40:
|
||||
val, err = parseSpecialFloat(v)
|
||||
case v[0]&0xC0 == 0x0:
|
||||
val, err = parseDecimalFloat(v)
|
||||
default:
|
||||
return 0.0, fmt.Errorf("invalid info block")
|
||||
}
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
|
||||
if val == 0.0 && !math.Signbit(val) {
|
||||
return 0.0, errors.New("REAL value +0 must be encoded with zero-length value block")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func parseBinaryFloat(v []byte) (float64, error) {
|
||||
var info byte
|
||||
var buf []byte
|
||||
|
||||
info, v = v[0], v[1:]
|
||||
|
||||
var base int
|
||||
switch info & 0x30 {
|
||||
case 0x00:
|
||||
base = 2
|
||||
case 0x10:
|
||||
base = 8
|
||||
case 0x20:
|
||||
base = 16
|
||||
case 0x30:
|
||||
return 0.0, errors.New("bits 6 and 5 of information octet for REAL are equal to 11")
|
||||
}
|
||||
|
||||
scale := uint((info & 0x0c) >> 2)
|
||||
|
||||
var expLen int
|
||||
switch info & 0x03 {
|
||||
case 0x00:
|
||||
expLen = 1
|
||||
case 0x01:
|
||||
expLen = 2
|
||||
case 0x02:
|
||||
expLen = 3
|
||||
case 0x03:
|
||||
expLen = int(v[0])
|
||||
if expLen > 8 {
|
||||
return 0.0, errors.New("too big value of exponent")
|
||||
}
|
||||
v = v[1:]
|
||||
}
|
||||
buf, v = v[:expLen], v[expLen:]
|
||||
exponent, err := ParseInt64(buf)
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
|
||||
if len(v) > 8 {
|
||||
return 0.0, errors.New("too big value of mantissa")
|
||||
}
|
||||
|
||||
mant, err := ParseInt64(v)
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
mantissa := mant << scale
|
||||
|
||||
if info&0x40 == 0x40 {
|
||||
mantissa = -mantissa
|
||||
}
|
||||
|
||||
return float64(mantissa) * math.Pow(float64(base), float64(exponent)), nil
|
||||
}
|
||||
|
||||
func parseDecimalFloat(v []byte) (val float64, err error) {
|
||||
switch v[0] & 0x3F {
|
||||
case 0x01: // NR form 1
|
||||
var iVal int64
|
||||
iVal, err = strconv.ParseInt(strings.TrimLeft(string(v[1:]), " "), 10, 64)
|
||||
val = float64(iVal)
|
||||
case 0x02, 0x03: // NR form 2, 3
|
||||
val, err = strconv.ParseFloat(strings.Replace(strings.TrimLeft(string(v[1:]), " "), ",", ".", -1), 64)
|
||||
default:
|
||||
err = errors.New("incorrect NR form")
|
||||
}
|
||||
if err != nil {
|
||||
return 0.0, err
|
||||
}
|
||||
|
||||
if val == 0.0 && math.Signbit(val) {
|
||||
return 0.0, errors.New("REAL value -0 must be encoded as a special value")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func parseSpecialFloat(v []byte) (float64, error) {
|
||||
if len(v) != 1 {
|
||||
return 0.0, errors.New(`encoding of "special value" must not contain exponent and mantissa`)
|
||||
}
|
||||
switch v[0] {
|
||||
case 0x40:
|
||||
return math.Inf(1), nil
|
||||
case 0x41:
|
||||
return math.Inf(-1), nil
|
||||
case 0x42:
|
||||
return math.NaN(), nil
|
||||
case 0x43:
|
||||
return math.Copysign(0, -1), nil
|
||||
}
|
||||
return 0.0, errors.New(`encoding of "special value" not from ASN.1 standard`)
|
||||
}
|
2
vendor/gopkg.in/asn1-ber.v1/util.go → vendor/github.com/go-asn1-ber/asn1-ber/util.go
generated
vendored
2
vendor/gopkg.in/asn1-ber.v1/util.go → vendor/github.com/go-asn1-ber/asn1-ber/util.go
generated
vendored
|
@ -3,7 +3,7 @@ package ber
|
|||
import "io"
|
||||
|
||||
func readByte(reader io.Reader) (byte, error) {
|
||||
bytes := make([]byte, 1, 1)
|
||||
bytes := make([]byte, 1)
|
||||
_, err := io.ReadFull(reader, bytes)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
0
vendor/gopkg.in/ldap.v3/LICENSE → vendor/github.com/go-ldap/ldap/v3/LICENSE
generated
vendored
0
vendor/gopkg.in/ldap.v3/LICENSE → vendor/github.com/go-ldap/ldap/v3/LICENSE
generated
vendored
11
vendor/gopkg.in/ldap.v3/add.go → vendor/github.com/go-ldap/ldap/v3/add.go
generated
vendored
11
vendor/gopkg.in/ldap.v3/add.go → vendor/github.com/go-ldap/ldap/v3/add.go
generated
vendored
|
@ -1,18 +1,9 @@
|
|||
//
|
||||
// https://tools.ietf.org/html/rfc4511
|
||||
//
|
||||
// AddRequest ::= [APPLICATION 8] SEQUENCE {
|
||||
// entry LDAPDN,
|
||||
// attributes AttributeList }
|
||||
//
|
||||
// AttributeList ::= SEQUENCE OF attribute Attribute
|
||||
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
ber "gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// Attribute represents an LDAP attribute
|
540
vendor/github.com/go-ldap/ldap/v3/bind.go
generated
vendored
Normal file
540
vendor/github.com/go-ldap/ldap/v3/bind.go
generated
vendored
Normal file
|
@ -0,0 +1,540 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
enchex "encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/Azure/go-ntlmssp"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// SimpleBindRequest represents a username/password bind operation
|
||||
type SimpleBindRequest struct {
|
||||
// Username is the name of the Directory object that the client wishes to bind as
|
||||
Username string
|
||||
// Password is the credentials to bind with
|
||||
Password string
|
||||
// Controls are optional controls to send with the bind request
|
||||
Controls []Control
|
||||
// AllowEmptyPassword sets whether the client allows binding with an empty password
|
||||
// (normally used for unauthenticated bind).
|
||||
AllowEmptyPassword bool
|
||||
}
|
||||
|
||||
// SimpleBindResult contains the response from the server
|
||||
type SimpleBindResult struct {
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
// NewSimpleBindRequest returns a bind request
|
||||
func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest {
|
||||
return &SimpleBindRequest{
|
||||
Username: username,
|
||||
Password: password,
|
||||
Controls: controls,
|
||||
AllowEmptyPassword: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (req *SimpleBindRequest) appendTo(envelope *ber.Packet) error {
|
||||
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Username, "User Name"))
|
||||
pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.Password, "Password"))
|
||||
|
||||
envelope.AppendChild(pkt)
|
||||
if len(req.Controls) > 0 {
|
||||
envelope.AppendChild(encodeControls(req.Controls))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SimpleBind performs the simple bind operation defined in the given request
|
||||
func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) {
|
||||
if simpleBindRequest.Password == "" && !simpleBindRequest.AllowEmptyPassword {
|
||||
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
|
||||
}
|
||||
|
||||
msgCtx, err := l.doRequest(simpleBindRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &SimpleBindResult{
|
||||
Controls: make([]Control, 0),
|
||||
}
|
||||
|
||||
if len(packet.Children) == 3 {
|
||||
for _, child := range packet.Children[2].Children {
|
||||
decodedChild, decodeErr := DecodeControl(child)
|
||||
if decodeErr != nil {
|
||||
return nil, fmt.Errorf("failed to decode child control: %s", decodeErr)
|
||||
}
|
||||
result.Controls = append(result.Controls, decodedChild)
|
||||
}
|
||||
}
|
||||
|
||||
err = GetLDAPError(packet)
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Bind performs a bind with the given username and password.
|
||||
//
|
||||
// It does not allow unauthenticated bind (i.e. empty password). Use the UnauthenticatedBind method
|
||||
// for that.
|
||||
func (l *Conn) Bind(username, password string) error {
|
||||
req := &SimpleBindRequest{
|
||||
Username: username,
|
||||
Password: password,
|
||||
AllowEmptyPassword: false,
|
||||
}
|
||||
_, err := l.SimpleBind(req)
|
||||
return err
|
||||
}
|
||||
|
||||
// UnauthenticatedBind performs an unauthenticated bind.
|
||||
//
|
||||
// A username may be provided for trace (e.g. logging) purpose only, but it is normally not
|
||||
// authenticated or otherwise validated by the LDAP server.
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc4513#section-5.1.2 .
|
||||
// See https://tools.ietf.org/html/rfc4513#section-6.3.1 .
|
||||
func (l *Conn) UnauthenticatedBind(username string) error {
|
||||
req := &SimpleBindRequest{
|
||||
Username: username,
|
||||
Password: "",
|
||||
AllowEmptyPassword: true,
|
||||
}
|
||||
_, err := l.SimpleBind(req)
|
||||
return err
|
||||
}
|
||||
|
||||
// DigestMD5BindRequest represents a digest-md5 bind operation
|
||||
type DigestMD5BindRequest struct {
|
||||
Host string
|
||||
// Username is the name of the Directory object that the client wishes to bind as
|
||||
Username string
|
||||
// Password is the credentials to bind with
|
||||
Password string
|
||||
// Controls are optional controls to send with the bind request
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
func (req *DigestMD5BindRequest) appendTo(envelope *ber.Packet) error {
|
||||
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
|
||||
|
||||
auth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
|
||||
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "DIGEST-MD5", "SASL Mech"))
|
||||
request.AppendChild(auth)
|
||||
envelope.AppendChild(request)
|
||||
if len(req.Controls) > 0 {
|
||||
envelope.AppendChild(encodeControls(req.Controls))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DigestMD5BindResult contains the response from the server
|
||||
type DigestMD5BindResult struct {
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
// MD5Bind performs a digest-md5 bind with the given host, username and password.
|
||||
func (l *Conn) MD5Bind(host, username, password string) error {
|
||||
req := &DigestMD5BindRequest{
|
||||
Host: host,
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
_, err := l.DigestMD5Bind(req)
|
||||
return err
|
||||
}
|
||||
|
||||
// DigestMD5Bind performs the digest-md5 bind operation defined in the given request
|
||||
func (l *Conn) DigestMD5Bind(digestMD5BindRequest *DigestMD5BindRequest) (*DigestMD5BindResult, error) {
|
||||
if digestMD5BindRequest.Password == "" {
|
||||
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
|
||||
}
|
||||
|
||||
msgCtx, err := l.doRequest(digestMD5BindRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
|
||||
if l.Debug {
|
||||
if err = addLDAPDescriptions(packet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ber.PrintPacket(packet)
|
||||
}
|
||||
|
||||
result := &DigestMD5BindResult{
|
||||
Controls: make([]Control, 0),
|
||||
}
|
||||
var params map[string]string
|
||||
if len(packet.Children) == 2 {
|
||||
if len(packet.Children[1].Children) == 4 {
|
||||
child := packet.Children[1].Children[0]
|
||||
if child.Tag != ber.TagEnumerated {
|
||||
return result, GetLDAPError(packet)
|
||||
}
|
||||
if child.Value.(int64) != 14 {
|
||||
return result, GetLDAPError(packet)
|
||||
}
|
||||
child = packet.Children[1].Children[3]
|
||||
if child.Tag != ber.TagObjectDescriptor {
|
||||
return result, GetLDAPError(packet)
|
||||
}
|
||||
if child.Data == nil {
|
||||
return result, GetLDAPError(packet)
|
||||
}
|
||||
data, _ := ioutil.ReadAll(child.Data)
|
||||
params, err = parseParams(string(data))
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("parsing digest-challenge: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if params != nil {
|
||||
resp := computeResponse(
|
||||
params,
|
||||
"ldap/"+strings.ToLower(digestMD5BindRequest.Host),
|
||||
digestMD5BindRequest.Username,
|
||||
digestMD5BindRequest.Password,
|
||||
)
|
||||
packet = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
|
||||
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
|
||||
|
||||
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
|
||||
|
||||
auth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
|
||||
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "DIGEST-MD5", "SASL Mech"))
|
||||
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, resp, "Credentials"))
|
||||
request.AppendChild(auth)
|
||||
packet.AppendChild(request)
|
||||
msgCtx, err = l.sendMessage(packet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send message: %s", err)
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
packetResponse, ok := <-msgCtx.responses
|
||||
if !ok {
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
|
||||
}
|
||||
packet, err = packetResponse.ReadPacket()
|
||||
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read packet: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = GetLDAPError(packet)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func parseParams(str string) (map[string]string, error) {
|
||||
m := make(map[string]string)
|
||||
var key, value string
|
||||
var state int
|
||||
for i := 0; i <= len(str); i++ {
|
||||
switch state {
|
||||
case 0: //reading key
|
||||
if i == len(str) {
|
||||
return nil, fmt.Errorf("syntax error on %d", i)
|
||||
}
|
||||
if str[i] != '=' {
|
||||
key += string(str[i])
|
||||
continue
|
||||
}
|
||||
state = 1
|
||||
case 1: //reading value
|
||||
if i == len(str) {
|
||||
m[key] = value
|
||||
break
|
||||
}
|
||||
switch str[i] {
|
||||
case ',':
|
||||
m[key] = value
|
||||
state = 0
|
||||
key = ""
|
||||
value = ""
|
||||
case '"':
|
||||
if value != "" {
|
||||
return nil, fmt.Errorf("syntax error on %d", i)
|
||||
}
|
||||
state = 2
|
||||
default:
|
||||
value += string(str[i])
|
||||
}
|
||||
case 2: //inside quotes
|
||||
if i == len(str) {
|
||||
return nil, fmt.Errorf("syntax error on %d", i)
|
||||
}
|
||||
if str[i] != '"' {
|
||||
value += string(str[i])
|
||||
} else {
|
||||
state = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func computeResponse(params map[string]string, uri, username, password string) string {
|
||||
nc := "00000001"
|
||||
qop := "auth"
|
||||
cnonce := enchex.EncodeToString(randomBytes(16))
|
||||
x := username + ":" + params["realm"] + ":" + password
|
||||
y := md5Hash([]byte(x))
|
||||
|
||||
a1 := bytes.NewBuffer(y)
|
||||
a1.WriteString(":" + params["nonce"] + ":" + cnonce)
|
||||
if len(params["authzid"]) > 0 {
|
||||
a1.WriteString(":" + params["authzid"])
|
||||
}
|
||||
a2 := bytes.NewBuffer([]byte("AUTHENTICATE"))
|
||||
a2.WriteString(":" + uri)
|
||||
ha1 := enchex.EncodeToString(md5Hash(a1.Bytes()))
|
||||
ha2 := enchex.EncodeToString(md5Hash(a2.Bytes()))
|
||||
|
||||
kd := ha1
|
||||
kd += ":" + params["nonce"]
|
||||
kd += ":" + nc
|
||||
kd += ":" + cnonce
|
||||
kd += ":" + qop
|
||||
kd += ":" + ha2
|
||||
resp := enchex.EncodeToString(md5Hash([]byte(kd)))
|
||||
return fmt.Sprintf(
|
||||
`username="%s",realm="%s",nonce="%s",cnonce="%s",nc=00000001,qop=%s,digest-uri="%s",response=%s`,
|
||||
username,
|
||||
params["realm"],
|
||||
params["nonce"],
|
||||
cnonce,
|
||||
qop,
|
||||
uri,
|
||||
resp,
|
||||
)
|
||||
}
|
||||
|
||||
func md5Hash(b []byte) []byte {
|
||||
hasher := md5.New()
|
||||
hasher.Write(b)
|
||||
return hasher.Sum(nil)
|
||||
}
|
||||
|
||||
func randomBytes(len int) []byte {
|
||||
b := make([]byte, len)
|
||||
for i := 0; i < len; i++ {
|
||||
b[i] = byte(rand.Intn(256))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
var externalBindRequest = requestFunc(func(envelope *ber.Packet) error {
|
||||
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
|
||||
|
||||
saslAuth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
|
||||
saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "EXTERNAL", "SASL Mech"))
|
||||
saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "SASL Cred"))
|
||||
|
||||
pkt.AppendChild(saslAuth)
|
||||
|
||||
envelope.AppendChild(pkt)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// ExternalBind performs SASL/EXTERNAL authentication.
|
||||
//
|
||||
// Use ldap.DialURL("ldapi://") to connect to the Unix socket before ExternalBind.
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc4422#appendix-A
|
||||
func (l *Conn) ExternalBind() error {
|
||||
msgCtx, err := l.doRequest(externalBindRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return GetLDAPError(packet)
|
||||
}
|
||||
|
||||
// NTLMBind performs an NTLMSSP bind leveraging https://github.com/Azure/go-ntlmssp
|
||||
|
||||
// NTLMBindRequest represents an NTLMSSP bind operation
|
||||
type NTLMBindRequest struct {
|
||||
// Domain is the AD Domain to authenticate too. If not specified, it will be grabbed from the NTLMSSP Challenge
|
||||
Domain string
|
||||
// Username is the name of the Directory object that the client wishes to bind as
|
||||
Username string
|
||||
// Password is the credentials to bind with
|
||||
Password string
|
||||
// Hash is the hex NTLM hash to bind with. Password or hash must be provided
|
||||
Hash string
|
||||
// Controls are optional controls to send with the bind request
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
func (req *NTLMBindRequest) appendTo(envelope *ber.Packet) error {
|
||||
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
|
||||
|
||||
// generate an NTLMSSP Negotiation message for the specified domain (it can be blank)
|
||||
negMessage, err := ntlmssp.NewNegotiateMessage(req.Domain, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("err creating negmessage: %s", err)
|
||||
}
|
||||
|
||||
// append the generated NTLMSSP message as a TagEnumerated BER value
|
||||
auth := ber.Encode(ber.ClassContext, ber.TypePrimitive, ber.TagEnumerated, negMessage, "authentication")
|
||||
request.AppendChild(auth)
|
||||
envelope.AppendChild(request)
|
||||
if len(req.Controls) > 0 {
|
||||
envelope.AppendChild(encodeControls(req.Controls))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NTLMBindResult contains the response from the server
|
||||
type NTLMBindResult struct {
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
// NTLMBind performs an NTLMSSP Bind with the given domain, username and password
|
||||
func (l *Conn) NTLMBind(domain, username, password string) error {
|
||||
req := &NTLMBindRequest{
|
||||
Domain: domain,
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
_, err := l.NTLMChallengeBind(req)
|
||||
return err
|
||||
}
|
||||
|
||||
// NTLMBindWithHash performs an NTLM Bind with an NTLM hash instead of plaintext password (pass-the-hash)
|
||||
func (l *Conn) NTLMBindWithHash(domain, username, hash string) error {
|
||||
req := &NTLMBindRequest{
|
||||
Domain: domain,
|
||||
Username: username,
|
||||
Hash: hash,
|
||||
}
|
||||
_, err := l.NTLMChallengeBind(req)
|
||||
return err
|
||||
}
|
||||
|
||||
// NTLMChallengeBind performs the NTLMSSP bind operation defined in the given request
|
||||
func (l *Conn) NTLMChallengeBind(ntlmBindRequest *NTLMBindRequest) (*NTLMBindResult, error) {
|
||||
if ntlmBindRequest.Password == "" && ntlmBindRequest.Hash == "" {
|
||||
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
|
||||
}
|
||||
|
||||
msgCtx, err := l.doRequest(ntlmBindRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
|
||||
if l.Debug {
|
||||
if err = addLDAPDescriptions(packet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ber.PrintPacket(packet)
|
||||
}
|
||||
result := &NTLMBindResult{
|
||||
Controls: make([]Control, 0),
|
||||
}
|
||||
var ntlmsspChallenge []byte
|
||||
|
||||
// now find the NTLM Response Message
|
||||
if len(packet.Children) == 2 {
|
||||
if len(packet.Children[1].Children) == 3 {
|
||||
child := packet.Children[1].Children[1]
|
||||
ntlmsspChallenge = child.ByteValue
|
||||
// Check to make sure we got the right message. It will always start with NTLMSSP
|
||||
if len(ntlmsspChallenge) < 7 || !bytes.Equal(ntlmsspChallenge[:7], []byte("NTLMSSP")) {
|
||||
return result, GetLDAPError(packet)
|
||||
}
|
||||
l.Debug.Printf("%d: found ntlmssp challenge", msgCtx.id)
|
||||
}
|
||||
}
|
||||
if ntlmsspChallenge != nil {
|
||||
var err error
|
||||
var responseMessage []byte
|
||||
// generate a response message to the challenge with the given Username/Password if password is provided
|
||||
if ntlmBindRequest.Password != "" {
|
||||
responseMessage, err = ntlmssp.ProcessChallenge(ntlmsspChallenge, ntlmBindRequest.Username, ntlmBindRequest.Password)
|
||||
} else if ntlmBindRequest.Hash != "" {
|
||||
responseMessage, err = ntlmssp.ProcessChallengeWithHash(ntlmsspChallenge, ntlmBindRequest.Username, ntlmBindRequest.Hash)
|
||||
} else {
|
||||
err = fmt.Errorf("need a password or hash to generate reply")
|
||||
}
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("parsing ntlm-challenge: %s", err)
|
||||
}
|
||||
packet = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
|
||||
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
|
||||
|
||||
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
|
||||
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
|
||||
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
|
||||
|
||||
// append the challenge response message as a TagEmbeddedPDV BER value
|
||||
auth := ber.Encode(ber.ClassContext, ber.TypePrimitive, ber.TagEmbeddedPDV, responseMessage, "authentication")
|
||||
|
||||
request.AppendChild(auth)
|
||||
packet.AppendChild(request)
|
||||
msgCtx, err = l.sendMessage(packet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send message: %s", err)
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
packetResponse, ok := <-msgCtx.responses
|
||||
if !ok {
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
|
||||
}
|
||||
packet, err = packetResponse.ReadPacket()
|
||||
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read packet: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
err = GetLDAPError(packet)
|
||||
return result, err
|
||||
}
|
2
vendor/gopkg.in/ldap.v3/client.go → vendor/github.com/go-ldap/ldap/v3/client.go
generated
vendored
2
vendor/gopkg.in/ldap.v3/client.go → vendor/github.com/go-ldap/ldap/v3/client.go
generated
vendored
|
@ -10,6 +10,7 @@ type Client interface {
|
|||
Start()
|
||||
StartTLS(*tls.Config) error
|
||||
Close()
|
||||
IsClosing() bool
|
||||
SetTimeout(time.Duration)
|
||||
|
||||
Bind(username, password string) error
|
||||
|
@ -21,6 +22,7 @@ type Client interface {
|
|||
Del(*DelRequest) error
|
||||
Modify(*ModifyRequest) error
|
||||
ModifyDN(*ModifyDNRequest) error
|
||||
ModifyWithResult(*ModifyRequest) (*ModifyResult, error)
|
||||
|
||||
Compare(dn, attribute, value string) (bool, error)
|
||||
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)
|
21
vendor/gopkg.in/ldap.v3/compare.go → vendor/github.com/go-ldap/ldap/v3/compare.go
generated
vendored
21
vendor/gopkg.in/ldap.v3/compare.go → vendor/github.com/go-ldap/ldap/v3/compare.go
generated
vendored
|
@ -1,28 +1,9 @@
|
|||
// File contains Compare functionality
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc4511
|
||||
//
|
||||
// CompareRequest ::= [APPLICATION 14] SEQUENCE {
|
||||
// entry LDAPDN,
|
||||
// ava AttributeValueAssertion }
|
||||
//
|
||||
// AttributeValueAssertion ::= SEQUENCE {
|
||||
// attributeDesc AttributeDescription,
|
||||
// assertionValue AssertionValue }
|
||||
//
|
||||
// AttributeDescription ::= LDAPString
|
||||
// -- Constrained to <attributedescription>
|
||||
// -- [RFC4512]
|
||||
//
|
||||
// AttributeValue ::= OCTET STRING
|
||||
//
|
||||
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
ber "gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// CompareRequest represents an LDAP CompareRequest operation.
|
135
vendor/gopkg.in/ldap.v3/conn.go → vendor/github.com/go-ldap/ldap/v3/conn.go
generated
vendored
135
vendor/gopkg.in/ldap.v3/conn.go → vendor/github.com/go-ldap/ldap/v3/conn.go
generated
vendored
|
@ -1,6 +1,7 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -11,7 +12,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
ber "gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -112,8 +113,72 @@ var _ Client = &Conn{}
|
|||
// multiple places will probably result in undesired behaviour.
|
||||
var DefaultTimeout = 60 * time.Second
|
||||
|
||||
// DialOpt configures DialContext.
|
||||
type DialOpt func(*DialContext)
|
||||
|
||||
// DialWithDialer updates net.Dialer in DialContext.
|
||||
func DialWithDialer(d *net.Dialer) DialOpt {
|
||||
return func(dc *DialContext) {
|
||||
dc.d = d
|
||||
}
|
||||
}
|
||||
|
||||
// DialWithTLSConfig updates tls.Config in DialContext.
|
||||
func DialWithTLSConfig(tc *tls.Config) DialOpt {
|
||||
return func(dc *DialContext) {
|
||||
dc.tc = tc
|
||||
}
|
||||
}
|
||||
|
||||
// DialWithTLSDialer is a wrapper for DialWithTLSConfig with the option to
|
||||
// specify a net.Dialer to for example define a timeout or a custom resolver.
|
||||
func DialWithTLSDialer(tlsConfig *tls.Config, dialer *net.Dialer) DialOpt {
|
||||
return func(dc *DialContext) {
|
||||
dc.tc = tlsConfig
|
||||
dc.d = dialer
|
||||
}
|
||||
}
|
||||
|
||||
// DialContext contains necessary parameters to dial the given ldap URL.
|
||||
type DialContext struct {
|
||||
d *net.Dialer
|
||||
tc *tls.Config
|
||||
}
|
||||
|
||||
func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
|
||||
if u.Scheme == "ldapi" {
|
||||
if u.Path == "" || u.Path == "/" {
|
||||
u.Path = "/var/run/slapd/ldapi"
|
||||
}
|
||||
return dc.d.Dial("unix", u.Path)
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(u.Host)
|
||||
if err != nil {
|
||||
// we assume that error is due to missing port
|
||||
host = u.Host
|
||||
port = ""
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case "ldap":
|
||||
if port == "" {
|
||||
port = DefaultLdapPort
|
||||
}
|
||||
return dc.d.Dial("tcp", net.JoinHostPort(host, port))
|
||||
case "ldaps":
|
||||
if port == "" {
|
||||
port = DefaultLdapsPort
|
||||
}
|
||||
return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
|
||||
}
|
||||
|
||||
// Dial connects to the given address on the given network using net.Dial
|
||||
// and then returns a new Conn for the connection.
|
||||
// @deprecated Use DialURL instead.
|
||||
func Dial(network, addr string) (*Conn, error) {
|
||||
c, err := net.DialTimeout(network, addr, DefaultTimeout)
|
||||
if err != nil {
|
||||
|
@ -126,6 +191,7 @@ func Dial(network, addr string) (*Conn, error) {
|
|||
|
||||
// DialTLS connects to the given address on the given network using tls.Dial
|
||||
// and then returns a new Conn for the connection.
|
||||
// @deprecated Use DialURL instead.
|
||||
func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
|
||||
c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config)
|
||||
if err != nil {
|
||||
|
@ -136,44 +202,31 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
|
|||
return conn, nil
|
||||
}
|
||||
|
||||
// DialURL connects to the given ldap URL vie TCP using tls.Dial or net.Dial if ldaps://
|
||||
// or ldap:// specified as protocol. On success a new Conn for the connection
|
||||
// is returned.
|
||||
func DialURL(addr string) (*Conn, error) {
|
||||
lurl, err := url.Parse(addr)
|
||||
// DialURL connects to the given ldap URL.
|
||||
// The following schemas are supported: ldap://, ldaps://, ldapi://.
|
||||
// On success a new Conn for the connection is returned.
|
||||
func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
|
||||
u, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return nil, NewError(ErrorNetwork, err)
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(lurl.Host)
|
||||
var dc DialContext
|
||||
for _, opt := range opts {
|
||||
opt(&dc)
|
||||
}
|
||||
if dc.d == nil {
|
||||
dc.d = &net.Dialer{Timeout: DefaultTimeout}
|
||||
}
|
||||
|
||||
c, err := dc.dial(u)
|
||||
if err != nil {
|
||||
// we asume that error is due to missing port
|
||||
host = lurl.Host
|
||||
port = ""
|
||||
return nil, NewError(ErrorNetwork, err)
|
||||
}
|
||||
|
||||
switch lurl.Scheme {
|
||||
case "ldapi":
|
||||
if lurl.Path == "" || lurl.Path == "/" {
|
||||
lurl.Path = "/var/run/slapd/ldapi"
|
||||
}
|
||||
return Dial("unix", lurl.Path)
|
||||
case "ldap":
|
||||
if port == "" {
|
||||
port = DefaultLdapPort
|
||||
}
|
||||
return Dial("tcp", net.JoinHostPort(host, port))
|
||||
case "ldaps":
|
||||
if port == "" {
|
||||
port = DefaultLdapsPort
|
||||
}
|
||||
tlsConf := &tls.Config{
|
||||
ServerName: host,
|
||||
}
|
||||
return DialTLS("tcp", net.JoinHostPort(host, port), tlsConf)
|
||||
}
|
||||
|
||||
return nil, NewError(ErrorNetwork, fmt.Errorf("Unknown scheme '%s'", lurl.Scheme))
|
||||
conn := NewConn(c, u.Scheme == "ldaps")
|
||||
conn.Start()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// NewConn returns a new Conn using conn for network I/O.
|
||||
|
@ -191,9 +244,9 @@ func NewConn(conn net.Conn, isTLS bool) *Conn {
|
|||
|
||||
// Start initializes goroutines to read responses and process messages
|
||||
func (l *Conn) Start() {
|
||||
l.wgClose.Add(1)
|
||||
go l.reader()
|
||||
go l.processMessages()
|
||||
l.wgClose.Add(1)
|
||||
}
|
||||
|
||||
// IsClosing returns whether or not we're currently closing.
|
||||
|
@ -278,7 +331,7 @@ func (l *Conn) StartTLS(config *tls.Config) error {
|
|||
l.Close()
|
||||
return err
|
||||
}
|
||||
ber.PrintPacket(packet)
|
||||
l.Debug.PrintPacket(packet)
|
||||
}
|
||||
|
||||
if err := GetLDAPError(packet); err == nil {
|
||||
|
@ -347,7 +400,12 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags)
|
|||
responses: responses,
|
||||
},
|
||||
}
|
||||
l.sendProcessMessage(message)
|
||||
if !l.sendProcessMessage(message) {
|
||||
if l.IsClosing() {
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
|
||||
}
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message for unknown reason"))
|
||||
}
|
||||
return message.Context, nil
|
||||
}
|
||||
|
||||
|
@ -451,14 +509,14 @@ func (l *Conn) processMessages() {
|
|||
msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
|
||||
} else {
|
||||
log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing())
|
||||
ber.PrintPacket(message.Packet)
|
||||
l.Debug.PrintPacket(message.Packet)
|
||||
}
|
||||
case MessageTimeout:
|
||||
// Handle the timeout by closing the channel
|
||||
// All reads will return immediately
|
||||
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
|
||||
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
|
||||
msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
|
||||
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))})
|
||||
delete(l.messageContexts, message.MessageID)
|
||||
close(msgCtx.responses)
|
||||
}
|
||||
|
@ -484,12 +542,13 @@ func (l *Conn) reader() {
|
|||
}
|
||||
}()
|
||||
|
||||
bufConn := bufio.NewReader(l.conn)
|
||||
for {
|
||||
if cleanstop {
|
||||
l.Debug.Printf("reader clean stopping (without closing the connection)")
|
||||
return
|
||||
}
|
||||
packet, err := ber.ReadPacket(l.conn)
|
||||
packet, err := ber.ReadPacket(bufConn)
|
||||
if err != nil {
|
||||
// A read error is expected here if we are closing the connection...
|
||||
if !l.IsClosing() {
|
81
vendor/gopkg.in/ldap.v3/control.go → vendor/github.com/go-ldap/ldap/v3/control.go
generated
vendored
81
vendor/gopkg.in/ldap.v3/control.go → vendor/github.com/go-ldap/ldap/v3/control.go
generated
vendored
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -18,20 +18,25 @@ const (
|
|||
ControlTypeVChuPasswordWarning = "2.16.840.1.113730.3.4.5"
|
||||
// ControlTypeManageDsaIT - https://tools.ietf.org/html/rfc3296
|
||||
ControlTypeManageDsaIT = "2.16.840.1.113730.3.4.2"
|
||||
// ControlTypeWhoAmI - https://tools.ietf.org/html/rfc4532
|
||||
ControlTypeWhoAmI = "1.3.6.1.4.1.4203.1.11.3"
|
||||
|
||||
// ControlTypeMicrosoftNotification - https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx
|
||||
ControlTypeMicrosoftNotification = "1.2.840.113556.1.4.528"
|
||||
// ControlTypeMicrosoftShowDeleted - https://msdn.microsoft.com/en-us/library/aa366989(v=vs.85).aspx
|
||||
ControlTypeMicrosoftShowDeleted = "1.2.840.113556.1.4.417"
|
||||
// ControlTypeMicrosoftServerLinkTTL - https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/f4f523a8-abc0-4b3a-a471-6b2fef135481?redirectedfrom=MSDN
|
||||
ControlTypeMicrosoftServerLinkTTL = "1.2.840.113556.1.4.2309"
|
||||
)
|
||||
|
||||
// ControlTypeMap maps controls to text descriptions
|
||||
var ControlTypeMap = map[string]string{
|
||||
ControlTypePaging: "Paging",
|
||||
ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft",
|
||||
ControlTypeManageDsaIT: "Manage DSA IT",
|
||||
ControlTypeMicrosoftNotification: "Change Notification - Microsoft",
|
||||
ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft",
|
||||
ControlTypePaging: "Paging",
|
||||
ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft",
|
||||
ControlTypeManageDsaIT: "Manage DSA IT",
|
||||
ControlTypeMicrosoftNotification: "Change Notification - Microsoft",
|
||||
ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft",
|
||||
ControlTypeMicrosoftServerLinkTTL: "Return TTL-DNs for link values with associated expiry times - Microsoft",
|
||||
}
|
||||
|
||||
// Control defines an interface controls provide to encode and describe themselves
|
||||
|
@ -305,6 +310,35 @@ func NewControlMicrosoftShowDeleted() *ControlMicrosoftShowDeleted {
|
|||
return &ControlMicrosoftShowDeleted{}
|
||||
}
|
||||
|
||||
// ControlMicrosoftServerLinkTTL implements the control described in https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/f4f523a8-abc0-4b3a-a471-6b2fef135481?redirectedfrom=MSDN
|
||||
type ControlMicrosoftServerLinkTTL struct{}
|
||||
|
||||
// GetControlType returns the OID
|
||||
func (c *ControlMicrosoftServerLinkTTL) GetControlType() string {
|
||||
return ControlTypeMicrosoftServerLinkTTL
|
||||
}
|
||||
|
||||
// Encode returns the ber packet representation
|
||||
func (c *ControlMicrosoftServerLinkTTL) Encode() *ber.Packet {
|
||||
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
|
||||
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeMicrosoftServerLinkTTL, "Control Type ("+ControlTypeMap[ControlTypeMicrosoftServerLinkTTL]+")"))
|
||||
|
||||
return packet
|
||||
}
|
||||
|
||||
// String returns a human-readable description
|
||||
func (c *ControlMicrosoftServerLinkTTL) String() string {
|
||||
return fmt.Sprintf(
|
||||
"Control Type: %s (%q)",
|
||||
ControlTypeMap[ControlTypeMicrosoftServerLinkTTL],
|
||||
ControlTypeMicrosoftServerLinkTTL)
|
||||
}
|
||||
|
||||
// NewControlMicrosoftServerLinkTTL returns a ControlMicrosoftServerLinkTTL control
|
||||
func NewControlMicrosoftServerLinkTTL() *ControlMicrosoftServerLinkTTL {
|
||||
return &ControlMicrosoftServerLinkTTL{}
|
||||
}
|
||||
|
||||
// FindControl returns the first control of the given type in the list, or nil
|
||||
func FindControl(controls []Control, controlType string) Control {
|
||||
for _, c := range controls {
|
||||
|
@ -404,33 +438,26 @@ func DecodeControl(packet *ber.Packet) (Control, error) {
|
|||
if child.Tag == 0 {
|
||||
//Warning
|
||||
warningPacket := child.Children[0]
|
||||
packet, err := ber.DecodePacketErr(warningPacket.Data.Bytes())
|
||||
val, err := ber.ParseInt64(warningPacket.Data.Bytes())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
|
||||
}
|
||||
val, ok := packet.Value.(int64)
|
||||
if ok {
|
||||
if warningPacket.Tag == 0 {
|
||||
//timeBeforeExpiration
|
||||
c.Expire = val
|
||||
warningPacket.Value = c.Expire
|
||||
} else if warningPacket.Tag == 1 {
|
||||
//graceAuthNsRemaining
|
||||
c.Grace = val
|
||||
warningPacket.Value = c.Grace
|
||||
}
|
||||
if warningPacket.Tag == 0 {
|
||||
//timeBeforeExpiration
|
||||
c.Expire = val
|
||||
warningPacket.Value = c.Expire
|
||||
} else if warningPacket.Tag == 1 {
|
||||
//graceAuthNsRemaining
|
||||
c.Grace = val
|
||||
warningPacket.Value = c.Grace
|
||||
}
|
||||
} else if child.Tag == 1 {
|
||||
// Error
|
||||
packet, err := ber.DecodePacketErr(child.Data.Bytes())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
|
||||
}
|
||||
val, ok := packet.Value.(int8)
|
||||
if !ok {
|
||||
// what to do?
|
||||
val = -1
|
||||
bs := child.Data.Bytes()
|
||||
if len(bs) != 1 || bs[0] > 8 {
|
||||
return nil, fmt.Errorf("failed to decode data bytes: %s", "invalid PasswordPolicyResponse enum value")
|
||||
}
|
||||
val := int8(bs[0])
|
||||
c.Error = val
|
||||
child.Value = c.Error
|
||||
c.ErrorString = BeheraPasswordPolicyErrorMap[c.Error]
|
||||
|
@ -456,6 +483,8 @@ func DecodeControl(packet *ber.Packet) (Control, error) {
|
|||
return NewControlMicrosoftNotification(), nil
|
||||
case ControlTypeMicrosoftShowDeleted:
|
||||
return NewControlMicrosoftShowDeleted(), nil
|
||||
case ControlTypeMicrosoftServerLinkTTL:
|
||||
return NewControlMicrosoftServerLinkTTL(), nil
|
||||
default:
|
||||
c := new(ControlString)
|
||||
c.ControlType = ControlType
|
4
vendor/gopkg.in/ldap.v3/debug.go → vendor/github.com/go-ldap/ldap/v3/debug.go
generated
vendored
4
vendor/gopkg.in/ldap.v3/debug.go → vendor/github.com/go-ldap/ldap/v3/debug.go
generated
vendored
|
@ -3,7 +3,7 @@ package ldap
|
|||
import (
|
||||
"log"
|
||||
|
||||
ber "gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// debugging type
|
||||
|
@ -25,6 +25,6 @@ func (debug debugging) Printf(format string, args ...interface{}) {
|
|||
// PrintPacket dumps a packet.
|
||||
func (debug debugging) PrintPacket(packet *ber.Packet) {
|
||||
if debug {
|
||||
ber.PrintPacket(packet)
|
||||
ber.WritePacket(log.Writer(), packet)
|
||||
}
|
||||
}
|
7
vendor/gopkg.in/ldap.v3/del.go → vendor/github.com/go-ldap/ldap/v3/del.go
generated
vendored
7
vendor/gopkg.in/ldap.v3/del.go → vendor/github.com/go-ldap/ldap/v3/del.go
generated
vendored
|
@ -1,14 +1,9 @@
|
|||
//
|
||||
// https://tools.ietf.org/html/rfc4511
|
||||
//
|
||||
// DelRequest ::= [APPLICATION 10] LDAPDN
|
||||
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
ber "gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// DelRequest implements an LDAP deletion request
|
109
vendor/gopkg.in/ldap.v3/dn.go → vendor/github.com/go-ldap/ldap/v3/dn.go
generated
vendored
109
vendor/gopkg.in/ldap.v3/dn.go → vendor/github.com/go-ldap/ldap/v3/dn.go
generated
vendored
|
@ -1,44 +1,3 @@
|
|||
// File contains DN parsing functionality
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc4514
|
||||
//
|
||||
// distinguishedName = [ relativeDistinguishedName
|
||||
// *( COMMA relativeDistinguishedName ) ]
|
||||
// relativeDistinguishedName = attributeTypeAndValue
|
||||
// *( PLUS attributeTypeAndValue )
|
||||
// attributeTypeAndValue = attributeType EQUALS attributeValue
|
||||
// attributeType = descr / numericoid
|
||||
// attributeValue = string / hexstring
|
||||
//
|
||||
// ; The following characters are to be escaped when they appear
|
||||
// ; in the value to be encoded: ESC, one of <escaped>, leading
|
||||
// ; SHARP or SPACE, trailing SPACE, and NULL.
|
||||
// string = [ ( leadchar / pair ) [ *( stringchar / pair )
|
||||
// ( trailchar / pair ) ] ]
|
||||
//
|
||||
// leadchar = LUTF1 / UTFMB
|
||||
// LUTF1 = %x01-1F / %x21 / %x24-2A / %x2D-3A /
|
||||
// %x3D / %x3F-5B / %x5D-7F
|
||||
//
|
||||
// trailchar = TUTF1 / UTFMB
|
||||
// TUTF1 = %x01-1F / %x21 / %x23-2A / %x2D-3A /
|
||||
// %x3D / %x3F-5B / %x5D-7F
|
||||
//
|
||||
// stringchar = SUTF1 / UTFMB
|
||||
// SUTF1 = %x01-21 / %x23-2A / %x2D-3A /
|
||||
// %x3D / %x3F-5B / %x5D-7F
|
||||
//
|
||||
// pair = ESC ( ESC / special / hexpair )
|
||||
// special = escaped / SPACE / SHARP / EQUALS
|
||||
// escaped = DQUOTE / PLUS / COMMA / SEMI / LANGLE / RANGLE
|
||||
// hexstring = SHARP 1*hexpair
|
||||
// hexpair = HEX HEX
|
||||
//
|
||||
// where the productions <descr>, <numericoid>, <COMMA>, <DQUOTE>,
|
||||
// <EQUALS>, <ESC>, <HEX>, <LANGLE>, <NULL>, <PLUS>, <RANGLE>, <SEMI>,
|
||||
// <SPACE>, <SHARP>, and <UTFMB> are defined in [RFC4512].
|
||||
//
|
||||
|
||||
package ldap
|
||||
|
||||
import (
|
||||
|
@ -48,7 +7,7 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514
|
||||
|
@ -69,7 +28,8 @@ type DN struct {
|
|||
RDNs []*RelativeDN
|
||||
}
|
||||
|
||||
// ParseDN returns a distinguishedName or an error
|
||||
// ParseDN returns a distinguishedName or an error.
|
||||
// The function respects https://tools.ietf.org/html/rfc4514
|
||||
func ParseDN(str string) (*DN, error) {
|
||||
dn := new(DN)
|
||||
dn.RDNs = make([]*RelativeDN, 0)
|
||||
|
@ -245,3 +205,66 @@ func (r *RelativeDN) hasAllAttributes(attrs []*AttributeTypeAndValue) bool {
|
|||
func (a *AttributeTypeAndValue) Equal(other *AttributeTypeAndValue) bool {
|
||||
return strings.EqualFold(a.Type, other.Type) && a.Value == other.Value
|
||||
}
|
||||
|
||||
// Equal returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
|
||||
// Returns true if they have the same number of relative distinguished names
|
||||
// and corresponding relative distinguished names (by position) are the same.
|
||||
// Case of the attribute type and value is not significant
|
||||
func (d *DN) EqualFold(other *DN) bool {
|
||||
if len(d.RDNs) != len(other.RDNs) {
|
||||
return false
|
||||
}
|
||||
for i := range d.RDNs {
|
||||
if !d.RDNs[i].EqualFold(other.RDNs[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// AncestorOfFold returns true if the other DN consists of at least one RDN followed by all the RDNs of the current DN.
|
||||
// Case of the attribute type and value is not significant
|
||||
func (d *DN) AncestorOfFold(other *DN) bool {
|
||||
if len(d.RDNs) >= len(other.RDNs) {
|
||||
return false
|
||||
}
|
||||
// Take the last `len(d.RDNs)` RDNs from the other DN to compare against
|
||||
otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):]
|
||||
for i := range d.RDNs {
|
||||
if !d.RDNs[i].EqualFold(otherRDNs[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Equal returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
|
||||
// Case of the attribute type is not significant
|
||||
func (r *RelativeDN) EqualFold(other *RelativeDN) bool {
|
||||
if len(r.Attributes) != len(other.Attributes) {
|
||||
return false
|
||||
}
|
||||
return r.hasAllAttributesFold(other.Attributes) && other.hasAllAttributesFold(r.Attributes)
|
||||
}
|
||||
|
||||
func (r *RelativeDN) hasAllAttributesFold(attrs []*AttributeTypeAndValue) bool {
|
||||
for _, attr := range attrs {
|
||||
found := false
|
||||
for _, myattr := range r.Attributes {
|
||||
if myattr.EqualFold(attr) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// EqualFold returns true if the AttributeTypeAndValue is equivalent to the specified AttributeTypeAndValue
|
||||
// Case of the attribute type and value is not significant
|
||||
func (a *AttributeTypeAndValue) EqualFold(other *AttributeTypeAndValue) bool {
|
||||
return strings.EqualFold(a.Type, other.Type) && strings.EqualFold(a.Value, other.Value)
|
||||
}
|
0
vendor/gopkg.in/ldap.v3/doc.go → vendor/github.com/go-ldap/ldap/v3/doc.go
generated
vendored
0
vendor/gopkg.in/ldap.v3/doc.go → vendor/github.com/go-ldap/ldap/v3/doc.go
generated
vendored
33
vendor/gopkg.in/ldap.v3/error.go → vendor/github.com/go-ldap/ldap/v3/error.go
generated
vendored
33
vendor/gopkg.in/ldap.v3/error.go → vendor/github.com/go-ldap/ldap/v3/error.go
generated
vendored
|
@ -3,7 +3,7 @@ package ldap
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
ber "gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// LDAP Result Codes
|
||||
|
@ -184,6 +184,8 @@ type Error struct {
|
|||
ResultCode uint16
|
||||
// MatchedDN is the matchedDN returned if any
|
||||
MatchedDN string
|
||||
// Packet is the returned packet if any
|
||||
Packet *ber.Packet
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
|
@ -201,19 +203,23 @@ func GetLDAPError(packet *ber.Packet) error {
|
|||
if len(packet.Children) >= 2 {
|
||||
response := packet.Children[1]
|
||||
if response == nil {
|
||||
return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet")}
|
||||
return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet"), Packet: packet}
|
||||
}
|
||||
if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 {
|
||||
resultCode := uint16(response.Children[0].Value.(int64))
|
||||
if resultCode == 0 { // No error
|
||||
return nil
|
||||
}
|
||||
return &Error{ResultCode: resultCode, MatchedDN: response.Children[1].Value.(string),
|
||||
Err: fmt.Errorf("%s", response.Children[2].Value.(string))}
|
||||
return &Error{
|
||||
ResultCode: resultCode,
|
||||
MatchedDN: response.Children[1].Value.(string),
|
||||
Err: fmt.Errorf("%s", response.Children[2].Value.(string)),
|
||||
Packet: packet,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("Invalid packet format")}
|
||||
return &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("Invalid packet format"), Packet: packet}
|
||||
}
|
||||
|
||||
// NewError creates an LDAP error with the given code and underlying error
|
||||
|
@ -221,8 +227,8 @@ func NewError(resultCode uint16, err error) error {
|
|||
return &Error{ResultCode: resultCode, Err: err}
|
||||
}
|
||||
|
||||
// IsErrorWithCode returns true if the given error is an LDAP error with the given result code
|
||||
func IsErrorWithCode(err error, desiredResultCode uint16) bool {
|
||||
// IsErrorAnyOf returns true if the given error is an LDAP error with any one of the given result codes
|
||||
func IsErrorAnyOf(err error, codes ...uint16) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
@ -232,5 +238,16 @@ func IsErrorWithCode(err error, desiredResultCode uint16) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
return serverError.ResultCode == desiredResultCode
|
||||
for _, code := range codes {
|
||||
if serverError.ResultCode == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsErrorWithCode returns true if the given error is an LDAP error with the given result code
|
||||
func IsErrorWithCode(err error, desiredResultCode uint16) bool {
|
||||
return IsErrorAnyOf(err, desiredResultCode)
|
||||
}
|
176
vendor/gopkg.in/ldap.v3/filter.go → vendor/github.com/go-ldap/ldap/v3/filter.go
generated
vendored
176
vendor/gopkg.in/ldap.v3/filter.go → vendor/github.com/go-ldap/ldap/v3/filter.go
generated
vendored
|
@ -5,10 +5,12 @@ import (
|
|||
hexpac "encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// Filter choices
|
||||
|
@ -69,6 +71,8 @@ var MatchingRuleAssertionMap = map[uint64]string{
|
|||
MatchingRuleAssertionDNAttributes: "Matching Rule Assertion DN Attributes",
|
||||
}
|
||||
|
||||
var _SymbolAny = []byte{'*'}
|
||||
|
||||
// CompileFilter converts a string representation of a filter into a BER-encoded packet
|
||||
func CompileFilter(filter string) (*ber.Packet, error) {
|
||||
if len(filter) == 0 || filter[0] != '(' {
|
||||
|
@ -88,74 +92,75 @@ func CompileFilter(filter string) (*ber.Packet, error) {
|
|||
}
|
||||
|
||||
// DecompileFilter converts a packet representation of a filter into a string representation
|
||||
func DecompileFilter(packet *ber.Packet) (ret string, err error) {
|
||||
func DecompileFilter(packet *ber.Packet) (_ string, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter"))
|
||||
}
|
||||
}()
|
||||
ret = "("
|
||||
err = nil
|
||||
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.WriteByte('(')
|
||||
childStr := ""
|
||||
|
||||
switch packet.Tag {
|
||||
case FilterAnd:
|
||||
ret += "&"
|
||||
buf.WriteByte('&')
|
||||
for _, child := range packet.Children {
|
||||
childStr, err = DecompileFilter(child)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ret += childStr
|
||||
buf.WriteString(childStr)
|
||||
}
|
||||
case FilterOr:
|
||||
ret += "|"
|
||||
buf.WriteByte('|')
|
||||
for _, child := range packet.Children {
|
||||
childStr, err = DecompileFilter(child)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ret += childStr
|
||||
buf.WriteString(childStr)
|
||||
}
|
||||
case FilterNot:
|
||||
ret += "!"
|
||||
buf.WriteByte('!')
|
||||
childStr, err = DecompileFilter(packet.Children[0])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ret += childStr
|
||||
buf.WriteString(childStr)
|
||||
|
||||
case FilterSubstrings:
|
||||
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
|
||||
ret += "="
|
||||
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
|
||||
buf.WriteByte('=')
|
||||
for i, child := range packet.Children[1].Children {
|
||||
if i == 0 && child.Tag != FilterSubstringsInitial {
|
||||
ret += "*"
|
||||
buf.Write(_SymbolAny)
|
||||
}
|
||||
ret += EscapeFilter(ber.DecodeString(child.Data.Bytes()))
|
||||
buf.WriteString(EscapeFilter(ber.DecodeString(child.Data.Bytes())))
|
||||
if child.Tag != FilterSubstringsFinal {
|
||||
ret += "*"
|
||||
buf.Write(_SymbolAny)
|
||||
}
|
||||
}
|
||||
case FilterEqualityMatch:
|
||||
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
|
||||
ret += "="
|
||||
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
|
||||
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
|
||||
buf.WriteByte('=')
|
||||
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
|
||||
case FilterGreaterOrEqual:
|
||||
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
|
||||
ret += ">="
|
||||
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
|
||||
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
|
||||
buf.WriteString(">=")
|
||||
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
|
||||
case FilterLessOrEqual:
|
||||
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
|
||||
ret += "<="
|
||||
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
|
||||
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
|
||||
buf.WriteString("<=")
|
||||
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
|
||||
case FilterPresent:
|
||||
ret += ber.DecodeString(packet.Data.Bytes())
|
||||
ret += "=*"
|
||||
buf.WriteString(ber.DecodeString(packet.Data.Bytes()))
|
||||
buf.WriteString("=*")
|
||||
case FilterApproxMatch:
|
||||
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
|
||||
ret += "~="
|
||||
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
|
||||
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
|
||||
buf.WriteString("~=")
|
||||
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
|
||||
case FilterExtensibleMatch:
|
||||
attr := ""
|
||||
dnAttributes := false
|
||||
|
@ -176,21 +181,22 @@ func DecompileFilter(packet *ber.Packet) (ret string, err error) {
|
|||
}
|
||||
|
||||
if len(attr) > 0 {
|
||||
ret += attr
|
||||
buf.WriteString(attr)
|
||||
}
|
||||
if dnAttributes {
|
||||
ret += ":dn"
|
||||
buf.WriteString(":dn")
|
||||
}
|
||||
if len(matchingRule) > 0 {
|
||||
ret += ":"
|
||||
ret += matchingRule
|
||||
buf.WriteString(":")
|
||||
buf.WriteString(matchingRule)
|
||||
}
|
||||
ret += ":="
|
||||
ret += EscapeFilter(value)
|
||||
buf.WriteString(":=")
|
||||
buf.WriteString(EscapeFilter(value))
|
||||
}
|
||||
|
||||
ret += ")"
|
||||
return
|
||||
buf.WriteByte(')')
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) {
|
||||
|
@ -253,11 +259,10 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
|
|||
)
|
||||
|
||||
state := stateReadingAttr
|
||||
|
||||
attribute := ""
|
||||
attribute := bytes.NewBuffer(nil)
|
||||
extensibleDNAttributes := false
|
||||
extensibleMatchingRule := ""
|
||||
condition := ""
|
||||
extensibleMatchingRule := bytes.NewBuffer(nil)
|
||||
condition := bytes.NewBuffer(nil)
|
||||
|
||||
for newPos < len(filter) {
|
||||
remainingFilter := filter[newPos:]
|
||||
|
@ -324,7 +329,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
|
|||
|
||||
// Still reading the attribute name
|
||||
default:
|
||||
attribute += fmt.Sprintf("%c", currentRune)
|
||||
attribute.WriteRune(currentRune)
|
||||
newPos += currentWidth
|
||||
}
|
||||
|
||||
|
@ -338,13 +343,13 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
|
|||
|
||||
// Still reading the matching rule oid
|
||||
default:
|
||||
extensibleMatchingRule += fmt.Sprintf("%c", currentRune)
|
||||
extensibleMatchingRule.WriteRune(currentRune)
|
||||
newPos += currentWidth
|
||||
}
|
||||
|
||||
case stateReadingCondition:
|
||||
// append to the condition
|
||||
condition += fmt.Sprintf("%c", currentRune)
|
||||
condition.WriteRune(currentRune)
|
||||
newPos += currentWidth
|
||||
}
|
||||
}
|
||||
|
@ -368,17 +373,17 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
|
|||
// }
|
||||
|
||||
// Include the matching rule oid, if specified
|
||||
if len(extensibleMatchingRule) > 0 {
|
||||
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule, MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule]))
|
||||
if extensibleMatchingRule.Len() > 0 {
|
||||
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule.String(), MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule]))
|
||||
}
|
||||
|
||||
// Include the attribute, if specified
|
||||
if len(attribute) > 0 {
|
||||
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute, MatchingRuleAssertionMap[MatchingRuleAssertionType]))
|
||||
if attribute.Len() > 0 {
|
||||
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute.String(), MatchingRuleAssertionMap[MatchingRuleAssertionType]))
|
||||
}
|
||||
|
||||
// Add the value (only required child)
|
||||
encodedString, encodeErr := escapedStringToEncodedBytes(condition)
|
||||
encodedString, encodeErr := decodeEscapedSymbols(condition.Bytes())
|
||||
if encodeErr != nil {
|
||||
return packet, newPos, encodeErr
|
||||
}
|
||||
|
@ -389,16 +394,16 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
|
|||
packet.AppendChild(ber.NewBoolean(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionDNAttributes, extensibleDNAttributes, MatchingRuleAssertionMap[MatchingRuleAssertionDNAttributes]))
|
||||
}
|
||||
|
||||
case packet.Tag == FilterEqualityMatch && condition == "*":
|
||||
packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute, FilterMap[FilterPresent])
|
||||
case packet.Tag == FilterEqualityMatch && strings.Contains(condition, "*"):
|
||||
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
|
||||
case packet.Tag == FilterEqualityMatch && bytes.Equal(condition.Bytes(), _SymbolAny):
|
||||
packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute.String(), FilterMap[FilterPresent])
|
||||
case packet.Tag == FilterEqualityMatch && bytes.Index(condition.Bytes(), _SymbolAny) > -1:
|
||||
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute.String(), "Attribute"))
|
||||
packet.Tag = FilterSubstrings
|
||||
packet.Description = FilterMap[uint64(packet.Tag)]
|
||||
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
|
||||
parts := strings.Split(condition, "*")
|
||||
parts := bytes.Split(condition.Bytes(), _SymbolAny)
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
if len(part) == 0 {
|
||||
continue
|
||||
}
|
||||
var tag ber.Tag
|
||||
|
@ -410,7 +415,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
|
|||
default:
|
||||
tag = FilterSubstringsAny
|
||||
}
|
||||
encodedString, encodeErr := escapedStringToEncodedBytes(part)
|
||||
encodedString, encodeErr := decodeEscapedSymbols(part)
|
||||
if encodeErr != nil {
|
||||
return packet, newPos, encodeErr
|
||||
}
|
||||
|
@ -418,11 +423,11 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
|
|||
}
|
||||
packet.AppendChild(seq)
|
||||
default:
|
||||
encodedString, encodeErr := escapedStringToEncodedBytes(condition)
|
||||
encodedString, encodeErr := decodeEscapedSymbols(condition.Bytes())
|
||||
if encodeErr != nil {
|
||||
return packet, newPos, encodeErr
|
||||
}
|
||||
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
|
||||
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute.String(), "Attribute"))
|
||||
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, encodedString, "Condition"))
|
||||
}
|
||||
|
||||
|
@ -432,34 +437,51 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
|
|||
}
|
||||
|
||||
// Convert from "ABC\xx\xx\xx" form to literal bytes for transport
|
||||
func escapedStringToEncodedBytes(escapedString string) (string, error) {
|
||||
var buffer bytes.Buffer
|
||||
i := 0
|
||||
for i < len(escapedString) {
|
||||
currentRune, currentWidth := utf8.DecodeRuneInString(escapedString[i:])
|
||||
if currentRune == utf8.RuneError {
|
||||
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", i))
|
||||
func decodeEscapedSymbols(src []byte) (string, error) {
|
||||
|
||||
var (
|
||||
buffer bytes.Buffer
|
||||
offset int
|
||||
reader = bytes.NewReader(src)
|
||||
byteHex []byte
|
||||
byteVal []byte
|
||||
)
|
||||
|
||||
for {
|
||||
runeVal, runeSize, err := reader.ReadRune()
|
||||
if err == io.EOF {
|
||||
return buffer.String(), nil
|
||||
} else if err != nil {
|
||||
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: failed to read filter: %v", err))
|
||||
} else if runeVal == unicode.ReplacementChar {
|
||||
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", offset))
|
||||
}
|
||||
|
||||
// Check for escaped hex characters and convert them to their literal value for transport.
|
||||
if currentRune == '\\' {
|
||||
if runeVal == '\\' {
|
||||
// http://tools.ietf.org/search/rfc4515
|
||||
// \ (%x5C) is not a valid character unless it is followed by two HEX characters due to not
|
||||
// being a member of UTF1SUBSET.
|
||||
if i+2 > len(escapedString) {
|
||||
return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter"))
|
||||
if byteHex == nil {
|
||||
byteHex = make([]byte, 2)
|
||||
byteVal = make([]byte, 1)
|
||||
}
|
||||
escByte, decodeErr := hexpac.DecodeString(escapedString[i+1 : i+3])
|
||||
if decodeErr != nil {
|
||||
return "", NewError(ErrorFilterCompile, errors.New("ldap: invalid characters for escape in filter"))
|
||||
|
||||
if _, err := io.ReadFull(reader, byteHex); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter"))
|
||||
}
|
||||
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: invalid characters for escape in filter: %v", err))
|
||||
}
|
||||
buffer.WriteByte(escByte[0])
|
||||
i += 2 // +1 from end of loop, so 3 total for \xx.
|
||||
|
||||
if _, err := hexpac.Decode(byteVal, byteHex); err != nil {
|
||||
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: invalid characters for escape in filter: %v", err))
|
||||
}
|
||||
|
||||
buffer.Write(byteVal)
|
||||
} else {
|
||||
buffer.WriteRune(currentRune)
|
||||
buffer.WriteRune(runeVal)
|
||||
}
|
||||
|
||||
i += currentWidth
|
||||
offset += runeSize
|
||||
}
|
||||
return buffer.String(), nil
|
||||
}
|
51
vendor/gopkg.in/ldap.v3/ldap.go → vendor/github.com/go-ldap/ldap/v3/ldap.go
generated
vendored
51
vendor/gopkg.in/ldap.v3/ldap.go → vendor/github.com/go-ldap/ldap/v3/ldap.go
generated
vendored
|
@ -5,7 +5,7 @@ import (
|
|||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
ber "gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// LDAP Application Codes
|
||||
|
@ -223,32 +223,26 @@ func addControlDescriptions(packet *ber.Packet) error {
|
|||
if child.Tag == 0 {
|
||||
//Warning
|
||||
warningPacket := child.Children[0]
|
||||
packet, err := ber.DecodePacketErr(warningPacket.Data.Bytes())
|
||||
val, err := ber.ParseInt64(warningPacket.Data.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode data bytes: %s", err)
|
||||
}
|
||||
val, ok := packet.Value.(int64)
|
||||
if ok {
|
||||
if warningPacket.Tag == 0 {
|
||||
//timeBeforeExpiration
|
||||
value.Description += " (TimeBeforeExpiration)"
|
||||
warningPacket.Value = val
|
||||
} else if warningPacket.Tag == 1 {
|
||||
//graceAuthNsRemaining
|
||||
value.Description += " (GraceAuthNsRemaining)"
|
||||
warningPacket.Value = val
|
||||
}
|
||||
if warningPacket.Tag == 0 {
|
||||
//timeBeforeExpiration
|
||||
value.Description += " (TimeBeforeExpiration)"
|
||||
warningPacket.Value = val
|
||||
} else if warningPacket.Tag == 1 {
|
||||
//graceAuthNsRemaining
|
||||
value.Description += " (GraceAuthNsRemaining)"
|
||||
warningPacket.Value = val
|
||||
}
|
||||
} else if child.Tag == 1 {
|
||||
// Error
|
||||
packet, err := ber.DecodePacketErr(child.Data.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode data bytes: %s", err)
|
||||
}
|
||||
val, ok := packet.Value.(int8)
|
||||
if !ok {
|
||||
val = -1
|
||||
bs := child.Data.Bytes()
|
||||
if len(bs) != 1 || bs[0] > 8 {
|
||||
return fmt.Errorf("failed to decode data bytes: %s", "invalid PasswordPolicyResponse enum value")
|
||||
}
|
||||
val := int8(bs[0])
|
||||
child.Description = "Error"
|
||||
child.Value = val
|
||||
}
|
||||
|
@ -269,13 +263,18 @@ func addRequestDescriptions(packet *ber.Packet) error {
|
|||
}
|
||||
|
||||
func addDefaultLDAPResponseDescriptions(packet *ber.Packet) error {
|
||||
err := GetLDAPError(packet)
|
||||
if err == nil {
|
||||
return nil
|
||||
resultCode := uint16(LDAPResultSuccess)
|
||||
matchedDN := ""
|
||||
description := "Success"
|
||||
if err := GetLDAPError(packet); err != nil {
|
||||
resultCode = err.(*Error).ResultCode
|
||||
matchedDN = err.(*Error).MatchedDN
|
||||
description = "Error Message"
|
||||
}
|
||||
packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[err.(*Error).ResultCode] + ")"
|
||||
packet.Children[1].Children[1].Description = "Matched DN (" + err.(*Error).MatchedDN + ")"
|
||||
packet.Children[1].Children[2].Description = "Error Message"
|
||||
|
||||
packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[resultCode] + ")"
|
||||
packet.Children[1].Children[1].Description = "Matched DN (" + matchedDN + ")"
|
||||
packet.Children[1].Children[2].Description = description
|
||||
if len(packet.Children[1].Children) > 3 {
|
||||
packet.Children[1].Children[3].Description = "Referral"
|
||||
}
|
39
vendor/gopkg.in/ldap.v3/moddn.go → vendor/github.com/go-ldap/ldap/v3/moddn.go
generated
vendored
39
vendor/gopkg.in/ldap.v3/moddn.go → vendor/github.com/go-ldap/ldap/v3/moddn.go
generated
vendored
|
@ -1,19 +1,9 @@
|
|||
// Package ldap - moddn.go contains ModifyDN functionality
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc4511
|
||||
// ModifyDNRequest ::= [APPLICATION 12] SEQUENCE {
|
||||
// entry LDAPDN,
|
||||
// newrdn RelativeLDAPDN,
|
||||
// deleteoldrdn BOOLEAN,
|
||||
// newSuperior [0] LDAPDN OPTIONAL }
|
||||
//
|
||||
//
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
ber "gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// ModifyDNRequest holds the request to modify a DN
|
||||
|
@ -22,6 +12,8 @@ type ModifyDNRequest struct {
|
|||
NewRDN string
|
||||
DeleteOldRDN bool
|
||||
NewSuperior string
|
||||
// Controls hold optional controls to send with the request
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
// NewModifyDNRequest creates a new request which can be passed to ModifyDN().
|
||||
|
@ -45,16 +37,39 @@ func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *Modi
|
|||
}
|
||||
}
|
||||
|
||||
// NewModifyDNWithControlsRequest creates a new request which can be passed to ModifyDN()
|
||||
// and also allows setting LDAP request controls.
|
||||
//
|
||||
// Refer NewModifyDNRequest for other parameters
|
||||
func NewModifyDNWithControlsRequest(dn string, rdn string, delOld bool,
|
||||
newSup string, controls []Control) *ModifyDNRequest {
|
||||
return &ModifyDNRequest{
|
||||
DN: dn,
|
||||
NewRDN: rdn,
|
||||
DeleteOldRDN: delOld,
|
||||
NewSuperior: newSup,
|
||||
Controls: controls,
|
||||
}
|
||||
}
|
||||
|
||||
func (req *ModifyDNRequest) appendTo(envelope *ber.Packet) error {
|
||||
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request")
|
||||
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
|
||||
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.NewRDN, "New RDN"))
|
||||
pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.DeleteOldRDN, "Delete old RDN"))
|
||||
if req.DeleteOldRDN {
|
||||
buf := []byte{0xff}
|
||||
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, string(buf), "Delete old RDN"))
|
||||
} else {
|
||||
pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.DeleteOldRDN, "Delete old RDN"))
|
||||
}
|
||||
if req.NewSuperior != "" {
|
||||
pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.NewSuperior, "New Superior"))
|
||||
}
|
||||
|
||||
envelope.AppendChild(pkt)
|
||||
if len(req.Controls) > 0 {
|
||||
envelope.AppendChild(encodeControls(req.Controls))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
84
vendor/gopkg.in/ldap.v3/modify.go → vendor/github.com/go-ldap/ldap/v3/modify.go
generated
vendored
84
vendor/gopkg.in/ldap.v3/modify.go → vendor/github.com/go-ldap/ldap/v3/modify.go
generated
vendored
|
@ -1,41 +1,18 @@
|
|||
// File contains Modify functionality
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc4511
|
||||
//
|
||||
// ModifyRequest ::= [APPLICATION 6] SEQUENCE {
|
||||
// object LDAPDN,
|
||||
// changes SEQUENCE OF change SEQUENCE {
|
||||
// operation ENUMERATED {
|
||||
// add (0),
|
||||
// delete (1),
|
||||
// replace (2),
|
||||
// ... },
|
||||
// modification PartialAttribute } }
|
||||
//
|
||||
// PartialAttribute ::= SEQUENCE {
|
||||
// type AttributeDescription,
|
||||
// vals SET OF value AttributeValue }
|
||||
//
|
||||
// AttributeDescription ::= LDAPString
|
||||
// -- Constrained to <attributedescription>
|
||||
// -- [RFC4512]
|
||||
//
|
||||
// AttributeValue ::= OCTET STRING
|
||||
//
|
||||
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
ber "gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// Change operation choices
|
||||
const (
|
||||
AddAttribute = 0
|
||||
DeleteAttribute = 1
|
||||
ReplaceAttribute = 2
|
||||
AddAttribute = 0
|
||||
DeleteAttribute = 1
|
||||
ReplaceAttribute = 2
|
||||
IncrementAttribute = 3 // (https://tools.ietf.org/html/rfc4525)
|
||||
)
|
||||
|
||||
// PartialAttribute for a ModifyRequest as defined in https://tools.ietf.org/html/rfc4511
|
||||
|
@ -97,6 +74,11 @@ func (req *ModifyRequest) Replace(attrType string, attrVals []string) {
|
|||
req.appendChange(ReplaceAttribute, attrType, attrVals)
|
||||
}
|
||||
|
||||
// Increment appends the given attribute to the list of changes to be made
|
||||
func (req *ModifyRequest) Increment(attrType string, attrVal string) {
|
||||
req.appendChange(IncrementAttribute, attrType, []string{attrVal})
|
||||
}
|
||||
|
||||
func (req *ModifyRequest) appendChange(operation uint, attrType string, attrVals []string) {
|
||||
req.Changes = append(req.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}})
|
||||
}
|
||||
|
@ -149,3 +131,47 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ModifyResult holds the server's response to a modify request
|
||||
type ModifyResult struct {
|
||||
// Controls are the returned controls
|
||||
Controls []Control
|
||||
}
|
||||
|
||||
// ModifyWithResult performs the ModifyRequest and returns the result
|
||||
func (l *Conn) ModifyWithResult(modifyRequest *ModifyRequest) (*ModifyResult, error) {
|
||||
msgCtx, err := l.doRequest(modifyRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
result := &ModifyResult{
|
||||
Controls: make([]Control, 0),
|
||||
}
|
||||
|
||||
l.Debug.Printf("%d: waiting for response", msgCtx.id)
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch packet.Children[1].Tag {
|
||||
case ApplicationModifyResponse:
|
||||
err := GetLDAPError(packet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(packet.Children) == 3 {
|
||||
for _, child := range packet.Children[2].Children {
|
||||
decodedChild, err := DecodeControl(child)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to decode child control: " + err.Error())
|
||||
}
|
||||
result.Controls = append(result.Controls, decodedChild)
|
||||
}
|
||||
}
|
||||
}
|
||||
l.Debug.Printf("%d: returning", msgCtx.id)
|
||||
return result, nil
|
||||
}
|
|
@ -1,14 +1,9 @@
|
|||
// This file contains the password modify extended operation as specified in rfc 3062
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc3062
|
||||
//
|
||||
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
ber "gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -61,7 +56,7 @@ func (req *PasswordModifyRequest) appendTo(envelope *ber.Packet) error {
|
|||
|
||||
// NewPasswordModifyRequest creates a new PasswordModifyRequest
|
||||
//
|
||||
// According to the RFC 3602:
|
||||
// According to the RFC 3602 (https://tools.ietf.org/html/rfc3062):
|
||||
// userIdentity is a string representing the user associated with the request.
|
||||
// This string may or may not be an LDAPDN (RFC 2253).
|
||||
// If userIdentity is empty then the operation will act on the user associated
|
11
vendor/gopkg.in/ldap.v3/request.go → vendor/github.com/go-ldap/ldap/v3/request.go
generated
vendored
11
vendor/gopkg.in/ldap.v3/request.go → vendor/github.com/go-ldap/ldap/v3/request.go
generated
vendored
|
@ -3,12 +3,13 @@ package ldap
|
|||
import (
|
||||
"errors"
|
||||
|
||||
ber "gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
var (
|
||||
errRespChanClosed = errors.New("ldap: response channel closed")
|
||||
errCouldNotRetMsg = errors.New("ldap: could not retrieve message")
|
||||
ErrNilConnection = errors.New("ldap: conn is nil, expected net.Conn")
|
||||
)
|
||||
|
||||
type request interface {
|
||||
|
@ -22,6 +23,10 @@ func (f requestFunc) appendTo(p *ber.Packet) error {
|
|||
}
|
||||
|
||||
func (l *Conn) doRequest(req request) (*messageContext, error) {
|
||||
if l == nil || l.conn == nil {
|
||||
return nil, ErrNilConnection
|
||||
}
|
||||
|
||||
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
|
||||
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
|
||||
if err := req.appendTo(packet); err != nil {
|
||||
|
@ -29,7 +34,7 @@ func (l *Conn) doRequest(req request) (*messageContext, error) {
|
|||
}
|
||||
|
||||
if l.Debug {
|
||||
ber.PrintPacket(packet)
|
||||
l.Debug.PrintPacket(packet)
|
||||
}
|
||||
|
||||
msgCtx, err := l.sendMessage(packet)
|
||||
|
@ -60,7 +65,7 @@ func (l *Conn) readPacket(msgCtx *messageContext) (*ber.Packet, error) {
|
|||
if err = addLDAPDescriptions(packet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ber.PrintPacket(packet)
|
||||
l.Debug.PrintPacket(packet)
|
||||
}
|
||||
return packet, nil
|
||||
}
|
140
vendor/gopkg.in/ldap.v3/search.go → vendor/github.com/go-ldap/ldap/v3/search.go
generated
vendored
140
vendor/gopkg.in/ldap.v3/search.go → vendor/github.com/go-ldap/ldap/v3/search.go
generated
vendored
|
@ -1,58 +1,3 @@
|
|||
// File contains Search functionality
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc4511
|
||||
//
|
||||
// SearchRequest ::= [APPLICATION 3] SEQUENCE {
|
||||
// baseObject LDAPDN,
|
||||
// scope ENUMERATED {
|
||||
// baseObject (0),
|
||||
// singleLevel (1),
|
||||
// wholeSubtree (2),
|
||||
// ... },
|
||||
// derefAliases ENUMERATED {
|
||||
// neverDerefAliases (0),
|
||||
// derefInSearching (1),
|
||||
// derefFindingBaseObj (2),
|
||||
// derefAlways (3) },
|
||||
// sizeLimit INTEGER (0 .. maxInt),
|
||||
// timeLimit INTEGER (0 .. maxInt),
|
||||
// typesOnly BOOLEAN,
|
||||
// filter Filter,
|
||||
// attributes AttributeSelection }
|
||||
//
|
||||
// AttributeSelection ::= SEQUENCE OF selector LDAPString
|
||||
// -- The LDAPString is constrained to
|
||||
// -- <attributeSelector> in Section 4.5.1.8
|
||||
//
|
||||
// Filter ::= CHOICE {
|
||||
// and [0] SET SIZE (1..MAX) OF filter Filter,
|
||||
// or [1] SET SIZE (1..MAX) OF filter Filter,
|
||||
// not [2] Filter,
|
||||
// equalityMatch [3] AttributeValueAssertion,
|
||||
// substrings [4] SubstringFilter,
|
||||
// greaterOrEqual [5] AttributeValueAssertion,
|
||||
// lessOrEqual [6] AttributeValueAssertion,
|
||||
// present [7] AttributeDescription,
|
||||
// approxMatch [8] AttributeValueAssertion,
|
||||
// extensibleMatch [9] MatchingRuleAssertion,
|
||||
// ... }
|
||||
//
|
||||
// SubstringFilter ::= SEQUENCE {
|
||||
// type AttributeDescription,
|
||||
// substrings SEQUENCE SIZE (1..MAX) OF substring CHOICE {
|
||||
// initial [0] AssertionValue, -- can occur at most once
|
||||
// any [1] AssertionValue,
|
||||
// final [2] AssertionValue } -- can occur at most once
|
||||
// }
|
||||
//
|
||||
// MatchingRuleAssertion ::= SEQUENCE {
|
||||
// matchingRule [1] MatchingRuleId OPTIONAL,
|
||||
// type [2] AttributeDescription OPTIONAL,
|
||||
// matchValue [3] AssertionValue,
|
||||
// dnAttributes [4] BOOLEAN DEFAULT FALSE }
|
||||
//
|
||||
//
|
||||
|
||||
package ldap
|
||||
|
||||
import (
|
||||
|
@ -61,7 +6,7 @@ import (
|
|||
"sort"
|
||||
"strings"
|
||||
|
||||
ber "gopkg.in/asn1-ber.v1"
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
// scope choices
|
||||
|
@ -132,6 +77,17 @@ func (e *Entry) GetAttributeValues(attribute string) []string {
|
|||
return []string{}
|
||||
}
|
||||
|
||||
// GetEqualFoldAttributeValues returns the values for the named attribute, or an
|
||||
// empty list. Attribute matching is done with strings.EqualFold.
|
||||
func (e *Entry) GetEqualFoldAttributeValues(attribute string) []string {
|
||||
for _, attr := range e.Attributes {
|
||||
if strings.EqualFold(attribute, attr.Name) {
|
||||
return attr.Values
|
||||
}
|
||||
}
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// GetRawAttributeValues returns the byte values for the named attribute, or an empty list
|
||||
func (e *Entry) GetRawAttributeValues(attribute string) [][]byte {
|
||||
for _, attr := range e.Attributes {
|
||||
|
@ -142,6 +98,16 @@ func (e *Entry) GetRawAttributeValues(attribute string) [][]byte {
|
|||
return [][]byte{}
|
||||
}
|
||||
|
||||
// GetEqualFoldRawAttributeValues returns the byte values for the named attribute, or an empty list
|
||||
func (e *Entry) GetEqualFoldRawAttributeValues(attribute string) [][]byte {
|
||||
for _, attr := range e.Attributes {
|
||||
if strings.EqualFold(attr.Name, attribute) {
|
||||
return attr.ByteValues
|
||||
}
|
||||
}
|
||||
return [][]byte{}
|
||||
}
|
||||
|
||||
// GetAttributeValue returns the first value for the named attribute, or ""
|
||||
func (e *Entry) GetAttributeValue(attribute string) string {
|
||||
values := e.GetAttributeValues(attribute)
|
||||
|
@ -151,6 +117,16 @@ func (e *Entry) GetAttributeValue(attribute string) string {
|
|||
return values[0]
|
||||
}
|
||||
|
||||
// GetEqualFoldAttributeValue returns the first value for the named attribute, or "".
|
||||
// Attribute comparison is done with strings.EqualFold.
|
||||
func (e *Entry) GetEqualFoldAttributeValue(attribute string) string {
|
||||
values := e.GetEqualFoldAttributeValues(attribute)
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
return values[0]
|
||||
}
|
||||
|
||||
// GetRawAttributeValue returns the first value for the named attribute, or an empty slice
|
||||
func (e *Entry) GetRawAttributeValue(attribute string) []byte {
|
||||
values := e.GetRawAttributeValues(attribute)
|
||||
|
@ -160,6 +136,15 @@ func (e *Entry) GetRawAttributeValue(attribute string) []byte {
|
|||
return values[0]
|
||||
}
|
||||
|
||||
// GetEqualFoldRawAttributeValue returns the first value for the named attribute, or an empty slice
|
||||
func (e *Entry) GetEqualFoldRawAttributeValue(attribute string) []byte {
|
||||
values := e.GetEqualFoldRawAttributeValues(attribute)
|
||||
if len(values) == 0 {
|
||||
return []byte{}
|
||||
}
|
||||
return values[0]
|
||||
}
|
||||
|
||||
// Print outputs a human-readable description
|
||||
func (e *Entry) Print() {
|
||||
fmt.Printf("DN: %s\n", e.DN)
|
||||
|
@ -386,33 +371,26 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
|
|||
for {
|
||||
packet, err := l.readPacket(msgCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return result, err
|
||||
}
|
||||
|
||||
switch packet.Children[1].Tag {
|
||||
case 4:
|
||||
entry := new(Entry)
|
||||
entry.DN = packet.Children[1].Children[0].Value.(string)
|
||||
for _, child := range packet.Children[1].Children[1].Children {
|
||||
attr := new(EntryAttribute)
|
||||
attr.Name = child.Children[0].Value.(string)
|
||||
for _, value := range child.Children[1].Children {
|
||||
attr.Values = append(attr.Values, value.Value.(string))
|
||||
attr.ByteValues = append(attr.ByteValues, value.ByteValue)
|
||||
}
|
||||
entry.Attributes = append(entry.Attributes, attr)
|
||||
entry := &Entry{
|
||||
DN: packet.Children[1].Children[0].Value.(string),
|
||||
Attributes: unpackAttributes(packet.Children[1].Children[1].Children),
|
||||
}
|
||||
result.Entries = append(result.Entries, entry)
|
||||
case 5:
|
||||
err := GetLDAPError(packet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return result, err
|
||||
}
|
||||
if len(packet.Children) == 3 {
|
||||
for _, child := range packet.Children[2].Children {
|
||||
decodedChild, err := DecodeControl(child)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode child control: %s", err)
|
||||
return result, fmt.Errorf("failed to decode child control: %s", err)
|
||||
}
|
||||
result.Controls = append(result.Controls, decodedChild)
|
||||
}
|
||||
|
@ -423,3 +401,27 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// unpackAttributes will extract all given LDAP attributes and it's values
|
||||
// from the ber.Packet
|
||||
func unpackAttributes(children []*ber.Packet) []*EntryAttribute {
|
||||
entries := make([]*EntryAttribute, len(children))
|
||||
for i, child := range children {
|
||||
length := len(child.Children[1].Children)
|
||||
entry := &EntryAttribute{
|
||||
Name: child.Children[0].Value.(string),
|
||||
// pre-allocate the slice since we can determine
|
||||
// the number of attributes at this point
|
||||
Values: make([]string, length),
|
||||
ByteValues: make([][]byte, length),
|
||||
}
|
||||
|
||||
for i, value := range child.Children[1].Children {
|
||||
entry.ByteValues[i] = value.ByteValue
|
||||
entry.Values[i] = value.Value.(string)
|
||||
}
|
||||
entries[i] = entry
|
||||
}
|
||||
|
||||
return entries
|
||||
}
|
37
vendor/github.com/go-ldap/ldap/v3/unbind.go
generated
vendored
Normal file
37
vendor/github.com/go-ldap/ldap/v3/unbind.go
generated
vendored
Normal file
|
@ -0,0 +1,37 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
var ErrConnUnbound = NewError(ErrorNetwork, errors.New("ldap: connection is closed"))
|
||||
|
||||
type unbindRequest struct{}
|
||||
|
||||
func (unbindRequest) appendTo(envelope *ber.Packet) error {
|
||||
envelope.AppendChild(ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationUnbindRequest, nil, ApplicationMap[ApplicationUnbindRequest]))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unbind will perform an unbind request. The Unbind operation
|
||||
// should be thought of as the "quit" operation.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc4511#section-4.3
|
||||
func (l *Conn) Unbind() error {
|
||||
if l.IsClosing() {
|
||||
return ErrConnUnbound
|
||||
}
|
||||
|
||||
_, err := l.doRequest(unbindRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Sending an unbindRequest will make the connection unusable.
|
||||
// Pending requests will fail with:
|
||||
// LDAP Result Code 200 "Network Error": ldap: response channel closed
|
||||
l.Close()
|
||||
|
||||
return nil
|
||||
}
|
91
vendor/github.com/go-ldap/ldap/v3/whoami.go
generated
vendored
Normal file
91
vendor/github.com/go-ldap/ldap/v3/whoami.go
generated
vendored
Normal file
|
@ -0,0 +1,91 @@
|
|||
package ldap
|
||||
|
||||
// This file contains the "Who Am I?" extended operation as specified in rfc 4532
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc4532
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
ber "github.com/go-asn1-ber/asn1-ber"
|
||||
)
|
||||
|
||||
type whoAmIRequest bool
|
||||
|
||||
// WhoAmIResult is returned by the WhoAmI() call
|
||||
type WhoAmIResult struct {
|
||||
AuthzID string
|
||||
}
|
||||
|
||||
func (r whoAmIRequest) encode() (*ber.Packet, error) {
|
||||
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Who Am I? Extended Operation")
|
||||
request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, ControlTypeWhoAmI, "Extended Request Name: Who Am I? OID"))
|
||||
return request, nil
|
||||
}
|
||||
|
||||
// WhoAmI returns the authzId the server thinks we are, you may pass controls
|
||||
// like a Proxied Authorization control
|
||||
func (l *Conn) WhoAmI(controls []Control) (*WhoAmIResult, error) {
|
||||
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
|
||||
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
|
||||
req := whoAmIRequest(true)
|
||||
encodedWhoAmIRequest, err := req.encode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
packet.AppendChild(encodedWhoAmIRequest)
|
||||
|
||||
if len(controls) != 0 {
|
||||
packet.AppendChild(encodeControls(controls))
|
||||
}
|
||||
|
||||
l.Debug.PrintPacket(packet)
|
||||
|
||||
msgCtx, err := l.sendMessage(packet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer l.finishMessage(msgCtx)
|
||||
|
||||
result := &WhoAmIResult{}
|
||||
|
||||
l.Debug.Printf("%d: waiting for response", msgCtx.id)
|
||||
packetResponse, ok := <-msgCtx.responses
|
||||
if !ok {
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
|
||||
}
|
||||
packet, err = packetResponse.ReadPacket()
|
||||
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if packet == nil {
|
||||
return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))
|
||||
}
|
||||
|
||||
if l.Debug {
|
||||
if err := addLDAPDescriptions(packet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ber.PrintPacket(packet)
|
||||
}
|
||||
|
||||
if packet.Children[1].Tag == ApplicationExtendedResponse {
|
||||
if err := GetLDAPError(packet); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, NewError(ErrorUnexpectedResponse, fmt.Errorf("Unexpected Response: %d", packet.Children[1].Tag))
|
||||
}
|
||||
|
||||
extendedResponse := packet.Children[1]
|
||||
for _, child := range extendedResponse.Children {
|
||||
if child.Tag == 11 {
|
||||
result.AuthzID = ber.DecodeString(child.Data.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
129
vendor/github.com/go-sql-driver/mysql/.travis.yml
generated
vendored
129
vendor/github.com/go-sql-driver/mysql/.travis.yml
generated
vendored
|
@ -1,129 +0,0 @@
|
|||
sudo: false
|
||||
language: go
|
||||
go:
|
||||
- 1.10.x
|
||||
- 1.11.x
|
||||
- 1.12.x
|
||||
- 1.13.x
|
||||
- master
|
||||
|
||||
before_install:
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/mattn/goveralls
|
||||
|
||||
before_script:
|
||||
- echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB" | sudo tee -a /etc/mysql/my.cnf
|
||||
- sudo service mysql restart
|
||||
- .travis/wait_mysql.sh
|
||||
- mysql -e 'create database gotest;'
|
||||
|
||||
matrix:
|
||||
include:
|
||||
- env: DB=MYSQL8
|
||||
sudo: required
|
||||
dist: trusty
|
||||
go: 1.10.x
|
||||
services:
|
||||
- docker
|
||||
before_install:
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/mattn/goveralls
|
||||
- docker pull mysql:8.0
|
||||
- docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret
|
||||
mysql:8.0 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1
|
||||
- cp .travis/docker.cnf ~/.my.cnf
|
||||
- .travis/wait_mysql.sh
|
||||
before_script:
|
||||
- export MYSQL_TEST_USER=gotest
|
||||
- export MYSQL_TEST_PASS=secret
|
||||
- export MYSQL_TEST_ADDR=127.0.0.1:3307
|
||||
- export MYSQL_TEST_CONCURRENT=1
|
||||
|
||||
- env: DB=MYSQL57
|
||||
sudo: required
|
||||
dist: trusty
|
||||
go: 1.10.x
|
||||
services:
|
||||
- docker
|
||||
before_install:
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/mattn/goveralls
|
||||
- docker pull mysql:5.7
|
||||
- docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret
|
||||
mysql:5.7 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1
|
||||
- cp .travis/docker.cnf ~/.my.cnf
|
||||
- .travis/wait_mysql.sh
|
||||
before_script:
|
||||
- export MYSQL_TEST_USER=gotest
|
||||
- export MYSQL_TEST_PASS=secret
|
||||
- export MYSQL_TEST_ADDR=127.0.0.1:3307
|
||||
- export MYSQL_TEST_CONCURRENT=1
|
||||
|
||||
- env: DB=MARIA55
|
||||
sudo: required
|
||||
dist: trusty
|
||||
go: 1.10.x
|
||||
services:
|
||||
- docker
|
||||
before_install:
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/mattn/goveralls
|
||||
- docker pull mariadb:5.5
|
||||
- docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret
|
||||
mariadb:5.5 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1
|
||||
- cp .travis/docker.cnf ~/.my.cnf
|
||||
- .travis/wait_mysql.sh
|
||||
before_script:
|
||||
- export MYSQL_TEST_USER=gotest
|
||||
- export MYSQL_TEST_PASS=secret
|
||||
- export MYSQL_TEST_ADDR=127.0.0.1:3307
|
||||
- export MYSQL_TEST_CONCURRENT=1
|
||||
|
||||
- env: DB=MARIA10_1
|
||||
sudo: required
|
||||
dist: trusty
|
||||
go: 1.10.x
|
||||
services:
|
||||
- docker
|
||||
before_install:
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/mattn/goveralls
|
||||
- docker pull mariadb:10.1
|
||||
- docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret
|
||||
mariadb:10.1 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1
|
||||
- cp .travis/docker.cnf ~/.my.cnf
|
||||
- .travis/wait_mysql.sh
|
||||
before_script:
|
||||
- export MYSQL_TEST_USER=gotest
|
||||
- export MYSQL_TEST_PASS=secret
|
||||
- export MYSQL_TEST_ADDR=127.0.0.1:3307
|
||||
- export MYSQL_TEST_CONCURRENT=1
|
||||
|
||||
- os: osx
|
||||
osx_image: xcode10.1
|
||||
addons:
|
||||
homebrew:
|
||||
packages:
|
||||
- mysql
|
||||
update: true
|
||||
go: 1.12.x
|
||||
before_install:
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/mattn/goveralls
|
||||
before_script:
|
||||
- echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB\nlocal_infile=1" >> /usr/local/etc/my.cnf
|
||||
- mysql.server start
|
||||
- mysql -uroot -e 'CREATE USER gotest IDENTIFIED BY "secret"'
|
||||
- mysql -uroot -e 'GRANT ALL ON *.* TO gotest'
|
||||
- mysql -uroot -e 'create database gotest;'
|
||||
- export MYSQL_TEST_USER=gotest
|
||||
- export MYSQL_TEST_PASS=secret
|
||||
- export MYSQL_TEST_ADDR=127.0.0.1:3306
|
||||
- export MYSQL_TEST_CONCURRENT=1
|
||||
|
||||
script:
|
||||
- go test -v -covermode=count -coverprofile=coverage.out
|
||||
- go vet ./...
|
||||
- .travis/gofmt.sh
|
||||
after_script:
|
||||
- $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci
|
12
vendor/github.com/go-sql-driver/mysql/AUTHORS
generated
vendored
12
vendor/github.com/go-sql-driver/mysql/AUTHORS
generated
vendored
|
@ -13,11 +13,15 @@
|
|||
|
||||
Aaron Hopkins <go-sql-driver at die.net>
|
||||
Achille Roussel <achille.roussel at gmail.com>
|
||||
Alex Snast <alexsn at fb.com>
|
||||
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
|
||||
Andrew Reid <andrew.reid at tixtrack.com>
|
||||
Animesh Ray <mail.rayanimesh at gmail.com>
|
||||
Arne Hormann <arnehormann at gmail.com>
|
||||
Ariel Mashraki <ariel at mashraki.co.il>
|
||||
Asta Xie <xiemengjun at gmail.com>
|
||||
Bulat Gaifullin <gaifullinbf at gmail.com>
|
||||
Caine Jette <jette at alum.mit.edu>
|
||||
Carlos Nieto <jose.carlos at menteslibres.net>
|
||||
Chris Moos <chris at tech9computers.com>
|
||||
Craig Wilson <craiggwilson at gmail.com>
|
||||
|
@ -52,6 +56,7 @@ Julien Schmidt <go-sql-driver at julienschmidt.com>
|
|||
Justin Li <jli at j-li.net>
|
||||
Justin Nuß <nuss.justin at gmail.com>
|
||||
Kamil Dziedzic <kamil at klecza.pl>
|
||||
Kei Kamikawa <x00.x7f.x86 at gmail.com>
|
||||
Kevin Malachowski <kevin at chowski.com>
|
||||
Kieron Woodhouse <kieron.woodhouse at infosum.com>
|
||||
Lennart Rudolph <lrudolph at hmc.edu>
|
||||
|
@ -74,20 +79,26 @@ Reed Allman <rdallman10 at gmail.com>
|
|||
Richard Wilkes <wilkes at me.com>
|
||||
Robert Russell <robert at rrbrussell.com>
|
||||
Runrioter Wung <runrioter at gmail.com>
|
||||
Sho Iizuka <sho.i518 at gmail.com>
|
||||
Sho Ikeda <suicaicoca at gmail.com>
|
||||
Shuode Li <elemount at qq.com>
|
||||
Simon J Mudd <sjmudd at pobox.com>
|
||||
Soroush Pour <me at soroushjp.com>
|
||||
Stan Putrya <root.vagner at gmail.com>
|
||||
Stanley Gunawan <gunawan.stanley at gmail.com>
|
||||
Steven Hartland <steven.hartland at multiplay.co.uk>
|
||||
Tan Jinhua <312841925 at qq.com>
|
||||
Thomas Wodarek <wodarekwebpage at gmail.com>
|
||||
Tim Ruffles <timruffles at gmail.com>
|
||||
Tom Jenkinson <tom at tjenkinson.me>
|
||||
Vladimir Kovpak <cn007b at gmail.com>
|
||||
Vladyslav Zhelezniak <zhvladi at gmail.com>
|
||||
Xiangyu Hu <xiangyu.hu at outlook.com>
|
||||
Xiaobing Jiang <s7v7nislands at gmail.com>
|
||||
Xiuming Chen <cc at cxm.cc>
|
||||
Xuehong Chan <chanxuehong at gmail.com>
|
||||
Zhenye Xie <xiezhenye at gmail.com>
|
||||
Zhixin Wen <john.wenzhixin at gmail.com>
|
||||
|
||||
# Organizations
|
||||
|
||||
|
@ -103,3 +114,4 @@ Multiplay Ltd.
|
|||
Percona LLC
|
||||
Pivotal Inc.
|
||||
Stripe Inc.
|
||||
Zendesk Inc.
|
||||
|
|
26
vendor/github.com/go-sql-driver/mysql/CHANGELOG.md
generated
vendored
26
vendor/github.com/go-sql-driver/mysql/CHANGELOG.md
generated
vendored
|
@ -1,3 +1,29 @@
|
|||
## Version 1.6 (2021-04-01)
|
||||
|
||||
Changes:
|
||||
|
||||
- Migrate the CI service from travis-ci to GitHub Actions (#1176, #1183, #1190)
|
||||
- `NullTime` is deprecated (#960, #1144)
|
||||
- Reduce allocations when building SET command (#1111)
|
||||
- Performance improvement for time formatting (#1118)
|
||||
- Performance improvement for time parsing (#1098, #1113)
|
||||
|
||||
New Features:
|
||||
|
||||
- Implement `driver.Validator` interface (#1106, #1174)
|
||||
- Support returning `uint64` from `Valuer` in `ConvertValue` (#1143)
|
||||
- Add `json.RawMessage` for converter and prepared statement (#1059)
|
||||
- Interpolate `json.RawMessage` as `string` (#1058)
|
||||
- Implements `CheckNamedValue` (#1090)
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Stop rounding times (#1121, #1172)
|
||||
- Put zero filler into the SSL handshake packet (#1066)
|
||||
- Fix checking cancelled connections back into the connection pool (#1095)
|
||||
- Fix remove last 0 byte for mysql_old_password when password is empty (#1133)
|
||||
|
||||
|
||||
## Version 1.5 (2020-01-07)
|
||||
|
||||
Changes:
|
||||
|
|
43
vendor/github.com/go-sql-driver/mysql/README.md
generated
vendored
43
vendor/github.com/go-sql-driver/mysql/README.md
generated
vendored
|
@ -35,7 +35,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
|
|||
* Supports queries larger than 16MB
|
||||
* Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support.
|
||||
* Intelligent `LONG DATA` handling in prepared statements
|
||||
* Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support
|
||||
* Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support
|
||||
* Optional `time.Time` parsing
|
||||
* Optional placeholder interpolation
|
||||
|
||||
|
@ -56,15 +56,37 @@ Make sure [Git is installed](https://git-scm.com/downloads) on your machine and
|
|||
_Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then.
|
||||
|
||||
Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`:
|
||||
|
||||
```go
|
||||
import "database/sql"
|
||||
import _ "github.com/go-sql-driver/mysql"
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
// ...
|
||||
|
||||
db, err := sql.Open("mysql", "user:password@/dbname")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// See "Important settings" section.
|
||||
db.SetConnMaxLifetime(time.Minute * 3)
|
||||
db.SetMaxOpenConns(10)
|
||||
db.SetMaxIdleConns(10)
|
||||
```
|
||||
|
||||
[Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples").
|
||||
|
||||
### Important settings
|
||||
|
||||
`db.SetConnMaxLifetime()` is required to ensure connections are closed by the driver safely before connection is closed by MySQL server, OS, or other middlewares. Since some middlewares close idle connections by 5 minutes, we recommend timeout shorter than 5 minutes. This setting helps load balancing and changing system variables too.
|
||||
|
||||
`db.SetMaxOpenConns()` is highly recommended to limit the number of connection used by the application. There is no recommended limit number because it depends on application and MySQL server.
|
||||
|
||||
`db.SetMaxIdleConns()` is recommended to be set same to (or greater than) `db.SetMaxOpenConns()`. When it is smaller than `SetMaxOpenConns()`, connections can be opened and closed very frequently than you expect. Idle connections can be closed by the `db.SetConnMaxLifetime()`. If you want to close idle connections more rapidly, you can use `db.SetConnMaxIdleTime()` since Go 1.15.
|
||||
|
||||
|
||||
### DSN (Data Source Name)
|
||||
|
||||
|
@ -122,7 +144,7 @@ Valid Values: true, false
|
|||
Default: false
|
||||
```
|
||||
|
||||
`allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files.
|
||||
`allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files.
|
||||
[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)
|
||||
|
||||
##### `allowCleartextPasswords`
|
||||
|
@ -133,7 +155,7 @@ Valid Values: true, false
|
|||
Default: false
|
||||
```
|
||||
|
||||
`allowCleartextPasswords=true` allows using the [cleartext client side plugin](http://dev.mysql.com/doc/en/cleartext-authentication-plugin.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network.
|
||||
`allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network.
|
||||
|
||||
##### `allowNativePasswords`
|
||||
|
||||
|
@ -230,7 +252,7 @@ Default: false
|
|||
|
||||
If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`.
|
||||
|
||||
*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are blacklisted as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!*
|
||||
*This can not be used together with the multibyte encodings BIG5, CP932, GB2312, GBK or SJIS. These are rejected as they may [introduce a SQL injection vulnerability](http://stackoverflow.com/a/12118602/3430118)!*
|
||||
|
||||
##### `loc`
|
||||
|
||||
|
@ -376,7 +398,7 @@ Rules:
|
|||
Examples:
|
||||
* `autocommit=1`: `SET autocommit=1`
|
||||
* [`time_zone=%27Europe%2FParis%27`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `SET time_zone='Europe/Paris'`
|
||||
* [`tx_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `SET tx_isolation='REPEATABLE-READ'`
|
||||
* [`transaction_isolation=%27REPEATABLE-READ%27`](https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation): `SET transaction_isolation='REPEATABLE-READ'`
|
||||
|
||||
|
||||
#### Examples
|
||||
|
@ -445,7 +467,7 @@ For this feature you need direct access to the package. Therefore you must chang
|
|||
import "github.com/go-sql-driver/mysql"
|
||||
```
|
||||
|
||||
Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).
|
||||
Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).
|
||||
|
||||
To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore.
|
||||
|
||||
|
@ -459,8 +481,6 @@ However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` v
|
|||
|
||||
**Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
|
||||
|
||||
Alternatively you can use the [`NullTime`](https://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`.
|
||||
|
||||
|
||||
### Unicode support
|
||||
Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default.
|
||||
|
@ -477,7 +497,7 @@ To run the driver tests you may need to adjust the configuration. See the [Testi
|
|||
Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated.
|
||||
If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open) or review a [pull request](https://github.com/go-sql-driver/mysql/pulls).
|
||||
|
||||
See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/CONTRIBUTING.md) for details.
|
||||
See the [Contribution Guidelines](https://github.com/go-sql-driver/mysql/blob/master/.github/CONTRIBUTING.md) for details.
|
||||
|
||||
---------------------------------------
|
||||
|
||||
|
@ -498,4 +518,3 @@ Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you
|
|||
You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE).
|
||||
|
||||

|
||||
|
||||
|
|
13
vendor/github.com/go-sql-driver/mysql/auth.go
generated
vendored
13
vendor/github.com/go-sql-driver/mysql/auth.go
generated
vendored
|
@ -15,6 +15,7 @@ import (
|
|||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
@ -136,10 +137,6 @@ func pwHash(password []byte) (result [2]uint32) {
|
|||
|
||||
// Hash password using insecure pre 4.1 method
|
||||
func scrambleOldPassword(scramble []byte, password string) []byte {
|
||||
if len(password) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
scramble = scramble[:8]
|
||||
|
||||
hashPw := pwHash([]byte(password))
|
||||
|
@ -247,6 +244,9 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
|
|||
if !mc.cfg.AllowOldPasswords {
|
||||
return nil, ErrOldPassword
|
||||
}
|
||||
if len(mc.cfg.Passwd) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// Note: there are edge cases where this should work but doesn't;
|
||||
// this is currently "wontfix":
|
||||
// https://github.com/go-sql-driver/mysql/issues/184
|
||||
|
@ -372,7 +372,10 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(data[1:])
|
||||
block, rest := pem.Decode(data[1:])
|
||||
if block == nil {
|
||||
return fmt.Errorf("No Pem data found, data: %s", rest)
|
||||
}
|
||||
pkix, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
2
vendor/github.com/go-sql-driver/mysql/collations.go
generated
vendored
2
vendor/github.com/go-sql-driver/mysql/collations.go
generated
vendored
|
@ -247,7 +247,7 @@ var collations = map[string]byte{
|
|||
"utf8mb4_0900_ai_ci": 255,
|
||||
}
|
||||
|
||||
// A blacklist of collations which is unsafe to interpolate parameters.
|
||||
// A denylist of collations which is unsafe to interpolate parameters.
|
||||
// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes.
|
||||
var unsafeCollations = map[string]bool{
|
||||
"big5_chinese_ci": true,
|
||||
|
|
85
vendor/github.com/go-sql-driver/mysql/connection.go
generated
vendored
85
vendor/github.com/go-sql-driver/mysql/connection.go
generated
vendored
|
@ -12,6 +12,7 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
|
@ -46,9 +47,10 @@ type mysqlConn struct {
|
|||
|
||||
// Handles parameters set in DSN after the connection is established
|
||||
func (mc *mysqlConn) handleParams() (err error) {
|
||||
var cmdSet strings.Builder
|
||||
for param, val := range mc.cfg.Params {
|
||||
switch param {
|
||||
// Charset
|
||||
// Charset: character_set_connection, character_set_client, character_set_results
|
||||
case "charset":
|
||||
charsets := strings.Split(val, ",")
|
||||
for i := range charsets {
|
||||
|
@ -62,12 +64,25 @@ func (mc *mysqlConn) handleParams() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// System Vars
|
||||
// Other system vars accumulated in a single SET command
|
||||
default:
|
||||
err = mc.exec("SET " + param + "=" + val + "")
|
||||
if err != nil {
|
||||
return
|
||||
if cmdSet.Len() == 0 {
|
||||
// Heuristic: 29 chars for each other key=value to reduce reallocations
|
||||
cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1))
|
||||
cmdSet.WriteString("SET ")
|
||||
} else {
|
||||
cmdSet.WriteByte(',')
|
||||
}
|
||||
cmdSet.WriteString(param)
|
||||
cmdSet.WriteByte('=')
|
||||
cmdSet.WriteString(val)
|
||||
}
|
||||
}
|
||||
|
||||
if cmdSet.Len() > 0 {
|
||||
err = mc.exec(cmdSet.String())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -230,47 +245,21 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
|
|||
if v.IsZero() {
|
||||
buf = append(buf, "'0000-00-00'"...)
|
||||
} else {
|
||||
v := v.In(mc.cfg.Loc)
|
||||
v = v.Add(time.Nanosecond * 500) // To round under microsecond
|
||||
year := v.Year()
|
||||
year100 := year / 100
|
||||
year1 := year % 100
|
||||
month := v.Month()
|
||||
day := v.Day()
|
||||
hour := v.Hour()
|
||||
minute := v.Minute()
|
||||
second := v.Second()
|
||||
micro := v.Nanosecond() / 1000
|
||||
|
||||
buf = append(buf, []byte{
|
||||
'\'',
|
||||
digits10[year100], digits01[year100],
|
||||
digits10[year1], digits01[year1],
|
||||
'-',
|
||||
digits10[month], digits01[month],
|
||||
'-',
|
||||
digits10[day], digits01[day],
|
||||
' ',
|
||||
digits10[hour], digits01[hour],
|
||||
':',
|
||||
digits10[minute], digits01[minute],
|
||||
':',
|
||||
digits10[second], digits01[second],
|
||||
}...)
|
||||
|
||||
if micro != 0 {
|
||||
micro10000 := micro / 10000
|
||||
micro100 := micro / 100 % 100
|
||||
micro1 := micro % 100
|
||||
buf = append(buf, []byte{
|
||||
'.',
|
||||
digits10[micro10000], digits01[micro10000],
|
||||
digits10[micro100], digits01[micro100],
|
||||
digits10[micro1], digits01[micro1],
|
||||
}...)
|
||||
buf = append(buf, '\'')
|
||||
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
}
|
||||
case json.RawMessage:
|
||||
buf = append(buf, '\'')
|
||||
if mc.status&statusNoBackslashEscapes == 0 {
|
||||
buf = escapeBytesBackslash(buf, v)
|
||||
} else {
|
||||
buf = escapeBytesQuotes(buf, v)
|
||||
}
|
||||
buf = append(buf, '\'')
|
||||
case []byte:
|
||||
if v == nil {
|
||||
buf = append(buf, "NULL"...)
|
||||
|
@ -480,6 +469,10 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
|
|||
|
||||
// BeginTx implements driver.ConnBeginTx interface
|
||||
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
if mc.closed.IsSet() {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -649,3 +642,9 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
|
|||
mc.reset = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValid implements driver.Validator interface
|
||||
// (From Go 1.15)
|
||||
func (mc *mysqlConn) IsValid() bool {
|
||||
return !mc.closed.IsSet()
|
||||
}
|
||||
|
|
2
vendor/github.com/go-sql-driver/mysql/dsn.go
generated
vendored
2
vendor/github.com/go-sql-driver/mysql/dsn.go
generated
vendored
|
@ -375,7 +375,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
|
|||
|
||||
// cfg params
|
||||
switch value := param[1]; param[0] {
|
||||
// Disable INFILE whitelist / enable all files
|
||||
// Disable INFILE allowlist / enable all files
|
||||
case "allowAllFiles":
|
||||
var isBool bool
|
||||
cfg.AllowAllFiles, isBool = readBool(value)
|
||||
|
|
2
vendor/github.com/go-sql-driver/mysql/fields.go
generated
vendored
2
vendor/github.com/go-sql-driver/mysql/fields.go
generated
vendored
|
@ -106,7 +106,7 @@ var (
|
|||
scanTypeInt64 = reflect.TypeOf(int64(0))
|
||||
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
|
||||
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
|
||||
scanTypeNullTime = reflect.TypeOf(NullTime{})
|
||||
scanTypeNullTime = reflect.TypeOf(nullTime{})
|
||||
scanTypeUint8 = reflect.TypeOf(uint8(0))
|
||||
scanTypeUint16 = reflect.TypeOf(uint16(0))
|
||||
scanTypeUint32 = reflect.TypeOf(uint32(0))
|
||||
|
|
24
vendor/github.com/go-sql-driver/mysql/fuzz.go
generated
vendored
Normal file
24
vendor/github.com/go-sql-driver/mysql/fuzz.go
generated
vendored
Normal file
|
@ -0,0 +1,24 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package.
|
||||
//
|
||||
// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build gofuzz
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func Fuzz(data []byte) int {
|
||||
db, err := sql.Open("mysql", string(data))
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
db.Close()
|
||||
return 1
|
||||
}
|
4
vendor/github.com/go-sql-driver/mysql/infile.go
generated
vendored
4
vendor/github.com/go-sql-driver/mysql/infile.go
generated
vendored
|
@ -23,7 +23,7 @@ var (
|
|||
readerRegisterLock sync.RWMutex
|
||||
)
|
||||
|
||||
// RegisterLocalFile adds the given file to the file whitelist,
|
||||
// RegisterLocalFile adds the given file to the file allowlist,
|
||||
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
|
||||
// Alternatively you can allow the use of all local files with
|
||||
// the DSN parameter 'allowAllFiles=true'
|
||||
|
@ -45,7 +45,7 @@ func RegisterLocalFile(filePath string) {
|
|||
fileRegisterLock.Unlock()
|
||||
}
|
||||
|
||||
// DeregisterLocalFile removes the given filepath from the whitelist.
|
||||
// DeregisterLocalFile removes the given filepath from the allowlist.
|
||||
func DeregisterLocalFile(filePath string) {
|
||||
fileRegisterLock.Lock()
|
||||
delete(fileRegister, strings.Trim(filePath, `"`))
|
||||
|
|
4
vendor/github.com/go-sql-driver/mysql/nulltime.go
generated
vendored
4
vendor/github.com/go-sql-driver/mysql/nulltime.go
generated
vendored
|
@ -28,11 +28,11 @@ func (nt *NullTime) Scan(value interface{}) (err error) {
|
|||
nt.Time, nt.Valid = v, true
|
||||
return
|
||||
case []byte:
|
||||
nt.Time, err = parseDateTime(string(v), time.UTC)
|
||||
nt.Time, err = parseDateTime(v, time.UTC)
|
||||
nt.Valid = (err == nil)
|
||||
return
|
||||
case string:
|
||||
nt.Time, err = parseDateTime(v, time.UTC)
|
||||
nt.Time, err = parseDateTime([]byte(v), time.UTC)
|
||||
nt.Valid = (err == nil)
|
||||
return
|
||||
}
|
||||
|
|
9
vendor/github.com/go-sql-driver/mysql/nulltime_go113.go
generated
vendored
9
vendor/github.com/go-sql-driver/mysql/nulltime_go113.go
generated
vendored
|
@ -28,4 +28,13 @@ import (
|
|||
// }
|
||||
//
|
||||
// This NullTime implementation is not driver-specific
|
||||
//
|
||||
// Deprecated: NullTime doesn't honor the loc DSN parameter.
|
||||
// NullTime.Scan interprets a time as UTC, not the loc DSN parameter.
|
||||
// Use sql.NullTime instead.
|
||||
type NullTime sql.NullTime
|
||||
|
||||
// for internal use.
|
||||
// the mysql package uses sql.NullTime if it is available.
|
||||
// if not, the package uses mysql.NullTime.
|
||||
type nullTime = sql.NullTime // sql.NullTime is available
|
||||
|
|
5
vendor/github.com/go-sql-driver/mysql/nulltime_legacy.go
generated
vendored
5
vendor/github.com/go-sql-driver/mysql/nulltime_legacy.go
generated
vendored
|
@ -32,3 +32,8 @@ type NullTime struct {
|
|||
Time time.Time
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
}
|
||||
|
||||
// for internal use.
|
||||
// the mysql package uses sql.NullTime if it is available.
|
||||
// if not, the package uses mysql.NullTime.
|
||||
type nullTime = NullTime // sql.NullTime is not available
|
||||
|
|
23
vendor/github.com/go-sql-driver/mysql/packets.go
generated
vendored
23
vendor/github.com/go-sql-driver/mysql/packets.go
generated
vendored
|
@ -13,6 +13,7 @@ import (
|
|||
"crypto/tls"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -348,6 +349,12 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
|
|||
return errors.New("unknown collation")
|
||||
}
|
||||
|
||||
// Filler [23 bytes] (all 0x00)
|
||||
pos := 13
|
||||
for ; pos < 13+23; pos++ {
|
||||
data[pos] = 0
|
||||
}
|
||||
|
||||
// SSL Connection Request Packet
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
|
||||
if mc.cfg.tls != nil {
|
||||
|
@ -366,12 +373,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
|
|||
mc.buf.nc = tlsConn
|
||||
}
|
||||
|
||||
// Filler [23 bytes] (all 0x00)
|
||||
pos := 13
|
||||
for ; pos < 13+23; pos++ {
|
||||
data[pos] = 0
|
||||
}
|
||||
|
||||
// User [null terminated string]
|
||||
if len(mc.cfg.User) > 0 {
|
||||
pos += copy(data[pos:], mc.cfg.User)
|
||||
|
@ -777,7 +778,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
|
|||
case fieldTypeTimestamp, fieldTypeDateTime,
|
||||
fieldTypeDate, fieldTypeNewDate:
|
||||
dest[i], err = parseDateTime(
|
||||
string(dest[i].([]byte)),
|
||||
dest[i].([]byte),
|
||||
mc.cfg.Loc,
|
||||
)
|
||||
if err == nil {
|
||||
|
@ -1003,6 +1004,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|||
continue
|
||||
}
|
||||
|
||||
if v, ok := arg.(json.RawMessage); ok {
|
||||
arg = []byte(v)
|
||||
}
|
||||
// cache types and values
|
||||
switch v := arg.(type) {
|
||||
case int64:
|
||||
|
@ -1112,7 +1116,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|||
if v.IsZero() {
|
||||
b = append(b, "0000-00-00"...)
|
||||
} else {
|
||||
b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat)
|
||||
b, err = appendDateTime(b, v.In(mc.cfg.Loc))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
paramValues = appendLengthEncodedInteger(paramValues,
|
||||
|
|
30
vendor/github.com/go-sql-driver/mysql/statement.go
generated
vendored
30
vendor/github.com/go-sql-driver/mysql/statement.go
generated
vendored
|
@ -10,6 +10,7 @@ package mysql
|
|||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
|
@ -43,6 +44,11 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
|
|||
return converter{}
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
|
||||
nv.Value, err = converter{}.ConvertValue(nv.Value)
|
||||
return
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
if stmt.mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
|
@ -129,6 +135,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
|
|||
return rows, err
|
||||
}
|
||||
|
||||
var jsonType = reflect.TypeOf(json.RawMessage{})
|
||||
|
||||
type converter struct{}
|
||||
|
||||
// ConvertValue mirrors the reference/default converter in database/sql/driver
|
||||
|
@ -146,12 +154,17 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !driver.IsValue(sv) {
|
||||
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
|
||||
if driver.IsValue(sv) {
|
||||
return sv, nil
|
||||
}
|
||||
return sv, nil
|
||||
// A value returend from the Valuer interface can be "a type handled by
|
||||
// a database driver's NamedValueChecker interface" so we should accept
|
||||
// uint64 here as well.
|
||||
if u, ok := sv.(uint64); ok {
|
||||
return u, nil
|
||||
}
|
||||
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
switch rv.Kind() {
|
||||
case reflect.Ptr:
|
||||
|
@ -170,11 +183,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
|
|||
case reflect.Bool:
|
||||
return rv.Bool(), nil
|
||||
case reflect.Slice:
|
||||
ek := rv.Type().Elem().Kind()
|
||||
if ek == reflect.Uint8 {
|
||||
switch t := rv.Type(); {
|
||||
case t == jsonType:
|
||||
return v, nil
|
||||
case t.Elem().Kind() == reflect.Uint8:
|
||||
return rv.Bytes(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind())
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
|
||||
case reflect.String:
|
||||
return rv.String(), nil
|
||||
}
|
||||
|
|
195
vendor/github.com/go-sql-driver/mysql/utils.go
generated
vendored
195
vendor/github.com/go-sql-driver/mysql/utils.go
generated
vendored
|
@ -106,27 +106,136 @@ func readBool(input string) (value bool, valid bool) {
|
|||
* Time related utils *
|
||||
******************************************************************************/
|
||||
|
||||
func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
|
||||
base := "0000-00-00 00:00:00.0000000"
|
||||
switch len(str) {
|
||||
func parseDateTime(b []byte, loc *time.Location) (time.Time, error) {
|
||||
const base = "0000-00-00 00:00:00.000000"
|
||||
switch len(b) {
|
||||
case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
|
||||
if str == base[:len(str)] {
|
||||
return
|
||||
if string(b) == base[:len(b)] {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
t, err = time.Parse(timeFormat[:len(str)], str)
|
||||
|
||||
year, err := parseByteYear(b)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if year <= 0 {
|
||||
year = 1
|
||||
}
|
||||
|
||||
if b[4] != '-' {
|
||||
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4])
|
||||
}
|
||||
|
||||
m, err := parseByte2Digits(b[5], b[6])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if m <= 0 {
|
||||
m = 1
|
||||
}
|
||||
month := time.Month(m)
|
||||
|
||||
if b[7] != '-' {
|
||||
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7])
|
||||
}
|
||||
|
||||
day, err := parseByte2Digits(b[8], b[9])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if day <= 0 {
|
||||
day = 1
|
||||
}
|
||||
if len(b) == 10 {
|
||||
return time.Date(year, month, day, 0, 0, 0, 0, loc), nil
|
||||
}
|
||||
|
||||
if b[10] != ' ' {
|
||||
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10])
|
||||
}
|
||||
|
||||
hour, err := parseByte2Digits(b[11], b[12])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if b[13] != ':' {
|
||||
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13])
|
||||
}
|
||||
|
||||
min, err := parseByte2Digits(b[14], b[15])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if b[16] != ':' {
|
||||
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16])
|
||||
}
|
||||
|
||||
sec, err := parseByte2Digits(b[17], b[18])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if len(b) == 19 {
|
||||
return time.Date(year, month, day, hour, min, sec, 0, loc), nil
|
||||
}
|
||||
|
||||
if b[19] != '.' {
|
||||
return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19])
|
||||
}
|
||||
nsec, err := parseByteNanoSec(b[20:])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return time.Date(year, month, day, hour, min, sec, nsec, loc), nil
|
||||
default:
|
||||
err = fmt.Errorf("invalid time string: %s", str)
|
||||
return
|
||||
return time.Time{}, fmt.Errorf("invalid time bytes: %s", b)
|
||||
}
|
||||
}
|
||||
|
||||
// Adjust location
|
||||
if err == nil && loc != time.UTC {
|
||||
y, mo, d := t.Date()
|
||||
h, mi, s := t.Clock()
|
||||
t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil
|
||||
func parseByteYear(b []byte) (int, error) {
|
||||
year, n := 0, 1000
|
||||
for i := 0; i < 4; i++ {
|
||||
v, err := bToi(b[i])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
year += v * n
|
||||
n = n / 10
|
||||
}
|
||||
return year, nil
|
||||
}
|
||||
|
||||
return
|
||||
func parseByte2Digits(b1, b2 byte) (int, error) {
|
||||
d1, err := bToi(b1)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
d2, err := bToi(b2)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return d1*10 + d2, nil
|
||||
}
|
||||
|
||||
func parseByteNanoSec(b []byte) (int, error) {
|
||||
ns, digit := 0, 100000 // max is 6-digits
|
||||
for i := 0; i < len(b); i++ {
|
||||
v, err := bToi(b[i])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
ns += v * digit
|
||||
digit /= 10
|
||||
}
|
||||
// nanoseconds has 10-digits. (needs to scale digits)
|
||||
// 10 - 6 = 4, so we have to multiple 1000.
|
||||
return ns * 1000, nil
|
||||
}
|
||||
|
||||
func bToi(b byte) (int, error) {
|
||||
if b < '0' || b > '9' {
|
||||
return 0, errors.New("not [0-9]")
|
||||
}
|
||||
return int(b - '0'), nil
|
||||
}
|
||||
|
||||
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
|
||||
|
@ -167,6 +276,64 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va
|
|||
return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
|
||||
}
|
||||
|
||||
func appendDateTime(buf []byte, t time.Time) ([]byte, error) {
|
||||
year, month, day := t.Date()
|
||||
hour, min, sec := t.Clock()
|
||||
nsec := t.Nanosecond()
|
||||
|
||||
if year < 1 || year > 9999 {
|
||||
return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap
|
||||
}
|
||||
year100 := year / 100
|
||||
year1 := year % 100
|
||||
|
||||
var localBuf [len("2006-01-02T15:04:05.999999999")]byte // does not escape
|
||||
localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1]
|
||||
localBuf[4] = '-'
|
||||
localBuf[5], localBuf[6] = digits10[month], digits01[month]
|
||||
localBuf[7] = '-'
|
||||
localBuf[8], localBuf[9] = digits10[day], digits01[day]
|
||||
|
||||
if hour == 0 && min == 0 && sec == 0 && nsec == 0 {
|
||||
return append(buf, localBuf[:10]...), nil
|
||||
}
|
||||
|
||||
localBuf[10] = ' '
|
||||
localBuf[11], localBuf[12] = digits10[hour], digits01[hour]
|
||||
localBuf[13] = ':'
|
||||
localBuf[14], localBuf[15] = digits10[min], digits01[min]
|
||||
localBuf[16] = ':'
|
||||
localBuf[17], localBuf[18] = digits10[sec], digits01[sec]
|
||||
|
||||
if nsec == 0 {
|
||||
return append(buf, localBuf[:19]...), nil
|
||||
}
|
||||
nsec100000000 := nsec / 100000000
|
||||
nsec1000000 := (nsec / 1000000) % 100
|
||||
nsec10000 := (nsec / 10000) % 100
|
||||
nsec100 := (nsec / 100) % 100
|
||||
nsec1 := nsec % 100
|
||||
localBuf[19] = '.'
|
||||
|
||||
// milli second
|
||||
localBuf[20], localBuf[21], localBuf[22] =
|
||||
digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000]
|
||||
// micro second
|
||||
localBuf[23], localBuf[24], localBuf[25] =
|
||||
digits10[nsec10000], digits01[nsec10000], digits10[nsec100]
|
||||
// nano second
|
||||
localBuf[26], localBuf[27], localBuf[28] =
|
||||
digits01[nsec100], digits10[nsec1], digits01[nsec1]
|
||||
|
||||
// trim trailing zeros
|
||||
n := len(localBuf)
|
||||
for n > 0 && localBuf[n-1] == '0' {
|
||||
n--
|
||||
}
|
||||
|
||||
return append(buf, localBuf[:n]...), nil
|
||||
}
|
||||
|
||||
// zeroDateTime is used in formatBinaryDateTime to avoid an allocation
|
||||
// if the DATE or DATETIME has the zero value.
|
||||
// It must never be changed.
|
||||
|
|
2
vendor/github.com/lib/pq/.gitignore
generated
vendored
2
vendor/github.com/lib/pq/.gitignore
generated
vendored
|
@ -2,3 +2,5 @@
|
|||
*.test
|
||||
*~
|
||||
*.swp
|
||||
.idea
|
||||
.vscode
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue