1
0
Fork 0
mirror of https://github.com/documize/community.git synced 2025-07-19 05:09:42 +02:00

Bump Go deps

This commit is contained in:
Harvey Kandola 2024-02-19 11:54:27 -05:00
parent f2ba294be8
commit acb59e1b43
91 changed files with 9004 additions and 513 deletions

41
vendor/github.com/Azure/go-ntlmssp/SECURITY.md generated vendored Normal file
View file

@ -0,0 +1,41 @@
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.8 BLOCK -->
## Security
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
## Reporting Security Issues
**Please do not report security vulnerabilities through public GitHub issues.**
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue
This information will help us triage your report more quickly.
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
## Preferred Languages
We prefer all communications to be in English.
## Policy
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
<!-- END MICROSOFT SECURITY.MD BLOCK -->

View file

@ -42,7 +42,7 @@ func (m authenicateMessage) MarshalBinary() ([]byte, error) {
}
target, user := toUnicode(m.TargetName), toUnicode(m.UserName)
workstation := toUnicode("go-ntlmssp")
workstation := toUnicode("")
ptr := binary.Size(&authenticateMessageFields{})
f := authenticateMessageFields{
@ -82,7 +82,7 @@ func (m authenicateMessage) MarshalBinary() ([]byte, error) {
//ProcessChallenge crafts an AUTHENTICATE message in response to the CHALLENGE message
//that was received from the server
func ProcessChallenge(challengeMessageData []byte, user, password string) ([]byte, error) {
func ProcessChallenge(challengeMessageData []byte, user, password string, domainNeeded bool) ([]byte, error) {
if user == "" && password == "" {
return nil, errors.New("Anonymous authentication not supported")
}
@ -98,6 +98,10 @@ func ProcessChallenge(challengeMessageData []byte, user, password string) ([]byt
if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEKEYEXCH) {
return nil, errors.New("Key exchange requested but not supported (NTLMSSP_NEGOTIATE_KEY_EXCH)")
}
if !domainNeeded {
cm.TargetName = ""
}
am := authenicateMessage{
UserName: user,

View file

@ -5,26 +5,55 @@ import (
"strings"
)
type authheader string
type authheader []string
func (h authheader) IsBasic() bool {
return strings.HasPrefix(string(h), "Basic ")
for _, s := range h {
if strings.HasPrefix(string(s), "Basic ") {
return true
}
}
return false
}
func (h authheader) Basic() string {
for _, s := range h {
if strings.HasPrefix(string(s), "Basic ") {
return s
}
}
return ""
}
func (h authheader) IsNegotiate() bool {
return strings.HasPrefix(string(h), "Negotiate")
for _, s := range h {
if strings.HasPrefix(string(s), "Negotiate") {
return true
}
}
return false
}
func (h authheader) IsNTLM() bool {
return strings.HasPrefix(string(h), "NTLM")
for _, s := range h {
if strings.HasPrefix(string(s), "NTLM") {
return true
}
}
return false
}
func (h authheader) GetData() ([]byte, error) {
p := strings.Split(string(h), " ")
if len(p) < 2 {
return nil, nil
for _, s := range h {
if strings.HasPrefix(string(s), "NTLM") || strings.HasPrefix(string(s), "Negotiate") || strings.HasPrefix(string(s), "Basic ") {
p := strings.Split(string(s), " ")
if len(p) < 2 {
return nil, nil
}
return base64.StdEncoding.DecodeString(string(p[1]))
}
}
return base64.StdEncoding.DecodeString(string(p[1]))
return nil, nil
}
func (h authheader) GetBasicCreds() (username, password string, err error) {

View file

@ -10,15 +10,22 @@ import (
)
// GetDomain : parse domain name from based on slashes in the input
func GetDomain(user string) (string, string) {
// Need to check for upn as well
func GetDomain(user string) (string, string, bool) {
domain := ""
domainNeeded := false
if strings.Contains(user, "\\") {
ucomponents := strings.SplitN(user, "\\", 2)
domain = ucomponents[0]
user = ucomponents[1]
domainNeeded = true
} else if strings.Contains(user, "@") {
domainNeeded = false
} else {
domainNeeded = true
}
return user, domain
return user, domain, domainNeeded
}
//Negotiator is a http.Roundtripper decorator that automatically
@ -34,10 +41,11 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error)
rt = http.DefaultTransport
}
// If it is not basic auth, just round trip the request as usual
reqauth := authheader(req.Header.Get("Authorization"))
reqauth := authheader(req.Header.Values("Authorization"))
if !reqauth.IsBasic() {
return rt.RoundTrip(req)
}
reqauthBasic := reqauth.Basic()
// Save request body
body := bytes.Buffer{}
if req.Body != nil {
@ -59,11 +67,10 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error)
if res.StatusCode != http.StatusUnauthorized {
return res, err
}
resauth := authheader(res.Header.Get("Www-Authenticate"))
resauth := authheader(res.Header.Values("Www-Authenticate"))
if !resauth.IsNegotiate() && !resauth.IsNTLM() {
// Unauthorized, Negotiate not requested, let's try with basic auth
req.Header.Set("Authorization", string(reqauth))
req.Header.Set("Authorization", string(reqauthBasic))
io.Copy(ioutil.Discard, res.Body)
res.Body.Close()
req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
@ -75,7 +82,7 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error)
if res.StatusCode != http.StatusUnauthorized {
return res, err
}
resauth = authheader(res.Header.Get("Www-Authenticate"))
resauth = authheader(res.Header.Values("Www-Authenticate"))
}
if resauth.IsNegotiate() || resauth.IsNTLM() {
@ -91,7 +98,7 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error)
// get domain from username
domain := ""
u, domain = GetDomain(u)
u, domain, domainNeeded := GetDomain(u)
// send negotiate
negotiateMessage, err := NewNegotiateMessage(domain, "")
@ -112,7 +119,7 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error)
}
// receive challenge?
resauth = authheader(res.Header.Get("Www-Authenticate"))
resauth = authheader(res.Header.Values("Www-Authenticate"))
challengeMessage, err := resauth.GetData()
if err != nil {
return nil, err
@ -125,7 +132,7 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error)
res.Body.Close()
// send authenticate
authenticateMessage, err := ProcessChallenge(challengeMessage, u, p)
authenticateMessage, err := ProcessChallenge(challengeMessage, u, p, domainNeeded)
if err != nil {
return nil, err
}

View file

@ -1,6 +0,0 @@
language: go
go:
- 1.6
- 1.7
- 1.8

View file

@ -1,7 +1,7 @@
.PHONY: ci generate clean
ci: clean generate
go test -v ./...
go test -race -v ./...
generate:
go generate .

View file

@ -7,8 +7,8 @@ http.Handlers.
Doing this requires non-trivial wrapping of the http.ResponseWriter interface,
which is also exposed for users interested in a more low-level API.
[![GoDoc](https://godoc.org/github.com/felixge/httpsnoop?status.svg)](https://godoc.org/github.com/felixge/httpsnoop)
[![Build Status](https://travis-ci.org/felixge/httpsnoop.svg?branch=master)](https://travis-ci.org/felixge/httpsnoop)
[![Go Reference](https://pkg.go.dev/badge/github.com/felixge/httpsnoop.svg)](https://pkg.go.dev/github.com/felixge/httpsnoop)
[![Build Status](https://github.com/felixge/httpsnoop/actions/workflows/main.yaml/badge.svg)](https://github.com/felixge/httpsnoop/actions/workflows/main.yaml)
## Usage Example

View file

@ -52,7 +52,7 @@ func (m *Metrics) CaptureMetrics(w http.ResponseWriter, fn func(http.ResponseWri
return func(code int) {
next(code)
if !headerWritten {
if !(code >= 100 && code <= 199) && !headerWritten {
m.Code = code
headerWritten = true
}

View file

@ -1,5 +1,5 @@
// +build go1.8
// Code generated by "httpsnoop/codegen"; DO NOT EDIT
// Code generated by "httpsnoop/codegen"; DO NOT EDIT.
package httpsnoop

View file

@ -1,5 +1,5 @@
// +build !go1.8
// Code generated by "httpsnoop/codegen"; DO NOT EDIT
// Code generated by "httpsnoop/codegen"; DO NOT EDIT.
package httpsnoop

View file

@ -1,8 +1,7 @@
package ldap
import (
"log"
"fmt"
ber "github.com/go-asn1-ber/asn1-ber"
)
@ -63,7 +62,6 @@ func NewAddRequest(dn string, controls []Control) *AddRequest {
DN: dn,
Controls: controls,
}
}
// Add performs the given AddRequest
@ -85,7 +83,7 @@ func (l *Conn) Add(addRequest *AddRequest) error {
return err
}
} else {
log.Printf("Unexpected Response: %d", packet.Children[1].Tag)
return fmt.Errorf("ldap: unexpected response: %d", packet.Children[1].Tag)
}
return nil
}

View file

@ -261,7 +261,7 @@ func parseParams(str string) (map[string]string, error) {
var state int
for i := 0; i <= len(str); i++ {
switch state {
case 0: //reading key
case 0: // reading key
if i == len(str) {
return nil, fmt.Errorf("syntax error on %d", i)
}
@ -270,7 +270,7 @@ func parseParams(str string) (map[string]string, error) {
continue
}
state = 1
case 1: //reading value
case 1: // reading value
if i == len(str) {
m[key] = value
break
@ -289,7 +289,7 @@ func parseParams(str string) (map[string]string, error) {
default:
value += string(str[i])
}
case 2: //inside quotes
case 2: // inside quotes
if i == len(str) {
return nil, fmt.Errorf("syntax error on %d", i)
}
@ -399,6 +399,9 @@ type NTLMBindRequest struct {
Username string
// Password is the credentials to bind with
Password string
// AllowEmptyPassword sets whether the client allows binding with an empty password
// (normally used for unauthenticated bind).
AllowEmptyPassword bool
// Hash is the hex NTLM hash to bind with. Password or hash must be provided
Hash string
// Controls are optional controls to send with the bind request
@ -442,6 +445,22 @@ func (l *Conn) NTLMBind(domain, username, password string) error {
return err
}
// NTLMUnauthenticatedBind performs an bind with an empty password.
//
// A username is required. The anonymous bind is not (yet) supported by the go-ntlmssp library (https://github.com/Azure/go-ntlmssp/blob/819c794454d067543bc61d29f61fef4b3c3df62c/authenticate_message.go#L87)
//
// See https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-nlmp/b38c36ed-2804-4868-a9ff-8dd3182128e4 part 3.2.5.1.2
func (l *Conn) NTLMUnauthenticatedBind(domain, username string) error {
req := &NTLMBindRequest{
Domain: domain,
Username: username,
Password: "",
AllowEmptyPassword: true,
}
_, err := l.NTLMChallengeBind(req)
return err
}
// NTLMBindWithHash performs an NTLM Bind with an NTLM hash instead of plaintext password (pass-the-hash)
func (l *Conn) NTLMBindWithHash(domain, username, hash string) error {
req := &NTLMBindRequest{
@ -455,7 +474,7 @@ func (l *Conn) NTLMBindWithHash(domain, username, hash string) error {
// NTLMChallengeBind performs the NTLMSSP bind operation defined in the given request
func (l *Conn) NTLMChallengeBind(ntlmBindRequest *NTLMBindRequest) (*NTLMBindResult, error) {
if ntlmBindRequest.Password == "" && ntlmBindRequest.Hash == "" {
if !ntlmBindRequest.AllowEmptyPassword && ntlmBindRequest.Password == "" && ntlmBindRequest.Hash == "" {
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
}
@ -496,10 +515,11 @@ func (l *Conn) NTLMChallengeBind(ntlmBindRequest *NTLMBindRequest) (*NTLMBindRes
var err error
var responseMessage []byte
// generate a response message to the challenge with the given Username/Password if password is provided
if ntlmBindRequest.Password != "" {
responseMessage, err = ntlmssp.ProcessChallenge(ntlmsspChallenge, ntlmBindRequest.Username, ntlmBindRequest.Password)
} else if ntlmBindRequest.Hash != "" {
if ntlmBindRequest.Hash != "" {
responseMessage, err = ntlmssp.ProcessChallengeWithHash(ntlmsspChallenge, ntlmBindRequest.Username, ntlmBindRequest.Hash)
} else if ntlmBindRequest.Password != "" || ntlmBindRequest.AllowEmptyPassword {
_, _, domainNeeded := ntlmssp.GetDomain(ntlmBindRequest.Username)
responseMessage, err = ntlmssp.ProcessChallenge(ntlmsspChallenge, ntlmBindRequest.Username, ntlmBindRequest.Password, domainNeeded)
} else {
err = fmt.Errorf("need a password or hash to generate reply")
}
@ -538,3 +558,178 @@ func (l *Conn) NTLMChallengeBind(ntlmBindRequest *NTLMBindRequest) (*NTLMBindRes
err = GetLDAPError(packet)
return result, err
}
// GSSAPIClient interface is used as the client-side implementation for the
// GSSAPI SASL mechanism.
// Interface inspired by GSSAPIClient from golang.org/x/crypto/ssh
type GSSAPIClient interface {
// InitSecContext initiates the establishment of a security context for
// GSS-API between the client and server.
// Initially the token parameter should be specified as nil.
// The routine may return a outputToken which should be transferred to
// the server, where the server will present it to AcceptSecContext.
// If no token need be sent, InitSecContext will indicate this by setting
// needContinue to false. To complete the context
// establishment, one or more reply tokens may be required from the server;
// if so, InitSecContext will return a needContinue which is true.
// In this case, InitSecContext should be called again when the
// reply token is received from the server, passing the reply token
// to InitSecContext via the token parameters.
// See RFC 4752 section 3.1.
InitSecContext(target string, token []byte) (outputToken []byte, needContinue bool, err error)
// NegotiateSaslAuth performs the last step of the Sasl handshake.
// It takes a token, which, when unwrapped, describes the servers supported
// security layers (first octet) and maximum receive buffer (remaining
// three octets).
// If the received token is unacceptable an error must be returned to abort
// the handshake.
// Outputs a signed token describing the client's selected security layer
// and receive buffer size and optionally an authorization identity.
// The returned token will be sent to the server and the handshake considered
// completed successfully and the server authenticated.
// See RFC 4752 section 3.1.
NegotiateSaslAuth(token []byte, authzid string) ([]byte, error)
// DeleteSecContext destroys any established secure context.
DeleteSecContext() error
}
// GSSAPIBindRequest represents a GSSAPI SASL mechanism bind request.
// See rfc4752 and rfc4513 section 5.2.1.2.
type GSSAPIBindRequest struct {
// Service Principal Name user for the service ticket. Eg. "ldap/<host>"
ServicePrincipalName string
// (Optional) Authorization entity
AuthZID string
// (Optional) Controls to send with the bind request
Controls []Control
}
// GSSAPIBind performs the GSSAPI SASL bind using the provided GSSAPI client.
func (l *Conn) GSSAPIBind(client GSSAPIClient, servicePrincipal, authzid string) error {
return l.GSSAPIBindRequest(client, &GSSAPIBindRequest{
ServicePrincipalName: servicePrincipal,
AuthZID: authzid,
})
}
// GSSAPIBindRequest performs the GSSAPI SASL bind using the provided GSSAPI client.
func (l *Conn) GSSAPIBindRequest(client GSSAPIClient, req *GSSAPIBindRequest) error {
//nolint:errcheck
defer client.DeleteSecContext()
var err error
var reqToken []byte
var recvToken []byte
needInit := true
for {
if needInit {
// Establish secure context between client and server.
reqToken, needInit, err = client.InitSecContext(req.ServicePrincipalName, recvToken)
if err != nil {
return err
}
} else {
// Secure context is set up, perform the last step of SASL handshake.
reqToken, err = client.NegotiateSaslAuth(recvToken, req.AuthZID)
if err != nil {
return err
}
}
// Send Bind request containing the current token and extract the
// token sent by server.
recvToken, err = l.saslBindTokenExchange(req.Controls, reqToken)
if err != nil {
return err
}
if !needInit && len(recvToken) == 0 {
break
}
}
return nil
}
func (l *Conn) saslBindTokenExchange(reqControls []Control, reqToken []byte) ([]byte, error) {
// Construct LDAP Bind request with GSSAPI SASL mechanism.
envelope := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
envelope.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
auth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "GSSAPI", "SASL Mech"))
if len(reqToken) > 0 {
auth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, string(reqToken), "Credentials"))
}
request.AppendChild(auth)
envelope.AppendChild(request)
if len(reqControls) > 0 {
envelope.AppendChild(encodeControls(reqControls))
}
msgCtx, err := l.sendMessage(envelope)
if err != nil {
return nil, err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return nil, err
}
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if l.Debug {
if err = addLDAPDescriptions(packet); err != nil {
return nil, err
}
ber.PrintPacket(packet)
}
// https://www.rfc-editor.org/rfc/rfc4511#section-4.1.1
// packet is an envelope
// child 0 is message id
// child 1 is protocolOp
if len(packet.Children) != 2 {
return nil, fmt.Errorf("bad bind response")
}
protocolOp := packet.Children[1]
RESP:
switch protocolOp.Description {
case "Bind Response": // Bind Response
// Bind Reponse is an LDAP Response (https://www.rfc-editor.org/rfc/rfc4511#section-4.1.9)
// with an additional optional serverSaslCreds string (https://www.rfc-editor.org/rfc/rfc4511#section-4.2.2)
// child 0 is resultCode
resultCode := protocolOp.Children[0]
if resultCode.Tag != ber.TagEnumerated {
break RESP
}
switch resultCode.Value.(int64) {
case 14: // Sasl bind in progress
if len(protocolOp.Children) < 3 {
break RESP
}
referral := protocolOp.Children[3]
switch referral.Description {
case "Referral":
if referral.ClassType != ber.ClassContext || referral.Tag != ber.TagObjectDescriptor {
break RESP
}
return ioutil.ReadAll(referral.Data)
}
// Optional:
//if len(protocolOp.Children) == 4 {
// serverSaslCreds := protocolOp.Children[4]
//}
case 0: // Success - Bind OK.
// SASL layer in effect (if any) (See https://www.rfc-editor.org/rfc/rfc4513#section-5.2.1.4)
// NOTE: SASL security layers are not supported currently.
return nil, nil
}
}
return nil, GetLDAPError(packet)
}

View file

@ -1,6 +1,7 @@
package ldap
import (
"context"
"crypto/tls"
"time"
)
@ -9,14 +10,18 @@ import (
type Client interface {
Start()
StartTLS(*tls.Config) error
Close()
Close() error
GetLastError() error
IsClosing() bool
SetTimeout(time.Duration)
TLSConnectionState() (tls.ConnectionState, bool)
Bind(username, password string) error
UnauthenticatedBind(username string) error
SimpleBind(*SimpleBindRequest) (*SimpleBindResult, error)
ExternalBind() error
NTLMUnauthenticatedBind(domain, username string) error
Unbind() error
Add(*AddRequest) error
Del(*DelRequest) error
@ -28,5 +33,9 @@ type Client interface {
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)
Search(*SearchRequest) (*SearchResult, error)
SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error)
DirSyncAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int, flags, maxAttrCount int64, cookie []byte) Response
Syncrepl(ctx context.Context, searchRequest *SearchRequest, bufferSize int, mode ControlSyncRequestMode, cookie []byte, reloadHint bool) Response
}

View file

@ -34,7 +34,8 @@ func (l *Conn) Compare(dn, attribute, value string) (bool, error) {
msgCtx, err := l.doRequest(&CompareRequest{
DN: dn,
Attribute: attribute,
Value: value})
Value: value,
})
if err != nil {
return false, err
}

View file

@ -2,10 +2,10 @@ package ldap
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/url"
"sync"
@ -61,13 +61,21 @@ type messageContext struct {
// sendResponse should only be called within the processMessages() loop which
// is also responsible for closing the responses channel.
func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
func (msgCtx *messageContext) sendResponse(packet *PacketResponse, timeout time.Duration) {
timeoutCtx := context.Background()
if timeout > 0 {
var cancelFunc context.CancelFunc
timeoutCtx, cancelFunc = context.WithTimeout(context.Background(), timeout)
defer cancelFunc()
}
select {
case msgCtx.responses <- packet:
// Successfully sent packet to message handler.
case <-msgCtx.done:
// The request handler is done and will not receive more
// packets.
case <-timeoutCtx.Done():
// The timeout was reached before the packet was sent.
}
}
@ -88,6 +96,7 @@ const (
type Conn struct {
// requestTimeout is loaded atomically
// so we need to ensure 64-bit alignment on 32-bit platforms.
// https://github.com/go-ldap/ldap/pull/199
requestTimeout int64
conn net.Conn
isTLS bool
@ -102,6 +111,8 @@ type Conn struct {
wgClose sync.WaitGroup
outstandingRequests uint
messageMutex sync.Mutex
err error
}
var _ Client = &Conn{}
@ -119,30 +130,31 @@ type DialOpt func(*DialContext)
// DialWithDialer updates net.Dialer in DialContext.
func DialWithDialer(d *net.Dialer) DialOpt {
return func(dc *DialContext) {
dc.d = d
dc.dialer = d
}
}
// DialWithTLSConfig updates tls.Config in DialContext.
func DialWithTLSConfig(tc *tls.Config) DialOpt {
return func(dc *DialContext) {
dc.tc = tc
dc.tlsConfig = tc
}
}
// DialWithTLSDialer is a wrapper for DialWithTLSConfig with the option to
// specify a net.Dialer to for example define a timeout or a custom resolver.
// @deprecated Use DialWithDialer and DialWithTLSConfig instead
func DialWithTLSDialer(tlsConfig *tls.Config, dialer *net.Dialer) DialOpt {
return func(dc *DialContext) {
dc.tc = tlsConfig
dc.d = dialer
dc.tlsConfig = tlsConfig
dc.dialer = dialer
}
}
// DialContext contains necessary parameters to dial the given ldap URL.
type DialContext struct {
d *net.Dialer
tc *tls.Config
dialer *net.Dialer
tlsConfig *tls.Config
}
func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
@ -150,7 +162,7 @@ func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
if u.Path == "" || u.Path == "/" {
u.Path = "/var/run/slapd/ldapi"
}
return dc.d.Dial("unix", u.Path)
return dc.dialer.Dial("unix", u.Path)
}
host, port, err := net.SplitHostPort(u.Host)
@ -161,16 +173,21 @@ func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
}
switch u.Scheme {
case "cldap":
if port == "" {
port = DefaultLdapPort
}
return dc.dialer.Dial("udp", net.JoinHostPort(host, port))
case "ldap":
if port == "" {
port = DefaultLdapPort
}
return dc.d.Dial("tcp", net.JoinHostPort(host, port))
return dc.dialer.Dial("tcp", net.JoinHostPort(host, port))
case "ldaps":
if port == "" {
port = DefaultLdapsPort
}
return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc)
return tls.DialWithDialer(dc.dialer, "tcp", net.JoinHostPort(host, port), dc.tlsConfig)
}
return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
@ -203,7 +220,8 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
}
// DialURL connects to the given ldap URL.
// The following schemas are supported: ldap://, ldaps://, ldapi://.
// The following schemas are supported: ldap://, ldaps://, ldapi://,
// and cldap:// (RFC1798, deprecated but used by Active Directory).
// On success a new Conn for the connection is returned.
func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
u, err := url.Parse(addr)
@ -215,8 +233,8 @@ func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
for _, opt := range opts {
opt(&dc)
}
if dc.d == nil {
dc.d = &net.Dialer{Timeout: DefaultTimeout}
if dc.dialer == nil {
dc.dialer = &net.Dialer{Timeout: DefaultTimeout}
}
c, err := dc.dial(u)
@ -231,7 +249,7 @@ func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
// NewConn returns a new Conn using conn for network I/O.
func NewConn(conn net.Conn, isTLS bool) *Conn {
return &Conn{
l := &Conn{
conn: conn,
chanConfirm: make(chan struct{}),
chanMessageID: make(chan int64),
@ -240,11 +258,12 @@ func NewConn(conn net.Conn, isTLS bool) *Conn {
requestTimeout: 0,
isTLS: isTLS,
}
l.wgClose.Add(1)
return l
}
// Start initializes goroutines to read responses and process messages
func (l *Conn) Start() {
l.wgClose.Add(1)
go l.reader()
go l.processMessages()
}
@ -260,31 +279,45 @@ func (l *Conn) setClosing() bool {
}
// Close closes the connection.
func (l *Conn) Close() {
func (l *Conn) Close() (err error) {
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
if l.setClosing() {
l.Debug.Printf("Sending quit message and waiting for confirmation")
l.chanMessage <- &messagePacket{Op: MessageQuit}
<-l.chanConfirm
timeoutCtx := context.Background()
if l.getTimeout() > 0 {
var cancelFunc context.CancelFunc
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.getTimeout()))
defer cancelFunc()
}
select {
case <-l.chanConfirm:
// Confirmation was received.
case <-timeoutCtx.Done():
// The timeout was reached before confirmation was received.
}
close(l.chanMessage)
l.Debug.Printf("Closing network connection")
if err := l.conn.Close(); err != nil {
log.Println(err)
}
err = l.conn.Close()
l.wgClose.Done()
}
l.wgClose.Wait()
return err
}
// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
func (l *Conn) SetTimeout(timeout time.Duration) {
if timeout > 0 {
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}
func (l *Conn) getTimeout() int64 {
return atomic.LoadInt64(&l.requestTimeout)
}
// Returns the next available messageID
@ -295,6 +328,14 @@ func (l *Conn) nextMessageID() int64 {
return 0
}
// GetLastError returns the last recorded error from goroutines like processMessages and reader.
// Only the last recorded error will be returned.
func (l *Conn) GetLastError() error {
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
return l.err
}
// StartTLS sends the command to start a TLS session and then creates a new TLS Client
func (l *Conn) StartTLS(config *tls.Config) error {
if l.isTLS {
@ -443,13 +484,13 @@ func (l *Conn) sendProcessMessage(message *messagePacket) bool {
func (l *Conn) processMessages() {
defer func() {
if err := recover(); err != nil {
log.Printf("ldap: recovered panic in processMessages: %v", err)
l.err = fmt.Errorf("ldap: recovered panic in processMessages: %v", err)
}
for messageID, msgCtx := range l.messageContexts {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l.IsClosing() && l.closeErr.Load() != nil {
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout()))
}
l.Debug.Printf("Closing channel for MessageID %d", messageID)
close(msgCtx.responses)
@ -477,7 +518,7 @@ func (l *Conn) processMessages() {
_, err := l.conn.Write(buf)
if err != nil {
l.Debug.Printf("Error Sending Message: %s", err.Error())
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout()))
close(message.Context.responses)
break
}
@ -487,28 +528,35 @@ func (l *Conn) processMessages() {
l.messageContexts[message.MessageID] = message.Context
// Add timeout if defined
requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
requestTimeout := l.getTimeout()
if requestTimeout > 0 {
go func() {
timer := time.NewTimer(time.Duration(requestTimeout))
defer func() {
if err := recover(); err != nil {
log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err)
}
timer.Stop()
}()
time.Sleep(requestTimeout)
timeoutMessage := &messagePacket{
Op: MessageTimeout,
MessageID: message.MessageID,
select {
case <-timer.C:
timeoutMessage := &messagePacket{
Op: MessageTimeout,
MessageID: message.MessageID,
}
l.sendProcessMessage(timeoutMessage)
case <-message.Context.done:
}
l.sendProcessMessage(timeoutMessage)
}()
}
case MessageResponse:
l.Debug.Printf("Receiving message %d", message.MessageID)
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout()))
} else {
log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing())
l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing())
l.Debug.PrintPacket(message.Packet)
}
case MessageTimeout:
@ -516,7 +564,7 @@ func (l *Conn) processMessages() {
// All reads will return immediately
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))})
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout()))
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
}
@ -535,7 +583,7 @@ func (l *Conn) reader() {
cleanstop := false
defer func() {
if err := recover(); err != nil {
log.Printf("ldap: recovered panic in reader: %v", err)
l.err = fmt.Errorf("ldap: recovered panic in reader: %v", err)
}
if !cleanstop {
l.Close()

View file

@ -5,6 +5,7 @@ import (
"strconv"
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/google/uuid"
)
const (
@ -20,6 +21,13 @@ const (
ControlTypeManageDsaIT = "2.16.840.1.113730.3.4.2"
// ControlTypeWhoAmI - https://tools.ietf.org/html/rfc4532
ControlTypeWhoAmI = "1.3.6.1.4.1.4203.1.11.3"
// ControlTypeSubtreeDelete - https://datatracker.ietf.org/doc/html/draft-armijo-ldap-treedelete-02
ControlTypeSubtreeDelete = "1.2.840.113556.1.4.805"
// ControlTypeServerSideSorting - https://www.ietf.org/rfc/rfc2891.txt
ControlTypeServerSideSorting = "1.2.840.113556.1.4.473"
// ControlTypeServerSideSorting - https://www.ietf.org/rfc/rfc2891.txt
ControlTypeServerSideSortingResult = "1.2.840.113556.1.4.474"
// ControlTypeMicrosoftNotification - https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx
ControlTypeMicrosoftNotification = "1.2.840.113556.1.4.528"
@ -27,16 +35,43 @@ const (
ControlTypeMicrosoftShowDeleted = "1.2.840.113556.1.4.417"
// ControlTypeMicrosoftServerLinkTTL - https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/f4f523a8-abc0-4b3a-a471-6b2fef135481?redirectedfrom=MSDN
ControlTypeMicrosoftServerLinkTTL = "1.2.840.113556.1.4.2309"
// ControlTypeDirSync - Active Directory DirSync - https://msdn.microsoft.com/en-us/library/aa366978(v=vs.85).aspx
ControlTypeDirSync = "1.2.840.113556.1.4.841"
// ControlTypeSyncRequest - https://www.ietf.org/rfc/rfc4533.txt
ControlTypeSyncRequest = "1.3.6.1.4.1.4203.1.9.1.1"
// ControlTypeSyncState - https://www.ietf.org/rfc/rfc4533.txt
ControlTypeSyncState = "1.3.6.1.4.1.4203.1.9.1.2"
// ControlTypeSyncDone - https://www.ietf.org/rfc/rfc4533.txt
ControlTypeSyncDone = "1.3.6.1.4.1.4203.1.9.1.3"
// ControlTypeSyncInfo - https://www.ietf.org/rfc/rfc4533.txt
ControlTypeSyncInfo = "1.3.6.1.4.1.4203.1.9.1.4"
)
// Flags for DirSync control
const (
DirSyncIncrementalValues int64 = 2147483648
DirSyncPublicDataOnly int64 = 8192
DirSyncAncestorsFirstOrder int64 = 2048
DirSyncObjectSecurity int64 = 1
)
// ControlTypeMap maps controls to text descriptions
var ControlTypeMap = map[string]string{
ControlTypePaging: "Paging",
ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft",
ControlTypeManageDsaIT: "Manage DSA IT",
ControlTypeMicrosoftNotification: "Change Notification - Microsoft",
ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft",
ControlTypeMicrosoftServerLinkTTL: "Return TTL-DNs for link values with associated expiry times - Microsoft",
ControlTypePaging: "Paging",
ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft",
ControlTypeManageDsaIT: "Manage DSA IT",
ControlTypeSubtreeDelete: "Subtree Delete Control",
ControlTypeMicrosoftNotification: "Change Notification - Microsoft",
ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft",
ControlTypeMicrosoftServerLinkTTL: "Return TTL-DNs for link values with associated expiry times - Microsoft",
ControlTypeServerSideSorting: "Server Side Sorting Request - LDAP Control Extension for Server Side Sorting of Search Results (RFC2891)",
ControlTypeServerSideSortingResult: "Server Side Sorting Results - LDAP Control Extension for Server Side Sorting of Search Results (RFC2891)",
ControlTypeDirSync: "DirSync",
ControlTypeSyncRequest: "Sync Request",
ControlTypeSyncState: "Sync State",
ControlTypeSyncDone: "Sync Done",
ControlTypeSyncInfo: "Sync Info",
}
// Control defines an interface controls provide to encode and describe themselves
@ -229,7 +264,7 @@ func (c *ControlManageDsaIT) GetControlType() string {
// Encode returns the ber packet representation
func (c *ControlManageDsaIT) Encode() *ber.Packet {
//FIXME
// FIXME
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeManageDsaIT, "Control Type ("+ControlTypeMap[ControlTypeManageDsaIT]+")"))
if c.Criticality {
@ -369,7 +404,13 @@ func DecodeControl(packet *ber.Packet) (Control, error) {
case 2:
packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")"
ControlType = packet.Children[0].Value.(string)
if packet.Children[0].Value != nil {
ControlType = packet.Children[0].Value.(string)
} else if packet.Children[0].Data != nil {
ControlType = packet.Children[0].Data.String()
} else {
return nil, fmt.Errorf("not found where to get the control type")
}
// Children[1] could be criticality or value (both are optional)
// duck-type on whether this is a boolean
@ -436,18 +477,18 @@ func DecodeControl(packet *ber.Packet) (Control, error) {
for _, child := range sequence.Children {
if child.Tag == 0 {
//Warning
// Warning
warningPacket := child.Children[0]
val, err := ber.ParseInt64(warningPacket.Data.Bytes())
if err != nil {
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
}
if warningPacket.Tag == 0 {
//timeBeforeExpiration
// timeBeforeExpiration
c.Expire = val
warningPacket.Value = c.Expire
} else if warningPacket.Tag == 1 {
//graceAuthNsRemaining
// graceAuthNsRemaining
c.Grace = val
warningPacket.Value = c.Grace
}
@ -485,6 +526,36 @@ func DecodeControl(packet *ber.Packet) (Control, error) {
return NewControlMicrosoftShowDeleted(), nil
case ControlTypeMicrosoftServerLinkTTL:
return NewControlMicrosoftServerLinkTTL(), nil
case ControlTypeSubtreeDelete:
return NewControlSubtreeDelete(), nil
case ControlTypeServerSideSorting:
return NewControlServerSideSorting(value)
case ControlTypeServerSideSortingResult:
return NewControlServerSideSortingResult(value)
case ControlTypeDirSync:
value.Description += " (DirSync)"
return NewResponseControlDirSync(value)
case ControlTypeSyncState:
value.Description += " (Sync State)"
valueChildren, err := ber.DecodePacketErr(value.Data.Bytes())
if err != nil {
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
}
return NewControlSyncState(valueChildren)
case ControlTypeSyncDone:
value.Description += " (Sync Done)"
valueChildren, err := ber.DecodePacketErr(value.Data.Bytes())
if err != nil {
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
}
return NewControlSyncDone(valueChildren)
case ControlTypeSyncInfo:
value.Description += " (Sync Info)"
valueChildren, err := ber.DecodePacketErr(value.Data.Bytes())
if err != nil {
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
}
return NewControlSyncInfo(valueChildren)
default:
c := new(ControlString)
c.ControlType = ControlType
@ -519,6 +590,35 @@ func NewControlBeheraPasswordPolicy() *ControlBeheraPasswordPolicy {
}
}
// ControlSubtreeDelete implements the subtree delete control described in
// https://datatracker.ietf.org/doc/html/draft-armijo-ldap-treedelete-02
type ControlSubtreeDelete struct{}
// GetControlType returns the OID
func (c *ControlSubtreeDelete) GetControlType() string {
return ControlTypeSubtreeDelete
}
// NewControlSubtreeDelete returns a ControlSubtreeDelete control.
func NewControlSubtreeDelete() *ControlSubtreeDelete {
return &ControlSubtreeDelete{}
}
// Encode returns the ber packet representation
func (c *ControlSubtreeDelete) Encode() *ber.Packet {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeSubtreeDelete, "Control Type ("+ControlTypeMap[ControlTypeSubtreeDelete]+")"))
return packet
}
func (c *ControlSubtreeDelete) String() string {
return fmt.Sprintf(
"Control Type: %s (%q)",
ControlTypeMap[ControlTypeSubtreeDelete],
ControlTypeSubtreeDelete)
}
func encodeControls(controls []Control) *ber.Packet {
packet := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0, nil, "Controls")
for _, control := range controls {
@ -526,3 +626,669 @@ func encodeControls(controls []Control) *ber.Packet {
}
return packet
}
// ControlDirSync implements the control described in https://msdn.microsoft.com/en-us/library/aa366978(v=vs.85).aspx
type ControlDirSync struct {
Criticality bool
Flags int64
MaxAttrCount int64
Cookie []byte
}
// @deprecated Use NewRequestControlDirSync instead
func NewControlDirSync(flags int64, maxAttrCount int64, cookie []byte) *ControlDirSync {
return NewRequestControlDirSync(flags, maxAttrCount, cookie)
}
// NewRequestControlDirSync returns a dir sync control
func NewRequestControlDirSync(
flags int64, maxAttrCount int64, cookie []byte,
) *ControlDirSync {
return &ControlDirSync{
Criticality: true,
Flags: flags,
MaxAttrCount: maxAttrCount,
Cookie: cookie,
}
}
// NewResponseControlDirSync returns a dir sync control
func NewResponseControlDirSync(value *ber.Packet) (*ControlDirSync, error) {
if value.Value != nil {
valueChildren, err := ber.DecodePacketErr(value.Data.Bytes())
if err != nil {
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
}
value.Data.Truncate(0)
value.Value = nil
value.AppendChild(valueChildren)
}
child := value.Children[0]
if len(child.Children) != 3 { // also on initial creation, Cookie is an empty string
return nil, fmt.Errorf("invalid number of children in dirSync control")
}
child.Description = "DirSync Control Value"
child.Children[0].Description = "Flags"
child.Children[1].Description = "MaxAttrCount"
child.Children[2].Description = "Cookie"
cookie := child.Children[2].Data.Bytes()
child.Children[2].Value = cookie
return &ControlDirSync{
Criticality: true,
Flags: child.Children[0].Value.(int64),
MaxAttrCount: child.Children[1].Value.(int64),
Cookie: cookie,
}, nil
}
// GetControlType returns the OID
func (c *ControlDirSync) GetControlType() string {
return ControlTypeDirSync
}
// String returns a human-readable description
func (c *ControlDirSync) String() string {
return fmt.Sprintf(
"ControlType: %s (%q) Criticality: %t ControlValue: Flags: %d MaxAttrCount: %d",
ControlTypeMap[ControlTypeDirSync],
ControlTypeDirSync,
c.Criticality,
c.Flags,
c.MaxAttrCount,
)
}
// Encode returns the ber packet representation
func (c *ControlDirSync) Encode() *ber.Packet {
cookie := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "Cookie")
if len(c.Cookie) != 0 {
cookie.Value = c.Cookie
cookie.Data.Write(c.Cookie)
}
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeDirSync, "Control Type ("+ControlTypeMap[ControlTypeDirSync]+")"))
packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.Criticality, "Criticality")) // must be true always
val := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value (DirSync)")
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "DirSync Control Value")
seq.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(c.Flags), "Flags"))
seq.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(c.MaxAttrCount), "MaxAttrCount"))
seq.AppendChild(cookie)
val.AppendChild(seq)
packet.AppendChild(val)
return packet
}
// SetCookie stores the given cookie in the dirSync control
func (c *ControlDirSync) SetCookie(cookie []byte) {
c.Cookie = cookie
}
// ControlServerSideSorting
type SortKey struct {
Reverse bool
AttributeType string
MatchingRule string
}
type ControlServerSideSorting struct {
SortKeys []*SortKey
}
func (c *ControlServerSideSorting) GetControlType() string {
return ControlTypeServerSideSorting
}
func NewControlServerSideSorting(value *ber.Packet) (*ControlServerSideSorting, error) {
sortKeys := []*SortKey{}
val := value.Children[1].Children
if len(val) != 1 {
return nil, fmt.Errorf("no sequence value in packet")
}
sequences := val[0].Children
for i, sequence := range sequences {
sortKey := &SortKey{}
if len(sequence.Children) < 2 {
return nil, fmt.Errorf("attributeType or matchingRule is missing from sequence %d", i)
}
sortKey.AttributeType = sequence.Children[0].Value.(string)
sortKey.MatchingRule = sequence.Children[1].Value.(string)
if len(sequence.Children) == 3 {
sortKey.Reverse = sequence.Children[2].Value.(bool)
}
sortKeys = append(sortKeys, sortKey)
}
return &ControlServerSideSorting{SortKeys: sortKeys}, nil
}
func NewControlServerSideSortingWithSortKeys(sortKeys []*SortKey) *ControlServerSideSorting {
return &ControlServerSideSorting{SortKeys: sortKeys}
}
func (c *ControlServerSideSorting) Encode() *ber.Packet {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
control := ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.GetControlType(), "Control Type")
value := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value")
seqs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "SortKeyList")
for _, f := range c.SortKeys {
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "")
seq.AppendChild(
ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, f.AttributeType, "attributeType"),
)
seq.AppendChild(
ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, f.MatchingRule, "orderingRule"),
)
if f.Reverse {
seq.AppendChild(
ber.NewBoolean(ber.ClassContext, ber.TypePrimitive, 1, f.Reverse, "reverseOrder"),
)
}
seqs.AppendChild(seq)
}
value.AppendChild(seqs)
packet.AppendChild(control)
packet.AppendChild(value)
return packet
}
func (c *ControlServerSideSorting) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality:%t %+v",
"Server Side Sorting",
c.GetControlType(),
false,
c.SortKeys,
)
}
// ControlServerSideSortingResponse
const (
ControlServerSideSortingCodeSuccess ControlServerSideSortingCode = 0
ControlServerSideSortingCodeOperationsError ControlServerSideSortingCode = 1
ControlServerSideSortingCodeTimeLimitExceeded ControlServerSideSortingCode = 2
ControlServerSideSortingCodeStrongAuthRequired ControlServerSideSortingCode = 8
ControlServerSideSortingCodeAdminLimitExceeded ControlServerSideSortingCode = 11
ControlServerSideSortingCodeNoSuchAttribute ControlServerSideSortingCode = 16
ControlServerSideSortingCodeInappropriateMatching ControlServerSideSortingCode = 18
ControlServerSideSortingCodeInsufficientAccessRights ControlServerSideSortingCode = 50
ControlServerSideSortingCodeBusy ControlServerSideSortingCode = 51
ControlServerSideSortingCodeUnwillingToPerform ControlServerSideSortingCode = 53
ControlServerSideSortingCodeOther ControlServerSideSortingCode = 80
)
var ControlServerSideSortingCodes = []ControlServerSideSortingCode{
ControlServerSideSortingCodeSuccess,
ControlServerSideSortingCodeOperationsError,
ControlServerSideSortingCodeTimeLimitExceeded,
ControlServerSideSortingCodeStrongAuthRequired,
ControlServerSideSortingCodeAdminLimitExceeded,
ControlServerSideSortingCodeNoSuchAttribute,
ControlServerSideSortingCodeInappropriateMatching,
ControlServerSideSortingCodeInsufficientAccessRights,
ControlServerSideSortingCodeBusy,
ControlServerSideSortingCodeUnwillingToPerform,
ControlServerSideSortingCodeOther,
}
type ControlServerSideSortingCode int64
// Valid test the code contained in the control against the ControlServerSideSortingCodes slice and return an error if the code is unknown.
func (c ControlServerSideSortingCode) Valid() error {
for _, validRet := range ControlServerSideSortingCodes {
if c == validRet {
return nil
}
}
return fmt.Errorf("unknown return code : %d", c)
}
func NewControlServerSideSortingResult(pkt *ber.Packet) (*ControlServerSideSortingResult, error) {
control := &ControlServerSideSortingResult{}
if pkt == nil || len(pkt.Children) == 0 {
return nil, fmt.Errorf("bad packet")
}
codeInt, err := ber.ParseInt64(pkt.Children[0].Data.Bytes())
if err != nil {
return nil, err
}
code := ControlServerSideSortingCode(codeInt)
if err := code.Valid(); err != nil {
return nil, err
}
return control, nil
}
type ControlServerSideSortingResult struct {
Criticality bool
Result ControlServerSideSortingCode
// Not populated for now. I can't get openldap to send me this value, so I think this is specific to other directory server
// AttributeType string
}
func (control *ControlServerSideSortingResult) GetControlType() string {
return ControlTypeServerSideSortingResult
}
func (c *ControlServerSideSortingResult) Encode() *ber.Packet {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "SortResult sequence")
sortResult := ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, int64(c.Result), "SortResult")
packet.AppendChild(sortResult)
return packet
}
func (c *ControlServerSideSortingResult) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality:%t ResultCode:%+v",
"Server Side Sorting Result",
c.GetControlType(),
c.Criticality,
c.Result,
)
}
// Mode for ControlTypeSyncRequest
type ControlSyncRequestMode int64
const (
SyncRequestModeRefreshOnly ControlSyncRequestMode = 1
SyncRequestModeRefreshAndPersist ControlSyncRequestMode = 3
)
// ControlSyncRequest implements the Sync Request Control described in https://www.ietf.org/rfc/rfc4533.txt
type ControlSyncRequest struct {
Criticality bool
Mode ControlSyncRequestMode
Cookie []byte
ReloadHint bool
}
func NewControlSyncRequest(
mode ControlSyncRequestMode, cookie []byte, reloadHint bool,
) *ControlSyncRequest {
return &ControlSyncRequest{
Criticality: true,
Mode: mode,
Cookie: cookie,
ReloadHint: reloadHint,
}
}
// GetControlType returns the OID
func (c *ControlSyncRequest) GetControlType() string {
return ControlTypeSyncRequest
}
// Encode encodes the control
func (c *ControlSyncRequest) Encode() *ber.Packet {
_mode := int64(c.Mode)
mode := ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, _mode, "Mode")
var cookie *ber.Packet
if len(c.Cookie) > 0 {
cookie = ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Cookie")
cookie.Value = c.Cookie
cookie.Data.Write(c.Cookie)
}
reloadHint := ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.ReloadHint, "Reload Hint")
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeSyncRequest, "Control Type ("+ControlTypeMap[ControlTypeSyncRequest]+")"))
packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.Criticality, "Criticality"))
val := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value (Sync Request)")
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Sync Request Value")
seq.AppendChild(mode)
if cookie != nil {
seq.AppendChild(cookie)
}
seq.AppendChild(reloadHint)
val.AppendChild(seq)
packet.AppendChild(val)
return packet
}
// String returns a human-readable description
func (c *ControlSyncRequest) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality: %t Mode: %d Cookie: %s ReloadHint: %t",
ControlTypeMap[ControlTypeSyncRequest],
ControlTypeSyncRequest,
c.Criticality,
c.Mode,
string(c.Cookie),
c.ReloadHint,
)
}
// State for ControlSyncState
type ControlSyncStateState int64
const (
SyncStatePresent ControlSyncStateState = 0
SyncStateAdd ControlSyncStateState = 1
SyncStateModify ControlSyncStateState = 2
SyncStateDelete ControlSyncStateState = 3
)
// ControlSyncState implements the Sync State Control described in https://www.ietf.org/rfc/rfc4533.txt
type ControlSyncState struct {
Criticality bool
State ControlSyncStateState
EntryUUID uuid.UUID
Cookie []byte
}
func NewControlSyncState(pkt *ber.Packet) (*ControlSyncState, error) {
var (
state ControlSyncStateState
entryUUID uuid.UUID
cookie []byte
err error
)
switch len(pkt.Children) {
case 0, 1:
return nil, fmt.Errorf("at least two children are required: %d", len(pkt.Children))
case 2:
state = ControlSyncStateState(pkt.Children[0].Value.(int64))
entryUUID, err = uuid.FromBytes(pkt.Children[1].ByteValue)
if err != nil {
return nil, fmt.Errorf("failed to decode uuid: %w", err)
}
case 3:
state = ControlSyncStateState(pkt.Children[0].Value.(int64))
entryUUID, err = uuid.FromBytes(pkt.Children[1].ByteValue)
if err != nil {
return nil, fmt.Errorf("failed to decode uuid: %w", err)
}
cookie = pkt.Children[2].ByteValue
}
return &ControlSyncState{
Criticality: false,
State: state,
EntryUUID: entryUUID,
Cookie: cookie,
}, nil
}
// GetControlType returns the OID
func (c *ControlSyncState) GetControlType() string {
return ControlTypeSyncState
}
// Encode encodes the control
func (c *ControlSyncState) Encode() *ber.Packet {
return nil
}
// String returns a human-readable description
func (c *ControlSyncState) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality: %t State: %d EntryUUID: %s Cookie: %s",
ControlTypeMap[ControlTypeSyncState],
ControlTypeSyncState,
c.Criticality,
c.State,
c.EntryUUID.String(),
string(c.Cookie),
)
}
// ControlSyncDone implements the Sync Done Control described in https://www.ietf.org/rfc/rfc4533.txt
type ControlSyncDone struct {
Criticality bool
Cookie []byte
RefreshDeletes bool
}
func NewControlSyncDone(pkt *ber.Packet) (*ControlSyncDone, error) {
var (
cookie []byte
refreshDeletes bool
)
switch len(pkt.Children) {
case 0:
// have nothing to do
case 1:
cookie = pkt.Children[0].ByteValue
case 2:
cookie = pkt.Children[0].ByteValue
refreshDeletes = pkt.Children[1].Value.(bool)
}
return &ControlSyncDone{
Criticality: false,
Cookie: cookie,
RefreshDeletes: refreshDeletes,
}, nil
}
// GetControlType returns the OID
func (c *ControlSyncDone) GetControlType() string {
return ControlTypeSyncDone
}
// Encode encodes the control
func (c *ControlSyncDone) Encode() *ber.Packet {
return nil
}
// String returns a human-readable description
func (c *ControlSyncDone) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality: %t Cookie: %s RefreshDeletes: %t",
ControlTypeMap[ControlTypeSyncDone],
ControlTypeSyncDone,
c.Criticality,
string(c.Cookie),
c.RefreshDeletes,
)
}
// Tag For ControlSyncInfo
type ControlSyncInfoValue uint64
const (
SyncInfoNewcookie ControlSyncInfoValue = 0
SyncInfoRefreshDelete ControlSyncInfoValue = 1
SyncInfoRefreshPresent ControlSyncInfoValue = 2
SyncInfoSyncIdSet ControlSyncInfoValue = 3
)
// ControlSyncInfoNewCookie implements a part of syncInfoValue described in https://www.ietf.org/rfc/rfc4533.txt
type ControlSyncInfoNewCookie struct {
Cookie []byte
}
// String returns a human-readable description
func (c *ControlSyncInfoNewCookie) String() string {
return fmt.Sprintf(
"NewCookie[Cookie: %s]",
string(c.Cookie),
)
}
// ControlSyncInfoRefreshDelete implements a part of syncInfoValue described in https://www.ietf.org/rfc/rfc4533.txt
type ControlSyncInfoRefreshDelete struct {
Cookie []byte
RefreshDone bool
}
// String returns a human-readable description
func (c *ControlSyncInfoRefreshDelete) String() string {
return fmt.Sprintf(
"RefreshDelete[Cookie: %s RefreshDone: %t]",
string(c.Cookie),
c.RefreshDone,
)
}
// ControlSyncInfoRefreshPresent implements a part of syncInfoValue described in https://www.ietf.org/rfc/rfc4533.txt
type ControlSyncInfoRefreshPresent struct {
Cookie []byte
RefreshDone bool
}
// String returns a human-readable description
func (c *ControlSyncInfoRefreshPresent) String() string {
return fmt.Sprintf(
"RefreshPresent[Cookie: %s RefreshDone: %t]",
string(c.Cookie),
c.RefreshDone,
)
}
// ControlSyncInfoSyncIdSet implements a part of syncInfoValue described in https://www.ietf.org/rfc/rfc4533.txt
type ControlSyncInfoSyncIdSet struct {
Cookie []byte
RefreshDeletes bool
SyncUUIDs []uuid.UUID
}
// String returns a human-readable description
func (c *ControlSyncInfoSyncIdSet) String() string {
return fmt.Sprintf(
"SyncIdSet[Cookie: %s RefreshDeletes: %t SyncUUIDs: %v]",
string(c.Cookie),
c.RefreshDeletes,
c.SyncUUIDs,
)
}
// ControlSyncInfo implements the Sync Info Control described in https://www.ietf.org/rfc/rfc4533.txt
type ControlSyncInfo struct {
Criticality bool
Value ControlSyncInfoValue
NewCookie *ControlSyncInfoNewCookie
RefreshDelete *ControlSyncInfoRefreshDelete
RefreshPresent *ControlSyncInfoRefreshPresent
SyncIdSet *ControlSyncInfoSyncIdSet
}
func NewControlSyncInfo(pkt *ber.Packet) (*ControlSyncInfo, error) {
var (
cookie []byte
refreshDone = true
refreshDeletes bool
syncUUIDs []uuid.UUID
)
c := &ControlSyncInfo{Criticality: false}
switch ControlSyncInfoValue(pkt.Identifier.Tag) {
case SyncInfoNewcookie:
c.Value = SyncInfoNewcookie
c.NewCookie = &ControlSyncInfoNewCookie{
Cookie: pkt.ByteValue,
}
case SyncInfoRefreshDelete:
c.Value = SyncInfoRefreshDelete
switch len(pkt.Children) {
case 0:
// have nothing to do
case 1:
cookie = pkt.Children[0].ByteValue
case 2:
cookie = pkt.Children[0].ByteValue
refreshDone = pkt.Children[1].Value.(bool)
}
c.RefreshDelete = &ControlSyncInfoRefreshDelete{
Cookie: cookie,
RefreshDone: refreshDone,
}
case SyncInfoRefreshPresent:
c.Value = SyncInfoRefreshPresent
switch len(pkt.Children) {
case 0:
// have nothing to do
case 1:
cookie = pkt.Children[0].ByteValue
case 2:
cookie = pkt.Children[0].ByteValue
refreshDone = pkt.Children[1].Value.(bool)
}
c.RefreshPresent = &ControlSyncInfoRefreshPresent{
Cookie: cookie,
RefreshDone: refreshDone,
}
case SyncInfoSyncIdSet:
c.Value = SyncInfoSyncIdSet
switch len(pkt.Children) {
case 0:
// have nothing to do
case 1:
cookie = pkt.Children[0].ByteValue
case 2:
cookie = pkt.Children[0].ByteValue
refreshDeletes = pkt.Children[1].Value.(bool)
case 3:
cookie = pkt.Children[0].ByteValue
refreshDeletes = pkt.Children[1].Value.(bool)
syncUUIDs = make([]uuid.UUID, 0, len(pkt.Children[2].Children))
for _, child := range pkt.Children[2].Children {
u, err := uuid.FromBytes(child.ByteValue)
if err != nil {
return nil, fmt.Errorf("failed to decode uuid: %w", err)
}
syncUUIDs = append(syncUUIDs, u)
}
}
c.SyncIdSet = &ControlSyncInfoSyncIdSet{
Cookie: cookie,
RefreshDeletes: refreshDeletes,
SyncUUIDs: syncUUIDs,
}
default:
return nil, fmt.Errorf("unknown sync info value: %d", pkt.Identifier.Tag)
}
return c, nil
}
// GetControlType returns the OID
func (c *ControlSyncInfo) GetControlType() string {
return ControlTypeSyncInfo
}
// Encode encodes the control
func (c *ControlSyncInfo) Encode() *ber.Packet {
return nil
}
// String returns a human-readable description
func (c *ControlSyncInfo) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality: %t Value: %d %s %s %s %s",
ControlTypeMap[ControlTypeSyncInfo],
ControlTypeSyncInfo,
c.Criticality,
c.Value,
c.NewCookie,
c.RefreshDelete,
c.RefreshPresent,
c.SyncIdSet,
)
}

View file

@ -1,13 +1,11 @@
package ldap
import (
"log"
ber "github.com/go-asn1-ber/asn1-ber"
)
// debugging type
// - has a Printf method to write the debug output
// - has a Printf method to write the debug output
type debugging bool
// Enable controls debugging mode.
@ -18,13 +16,13 @@ func (debug *debugging) Enable(b bool) {
// Printf writes debug output.
func (debug debugging) Printf(format string, args ...interface{}) {
if debug {
log.Printf(format, args...)
logger.Printf(format, args...)
}
}
// PrintPacket dumps a packet.
func (debug debugging) PrintPacket(packet *ber.Packet) {
if debug {
ber.WritePacket(log.Writer(), packet)
ber.WritePacket(logger.Writer(), packet)
}
}

View file

@ -1,8 +1,7 @@
package ldap
import (
"log"
"fmt"
ber "github.com/go-asn1-ber/asn1-ber"
)
@ -53,7 +52,8 @@ func (l *Conn) Del(delRequest *DelRequest) error {
return err
}
} else {
log.Printf("Unexpected Response: %d", packet.Children[1].Tag)
return fmt.Errorf("ldap: unexpected response: %d", packet.Children[1].Tag)
}
return nil
}

View file

@ -5,6 +5,7 @@ import (
enchex "encoding/hex"
"errors"
"fmt"
"sort"
"strings"
ber "github.com/go-asn1-ber/asn1-ber"
@ -18,16 +19,95 @@ type AttributeTypeAndValue struct {
Value string
}
// String returns a normalized string representation of this attribute type and
// value pair which is the a lowercased join of the Type and Value with a "=".
func (a *AttributeTypeAndValue) String() string {
return strings.ToLower(a.Type) + "=" + a.encodeValue()
}
func (a *AttributeTypeAndValue) encodeValue() string {
// Normalize the value first.
// value := strings.ToLower(a.Value)
value := a.Value
encodedBuf := bytes.Buffer{}
escapeChar := func(c byte) {
encodedBuf.WriteByte('\\')
encodedBuf.WriteByte(c)
}
escapeHex := func(c byte) {
encodedBuf.WriteByte('\\')
encodedBuf.WriteString(enchex.EncodeToString([]byte{c}))
}
for i := 0; i < len(value); i++ {
char := value[i]
if i == 0 && char == ' ' || char == '#' {
// Special case leading space or number sign.
escapeChar(char)
continue
}
if i == len(value)-1 && char == ' ' {
// Special case trailing space.
escapeChar(char)
continue
}
switch char {
case '"', '+', ',', ';', '<', '>', '\\':
// Each of these special characters must be escaped.
escapeChar(char)
continue
}
if char < ' ' || char > '~' {
// All special character escapes are handled first
// above. All bytes less than ASCII SPACE and all bytes
// greater than ASCII TILDE must be hex-escaped.
escapeHex(char)
continue
}
// Any other character does not require escaping.
encodedBuf.WriteByte(char)
}
return encodedBuf.String()
}
// RelativeDN represents a relativeDistinguishedName from https://tools.ietf.org/html/rfc4514
type RelativeDN struct {
Attributes []*AttributeTypeAndValue
}
// String returns a normalized string representation of this relative DN which
// is the a join of all attributes (sorted in increasing order) with a "+".
func (r *RelativeDN) String() string {
attrs := make([]string, len(r.Attributes))
for i := range r.Attributes {
attrs[i] = r.Attributes[i].String()
}
sort.Strings(attrs)
return strings.Join(attrs, "+")
}
// DN represents a distinguishedName from https://tools.ietf.org/html/rfc4514
type DN struct {
RDNs []*RelativeDN
}
// String returns a normalized string representation of this DN which is the
// join of all relative DNs with a ",".
func (d *DN) String() string {
rdns := make([]string, len(d.RDNs))
for i := range d.RDNs {
rdns[i] = d.RDNs[i].String()
}
return strings.Join(rdns, ",")
}
// ParseDN returns a distinguishedName or an error.
// The function respects https://tools.ietf.org/html/rfc4514
func ParseDN(str string) (*DN, error) {
@ -76,7 +156,7 @@ func ParseDN(str string) (*DN, error) {
case char == '\\':
unescapedTrailingSpaces = 0
escaping = true
case char == '=':
case char == '=' && attribute.Type == "":
attribute.Type = stringFromBuffer()
// Special case: If the first character in the value is # the
// following data is BER encoded so we can just fast forward
@ -84,7 +164,7 @@ func ParseDN(str string) (*DN, error) {
if len(str) > i+1 && str[i+1] == '#' {
i += 2
index := strings.IndexAny(str[i:], ",+")
data := str
var data string
if index > 0 {
data = str[i : i+index]
} else {
@ -101,7 +181,7 @@ func ParseDN(str string) (*DN, error) {
buffer.WriteString(packet.Data.String())
i += len(data) - 1
}
case char == ',' || char == '+':
case char == ',' || char == '+' || char == ';':
// We're done with this RDN or value, push it
if len(attribute.Type) == 0 {
return nil, errors.New("incomplete type, value pair")
@ -109,7 +189,7 @@ func ParseDN(str string) (*DN, error) {
attribute.Value = stringFromBuffer()
rdn.Attributes = append(rdn.Attributes, attribute)
attribute = new(AttributeTypeAndValue)
if char == ',' {
if char == ',' || char == ';' {
dn.RDNs = append(dn.RDNs, rdn)
rdn = new(RelativeDN)
rdn.Attributes = make([]*AttributeTypeAndValue, 0)
@ -206,7 +286,7 @@ func (a *AttributeTypeAndValue) Equal(other *AttributeTypeAndValue) bool {
return strings.EqualFold(a.Type, other.Type) && a.Value == other.Value
}
// Equal returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
// EqualFold returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
// Returns true if they have the same number of relative distinguished names
// and corresponding relative distinguished names (by position) are the same.
// Case of the attribute type and value is not significant
@ -238,7 +318,7 @@ func (d *DN) AncestorOfFold(other *DN) bool {
return true
}
// Equal returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
// EqualFold returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
// Case of the attribute type is not significant
func (r *RelativeDN) EqualFold(other *RelativeDN) bool {
if len(r.Attributes) != len(other.Attributes) {

View file

@ -192,6 +192,8 @@ func (e *Error) Error() string {
return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error())
}
func (e *Error) Unwrap() error { return e.Err }
// GetLDAPError creates an Error out of a BER packet representing a LDAPResult
// The return is an error object. It can be casted to a Error structure.
// This function returns nil if resultCode in the LDAPResult sequence is success(0).
@ -206,15 +208,21 @@ func GetLDAPError(packet *ber.Packet) error {
return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet"), Packet: packet}
}
if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 {
resultCode := uint16(response.Children[0].Value.(int64))
if resultCode == 0 { // No error
return nil
}
return &Error{
ResultCode: resultCode,
MatchedDN: response.Children[1].Value.(string),
Err: fmt.Errorf("%s", response.Children[2].Value.(string)),
Packet: packet,
if ber.Type(response.Children[0].Tag) == ber.Type(ber.TagInteger) || ber.Type(response.Children[0].Tag) == ber.Type(ber.TagEnumerated) {
resultCode := uint16(response.Children[0].Value.(int64))
if resultCode == 0 { // No error
return nil
}
if ber.Type(response.Children[1].Tag) == ber.Type(ber.TagOctetString) &&
ber.Type(response.Children[2].Tag) == ber.Type(ber.TagOctetString) {
return &Error{
ResultCode: resultCode,
MatchedDN: response.Children[1].Value.(string),
Err: fmt.Errorf("%s", response.Children[2].Value.(string)),
Packet: packet,
}
}
}
}
}

View file

@ -396,7 +396,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
case packet.Tag == FilterEqualityMatch && bytes.Equal(condition.Bytes(), _SymbolAny):
packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute.String(), FilterMap[FilterPresent])
case packet.Tag == FilterEqualityMatch && bytes.Index(condition.Bytes(), _SymbolAny) > -1:
case packet.Tag == FilterEqualityMatch && bytes.Contains(condition.Bytes(), _SymbolAny):
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute.String(), "Attribute"))
packet.Tag = FilterSubstrings
packet.Description = FilterMap[uint64(packet.Tag)]
@ -438,7 +438,6 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
// Convert from "ABC\xx\xx\xx" form to literal bytes for transport
func decodeEscapedSymbols(src []byte) (string, error) {
var (
buffer bytes.Buffer
offset int

View file

@ -3,7 +3,9 @@ package ldap
import (
"fmt"
"io/ioutil"
"log"
"os"
"strings"
ber "github.com/go-asn1-ber/asn1-ber"
)
@ -30,6 +32,7 @@ const (
ApplicationSearchResultReference = 19
ApplicationExtendedRequest = 23
ApplicationExtendedResponse = 24
ApplicationIntermediateResponse = 25
)
// ApplicationMap contains human readable descriptions of LDAP Application Codes
@ -54,6 +57,7 @@ var ApplicationMap = map[uint8]string{
ApplicationSearchResultReference: "Search Result Reference",
ApplicationExtendedRequest: "Extended Request",
ApplicationExtendedResponse: "Extended Response",
ApplicationIntermediateResponse: "Intermediate Response",
}
// Ldap Behera Password Policy Draft 10 (https://tools.ietf.org/html/draft-behera-ldap-password-policy-10)
@ -82,6 +86,13 @@ var BeheraPasswordPolicyErrorMap = map[int8]string{
BeheraPasswordInHistory: "New password is in list of old passwords",
}
var logger = log.New(os.Stderr, "", log.LstdFlags)
// Logger allows clients to override the default logger
func Logger(l *log.Logger) {
logger = l
}
// Adds descriptions to an LDAP Response packet for debugging
func addLDAPDescriptions(packet *ber.Packet) (err error) {
defer func() {
@ -221,18 +232,18 @@ func addControlDescriptions(packet *ber.Packet) error {
sequence := value.Children[0]
for _, child := range sequence.Children {
if child.Tag == 0 {
//Warning
// Warning
warningPacket := child.Children[0]
val, err := ber.ParseInt64(warningPacket.Data.Bytes())
if err != nil {
return fmt.Errorf("failed to decode data bytes: %s", err)
}
if warningPacket.Tag == 0 {
//timeBeforeExpiration
// timeBeforeExpiration
value.Description += " (TimeBeforeExpiration)"
warningPacket.Value = val
} else if warningPacket.Tag == 1 {
//graceAuthNsRemaining
// graceAuthNsRemaining
value.Description += " (GraceAuthNsRemaining)"
warningPacket.Value = val
}
@ -337,3 +348,43 @@ func EscapeFilter(filter string) string {
}
return string(buf)
}
// EscapeDN escapes distinguished names as described in RFC4514. Characters in the
// set `"+,;<>\` are escaped by prepending a backslash, which is also done for trailing
// spaces or a leading `#`. Null bytes are replaced with `\00`.
func EscapeDN(dn string) string {
if dn == "" {
return ""
}
builder := strings.Builder{}
for i, r := range dn {
// Escape leading and trailing spaces
if (i == 0 || i == len(dn)-1) && r == ' ' {
builder.WriteRune('\\')
builder.WriteRune(r)
continue
}
// Escape leading '#'
if i == 0 && r == '#' {
builder.WriteRune('\\')
builder.WriteRune(r)
continue
}
// Escape characters as defined in RFC4514
switch r {
case '"', '+', ',', ';', '<', '>', '\\':
builder.WriteRune('\\')
builder.WriteRune(r)
case '\x00': // Null byte may not be escaped by a leading backslash
builder.WriteString("\\00")
default:
builder.WriteRune(r)
}
}
return builder.String()
}

View file

@ -1,8 +1,7 @@
package ldap
import (
"log"
"fmt"
ber "github.com/go-asn1-ber/asn1-ber"
)
@ -25,7 +24,9 @@ type ModifyDNRequest struct {
// RDN of the given DN.
//
// A call like
// mdnReq := NewModifyDNRequest("uid=someone,dc=example,dc=org", "uid=newname", true, "")
//
// mdnReq := NewModifyDNRequest("uid=someone,dc=example,dc=org", "uid=newname", true, "")
//
// will setup the request to just rename uid=someone,dc=example,dc=org to
// uid=newname,dc=example,dc=org.
func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *ModifyDNRequest {
@ -94,7 +95,8 @@ func (l *Conn) ModifyDN(m *ModifyDNRequest) error {
return err
}
} else {
log.Printf("Unexpected Response: %d", packet.Children[1].Tag)
return fmt.Errorf("ldap: unexpected response: %d", packet.Children[1].Tag)
}
return nil
}

View file

@ -2,7 +2,7 @@ package ldap
import (
"errors"
"log"
"fmt"
ber "github.com/go-asn1-ber/asn1-ber"
)
@ -127,8 +127,9 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error {
return err
}
} else {
log.Printf("Unexpected Response: %d", packet.Children[1].Tag)
return fmt.Errorf("ldap: unexpected response: %d", packet.Children[1].Tag)
}
return nil
}
@ -136,6 +137,8 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error {
type ModifyResult struct {
// Controls are the returned controls
Controls []Control
// Referral is the returned referral
Referral string
}
// ModifyWithResult performs the ModifyRequest and returns the result
@ -158,9 +161,10 @@ func (l *Conn) ModifyWithResult(modifyRequest *ModifyRequest) (*ModifyResult, er
switch packet.Children[1].Tag {
case ApplicationModifyResponse:
err := GetLDAPError(packet)
if err != nil {
return nil, err
if err = GetLDAPError(packet); err != nil {
result.Referral = getReferral(err, packet)
return result, err
}
if len(packet.Children) == 3 {
for _, child := range packet.Children[2].Children {

View file

@ -70,7 +70,6 @@ func (req *PasswordModifyRequest) appendTo(envelope *ber.Packet) error {
// newPassword is the desired user's password. If empty the server can return
// an error or generate a new password that will be available in the
// PasswordModifyResult.GeneratedPassword
//
func NewPasswordModifyRequest(userIdentity string, oldPassword string, newPassword string) *PasswordModifyRequest {
return &PasswordModifyRequest{
UserIdentity: userIdentity,
@ -95,15 +94,9 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa
result := &PasswordModifyResult{}
if packet.Children[1].Tag == ApplicationExtendedResponse {
err := GetLDAPError(packet)
if err != nil {
if IsErrorWithCode(err, LDAPResultReferral) {
for _, child := range packet.Children[1].Children {
if child.Tag == 3 {
result.Referral = child.Children[0].Value.(string)
}
}
}
if err = GetLDAPError(packet); err != nil {
result.Referral = getReferral(err, packet)
return result, err
}
} else {
@ -112,10 +105,10 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa
extendedResponse := packet.Children[1]
for _, child := range extendedResponse.Children {
if child.Tag == 11 {
if child.Tag == ber.TagEmbeddedPDV {
passwordModifyResponseValue := ber.DecodePacket(child.Data.Bytes())
if len(passwordModifyResponseValue.Children) == 1 {
if passwordModifyResponseValue.Children[0].Tag == 0 {
if passwordModifyResponseValue.Children[0].Tag == ber.TagEOC {
result.GeneratedPassword = ber.DecodeString(passwordModifyResponseValue.Children[0].Data.Bytes())
}
}

View file

@ -9,7 +9,8 @@ import (
var (
errRespChanClosed = errors.New("ldap: response channel closed")
errCouldNotRetMsg = errors.New("ldap: could not retrieve message")
ErrNilConnection = errors.New("ldap: conn is nil, expected net.Conn")
// ErrNilConnection is returned if doRequest is called with a nil connection.
ErrNilConnection = errors.New("ldap: conn is nil, expected net.Conn")
)
type request interface {
@ -69,3 +70,41 @@ func (l *Conn) readPacket(msgCtx *messageContext) (*ber.Packet, error) {
}
return packet, nil
}
func getReferral(err error, packet *ber.Packet) (referral string) {
if !IsErrorWithCode(err, LDAPResultReferral) {
return ""
}
if len(packet.Children) < 2 {
return ""
}
// The packet Tag itself (of child 2) is generally a ber.TagObjectDescriptor with referrals however OpenLDAP
// seemingly returns a ber.Tag.GeneralizedTime. Every currently tested LDAP server which returns referrals returns
// an ASN.1 BER packet with the Type of ber.TypeConstructed and Class of ber.ClassApplication however. Thus this
// check expressly checks these fields instead.
//
// Related Issues:
// - https://github.com/authelia/authelia/issues/4199 (downstream)
if len(packet.Children[1].Children) == 0 || (packet.Children[1].TagType != ber.TypeConstructed || packet.Children[1].ClassType != ber.ClassApplication) {
return ""
}
var ok bool
for _, child := range packet.Children[1].Children {
// The referral URI itself should be contained within a child which has a Tag of ber.BitString or
// ber.TagPrintableString, and the Type of ber.TypeConstructed and the Class of ClassContext. As soon as any of
// these conditions is not true we can skip this child.
if (child.Tag != ber.TagBitString && child.Tag != ber.TagPrintableString) || child.TagType != ber.TypeConstructed || child.ClassType != ber.ClassContext {
continue
}
if referral, ok = child.Children[0].Value.(string); ok {
return referral
}
}
return ""
}

207
vendor/github.com/go-ldap/ldap/v3/response.go generated vendored Normal file
View file

@ -0,0 +1,207 @@
package ldap
import (
"context"
"errors"
"fmt"
ber "github.com/go-asn1-ber/asn1-ber"
)
// Response defines an interface to get data from an LDAP server
type Response interface {
Entry() *Entry
Referral() string
Controls() []Control
Err() error
Next() bool
}
type searchResponse struct {
conn *Conn
ch chan *SearchSingleResult
entry *Entry
referral string
controls []Control
err error
}
// Entry returns an entry from the given search request
func (r *searchResponse) Entry() *Entry {
return r.entry
}
// Referral returns a referral from the given search request
func (r *searchResponse) Referral() string {
return r.referral
}
// Controls returns controls from the given search request
func (r *searchResponse) Controls() []Control {
return r.controls
}
// Err returns an error when the given search request was failed
func (r *searchResponse) Err() error {
return r.err
}
// Next returns whether next data exist or not
func (r *searchResponse) Next() bool {
res, ok := <-r.ch
if !ok {
return false
}
if res == nil {
return false
}
r.err = res.Error
if r.err != nil {
return false
}
r.entry = res.Entry
r.referral = res.Referral
r.controls = res.Controls
return true
}
func (r *searchResponse) start(ctx context.Context, searchRequest *SearchRequest) {
go func() {
defer func() {
close(r.ch)
if err := recover(); err != nil {
r.conn.err = fmt.Errorf("ldap: recovered panic in searchResponse: %v", err)
}
}()
if r.conn.IsClosing() {
return
}
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID"))
// encode search request
err := searchRequest.appendTo(packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
r.conn.Debug.PrintPacket(packet)
msgCtx, err := r.conn.sendMessage(packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
defer r.conn.finishMessage(msgCtx)
foundSearchSingleResultDone := false
for !foundSearchSingleResultDone {
select {
case <-ctx.Done():
r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error())
return
default:
r.conn.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-msgCtx.responses
if !ok {
err := NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
r.ch <- &SearchSingleResult{Error: err}
return
}
packet, err = packetResponse.ReadPacket()
r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
if r.conn.Debug {
if err := addLDAPDescriptions(packet); err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
ber.PrintPacket(packet)
}
switch packet.Children[1].Tag {
case ApplicationSearchResultEntry:
result := &SearchSingleResult{
Entry: &Entry{
DN: packet.Children[1].Children[0].Value.(string),
Attributes: unpackAttributes(packet.Children[1].Children[1].Children),
},
}
if len(packet.Children) != 3 {
r.ch <- result
continue
}
decoded, err := DecodeControl(packet.Children[2].Children[0])
if err != nil {
werr := fmt.Errorf("failed to decode search result entry: %w", err)
result.Error = werr
r.ch <- result
return
}
result.Controls = append(result.Controls, decoded)
r.ch <- result
case ApplicationSearchResultDone:
if err := GetLDAPError(packet); err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
if len(packet.Children) == 3 {
result := &SearchSingleResult{}
for _, child := range packet.Children[2].Children {
decodedChild, err := DecodeControl(child)
if err != nil {
werr := fmt.Errorf("failed to decode child control: %w", err)
r.ch <- &SearchSingleResult{Error: werr}
return
}
result.Controls = append(result.Controls, decodedChild)
}
r.ch <- result
}
foundSearchSingleResultDone = true
case ApplicationSearchResultReference:
ref := packet.Children[1].Children[0].Value.(string)
r.ch <- &SearchSingleResult{Referral: ref}
case ApplicationIntermediateResponse:
decoded, err := DecodeControl(packet.Children[1])
if err != nil {
werr := fmt.Errorf("failed to decode intermediate response: %w", err)
r.ch <- &SearchSingleResult{Error: werr}
return
}
result := &SearchSingleResult{}
result.Controls = append(result.Controls, decoded)
r.ch <- result
default:
err := fmt.Errorf("unknown tag: %d", packet.Children[1].Tag)
r.ch <- &SearchSingleResult{Error: err}
return
}
}
}
r.conn.Debug.Printf("%d: returning", msgCtx.id)
}()
}
func newSearchResponse(conn *Conn, bufferSize int) *searchResponse {
var ch chan *SearchSingleResult
if bufferSize > 0 {
ch = make(chan *SearchSingleResult, bufferSize)
} else {
ch = make(chan *SearchSingleResult)
}
return &searchResponse{
conn: conn,
ch: ch,
}
}

View file

@ -1,10 +1,14 @@
package ldap
import (
"context"
"errors"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"time"
ber "github.com/go-asn1-ber/asn1-ber"
)
@ -161,6 +165,155 @@ func (e *Entry) PrettyPrint(indent int) {
}
}
// Describe the tag to use for struct field tags
const decoderTagName = "ldap"
// readTag will read the reflect.StructField value for
// the key defined in decoderTagName. If omitempty is
// specified, the field may not be filled.
func readTag(f reflect.StructField) (string, bool) {
val, ok := f.Tag.Lookup(decoderTagName)
if !ok {
return f.Name, false
}
opts := strings.Split(val, ",")
omit := false
if len(opts) == 2 {
omit = opts[1] == "omitempty"
}
return opts[0], omit
}
// Unmarshal parses the Entry in the value pointed to by i
//
// Currently, this methods only supports struct fields of type
// string, []string, int, int64, []byte, *DN, []*DN or time.Time. Other field types
// will not be regarded. If the field type is a string or int but multiple
// attribute values are returned, the first value will be used to fill the field.
//
// Example:
//
// type UserEntry struct {
// // Fields with the tag key `dn` are automatically filled with the
// // objects distinguishedName. This can be used multiple times.
// DN string `ldap:"dn"`
//
// // This field will be filled with the attribute value for
// // userPrincipalName. An attribute can be read into a struct field
// // multiple times. Missing attributes will not result in an error.
// UserPrincipalName string `ldap:"userPrincipalName"`
//
// // memberOf may have multiple values. If you don't
// // know the amount of attribute values at runtime, use a string array.
// MemberOf []string `ldap:"memberOf"`
//
// // ID is an integer value, it will fail unmarshaling when the given
// // attribute value cannot be parsed into an integer.
// ID int `ldap:"id"`
//
// // LongID is similar to ID but uses an int64 instead.
// LongID int64 `ldap:"longId"`
//
// // Data is similar to MemberOf a slice containing all attribute
// // values.
// Data []byte `ldap:"data"`
//
// // Time is parsed with the generalizedTime spec into a time.Time
// Created time.Time `ldap:"createdTimestamp"`
//
// // *DN is parsed with the ParseDN
// Owner *ldap.DN `ldap:"owner"`
//
// // []*DN is parsed with the ParseDN
// Children []*ldap.DN `ldap:"children"`
//
// // This won't work, as the field is not of type string. For this
// // to work, you'll have to temporarily store the result in string
// // (or string array) and convert it to the desired type afterwards.
// UserAccountControl uint32 `ldap:"userPrincipalName"`
// }
// user := UserEntry{}
//
// if err := result.Unmarshal(&user); err != nil {
// // ...
// }
func (e *Entry) Unmarshal(i interface{}) (err error) {
// Make sure it's a ptr
if vo := reflect.ValueOf(i).Kind(); vo != reflect.Ptr {
return fmt.Errorf("ldap: cannot use %s, expected pointer to a struct", vo)
}
sv, st := reflect.ValueOf(i).Elem(), reflect.TypeOf(i).Elem()
// Make sure it's pointing to a struct
if sv.Kind() != reflect.Struct {
return fmt.Errorf("ldap: expected pointer to a struct, got %s", sv.Kind())
}
for n := 0; n < st.NumField(); n++ {
// Holds struct field value and type
fv, ft := sv.Field(n), st.Field(n)
// skip unexported fields
if ft.PkgPath != "" {
continue
}
// omitempty can be safely discarded, as it's not needed when unmarshalling
fieldTag, _ := readTag(ft)
// Fill the field with the distinguishedName if the tag key is `dn`
if fieldTag == "dn" {
fv.SetString(e.DN)
continue
}
values := e.GetAttributeValues(fieldTag)
if len(values) == 0 {
continue
}
switch fv.Interface().(type) {
case []string:
for _, item := range values {
fv.Set(reflect.Append(fv, reflect.ValueOf(item)))
}
case string:
fv.SetString(values[0])
case []byte:
fv.SetBytes([]byte(values[0]))
case int, int64:
intVal, err := strconv.ParseInt(values[0], 10, 64)
if err != nil {
return fmt.Errorf("ldap: could not parse value '%s' into int field", values[0])
}
fv.SetInt(intVal)
case time.Time:
t, err := ber.ParseGeneralizedTime([]byte(values[0]))
if err != nil {
return fmt.Errorf("ldap: could not parse value '%s' into time.Time field", values[0])
}
fv.Set(reflect.ValueOf(t))
case *DN:
dn, err := ParseDN(values[0])
if err != nil {
return fmt.Errorf("ldap: could not parse value '%s' into *ldap.DN field", values[0])
}
fv.Set(reflect.ValueOf(dn))
case []*DN:
for _, item := range values {
dn, err := ParseDN(item)
if err != nil {
return fmt.Errorf("ldap: could not parse value '%s' into *ldap.DN field", item)
}
fv.Set(reflect.Append(fv, reflect.ValueOf(dn)))
}
default:
return fmt.Errorf("ldap: expected field to be of type string, []string, int, int64, []byte, *DN, []*DN or time.Time, got %v", ft.Type)
}
}
return
}
// NewEntryAttribute returns a new EntryAttribute with the desired key-value pair
func NewEntryAttribute(name string, values []string) *EntryAttribute {
var bytes [][]byte
@ -218,6 +371,35 @@ func (s *SearchResult) PrettyPrint(indent int) {
}
}
// appendTo appends all entries of `s` to `r`
func (s *SearchResult) appendTo(r *SearchResult) {
r.Entries = append(r.Entries, s.Entries...)
r.Referrals = append(r.Referrals, s.Referrals...)
r.Controls = append(r.Controls, s.Controls...)
}
// SearchSingleResult holds the server's single entry response to a search request
type SearchSingleResult struct {
// Entry is the returned entry
Entry *Entry
// Referral is the returned referral
Referral string
// Controls are the returned controls
Controls []Control
// Error is set when the search request was failed
Error error
}
// Print outputs a human-readable description
func (s *SearchSingleResult) Print() {
s.Entry.Print()
}
// PrettyPrint outputs a human-readable description with indenting
func (s *SearchSingleResult) PrettyPrint(indent int) {
s.Entry.PrettyPrint(indent)
}
// SearchRequest represents a search request to send to the server
type SearchRequest struct {
BaseDN string
@ -285,10 +467,11 @@ func NewSearchRequest(
// SearchWithPaging accepts a search request and desired page size in order to execute LDAP queries to fulfill the
// search request. All paged LDAP query responses will be buffered and the final result will be returned atomically.
// The following four cases are possible given the arguments:
// - given SearchRequest missing a control of type ControlTypePaging: we will add one with the desired paging size
// - given SearchRequest contains a control of type ControlTypePaging that isn't actually a ControlPaging: fail without issuing any queries
// - given SearchRequest contains a control of type ControlTypePaging with pagingSize equal to the size requested: no change to the search request
// - given SearchRequest contains a control of type ControlTypePaging with pagingSize not equal to the size requested: fail without issuing any queries
// - given SearchRequest missing a control of type ControlTypePaging: we will add one with the desired paging size
// - given SearchRequest contains a control of type ControlTypePaging that isn't actually a ControlPaging: fail without issuing any queries
// - given SearchRequest contains a control of type ControlTypePaging with pagingSize equal to the size requested: no change to the search request
// - given SearchRequest contains a control of type ControlTypePaging with pagingSize not equal to the size requested: fail without issuing any queries
//
// A requested pagingSize of 0 is interpreted as no limit by LDAP servers.
func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) {
var pagingControl *ControlPaging
@ -311,23 +494,19 @@ func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32)
searchResult := new(SearchResult)
for {
result, err := l.Search(searchRequest)
l.Debug.Printf("Looking for Paging Control...")
if result != nil {
result.appendTo(searchResult)
} else {
if err == nil {
// We have to do this beautifulness in case something absolutely strange happens, which
// should only occur in case there is no packet, but also no error.
return searchResult, NewError(ErrorNetwork, errors.New("ldap: packet not received"))
}
}
if err != nil {
// If an error occurred, all results that have been received so far will be returned
return searchResult, err
}
if result == nil {
return searchResult, NewError(ErrorNetwork, errors.New("ldap: packet not received"))
}
for _, entry := range result.Entries {
searchResult.Entries = append(searchResult.Entries, entry)
}
for _, referral := range result.Referrals {
searchResult.Referrals = append(searchResult.Referrals, referral)
}
for _, control := range result.Controls {
searchResult.Controls = append(searchResult.Controls, control)
}
l.Debug.Printf("Looking for Paging Control...")
pagingResult := FindControl(result.Controls, ControlTypePaging)
@ -349,7 +528,9 @@ func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32)
if pagingControl != nil {
l.Debug.Printf("Abandoning Paging...")
pagingControl.PagingSize = 0
l.Search(searchRequest)
if _, err := l.Search(searchRequest); err != nil {
return searchResult, err
}
}
return searchResult, nil
@ -366,7 +547,8 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
result := &SearchResult{
Entries: make([]*Entry, 0),
Referrals: make([]string, 0),
Controls: make([]Control, 0)}
Controls: make([]Control, 0),
}
for {
packet, err := l.readPacket(msgCtx)
@ -402,6 +584,32 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
}
}
// SearchAsync performs a search request and returns all search results asynchronously.
// This means you get all results until an error happens (or the search successfully finished),
// e.g. for size / time limited requests all are recieved until the limit is reached.
// To stop the search, call cancel function of the context.
func (l *Conn) SearchAsync(
ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response {
r := newSearchResponse(l, bufferSize)
r.start(ctx, searchRequest)
return r
}
// Syncrepl is a short name for LDAP Sync Replication engine that works on the
// consumer-side. This can perform a persistent search and returns an entry
// when the entry is updated on the server side.
// To stop the search, call cancel function of the context.
func (l *Conn) Syncrepl(
ctx context.Context, searchRequest *SearchRequest, bufferSize int,
mode ControlSyncRequestMode, cookie []byte, reloadHint bool,
) Response {
control := NewControlSyncRequest(mode, cookie, reloadHint)
searchRequest.Controls = append(searchRequest.Controls, control)
r := newSearchResponse(l, bufferSize)
r.start(ctx, searchRequest)
return r
}
// unpackAttributes will extract all given LDAP attributes and it's values
// from the ber.Packet
func unpackAttributes(children []*ber.Packet) []*EntryAttribute {
@ -425,3 +633,58 @@ func unpackAttributes(children []*ber.Packet) []*EntryAttribute {
return entries
}
// DirSync does a Search with dirSync Control.
func (l *Conn) DirSync(
searchRequest *SearchRequest, flags int64, maxAttrCount int64, cookie []byte,
) (*SearchResult, error) {
control := FindControl(searchRequest.Controls, ControlTypeDirSync)
if control == nil {
c := NewRequestControlDirSync(flags, maxAttrCount, cookie)
searchRequest.Controls = append(searchRequest.Controls, c)
} else {
c := control.(*ControlDirSync)
if c.Flags != flags {
return nil, fmt.Errorf("flags given in search request (%d) conflicts with flags given in search call (%d)", c.Flags, flags)
}
if c.MaxAttrCount != maxAttrCount {
return nil, fmt.Errorf("MaxAttrCnt given in search request (%d) conflicts with maxAttrCount given in search call (%d)", c.MaxAttrCount, maxAttrCount)
}
}
searchResult, err := l.Search(searchRequest)
l.Debug.Printf("Looking for result...")
if err != nil {
return nil, err
}
if searchResult == nil {
return nil, NewError(ErrorNetwork, errors.New("ldap: packet not received"))
}
l.Debug.Printf("Looking for DirSync Control...")
resultControl := FindControl(searchResult.Controls, ControlTypeDirSync)
if resultControl == nil {
l.Debug.Printf("Could not find dirSyncControl control. Breaking...")
return searchResult, nil
}
cookie = resultControl.(*ControlDirSync).Cookie
if len(cookie) == 0 {
l.Debug.Printf("Could not find cookie. Breaking...")
return searchResult, nil
}
return searchResult, nil
}
// DirSyncDirSyncAsync performs a search request and returns all search results
// asynchronously. This is efficient when the server returns lots of entries.
func (l *Conn) DirSyncAsync(
ctx context.Context, searchRequest *SearchRequest, bufferSize int,
flags, maxAttrCount int64, cookie []byte,
) Response {
control := NewRequestControlDirSync(flags, maxAttrCount, cookie)
searchRequest.Controls = append(searchRequest.Controls, control)
r := newSearchResponse(l, bufferSize)
r.start(ctx, searchRequest)
return r
}

View file

@ -6,6 +6,7 @@ import (
ber "github.com/go-asn1-ber/asn1-ber"
)
// ErrConnUnbound is returned when Unbind is called on an already closing connection.
var ErrConnUnbound = NewError(ErrorNetwork, errors.New("ldap: connection is closed"))
type unbindRequest struct{}

10
vendor/github.com/google/uuid/CHANGELOG.md generated vendored Normal file
View file

@ -0,0 +1,10 @@
# Changelog
## [1.3.1](https://github.com/google/uuid/compare/v1.3.0...v1.3.1) (2023-08-18)
### Bug Fixes
* Use .EqualFold() to parse urn prefixed UUIDs ([#118](https://github.com/google/uuid/issues/118)) ([574e687](https://github.com/google/uuid/commit/574e6874943741fb99d41764c705173ada5293f0))
## Changelog

26
vendor/github.com/google/uuid/CONTRIBUTING.md generated vendored Normal file
View file

@ -0,0 +1,26 @@
# How to contribute
We definitely welcome patches and contribution to this project!
### Tips
Commits must be formatted according to the [Conventional Commits Specification](https://www.conventionalcommits.org).
Always try to include a test case! If it is not possible or not necessary,
please explain why in the pull request description.
### Releasing
Commits that would precipitate a SemVer change, as desrcibed in the Conventional
Commits Specification, will trigger [`release-please`](https://github.com/google-github-actions/release-please-action)
to create a release candidate pull request. Once submitted, `release-please`
will create a release.
For tips on how to work with `release-please`, see its documentation.
### Legal requirements
In order to protect both you and ourselves, you will need to sign the
[Contributor License Agreement](https://cla.developers.google.com/clas).
You may have already signed it for other Google projects.

9
vendor/github.com/google/uuid/CONTRIBUTORS generated vendored Normal file
View file

@ -0,0 +1,9 @@
Paul Borman <borman@google.com>
bmatsuo
shawnps
theory
jboverfelt
dsymonds
cd1
wallclockbuilder
dansouza

27
vendor/github.com/google/uuid/LICENSE generated vendored Normal file
View file

@ -0,0 +1,27 @@
Copyright (c) 2009,2014 Google Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

21
vendor/github.com/google/uuid/README.md generated vendored Normal file
View file

@ -0,0 +1,21 @@
# uuid
The uuid package generates and inspects UUIDs based on
[RFC 4122](https://datatracker.ietf.org/doc/html/rfc4122)
and DCE 1.1: Authentication and Security Services.
This package is based on the github.com/pborman/uuid package (previously named
code.google.com/p/go-uuid). It differs from these earlier packages in that
a UUID is a 16 byte array rather than a byte slice. One loss due to this
change is the ability to represent an invalid UUID (vs a NIL UUID).
###### Install
```sh
go get github.com/google/uuid
```
###### Documentation
[![Go Reference](https://pkg.go.dev/badge/github.com/google/uuid.svg)](https://pkg.go.dev/github.com/google/uuid)
Full `go doc` style documentation for the package can be viewed online without
installing this package by using the GoDoc site here:
http://pkg.go.dev/github.com/google/uuid

80
vendor/github.com/google/uuid/dce.go generated vendored Normal file
View file

@ -0,0 +1,80 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
"fmt"
"os"
)
// A Domain represents a Version 2 domain
type Domain byte
// Domain constants for DCE Security (Version 2) UUIDs.
const (
Person = Domain(0)
Group = Domain(1)
Org = Domain(2)
)
// NewDCESecurity returns a DCE Security (Version 2) UUID.
//
// The domain should be one of Person, Group or Org.
// On a POSIX system the id should be the users UID for the Person
// domain and the users GID for the Group. The meaning of id for
// the domain Org or on non-POSIX systems is site defined.
//
// For a given domain/id pair the same token may be returned for up to
// 7 minutes and 10 seconds.
func NewDCESecurity(domain Domain, id uint32) (UUID, error) {
uuid, err := NewUUID()
if err == nil {
uuid[6] = (uuid[6] & 0x0f) | 0x20 // Version 2
uuid[9] = byte(domain)
binary.BigEndian.PutUint32(uuid[0:], id)
}
return uuid, err
}
// NewDCEPerson returns a DCE Security (Version 2) UUID in the person
// domain with the id returned by os.Getuid.
//
// NewDCESecurity(Person, uint32(os.Getuid()))
func NewDCEPerson() (UUID, error) {
return NewDCESecurity(Person, uint32(os.Getuid()))
}
// NewDCEGroup returns a DCE Security (Version 2) UUID in the group
// domain with the id returned by os.Getgid.
//
// NewDCESecurity(Group, uint32(os.Getgid()))
func NewDCEGroup() (UUID, error) {
return NewDCESecurity(Group, uint32(os.Getgid()))
}
// Domain returns the domain for a Version 2 UUID. Domains are only defined
// for Version 2 UUIDs.
func (uuid UUID) Domain() Domain {
return Domain(uuid[9])
}
// ID returns the id for a Version 2 UUID. IDs are only defined for Version 2
// UUIDs.
func (uuid UUID) ID() uint32 {
return binary.BigEndian.Uint32(uuid[0:4])
}
func (d Domain) String() string {
switch d {
case Person:
return "Person"
case Group:
return "Group"
case Org:
return "Org"
}
return fmt.Sprintf("Domain%d", int(d))
}

12
vendor/github.com/google/uuid/doc.go generated vendored Normal file
View file

@ -0,0 +1,12 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package uuid generates and inspects UUIDs.
//
// UUIDs are based on RFC 4122 and DCE 1.1: Authentication and Security
// Services.
//
// A UUID is a 16 byte (128 bit) array. UUIDs may be used as keys to
// maps or compared directly.
package uuid

53
vendor/github.com/google/uuid/hash.go generated vendored Normal file
View file

@ -0,0 +1,53 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"crypto/md5"
"crypto/sha1"
"hash"
)
// Well known namespace IDs and UUIDs
var (
NameSpaceDNS = Must(Parse("6ba7b810-9dad-11d1-80b4-00c04fd430c8"))
NameSpaceURL = Must(Parse("6ba7b811-9dad-11d1-80b4-00c04fd430c8"))
NameSpaceOID = Must(Parse("6ba7b812-9dad-11d1-80b4-00c04fd430c8"))
NameSpaceX500 = Must(Parse("6ba7b814-9dad-11d1-80b4-00c04fd430c8"))
Nil UUID // empty UUID, all zeros
)
// NewHash returns a new UUID derived from the hash of space concatenated with
// data generated by h. The hash should be at least 16 byte in length. The
// first 16 bytes of the hash are used to form the UUID. The version of the
// UUID will be the lower 4 bits of version. NewHash is used to implement
// NewMD5 and NewSHA1.
func NewHash(h hash.Hash, space UUID, data []byte, version int) UUID {
h.Reset()
h.Write(space[:]) //nolint:errcheck
h.Write(data) //nolint:errcheck
s := h.Sum(nil)
var uuid UUID
copy(uuid[:], s)
uuid[6] = (uuid[6] & 0x0f) | uint8((version&0xf)<<4)
uuid[8] = (uuid[8] & 0x3f) | 0x80 // RFC 4122 variant
return uuid
}
// NewMD5 returns a new MD5 (Version 3) UUID based on the
// supplied name space and data. It is the same as calling:
//
// NewHash(md5.New(), space, data, 3)
func NewMD5(space UUID, data []byte) UUID {
return NewHash(md5.New(), space, data, 3)
}
// NewSHA1 returns a new SHA1 (Version 5) UUID based on the
// supplied name space and data. It is the same as calling:
//
// NewHash(sha1.New(), space, data, 5)
func NewSHA1(space UUID, data []byte) UUID {
return NewHash(sha1.New(), space, data, 5)
}

38
vendor/github.com/google/uuid/marshal.go generated vendored Normal file
View file

@ -0,0 +1,38 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import "fmt"
// MarshalText implements encoding.TextMarshaler.
func (uuid UUID) MarshalText() ([]byte, error) {
var js [36]byte
encodeHex(js[:], uuid)
return js[:], nil
}
// UnmarshalText implements encoding.TextUnmarshaler.
func (uuid *UUID) UnmarshalText(data []byte) error {
id, err := ParseBytes(data)
if err != nil {
return err
}
*uuid = id
return nil
}
// MarshalBinary implements encoding.BinaryMarshaler.
func (uuid UUID) MarshalBinary() ([]byte, error) {
return uuid[:], nil
}
// UnmarshalBinary implements encoding.BinaryUnmarshaler.
func (uuid *UUID) UnmarshalBinary(data []byte) error {
if len(data) != 16 {
return fmt.Errorf("invalid UUID (got %d bytes)", len(data))
}
copy(uuid[:], data)
return nil
}

90
vendor/github.com/google/uuid/node.go generated vendored Normal file
View file

@ -0,0 +1,90 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"sync"
)
var (
nodeMu sync.Mutex
ifname string // name of interface being used
nodeID [6]byte // hardware for version 1 UUIDs
zeroID [6]byte // nodeID with only 0's
)
// NodeInterface returns the name of the interface from which the NodeID was
// derived. The interface "user" is returned if the NodeID was set by
// SetNodeID.
func NodeInterface() string {
defer nodeMu.Unlock()
nodeMu.Lock()
return ifname
}
// SetNodeInterface selects the hardware address to be used for Version 1 UUIDs.
// If name is "" then the first usable interface found will be used or a random
// Node ID will be generated. If a named interface cannot be found then false
// is returned.
//
// SetNodeInterface never fails when name is "".
func SetNodeInterface(name string) bool {
defer nodeMu.Unlock()
nodeMu.Lock()
return setNodeInterface(name)
}
func setNodeInterface(name string) bool {
iname, addr := getHardwareInterface(name) // null implementation for js
if iname != "" && addr != nil {
ifname = iname
copy(nodeID[:], addr)
return true
}
// We found no interfaces with a valid hardware address. If name
// does not specify a specific interface generate a random Node ID
// (section 4.1.6)
if name == "" {
ifname = "random"
randomBits(nodeID[:])
return true
}
return false
}
// NodeID returns a slice of a copy of the current Node ID, setting the Node ID
// if not already set.
func NodeID() []byte {
defer nodeMu.Unlock()
nodeMu.Lock()
if nodeID == zeroID {
setNodeInterface("")
}
nid := nodeID
return nid[:]
}
// SetNodeID sets the Node ID to be used for Version 1 UUIDs. The first 6 bytes
// of id are used. If id is less than 6 bytes then false is returned and the
// Node ID is not set.
func SetNodeID(id []byte) bool {
if len(id) < 6 {
return false
}
defer nodeMu.Unlock()
nodeMu.Lock()
copy(nodeID[:], id)
ifname = "user"
return true
}
// NodeID returns the 6 byte node id encoded in uuid. It returns nil if uuid is
// not valid. The NodeID is only well defined for version 1 and 2 UUIDs.
func (uuid UUID) NodeID() []byte {
var node [6]byte
copy(node[:], uuid[10:])
return node[:]
}

12
vendor/github.com/google/uuid/node_js.go generated vendored Normal file
View file

@ -0,0 +1,12 @@
// Copyright 2017 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build js
package uuid
// getHardwareInterface returns nil values for the JS version of the code.
// This removes the "net" dependency, because it is not used in the browser.
// Using the "net" library inflates the size of the transpiled JS code by 673k bytes.
func getHardwareInterface(name string) (string, []byte) { return "", nil }

33
vendor/github.com/google/uuid/node_net.go generated vendored Normal file
View file

@ -0,0 +1,33 @@
// Copyright 2017 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !js
package uuid
import "net"
var interfaces []net.Interface // cached list of interfaces
// getHardwareInterface returns the name and hardware address of interface name.
// If name is "" then the name and hardware address of one of the system's
// interfaces is returned. If no interfaces are found (name does not exist or
// there are no interfaces) then "", nil is returned.
//
// Only addresses of at least 6 bytes are returned.
func getHardwareInterface(name string) (string, []byte) {
if interfaces == nil {
var err error
interfaces, err = net.Interfaces()
if err != nil {
return "", nil
}
}
for _, ifs := range interfaces {
if len(ifs.HardwareAddr) >= 6 && (name == "" || name == ifs.Name) {
return ifs.Name, ifs.HardwareAddr
}
}
return "", nil
}

118
vendor/github.com/google/uuid/null.go generated vendored Normal file
View file

@ -0,0 +1,118 @@
// Copyright 2021 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"bytes"
"database/sql/driver"
"encoding/json"
"fmt"
)
var jsonNull = []byte("null")
// NullUUID represents a UUID that may be null.
// NullUUID implements the SQL driver.Scanner interface so
// it can be used as a scan destination:
//
// var u uuid.NullUUID
// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&u)
// ...
// if u.Valid {
// // use u.UUID
// } else {
// // NULL value
// }
//
type NullUUID struct {
UUID UUID
Valid bool // Valid is true if UUID is not NULL
}
// Scan implements the SQL driver.Scanner interface.
func (nu *NullUUID) Scan(value interface{}) error {
if value == nil {
nu.UUID, nu.Valid = Nil, false
return nil
}
err := nu.UUID.Scan(value)
if err != nil {
nu.Valid = false
return err
}
nu.Valid = true
return nil
}
// Value implements the driver Valuer interface.
func (nu NullUUID) Value() (driver.Value, error) {
if !nu.Valid {
return nil, nil
}
// Delegate to UUID Value function
return nu.UUID.Value()
}
// MarshalBinary implements encoding.BinaryMarshaler.
func (nu NullUUID) MarshalBinary() ([]byte, error) {
if nu.Valid {
return nu.UUID[:], nil
}
return []byte(nil), nil
}
// UnmarshalBinary implements encoding.BinaryUnmarshaler.
func (nu *NullUUID) UnmarshalBinary(data []byte) error {
if len(data) != 16 {
return fmt.Errorf("invalid UUID (got %d bytes)", len(data))
}
copy(nu.UUID[:], data)
nu.Valid = true
return nil
}
// MarshalText implements encoding.TextMarshaler.
func (nu NullUUID) MarshalText() ([]byte, error) {
if nu.Valid {
return nu.UUID.MarshalText()
}
return jsonNull, nil
}
// UnmarshalText implements encoding.TextUnmarshaler.
func (nu *NullUUID) UnmarshalText(data []byte) error {
id, err := ParseBytes(data)
if err != nil {
nu.Valid = false
return err
}
nu.UUID = id
nu.Valid = true
return nil
}
// MarshalJSON implements json.Marshaler.
func (nu NullUUID) MarshalJSON() ([]byte, error) {
if nu.Valid {
return json.Marshal(nu.UUID)
}
return jsonNull, nil
}
// UnmarshalJSON implements json.Unmarshaler.
func (nu *NullUUID) UnmarshalJSON(data []byte) error {
if bytes.Equal(data, jsonNull) {
*nu = NullUUID{}
return nil // valid null UUID
}
err := json.Unmarshal(data, &nu.UUID)
nu.Valid = err == nil
return err
}

59
vendor/github.com/google/uuid/sql.go generated vendored Normal file
View file

@ -0,0 +1,59 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"database/sql/driver"
"fmt"
)
// Scan implements sql.Scanner so UUIDs can be read from databases transparently.
// Currently, database types that map to string and []byte are supported. Please
// consult database-specific driver documentation for matching types.
func (uuid *UUID) Scan(src interface{}) error {
switch src := src.(type) {
case nil:
return nil
case string:
// if an empty UUID comes from a table, we return a null UUID
if src == "" {
return nil
}
// see Parse for required string format
u, err := Parse(src)
if err != nil {
return fmt.Errorf("Scan: %v", err)
}
*uuid = u
case []byte:
// if an empty UUID comes from a table, we return a null UUID
if len(src) == 0 {
return nil
}
// assumes a simple slice of bytes if 16 bytes
// otherwise attempts to parse
if len(src) != 16 {
return uuid.Scan(string(src))
}
copy((*uuid)[:], src)
default:
return fmt.Errorf("Scan: unable to scan type %T into UUID", src)
}
return nil
}
// Value implements sql.Valuer so that UUIDs can be written to databases
// transparently. Currently, UUIDs map to strings. Please consult
// database-specific driver documentation for matching types.
func (uuid UUID) Value() (driver.Value, error) {
return uuid.String(), nil
}

123
vendor/github.com/google/uuid/time.go generated vendored Normal file
View file

@ -0,0 +1,123 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
"sync"
"time"
)
// A Time represents a time as the number of 100's of nanoseconds since 15 Oct
// 1582.
type Time int64
const (
lillian = 2299160 // Julian day of 15 Oct 1582
unix = 2440587 // Julian day of 1 Jan 1970
epoch = unix - lillian // Days between epochs
g1582 = epoch * 86400 // seconds between epochs
g1582ns100 = g1582 * 10000000 // 100s of a nanoseconds between epochs
)
var (
timeMu sync.Mutex
lasttime uint64 // last time we returned
clockSeq uint16 // clock sequence for this run
timeNow = time.Now // for testing
)
// UnixTime converts t the number of seconds and nanoseconds using the Unix
// epoch of 1 Jan 1970.
func (t Time) UnixTime() (sec, nsec int64) {
sec = int64(t - g1582ns100)
nsec = (sec % 10000000) * 100
sec /= 10000000
return sec, nsec
}
// GetTime returns the current Time (100s of nanoseconds since 15 Oct 1582) and
// clock sequence as well as adjusting the clock sequence as needed. An error
// is returned if the current time cannot be determined.
func GetTime() (Time, uint16, error) {
defer timeMu.Unlock()
timeMu.Lock()
return getTime()
}
func getTime() (Time, uint16, error) {
t := timeNow()
// If we don't have a clock sequence already, set one.
if clockSeq == 0 {
setClockSequence(-1)
}
now := uint64(t.UnixNano()/100) + g1582ns100
// If time has gone backwards with this clock sequence then we
// increment the clock sequence
if now <= lasttime {
clockSeq = ((clockSeq + 1) & 0x3fff) | 0x8000
}
lasttime = now
return Time(now), clockSeq, nil
}
// ClockSequence returns the current clock sequence, generating one if not
// already set. The clock sequence is only used for Version 1 UUIDs.
//
// The uuid package does not use global static storage for the clock sequence or
// the last time a UUID was generated. Unless SetClockSequence is used, a new
// random clock sequence is generated the first time a clock sequence is
// requested by ClockSequence, GetTime, or NewUUID. (section 4.2.1.1)
func ClockSequence() int {
defer timeMu.Unlock()
timeMu.Lock()
return clockSequence()
}
func clockSequence() int {
if clockSeq == 0 {
setClockSequence(-1)
}
return int(clockSeq & 0x3fff)
}
// SetClockSequence sets the clock sequence to the lower 14 bits of seq. Setting to
// -1 causes a new sequence to be generated.
func SetClockSequence(seq int) {
defer timeMu.Unlock()
timeMu.Lock()
setClockSequence(seq)
}
func setClockSequence(seq int) {
if seq == -1 {
var b [2]byte
randomBits(b[:]) // clock sequence
seq = int(b[0])<<8 | int(b[1])
}
oldSeq := clockSeq
clockSeq = uint16(seq&0x3fff) | 0x8000 // Set our variant
if oldSeq != clockSeq {
lasttime = 0
}
}
// Time returns the time in 100s of nanoseconds since 15 Oct 1582 encoded in
// uuid. The time is only defined for version 1 and 2 UUIDs.
func (uuid UUID) Time() Time {
time := int64(binary.BigEndian.Uint32(uuid[0:4]))
time |= int64(binary.BigEndian.Uint16(uuid[4:6])) << 32
time |= int64(binary.BigEndian.Uint16(uuid[6:8])&0xfff) << 48
return Time(time)
}
// ClockSequence returns the clock sequence encoded in uuid.
// The clock sequence is only well defined for version 1 and 2 UUIDs.
func (uuid UUID) ClockSequence() int {
return int(binary.BigEndian.Uint16(uuid[8:10])) & 0x3fff
}

43
vendor/github.com/google/uuid/util.go generated vendored Normal file
View file

@ -0,0 +1,43 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"io"
)
// randomBits completely fills slice b with random data.
func randomBits(b []byte) {
if _, err := io.ReadFull(rander, b); err != nil {
panic(err.Error()) // rand should never fail
}
}
// xvalues returns the value of a byte as a hexadecimal digit or 255.
var xvalues = [256]byte{
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255, 255, 255, 255,
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
}
// xtob converts hex characters x1 and x2 into a byte.
func xtob(x1, x2 byte) (byte, bool) {
b1 := xvalues[x1]
b2 := xvalues[x2]
return (b1 << 4) | b2, b1 != 255 && b2 != 255
}

296
vendor/github.com/google/uuid/uuid.go generated vendored Normal file
View file

@ -0,0 +1,296 @@
// Copyright 2018 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"bytes"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
"sync"
)
// A UUID is a 128 bit (16 byte) Universal Unique IDentifier as defined in RFC
// 4122.
type UUID [16]byte
// A Version represents a UUID's version.
type Version byte
// A Variant represents a UUID's variant.
type Variant byte
// Constants returned by Variant.
const (
Invalid = Variant(iota) // Invalid UUID
RFC4122 // The variant specified in RFC4122
Reserved // Reserved, NCS backward compatibility.
Microsoft // Reserved, Microsoft Corporation backward compatibility.
Future // Reserved for future definition.
)
const randPoolSize = 16 * 16
var (
rander = rand.Reader // random function
poolEnabled = false
poolMu sync.Mutex
poolPos = randPoolSize // protected with poolMu
pool [randPoolSize]byte // protected with poolMu
)
type invalidLengthError struct{ len int }
func (err invalidLengthError) Error() string {
return fmt.Sprintf("invalid UUID length: %d", err.len)
}
// IsInvalidLengthError is matcher function for custom error invalidLengthError
func IsInvalidLengthError(err error) bool {
_, ok := err.(invalidLengthError)
return ok
}
// Parse decodes s into a UUID or returns an error. Both the standard UUID
// forms of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded as well as the
// Microsoft encoding {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx} and the raw hex
// encoding: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx.
func Parse(s string) (UUID, error) {
var uuid UUID
switch len(s) {
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
case 36:
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
case 36 + 9:
if !strings.EqualFold(s[:9], "urn:uuid:") {
return uuid, fmt.Errorf("invalid urn prefix: %q", s[:9])
}
s = s[9:]
// {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}
case 36 + 2:
s = s[1:]
// xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
case 32:
var ok bool
for i := range uuid {
uuid[i], ok = xtob(s[i*2], s[i*2+1])
if !ok {
return uuid, errors.New("invalid UUID format")
}
}
return uuid, nil
default:
return uuid, invalidLengthError{len(s)}
}
// s is now at least 36 bytes long
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
return uuid, errors.New("invalid UUID format")
}
for i, x := range [16]int{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34,
} {
v, ok := xtob(s[x], s[x+1])
if !ok {
return uuid, errors.New("invalid UUID format")
}
uuid[i] = v
}
return uuid, nil
}
// ParseBytes is like Parse, except it parses a byte slice instead of a string.
func ParseBytes(b []byte) (UUID, error) {
var uuid UUID
switch len(b) {
case 36: // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
case 36 + 9: // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if !bytes.EqualFold(b[:9], []byte("urn:uuid:")) {
return uuid, fmt.Errorf("invalid urn prefix: %q", b[:9])
}
b = b[9:]
case 36 + 2: // {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}
b = b[1:]
case 32: // xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
var ok bool
for i := 0; i < 32; i += 2 {
uuid[i/2], ok = xtob(b[i], b[i+1])
if !ok {
return uuid, errors.New("invalid UUID format")
}
}
return uuid, nil
default:
return uuid, invalidLengthError{len(b)}
}
// s is now at least 36 bytes long
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if b[8] != '-' || b[13] != '-' || b[18] != '-' || b[23] != '-' {
return uuid, errors.New("invalid UUID format")
}
for i, x := range [16]int{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34,
} {
v, ok := xtob(b[x], b[x+1])
if !ok {
return uuid, errors.New("invalid UUID format")
}
uuid[i] = v
}
return uuid, nil
}
// MustParse is like Parse but panics if the string cannot be parsed.
// It simplifies safe initialization of global variables holding compiled UUIDs.
func MustParse(s string) UUID {
uuid, err := Parse(s)
if err != nil {
panic(`uuid: Parse(` + s + `): ` + err.Error())
}
return uuid
}
// FromBytes creates a new UUID from a byte slice. Returns an error if the slice
// does not have a length of 16. The bytes are copied from the slice.
func FromBytes(b []byte) (uuid UUID, err error) {
err = uuid.UnmarshalBinary(b)
return uuid, err
}
// Must returns uuid if err is nil and panics otherwise.
func Must(uuid UUID, err error) UUID {
if err != nil {
panic(err)
}
return uuid
}
// String returns the string form of uuid, xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
// , or "" if uuid is invalid.
func (uuid UUID) String() string {
var buf [36]byte
encodeHex(buf[:], uuid)
return string(buf[:])
}
// URN returns the RFC 2141 URN form of uuid,
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx, or "" if uuid is invalid.
func (uuid UUID) URN() string {
var buf [36 + 9]byte
copy(buf[:], "urn:uuid:")
encodeHex(buf[9:], uuid)
return string(buf[:])
}
func encodeHex(dst []byte, uuid UUID) {
hex.Encode(dst, uuid[:4])
dst[8] = '-'
hex.Encode(dst[9:13], uuid[4:6])
dst[13] = '-'
hex.Encode(dst[14:18], uuid[6:8])
dst[18] = '-'
hex.Encode(dst[19:23], uuid[8:10])
dst[23] = '-'
hex.Encode(dst[24:], uuid[10:])
}
// Variant returns the variant encoded in uuid.
func (uuid UUID) Variant() Variant {
switch {
case (uuid[8] & 0xc0) == 0x80:
return RFC4122
case (uuid[8] & 0xe0) == 0xc0:
return Microsoft
case (uuid[8] & 0xe0) == 0xe0:
return Future
default:
return Reserved
}
}
// Version returns the version of uuid.
func (uuid UUID) Version() Version {
return Version(uuid[6] >> 4)
}
func (v Version) String() string {
if v > 15 {
return fmt.Sprintf("BAD_VERSION_%d", v)
}
return fmt.Sprintf("VERSION_%d", v)
}
func (v Variant) String() string {
switch v {
case RFC4122:
return "RFC4122"
case Reserved:
return "Reserved"
case Microsoft:
return "Microsoft"
case Future:
return "Future"
case Invalid:
return "Invalid"
}
return fmt.Sprintf("BadVariant%d", int(v))
}
// SetRand sets the random number generator to r, which implements io.Reader.
// If r.Read returns an error when the package requests random data then
// a panic will be issued.
//
// Calling SetRand with nil sets the random number generator to the default
// generator.
func SetRand(r io.Reader) {
if r == nil {
rander = rand.Reader
return
}
rander = r
}
// EnableRandPool enables internal randomness pool used for Random
// (Version 4) UUID generation. The pool contains random bytes read from
// the random number generator on demand in batches. Enabling the pool
// may improve the UUID generation throughput significantly.
//
// Since the pool is stored on the Go heap, this feature may be a bad fit
// for security sensitive applications.
//
// Both EnableRandPool and DisableRandPool are not thread-safe and should
// only be called when there is no possibility that New or any other
// UUID Version 4 generation function will be called concurrently.
func EnableRandPool() {
poolEnabled = true
}
// DisableRandPool disables the randomness pool if it was previously
// enabled with EnableRandPool.
//
// Both EnableRandPool and DisableRandPool are not thread-safe and should
// only be called when there is no possibility that New or any other
// UUID Version 4 generation function will be called concurrently.
func DisableRandPool() {
poolEnabled = false
defer poolMu.Unlock()
poolMu.Lock()
poolPos = randPoolSize
}

44
vendor/github.com/google/uuid/version1.go generated vendored Normal file
View file

@ -0,0 +1,44 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
)
// NewUUID returns a Version 1 UUID based on the current NodeID and clock
// sequence, and the current time. If the NodeID has not been set by SetNodeID
// or SetNodeInterface then it will be set automatically. If the NodeID cannot
// be set NewUUID returns nil. If clock sequence has not been set by
// SetClockSequence then it will be set automatically. If GetTime fails to
// return the current NewUUID returns nil and an error.
//
// In most cases, New should be used.
func NewUUID() (UUID, error) {
var uuid UUID
now, seq, err := GetTime()
if err != nil {
return uuid, err
}
timeLow := uint32(now & 0xffffffff)
timeMid := uint16((now >> 32) & 0xffff)
timeHi := uint16((now >> 48) & 0x0fff)
timeHi |= 0x1000 // Version 1
binary.BigEndian.PutUint32(uuid[0:], timeLow)
binary.BigEndian.PutUint16(uuid[4:], timeMid)
binary.BigEndian.PutUint16(uuid[6:], timeHi)
binary.BigEndian.PutUint16(uuid[8:], seq)
nodeMu.Lock()
if nodeID == zeroID {
setNodeInterface("")
}
copy(uuid[10:], nodeID[:])
nodeMu.Unlock()
return uuid, nil
}

76
vendor/github.com/google/uuid/version4.go generated vendored Normal file
View file

@ -0,0 +1,76 @@
// Copyright 2016 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import "io"
// New creates a new random UUID or panics. New is equivalent to
// the expression
//
// uuid.Must(uuid.NewRandom())
func New() UUID {
return Must(NewRandom())
}
// NewString creates a new random UUID and returns it as a string or panics.
// NewString is equivalent to the expression
//
// uuid.New().String()
func NewString() string {
return Must(NewRandom()).String()
}
// NewRandom returns a Random (Version 4) UUID.
//
// The strength of the UUIDs is based on the strength of the crypto/rand
// package.
//
// Uses the randomness pool if it was enabled with EnableRandPool.
//
// A note about uniqueness derived from the UUID Wikipedia entry:
//
// Randomly generated UUIDs have 122 random bits. One's annual risk of being
// hit by a meteorite is estimated to be one chance in 17 billion, that
// means the probability is about 0.00000000006 (6 × 1011),
// equivalent to the odds of creating a few tens of trillions of UUIDs in a
// year and having one duplicate.
func NewRandom() (UUID, error) {
if !poolEnabled {
return NewRandomFromReader(rander)
}
return newRandomFromPool()
}
// NewRandomFromReader returns a UUID based on bytes read from a given io.Reader.
func NewRandomFromReader(r io.Reader) (UUID, error) {
var uuid UUID
_, err := io.ReadFull(r, uuid[:])
if err != nil {
return Nil, err
}
uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4
uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10
return uuid, nil
}
func newRandomFromPool() (UUID, error) {
var uuid UUID
poolMu.Lock()
if poolPos == randPoolSize {
_, err := io.ReadFull(rander, pool[:])
if err != nil {
poolMu.Unlock()
return Nil, err
}
poolPos = 0
}
copy(uuid[:], pool[poolPos:(poolPos+16)])
poolPos += 16
poolMu.Unlock()
uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4
uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10
return uuid, nil
}

View file

@ -1,5 +1,19 @@
# Changelog
## 1.6.0
### Changed
* Go.mod updated to Go 1.17
* Azure SDK for Go dependencies updated
### Features
* Added `ActiveDirectoryAzCli` and `ActiveDirectoryDeviceCode` authentication types to `azuread` package
* Always Encrypted encryption and decryption with 2 hour key cache (#116)
* 'pfx', 'MSSQL_CERTIFICATE_STORE', and 'AZURE_KEY_VAULT' encryption key providers
* TDS8 can now be used for connections by setting encrypt="strict"
## 1.5.0
### Features

View file

@ -1,4 +1,4 @@
# A pure Go MSSQL driver for Go's database/sql package
# Microsoft's official Go MSSQL driver
[![Go Reference](https://pkg.go.dev/badge/github.com/microsoft/go-mssqldb.svg)](https://pkg.go.dev/github.com/microsoft/go-mssqldb)
[![Build status](https://ci.appveyor.com/api/projects/status/jrln8cs62wj9i0a2?svg=true)](https://ci.appveyor.com/project/microsoft/go-mssqldb)
@ -7,7 +7,7 @@
## Install
Requires Go 1.10 or above.
Requires Go 1.17 or above.
Install with `go install github.com/microsoft/go-mssqldb@latest`.
@ -25,9 +25,10 @@ Other supported formats are listed below.
* `connection timeout` - in seconds (default is 0 for no timeout), set to 0 for no timeout. Recommended to set to 0 and use context to manage query and connection timeouts.
* `dial timeout` - in seconds (default is 15 times the number of registered protocols), set to 0 for no timeout.
* `encrypt`
* `strict` - Data sent between client and server is encrypted E2E using [TDS8](https://learn.microsoft.com/en-us/sql/relational-databases/security/networking/tds-8?view=sql-server-ver16).
* `disable` - Data send between client and server is not encrypted.
* `false` - Data sent between client and server is not encrypted beyond the login packet. (Default)
* `true` - Data sent between client and server is encrypted.
* `false`/`optional`/`no`/`0`/`f` - Data sent between client and server is not encrypted beyond the login packet. (Default)
* `true`/`mandatory`/`yes`/`1`/`t` - Data sent between client and server is encrypted.
* `app name` - The application name (default is go-mssqldb)
* `authenticator` - Can be used to specify use of a registered authentication provider. (e.g. ntlm, winsspi (on windows) or krb5 (on linux))
@ -56,13 +57,14 @@ Other supported formats are listed below.
* `TrustServerCertificate`
* false - Server certificate is checked. Default is false if encrypt is specified.
* true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.
* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates.
* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. Currently, certificates of PEM type are supported.
* `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host.
* `tlsmin` - Specifies the minimum TLS version for negotiating encryption with the server. Recognized values are `1.0`, `1.1`, `1.2`, `1.3`. If not set to a recognized value the default value for the `tls` package will be used. The default is currently `1.2`.
* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
* `Workstation ID` - The workstation name (default is the host name)
* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`.
* `protocol` - forces use of a protocol. Make sure the corresponding package is imported.
* `columnencryption` or `column encryption setting` - a boolean value indicating whether Always Encrypted should be enabled on the connection.
### Connection parameters for namedpipe package
* `pipe` - If set, no Browser query is made and named pipe used will be `\\<host>\pipe\<pipe>`
@ -216,6 +218,8 @@ The credential type is determined by the new `fedauth` connection string paramet
* `resource id=<resource id>` - optional resource id of user-assigned managed identity. If empty, system-assigned managed identity or user id are used (if both user id and resource id are provided, resource id will be used)
* `fedauth=ActiveDirectoryInteractive` - authenticates using credentials acquired from an external web browser. Only suitable for use with human interaction.
* `applicationclientid=<application id>` - This guid identifies an Azure Active Directory enterprise application that the AAD admin has approved for accessing Azure SQL database resources in the tenant. This driver does not have an associated application id of its own.
* `fedauth=ActiveDirectoryDeviceCode` - prints a message to stdout giving the user a URL and code to authenticate. Connection continues after user completes the login separately.
* `fedauth=ActiveDirectoryAzCli` - reuses local authentication the user already performed using Azure CLI.
```go
@ -377,8 +381,63 @@ db.QueryContext(ctx, `select * from t2 where user_name = @p1;`, mssql.VarChar(na
// Note: Mismatched data types on table and parameter may cause long running queries
```
## Using Always Encrypted
The protocol and cryptography details for AE are [detailed elsewhere](https://learn.microsoft.com/sql/relational-databases/security/encryption/always-encrypted-database-engine?view=sql-server-ver16).
### Enablement
To enable AE on a connection, set the `ColumnEncryption` value to true on a config or pass `columnencryption=true` in the connection string.
Decryption and encryption won't succeed, however, without also including a decryption key provider. To avoid code size impacts on non-AE applications, key providers are not included by default.
Include the local certificate providers:
```go
import (
"github.com/microsoft/go-mssqldb/aecmk/localcert"
)
```
You can also instantiate a key provider directly in code and hand it to a `Connector` instance.
```go
c := mssql.NewConnectorConfig(myconfig)
c.RegisterCekProvider(providerName, MyProviderType{})
```
### Decryption
If the correct key provider is included in your application, decryption of encrypted cells happens automatically with no extra server round trips.
### Encryption
Encryption of parameters passed to `Exec` and `Query` variants requires an extra round trip per query to fetch the encryption metadata. If the error returned by a query attempt indicates a type mismatch between the parameter and the destination table, most likely your input type is not a strict match for the SQL Server data type of the destination. You may be using a Go `string` when you need to use one of the driver-specific aliases like `VarChar` or `NVarCharMax`.
*** NOTE *** - Currently `char` and `varchar` types do not include a collation parameter component so can't be used for inserting encrypted values. Also, using a nullable sql package type like `sql.NullableInt32` to pass a `NULL` value for an encrypted column will not work unless the encrypted column type is `nvarchar`.
https://github.com/microsoft/go-mssqldb/issues/129
https://github.com/microsoft/go-mssqldb/issues/130
### Local certificate AE key provider
Key provider configuration is managed separately without any properties in the connection string.
The `pfx` provider exposes its instance as the variable `PfxKeyProvider`. You can give it passwords for certificates using `SetCertificatePassword(pathToCertificate, path)`. Use an empty string or `"*"` as the path to use the same password for all certificates.
The `MSSQL_CERTIFICATE_STORE` provider exposes its instance as the variable `WindowsCertificateStoreKeyProvider`.
Both providers can be constrained to an allowed list of encryption key paths by appending paths to `provider.AllowedLocations`.
### Azure Key Vault (AZURE_KEY_VAULT) key provider
Import this provider using `github.com/microsoft/go-mssqldb/aecmk/akv`
Constrain the provider to an allowed list of key vaults by appending vault host strings like "mykeyvault.vault.azure.net" to `akv.KeyProvider.AllowedLocations`.
## Important Notes
* [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should
not be used with this driver (or SQL Server) due to how the TDS protocol
works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql)
@ -409,6 +468,9 @@ db.QueryContext(ctx, `select * from t2 where user_name = @p1;`, mssql.VarChar(na
* A `namedpipe` package to support connections using named pipes (np:) on Windows
* A `sharedmemory` package to support connections using shared memory (lpc:) on Windows
* Dedicated Administrator Connection (DAC) is supported using `admin` protocol
* Always Encrypted
- `MSSQL_CERTIFICATE_STORE` provider on Windows
- `pfx` provider on Linux and Windows
## Tests
@ -449,6 +511,7 @@ To fix SQL Server 2008 R2 issue, install SQL Server 2008 R2 Service Pack 2.
To fix SQL Server 2008 issue, install Microsoft SQL Server 2008 Service Pack 3 and Cumulative update package 3 for SQL Server 2008 SP3.
More information: <http://support.microsoft.com/kb/2653857>
* Bulk copy does not yet support encrypting column values using Always Encrypted. Tracked in [#127](https://github.com/microsoft/go-mssqldb/issues/127)
# Contributing
This project is a fork of [https://github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) and welcomes new and previous contributors. For more informaton on contributing to this project, please see [Contributing](./CONTRIBUTING.md).

View file

@ -0,0 +1,112 @@
package aecmk
import (
"fmt"
"sync"
"time"
)
const (
CertificateStoreKeyProvider = "MSSQL_CERTIFICATE_STORE"
CspKeyProvider = "MSSQL_CSP_PROVIDER"
CngKeyProvider = "MSSQL_CNG_STORE"
AzureKeyVaultKeyProvider = "AZURE_KEY_VAULT"
JavaKeyProvider = "MSSQL_JAVA_KEYSTORE"
KeyEncryptionAlgorithm = "RSA_OAEP"
)
// ColumnEncryptionKeyLifetime is the default lifetime of decrypted Column Encryption Keys in the global cache.
// The default is 2 hours
var ColumnEncryptionKeyLifetime time.Duration = 2 * time.Hour
type cekCacheEntry struct {
Expiry time.Time
Key []byte
}
type cekCache map[string]cekCacheEntry
type CekProvider struct {
Provider ColumnEncryptionKeyProvider
decryptedKeys cekCache
mutex sync.Mutex
}
func NewCekProvider(provider ColumnEncryptionKeyProvider) *CekProvider {
return &CekProvider{Provider: provider, decryptedKeys: make(cekCache), mutex: sync.Mutex{}}
}
func (cp *CekProvider) GetDecryptedKey(keyPath string, encryptedBytes []byte) (decryptedKey []byte, err error) {
cp.mutex.Lock()
ev, cachedKey := cp.decryptedKeys[keyPath]
if cachedKey {
if ev.Expiry.Before(time.Now()) {
delete(cp.decryptedKeys, keyPath)
cachedKey = false
} else {
decryptedKey = ev.Key
}
}
// decrypting a key can take a while, so let multiple callers race
// Key providers can choose to optimize their own concurrency.
// For example - there's probably minimal value in serializing access to a local certificate,
// but there'd be high value in having a queue of waiters for decrypting a key stored in the cloud.
cp.mutex.Unlock()
if !cachedKey {
decryptedKey = cp.Provider.DecryptColumnEncryptionKey(keyPath, KeyEncryptionAlgorithm, encryptedBytes)
}
if !cachedKey {
duration := cp.Provider.KeyLifetime()
if duration == nil {
duration = &ColumnEncryptionKeyLifetime
}
expiry := time.Now().Add(*duration)
cp.mutex.Lock()
cp.decryptedKeys[keyPath] = cekCacheEntry{Expiry: expiry, Key: decryptedKey}
cp.mutex.Unlock()
}
return
}
// no synchronization on this map. Providers register during init.
type ColumnEncryptionKeyProviderMap map[string]*CekProvider
var globalCekProviderFactoryMap = ColumnEncryptionKeyProviderMap{}
// ColumnEncryptionKeyProvider is the interface for decrypting and encrypting column encryption keys.
// It is similar to .Net https://learn.microsoft.com/dotnet/api/microsoft.data.sqlclient.sqlcolumnencryptionkeystoreprovider.
type ColumnEncryptionKeyProvider interface {
// DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key.
// The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm.
DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) []byte
// EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm.
EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte
// SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key
// referenced by the masterKeyPath parameter. The input values used to generate the signature should be the
// specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported.
SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte
// VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key
// with the specified key path and the specified enclave behavior. Return nil if not supported.
VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool
// KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires.
// If it returns nil, the keys will expire based on the value of ColumnEncryptionKeyLifetime.
// If it returns zero, the keys will not be cached.
KeyLifetime() *time.Duration
}
func RegisterCekProvider(name string, provider ColumnEncryptionKeyProvider) error {
_, ok := globalCekProviderFactoryMap[name]
if ok {
return fmt.Errorf("CEK provider %s is already registered", name)
}
globalCekProviderFactoryMap[name] = &CekProvider{Provider: provider, decryptedKeys: cekCache{}, mutex: sync.Mutex{}}
return nil
}
func GetGlobalCekProviders() (providers ColumnEncryptionKeyProviderMap) {
providers = make(ColumnEncryptionKeyProviderMap)
for i, p := range globalCekProviderFactoryMap {
providers[i] = p
}
return
}

View file

@ -11,52 +11,29 @@ environment:
SQLUSER: sa
SQLPASSWORD: Password12!
DATABASE: test
GOVERSION: 113
GOVERSION: 117
COLUMNENCRYPTION:
APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
RACE: -race -cpu 4
TAGS:
matrix:
- GOVERSION: 110
SQLINSTANCE: SQL2017
- GOVERSION: 111
SQLINSTANCE: SQL2017
- GOVERSION: 112
SQLINSTANCE: SQL2017
- SQLINSTANCE: SQL2017
- SQLINSTANCE: SQL2016
- SQLINSTANCE: SQL2014
- SQLINSTANCE: SQL2012SP1
- SQLINSTANCE: SQL2008R2SP2
# Go 1.14+ and SQL2019 are available on the Visual Studio 2019 image only
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 114
- GOVERSION: 118
SQLINSTANCE: SQL2017
- GOVERSION: 120
RACE:
SQLINSTANCE: SQL2019
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 115
SQLINSTANCE: SQL2019
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 115
SQLINSTANCE: SQL2017
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 116
SQLINSTANCE: SQL2017
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 117
SQLINSTANCE: SQL2017
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 118
SQLINSTANCE: SQL2017
COLUMNENCRYPTION: 1
# Cover 32bit and named pipes protocol
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 118-x86
- GOVERSION: 119-x86
SQLINSTANCE: SQL2017
GOARCH: 386
RACE:
PROTOCOL: np
TAGS: -tags np
# Cover SSPI and lpc protocol
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 118
- GOVERSION: 120
RACE:
SQLINSTANCE: SQL2019
PROTOCOL: lpc
TAGS: -tags sm
@ -67,9 +44,6 @@ install:
- set PATH=%GOPATH%\bin;%GOROOT%\bin;%PATH%
- go version
- go env
- go get -u github.com/golang-sql/civil
- go get -u github.com/golang-sql/sqlexp
- go get -u golang.org/x/crypto/md4
build_script:
- go build

View file

@ -250,6 +250,10 @@ func (b *Bulk) createColMetadata() []byte {
buf.WriteByte(byte(tokenColMetadata)) // token
binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count
// TODO: Write a valid CEK table if any parameters have cekTableEntry values
if b.cn.sess.alwaysEncrypted {
binary.Write(buf, binary.LittleEndian, uint16(0))
}
for i, col := range b.bulkColumns {
if b.cn.sess.loginAck.TDSVersion >= verTDS72 {

View file

@ -0,0 +1,40 @@
package mssql
const (
CertificateStoreKeyProvider = "MSSQL_CERTIFICATE_STORE"
CspKeyProvider = "MSSQL_CSP_PROVIDER"
CngKeyProvider = "MSSQL_CNG_STORE"
AzureKeyVaultKeyProvider = "AZURE_KEY_VAULT"
JavaKeyProvider = "MSSQL_JAVA_KEYSTORE"
KeyEncryptionAlgorithm = "RSA_OAEP"
)
// cek ==> Column Encryption Key
// Every row of an encrypted table has an associated list of keys used to decrypt its columns
type cekTable struct {
entries []cekTableEntry
}
type encryptionKeyInfo struct {
encryptedKey []byte
databaseID int
cekID int
cekVersion int
cekMdVersion []byte
keyPath string
keyStoreName string
algorithmName string
}
type cekTableEntry struct {
databaseID int
keyId int
keyVersion int
mdVersion []byte
valueCount int
cekValues []encryptionKeyInfo
}
func newCekTable(size uint16) cekTable {
return cekTable{entries: make([]cekTableEntry, size)}
}

292
vendor/github.com/microsoft/go-mssqldb/encrypt.go generated vendored Normal file
View file

@ -0,0 +1,292 @@
package mssql
import (
"context"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
"strings"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys"
)
type ColumnEncryptionType int
var (
ColumnEncryptionPlainText ColumnEncryptionType = 0
ColumnEncryptionDeterministic ColumnEncryptionType = 1
ColumnEncryptionRandomized ColumnEncryptionType = 2
)
type cekData struct {
ordinal int
database_id int
id int
version int
metadataVersion []byte
encryptedValue []byte
cmkStoreName string
cmkPath string
algorithm string
//byEnclave bool
//cmkSignature string
decryptedValue []byte
}
type parameterEncData struct {
ordinal int
name string
algorithm int
encType ColumnEncryptionType
cekOrdinal int
ruleVersion int
}
type paramMapEntry struct {
cek *cekData
p *parameterEncData
}
// when Always Encrypted is turned on, we have to ask the server for metadata about how to encrypt input parameters.
// This function stores the relevant encryption parameters in a copy of the args so they can be
// encrypted just before being sent to the server
func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArgs []namedValue, err error) {
q := Stmt{c: s.c,
paramCount: s.paramCount,
query: "sp_describe_parameter_encryption",
skipEncryption: true,
}
oldouts := s.c.outs
s.c.clearOuts()
newArgs, err := s.prepareEncryptionQuery(isProc(s.query), s.query, args)
if err != nil {
return
}
// TODO: Consider not using recursion.
rows, err := q.queryContext(ctx, newArgs)
if err != nil {
s.c.outs = oldouts
return
}
cekInfo, paramsInfo, err := processDescribeParameterEncryption(rows)
rows.Close()
s.c.outs = oldouts
if err != nil {
return
}
if len(cekInfo) == 0 {
return args, nil
}
err = s.decryptCek(cekInfo)
if err != nil {
return
}
paramMap := make(map[string]paramMapEntry)
for _, p := range paramsInfo {
if p.encType == ColumnEncryptionPlainText {
paramMap[p.name] = paramMapEntry{nil, p}
} else {
paramMap[p.name] = paramMapEntry{cekInfo[p.cekOrdinal-1], p}
}
}
encryptedArgs = make([]namedValue, len(args))
for i, a := range args {
encryptedArgs[i] = a
name := ""
if len(a.Name) > 0 {
name = "@" + a.Name
} else {
name = fmt.Sprintf("@p%d", a.Ordinal)
}
info := paramMap[name]
if info.p.encType == ColumnEncryptionPlainText || a.Value == nil {
continue
}
encryptedArgs[i].encrypt = getEncryptor(info)
}
return encryptedArgs, nil
}
// returns the arguments to sp_describe_parameter_encryption
// sp_describe_parameter_encryption
// [ @tsql = ] N'Transact-SQL_batch' ,
// [ @params = ] N'parameters'
// [ ;]
func (s *Stmt) prepareEncryptionQuery(isProc bool, q string, args []namedValue) (newArgs []namedValue, err error) {
newArgs = make([]namedValue, 2)
if isProc {
newArgs[0] = namedValue{Name: "tsql", Ordinal: 0, Value: buildStoredProcedureStatementForColumnEncryption(q, args)}
} else {
newArgs[0] = namedValue{Name: "tsql", Ordinal: 0, Value: q}
}
params, err := s.buildParametersForColumnEncryption(args)
if err != nil {
return
}
newArgs[1] = namedValue{Name: "params", Ordinal: 1, Value: params}
return
}
func (s *Stmt) buildParametersForColumnEncryption(args []namedValue) (parameters string, err error) {
_, decls, err := s.makeRPCParams(args, false)
if err != nil {
return
}
parameters = strings.Join(decls, ", ")
return
}
func (s *Stmt) decryptCek(cekInfo []*cekData) error {
for _, info := range cekInfo {
kp, ok := s.c.sess.aeSettings.keyProviders[info.cmkStoreName]
if !ok {
return fmt.Errorf("No provider found for key store %s", info.cmkStoreName)
}
dk, err := kp.GetDecryptedKey(info.cmkPath, info.encryptedValue)
if err != nil {
return err
}
info.decryptedValue = dk
}
return nil
}
func getEncryptor(info paramMapEntry) valueEncryptor {
k := keys.NewAeadAes256CbcHmac256(info.cek.decryptedValue)
alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encryption.From(byte(info.p.encType)), byte(info.cek.version))
// Metadata to append to an encrypted parameter. Doesn't include original typeinfo
// https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/619c43b6-9495-4a58-9e49-a4950db245b3
// ParamCipherInfo = TYPE_INFO
// EncryptionAlgo (byte)
// [AlgoName] (b_varchar) unused, no custom algorithm
// EncryptionType (byte)
// DatabaseId (ulong)
// CekId (ulong)
// CekVersion (ulong)
// CekMDVersion (ulonglong) - really a byte array
// NormVersion (byte)
// algo+ enctype+ dbid+ keyid+ keyver+ normversion
metadataLen := 1 + 1 + 4 + 4 + 4 + 1
metadataLen += len(info.cek.metadataVersion)
metadata := make([]byte, metadataLen)
offset := 0
// AEAD_AES_256_CBC_HMAC_SHA256
metadata[offset] = byte(info.p.algorithm)
offset++
metadata[offset] = byte(info.p.encType)
offset++
binary.LittleEndian.PutUint32(metadata[offset:], uint32(info.cek.database_id))
offset += 4
binary.LittleEndian.PutUint32(metadata[offset:], uint32(info.cek.id))
offset += 4
binary.LittleEndian.PutUint32(metadata[offset:], uint32(info.cek.version))
offset += 4
copy(metadata[offset:], info.cek.metadataVersion)
offset += len(info.cek.metadataVersion)
metadata[offset] = byte(info.p.ruleVersion)
return func(b []byte) ([]byte, []byte, error) {
encryptedData, err := alg.Encrypt(b)
if err != nil {
return nil, nil, err
}
return encryptedData, metadata, nil
}
}
// Based on the .Net implementation at https://github.com/dotnet/SqlClient/blob/2b31810ce69b88d707450e2059ee8fbde63f774f/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs#L6040
func buildStoredProcedureStatementForColumnEncryption(sproc string, args []namedValue) string {
b := new(strings.Builder)
_, _ = b.WriteString("EXEC ")
q := TSQLQuoter{}
sproc = q.ID(sproc)
b.WriteString(sproc)
// Unlike ADO.Net, go-mssqldb doesn't support ReturnValue named parameters
first := true
for _, a := range args {
if !first {
b.WriteRune(',')
}
first = false
b.WriteRune(' ')
name := a.Name
if len(name) == 0 {
name = fmt.Sprintf("@p%d", a.Ordinal)
}
appendPrefixedParameterName(b, name)
if len(a.Name) > 0 {
b.WriteRune('=')
appendPrefixedParameterName(b, a.Name)
}
if isOutputValue(a.Value) {
b.WriteString(" OUTPUT")
}
}
return b.String()
}
func appendPrefixedParameterName(b *strings.Builder, p string) {
if len(p) > 0 {
if p[0] != '@' {
b.WriteRune('@')
}
b.WriteString(p)
}
}
func processDescribeParameterEncryption(rows driver.Rows) (cekInfo []*cekData, paramInfo []*parameterEncData, err error) {
cekInfo = make([]*cekData, 0)
values := make([]driver.Value, 9)
qerr := rows.Next(values)
for qerr == nil {
cekInfo = append(cekInfo, &cekData{ordinal: int(values[0].(int64)),
database_id: int(values[1].(int64)),
id: int(values[2].(int64)),
version: int(values[3].(int64)),
metadataVersion: values[4].([]byte),
encryptedValue: values[5].([]byte),
cmkStoreName: values[6].(string),
cmkPath: values[7].(string),
algorithm: values[8].(string),
})
qerr = rows.Next(values)
}
if len(cekInfo) == 0 || qerr != io.EOF {
if qerr != io.EOF {
err = qerr
}
// No encryption needed
return
}
r := rows.(driver.RowsNextResultSet)
err = r.NextResultSet()
if err != nil {
return
}
paramInfo = make([]*parameterEncData, 0)
qerr = rows.Next(values[:6])
for qerr == nil {
paramInfo = append(paramInfo, &parameterEncData{ordinal: int(values[0].(int64)),
name: values[1].(string),
algorithm: int(values[2].(int64)),
encType: ColumnEncryptionType(values[3].(int64)),
cekOrdinal: int(values[4].(int64)),
ruleVersion: int(values[5].(int64)),
})
qerr = rows.Next(values[:6])
}
if len(paramInfo) == 0 || qerr != io.EOF {
if qerr != io.EOF {
err = qerr
} else {
err = fmt.Errorf("No parameter encryption rows were returned from sp_describe_parameter_encryption")
}
}
return
}

View file

@ -0,0 +1,20 @@
Copyright (c) 2021 Swisscom (Switzerland) Ltd
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -0,0 +1,120 @@
package algorithms
import (
"crypto/rand"
"crypto/subtle"
"fmt"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys"
)
// https://tools.ietf.org/html/draft-mcgrew-aead-aes-cbc-hmac-sha2-05
// https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-TDS/%5bMS-TDS%5d.pdf
var _ Algorithm = &AeadAes256CbcHmac256Algorithm{}
type AeadAes256CbcHmac256Algorithm struct {
algorithmVersion byte
deterministic bool
blockSizeBytes int
keySizeBytes int
minimumCipherTextLengthBytesNoAuthTag int
minimumCipherTextLengthBytesWithAuthTag int
cek keys.AeadAes256CbcHmac256
version []byte
versionSize []byte
}
func NewAeadAes256CbcHmac256Algorithm(key keys.AeadAes256CbcHmac256, encType encryption.Type, algorithmVersion byte) AeadAes256CbcHmac256Algorithm {
const keySizeBytes = 256 / 8
const blockSizeBytes = 16
const minimumCipherTextLengthBytesNoAuthTag = 1 + 2*blockSizeBytes
const minimumCipherTextLengthBytesWithAuthTag = minimumCipherTextLengthBytesNoAuthTag + keySizeBytes
a := AeadAes256CbcHmac256Algorithm{
algorithmVersion: algorithmVersion,
deterministic: encType.Deterministic,
blockSizeBytes: blockSizeBytes,
keySizeBytes: keySizeBytes,
cek: key,
minimumCipherTextLengthBytesNoAuthTag: minimumCipherTextLengthBytesNoAuthTag,
minimumCipherTextLengthBytesWithAuthTag: minimumCipherTextLengthBytesWithAuthTag,
version: []byte{0x01},
versionSize: []byte{1},
}
a.version[0] = algorithmVersion
return a
}
func (a *AeadAes256CbcHmac256Algorithm) Encrypt(cleartext []byte) ([]byte, error) {
buf := make([]byte, 0)
var iv []byte
if a.deterministic {
iv = crypto.Sha256Hmac(cleartext, a.cek.IvKey())
if len(iv) > a.blockSizeBytes {
iv = iv[:a.blockSizeBytes]
}
} else {
iv = make([]byte, a.blockSizeBytes)
_, err := rand.Read(iv)
if err != nil {
panic(err)
}
}
buf = append(buf, a.algorithmVersion)
aescdbc := crypto.NewAESCbcPKCS5(a.cek.EncryptionKey(), iv)
ciphertext := aescdbc.Encrypt(cleartext)
authTag := a.prepareAuthTag(iv, ciphertext)
buf = append(buf, authTag...)
buf = append(buf, iv...)
buf = append(buf, ciphertext...)
return buf, nil
}
func (a *AeadAes256CbcHmac256Algorithm) Decrypt(ciphertext []byte) ([]byte, error) {
// This algorithm always has the auth tag!
minimumCiphertextLength := a.minimumCipherTextLengthBytesWithAuthTag
if len(ciphertext) < minimumCiphertextLength {
return nil, fmt.Errorf("invalid ciphertext length: at least %v bytes expected", minimumCiphertextLength)
}
idx := 0
if ciphertext[idx] != a.algorithmVersion {
return nil, fmt.Errorf("invalid algorithm version used: %v found but %v expected", ciphertext[idx],
a.algorithmVersion)
}
idx++
authTag := ciphertext[idx : idx+a.keySizeBytes]
idx += a.keySizeBytes
iv := ciphertext[idx : idx+a.blockSizeBytes]
idx += len(iv)
realCiphertext := ciphertext[idx:]
ourAuthTag := a.prepareAuthTag(iv, realCiphertext)
// bytes.Compare is subject to timing attacks
if subtle.ConstantTimeCompare(ourAuthTag, authTag) != 1 {
return nil, fmt.Errorf("invalid auth tag")
}
// decrypt
aescdbc := crypto.NewAESCbcPKCS5(a.cek.EncryptionKey(), iv)
cleartext := aescdbc.Decrypt(realCiphertext)
return cleartext, nil
}
func (a *AeadAes256CbcHmac256Algorithm) prepareAuthTag(iv []byte, ciphertext []byte) []byte {
var input = make([]byte, 0)
input = append(input, a.algorithmVersion)
input = append(input, iv...)
input = append(input, ciphertext...)
input = append(input, a.versionSize...)
return crypto.Sha256Hmac(input, a.cek.MacKey())
}

View file

@ -0,0 +1,6 @@
package algorithms
type Algorithm interface {
Encrypt([]byte) ([]byte, error)
Decrypt([]byte) ([]byte, error)
}

View file

@ -0,0 +1,69 @@
package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"fmt"
)
// Inspired by: https://gist.github.com/hothero/7d085573f5cb7cdb5801d7adcf66dcf3
type AESCbcPKCS5 struct {
key []byte
iv []byte
block cipher.Block
}
func NewAESCbcPKCS5(key []byte, iv []byte) AESCbcPKCS5 {
a := AESCbcPKCS5{
key: key,
iv: iv,
block: nil,
}
a.initCipher()
return a
}
func (a AESCbcPKCS5) Encrypt(cleartext []byte) (cipherText []byte) {
if a.block == nil {
a.initCipher()
}
blockMode := cipher.NewCBCEncrypter(a.block, a.iv)
paddedCleartext := PKCS5Padding(cleartext, blockMode.BlockSize())
cipherText = make([]byte, len(paddedCleartext))
blockMode.CryptBlocks(cipherText, paddedCleartext)
return
}
func (a AESCbcPKCS5) Decrypt(ciphertext []byte) []byte {
if a.block == nil {
a.initCipher()
}
blockMode := cipher.NewCBCDecrypter(a.block, a.iv)
var cleartext = make([]byte, len(ciphertext))
blockMode.CryptBlocks(cleartext, ciphertext)
return PKCS5Trim(cleartext)
}
func PKCS5Padding(inArr []byte, blockSize int) []byte {
padding := blockSize - len(inArr)%blockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(inArr, padText...)
}
func PKCS5Trim(inArr []byte) []byte {
padding := inArr[len(inArr)-1]
return inArr[:len(inArr)-int(padding)]
}
func (a *AESCbcPKCS5) initCipher() {
block, err := aes.NewCipher(a.key)
if err != nil {
panic(fmt.Errorf("unable to create cipher: %v", err))
}
a.block = block
}

View file

@ -0,0 +1,12 @@
package crypto
import (
"crypto/hmac"
"crypto/sha256"
)
func Sha256Hmac(input []byte, key []byte) []byte {
sha256Hmac := hmac.New(sha256.New, key)
sha256Hmac.Write(input)
return sha256Hmac.Sum(nil)
}

View file

@ -0,0 +1,37 @@
package encryption
type Type struct {
Deterministic bool
Name string
Value byte
}
var Plaintext = Type{
Deterministic: false,
Name: "Plaintext",
Value: 0,
}
var Deterministic = Type{
Deterministic: true,
Name: "Deterministic",
Value: 1,
}
var Randomized = Type{
Deterministic: false,
Name: "Randomized",
Value: 2,
}
func From(encType byte) Type {
switch encType {
case 0:
return Plaintext
case 1:
return Deterministic
case 2:
return Randomized
}
return Plaintext
}

View file

@ -0,0 +1,51 @@
package keys
import (
"fmt"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils"
)
var _ Key = &AeadAes256CbcHmac256{}
type AeadAes256CbcHmac256 struct {
rootKey []byte
encryptionKey []byte
macKey []byte
ivKey []byte
}
func NewAeadAes256CbcHmac256(rootKey []byte) AeadAes256CbcHmac256 {
const keySize = 256
const encryptionKeySaltFormat = "Microsoft SQL Server cell encryption key with encryption algorithm:%v and key length:%v"
const macKeySaltFormat = "Microsoft SQL Server cell MAC key with encryption algorithm:%v and key length:%v"
const ivKeySaltFormat = "Microsoft SQL Server cell IV key with encryption algorithm:%v and key length:%v"
const algorithmName = "AEAD_AES_256_CBC_HMAC_SHA256"
encryptionKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(encryptionKeySaltFormat, algorithmName, keySize))
macKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(macKeySaltFormat, algorithmName, keySize))
ivKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(ivKeySaltFormat, algorithmName, keySize))
return AeadAes256CbcHmac256{
rootKey: rootKey,
encryptionKey: crypto.Sha256Hmac(encryptionKeySalt, rootKey),
macKey: crypto.Sha256Hmac(macKeySalt, rootKey),
ivKey: crypto.Sha256Hmac(ivKeySalt, rootKey)}
}
func (a AeadAes256CbcHmac256) IvKey() []byte {
return a.ivKey
}
func (a AeadAes256CbcHmac256) MacKey() []byte {
return a.macKey
}
func (a AeadAes256CbcHmac256) EncryptionKey() []byte {
return a.encryptionKey
}
func (a AeadAes256CbcHmac256) RootKey() []byte {
return a.rootKey
}

View file

@ -0,0 +1,5 @@
package keys
type Key interface {
RootKey() []byte
}

View file

@ -0,0 +1,18 @@
package utils
import (
"encoding/binary"
"unicode/utf16"
)
func ConvertUTF16ToLittleEndianBytes(u []uint16) []byte {
b := make([]byte, 2*len(u))
for index, value := range u {
binary.LittleEndian.PutUint16(b[index*2:], value)
}
return b
}
func ProcessUTF16LE(inputString string) []byte {
return ConvertUTF16ToLittleEndianBytes(utf16.Encode([]rune(inputString)))
}

View file

@ -21,10 +21,17 @@ type (
BrowserMsg byte
)
const (
DsnTypeURL = 1
DsnTypeOdbc = 2
DsnTypeAdo = 3
)
const (
EncryptionOff = 0
EncryptionRequired = 1
EncryptionDisabled = 3
EncryptionStrict = 4
)
const (
@ -44,6 +51,34 @@ const (
BrowserDAC BrowserMsg = 0x0f
)
const (
Database = "database"
Encrypt = "encrypt"
Password = "password"
ChangePassword = "change password"
UserID = "user id"
Port = "port"
TrustServerCertificate = "trustservercertificate"
Certificate = "certificate"
TLSMin = "tlsmin"
PacketSize = "packet size"
LogParam = "log"
ConnectionTimeout = "connection timeout"
HostNameInCertificate = "hostnameincertificate"
KeepAlive = "keepalive"
ServerSpn = "serverspn"
WorkstationID = "workstation id"
AppName = "app name"
ApplicationIntent = "applicationintent"
FailoverPartner = "failoverpartner"
FailOverPort = "failoverport"
DisableRetry = "disableretry"
Server = "server"
Protocol = "protocol"
DialTimeout = "dial timeout"
Pipe = "pipe"
)
type Config struct {
Port uint64
Host string
@ -88,6 +123,10 @@ type Config struct {
ProtocolParameters map[string]interface{}
// BrowserMsg is the message identifier to fetch instance data from SQL browser
BrowserMessage BrowserMsg
// ChangePassword is used to set the login's password during login. Ignored for non-SQL authentication.
ChangePassword string
//ColumnEncryption is true if the application needs to decrypt or encrypt Always Encrypted values
ColumnEncryption bool
}
// Build a tls.Config object from the supplied certificate.
@ -128,24 +167,26 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e
trustServerCert := false
var encryption Encryption = EncryptionOff
encrypt, ok := params["encrypt"]
encrypt, ok := params[Encrypt]
if ok {
if strings.EqualFold(encrypt, "DISABLE") {
encrypt = strings.ToLower(encrypt)
switch encrypt {
case "mandatory", "yes", "1", "t", "true":
encryption = EncryptionRequired
case "disable":
encryption = EncryptionDisabled
} else {
e, err := strconv.ParseBool(encrypt)
if err != nil {
f := "invalid encrypt '%s': %s"
return encryption, nil, fmt.Errorf(f, encrypt, err.Error())
}
if e {
encryption = EncryptionRequired
}
case "strict":
encryption = EncryptionStrict
case "optional", "no", "0", "f", "false":
encryption = EncryptionOff
default:
f := "invalid encrypt '%s'"
return encryption, nil, fmt.Errorf(f, encrypt)
}
} else {
trustServerCert = true
}
trust, ok := params["trustservercertificate"]
trust, ok := params[TrustServerCertificate]
if ok {
var err error
trustServerCert, err = strconv.ParseBool(trust)
@ -154,9 +195,12 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e
return encryption, nil, fmt.Errorf(f, trust, err.Error())
}
}
certificate := params["certificate"]
certificate := params[Certificate]
if encryption != EncryptionDisabled {
tlsMin := params["tlsmin"]
tlsMin := params[TLSMin]
if encrypt == "strict" {
trustServerCert = false
}
tlsConfig, err := SetupTLS(certificate, trustServerCert, host, tlsMin)
if err != nil {
return encryption, nil, fmt.Errorf("failed to setup TLS: %w", err)
@ -168,6 +212,38 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e
var skipSetup = errors.New("skip setting up TLS")
func getDsnType(dsn string) int {
if strings.HasPrefix(dsn, "sqlserver://") {
return DsnTypeURL
}
if strings.HasPrefix(dsn, "odbc:") {
return DsnTypeOdbc
}
return DsnTypeAdo
}
func getDsnParams(dsn string) (map[string]string, error) {
var params map[string]string
var err error
switch getDsnType(dsn) {
case DsnTypeOdbc:
params, err = splitConnectionStringOdbc(dsn[len("odbc:"):])
if err != nil {
return params, err
}
case DsnTypeURL:
params, err = splitConnectionStringURL(dsn)
if err != nil {
return params, err
}
default:
params = splitConnectionString(dsn)
}
return params, nil
}
func Parse(dsn string) (Config, error) {
p := Config{
ProtocolParameters: map[string]interface{}{},
@ -176,23 +252,14 @@ func Parse(dsn string) (Config, error) {
var params map[string]string
var err error
if strings.HasPrefix(dsn, "odbc:") {
params, err = splitConnectionStringOdbc(dsn[len("odbc:"):])
if err != nil {
return p, err
}
} else if strings.HasPrefix(dsn, "sqlserver://") {
params, err = splitConnectionStringURL(dsn)
if err != nil {
return p, err
}
} else {
params = splitConnectionString(dsn)
}
params, err = getDsnParams(dsn)
if err != nil {
return p, err
}
p.Parameters = params
strlog, ok := params["log"]
strlog, ok := params[LogParam]
if ok {
flags, err := strconv.ParseUint(strlog, 10, 64)
if err != nil {
@ -201,12 +268,12 @@ func Parse(dsn string) (Config, error) {
p.LogFlags = Log(flags)
}
p.Database = params["database"]
p.User = params["user id"]
p.Password = params["password"]
p.Database = params[Database]
p.User = params[UserID]
p.Password = params[Password]
p.ChangePassword = params[ChangePassword]
p.Port = 0
strport, ok := params["port"]
strport, ok := params[Port]
if ok {
var err error
p.Port, err = strconv.ParseUint(strport, 10, 16)
@ -217,7 +284,7 @@ func Parse(dsn string) (Config, error) {
}
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option\
strpsize, ok := params["packet size"]
strpsize, ok := params[PacketSize]
if ok {
var err error
psize, err := strconv.ParseUint(strpsize, 0, 16)
@ -242,7 +309,7 @@ func Parse(dsn string) (Config, error) {
//
// Do not set a connection timeout. Use Context to manage such things.
// Default to zero, but still allow it to be set.
if strconntimeout, ok := params["connection timeout"]; ok {
if strconntimeout, ok := params[ConnectionTimeout]; ok {
timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
if err != nil {
f := "invalid connection timeout '%v': %v"
@ -254,7 +321,7 @@ func Parse(dsn string) (Config, error) {
// default keep alive should be 30 seconds according to spec:
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
p.KeepAlive = 30 * time.Second
if keepAlive, ok := params["keepalive"]; ok {
if keepAlive, ok := params[KeepAlive]; ok {
timeout, err := strconv.ParseUint(keepAlive, 10, 64)
if err != nil {
f := "invalid keepAlive value '%s': %s"
@ -263,12 +330,12 @@ func Parse(dsn string) (Config, error) {
p.KeepAlive = time.Duration(timeout) * time.Second
}
serverSPN, ok := params["serverspn"]
serverSPN, ok := params[ServerSpn]
if ok {
p.ServerSPN = serverSPN
} // If not set by the app, ServerSPN will be set by the successful dialer.
workstation, ok := params["workstation id"]
workstation, ok := params[WorkstationID]
if ok {
p.Workstation = workstation
} else {
@ -278,13 +345,13 @@ func Parse(dsn string) (Config, error) {
}
}
appname, ok := params["app name"]
appname, ok := params[AppName]
if !ok {
appname = "go-mssqldb"
}
p.AppName = appname
appintent, ok := params["applicationintent"]
appintent, ok := params[ApplicationIntent]
if ok {
if appintent == "ReadOnly" {
if p.Database == "" {
@ -294,12 +361,12 @@ func Parse(dsn string) (Config, error) {
}
}
failOverPartner, ok := params["failoverpartner"]
failOverPartner, ok := params[FailoverPartner]
if ok {
p.FailOverPartner = failOverPartner
}
failOverPort, ok := params["failoverport"]
failOverPort, ok := params[FailOverPort]
if ok {
var err error
p.FailOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
@ -309,7 +376,7 @@ func Parse(dsn string) (Config, error) {
}
}
disableRetry, ok := params["disableretry"]
disableRetry, ok := params[DisableRetry]
if ok {
var err error
p.DisableRetry, err = strconv.ParseBool(disableRetry)
@ -321,8 +388,8 @@ func Parse(dsn string) (Config, error) {
p.DisableRetry = disableRetryDefault
}
server := params["server"]
protocol, ok := params["protocol"]
server := params[Server]
protocol, ok := params[Protocol]
for _, parser := range ProtocolParsers {
if (!ok && !parser.Hidden()) || parser.Protocol() == protocol {
@ -348,7 +415,7 @@ func Parse(dsn string) (Config, error) {
f = 1
}
p.DialTimeout = time.Duration(15*f) * time.Second
if strdialtimeout, ok := params["dial timeout"]; ok {
if strdialtimeout, ok := params[DialTimeout]; ok {
timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
if err != nil {
f := "invalid dial timeout '%v': %v"
@ -358,7 +425,7 @@ func Parse(dsn string) (Config, error) {
p.DialTimeout = time.Duration(timeout) * time.Second
}
hostInCertificate, ok := params["hostnameincertificate"]
hostInCertificate, ok := params[HostNameInCertificate]
if ok {
p.HostInCertificateProvided = true
} else {
@ -371,6 +438,19 @@ func Parse(dsn string) (Config, error) {
return p, err
}
if c, ok := params["columnencryption"]; ok {
columnEncryption, err := strconv.ParseBool(c)
if err != nil {
if strings.EqualFold(c, "Enabled") {
columnEncryption = true
} else if strings.EqualFold(c, "Disabled") {
columnEncryption = false
} else {
return p, fmt.Errorf("invalid columnencryption '%v' : %v", columnEncryption, err.Error())
}
}
p.ColumnEncryption = columnEncryption
}
return p, nil
}
@ -379,10 +459,10 @@ func Parse(dsn string) (Config, error) {
func (p Config) URL() *url.URL {
q := url.Values{}
if p.Database != "" {
q.Add("database", p.Database)
q.Add(Database, p.Database)
}
if p.LogFlags != 0 {
q.Add("log", strconv.FormatUint(uint64(p.LogFlags), 10))
q.Add(LogParam, strconv.FormatUint(uint64(p.LogFlags), 10))
}
host := p.Host
protocol := ""
@ -397,8 +477,8 @@ func (p Config) URL() *url.URL {
if p.Port > 0 {
host = fmt.Sprintf("%s:%d", host, p.Port)
}
q.Add("disableRetry", fmt.Sprintf("%t", p.DisableRetry))
protocolParam, ok := p.Parameters["protocol"]
q.Add(DisableRetry, fmt.Sprintf("%t", p.DisableRetry))
protocolParam, ok := p.Parameters[Protocol]
if ok {
if protocol != "" && protocolParam != protocol {
panic("Mismatched protocol parameters!")
@ -406,11 +486,11 @@ func (p Config) URL() *url.URL {
protocol = protocolParam
}
if protocol != "" {
q.Add("protocol", protocol)
q.Add(Protocol, protocol)
}
pipe, ok := p.Parameters["pipe"]
pipe, ok := p.Parameters[Pipe]
if ok {
q.Add("pipe", pipe)
q.Add(Pipe, pipe)
}
res := url.URL{
Scheme: "sqlserver",
@ -420,7 +500,17 @@ func (p Config) URL() *url.URL {
if p.Instance != "" {
res.Path = p.Instance
}
q.Add("dial timeout", strconv.FormatFloat(float64(p.DialTimeout.Seconds()), 'f', 0, 64))
q.Add(DialTimeout, strconv.FormatFloat(float64(p.DialTimeout.Seconds()), 'f', 0, 64))
switch p.Encryption {
case EncryptionDisabled:
q.Add(Encrypt, "DISABLE")
case EncryptionRequired:
q.Add(Encrypt, "true")
}
if p.ColumnEncryption {
q.Add("columnencryption", "true")
}
if len(q) > 0 {
res.RawQuery = q.Encode()
}
@ -428,15 +518,17 @@ func (p Config) URL() *url.URL {
return &res
}
// ADO connection string keywords at https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/DbConnectionStringCommon.cs
var adoSynonyms = map[string]string{
"application name": "app name",
"data source": "server",
"address": "server",
"network address": "server",
"addr": "server",
"user": "user id",
"uid": "user id",
"initial catalog": "database",
"application name": AppName,
"data source": Server,
"address": Server,
"network address": Server,
"addr": Server,
"user": UserID,
"uid": UserID,
"initial catalog": Database,
"column encryption setting": "columnencryption",
}
func splitConnectionString(dsn string) (res map[string]string) {
@ -460,18 +552,18 @@ func splitConnectionString(dsn string) (res map[string]string) {
name = synonym
}
// "server" in ADO can include a protocol and a port.
if name == "server" {
if name == Server {
for _, parser := range ProtocolParsers {
prot := parser.Protocol() + ":"
if strings.HasPrefix(value, prot) {
res["protocol"] = parser.Protocol()
res[Protocol] = parser.Protocol()
}
value = strings.TrimPrefix(value, prot)
}
serverParts := strings.Split(value, ",")
if len(serverParts) == 2 && len(serverParts[1]) > 0 {
value = serverParts[0]
res["port"] = serverParts[1]
res[Port] = serverParts[1]
}
}
res[name] = value
@ -493,10 +585,10 @@ func splitConnectionStringURL(dsn string) (map[string]string, error) {
}
if u.User != nil {
res["user id"] = u.User.Username()
res[UserID] = u.User.Username()
p, exists := u.User.Password()
if exists {
res["password"] = p
res[Password] = p
}
}
@ -506,13 +598,13 @@ func splitConnectionStringURL(dsn string) (map[string]string, error) {
}
if len(u.Path) > 0 {
res["server"] = host + "\\" + u.Path[1:]
res[Server] = host + "\\" + u.Path[1:]
} else {
res["server"] = host
res[Server] = host
}
if len(port) > 0 {
res["port"] = port
res[Port] = port
}
query := u.Query()

View file

@ -17,6 +17,7 @@ import (
"unicode"
"github.com/golang-sql/sqlexp"
"github.com/microsoft/go-mssqldb/aecmk"
"github.com/microsoft/go-mssqldb/internal/querytext"
"github.com/microsoft/go-mssqldb/msdsn"
)
@ -69,10 +70,7 @@ func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
return nil, err
}
return &Connector{
params: params,
driver: d,
}, nil
return newConnector(params, d), nil
}
func (d *Driver) Open(dsn string) (driver.Conn, error) {
@ -122,10 +120,8 @@ func NewConnector(dsn string) (*Connector, error) {
if err != nil {
return nil, err
}
c := &Connector{
params: params,
driver: driverInstanceNoProcess,
}
c := newConnector(params, driverInstanceNoProcess)
return c, nil
}
@ -146,9 +142,14 @@ func NewConnectorWithAccessTokenProvider(dsn string, tokenProvider func(ctx cont
// NewConnectorConfig creates a new Connector for a DSN Config struct.
// The returned connector may be used with sql.OpenDB.
func NewConnectorConfig(config msdsn.Config) *Connector {
return newConnector(config, driverInstanceNoProcess)
}
func newConnector(config msdsn.Config, driver *Driver) *Connector {
return &Connector{
params: config,
driver: driverInstanceNoProcess,
params: config,
driver: driver,
keyProviders: make(aecmk.ColumnEncryptionKeyProviderMap),
}
}
@ -199,6 +200,8 @@ type Connector struct {
//
// If Dialer is not set, normal net dialers are used.
Dialer Dialer
keyProviders aecmk.ColumnEncryptionKeyProviderMap
}
type Dialer interface {
@ -219,6 +222,11 @@ func (c *Connector) getDialer(p *msdsn.Config) Dialer {
return createDialer(p)
}
// RegisterCekProvider associates the given provider with the named key store. If an entry of the given name already exists, that entry is overwritten
func (c *Connector) RegisterCekProvider(name string, provider aecmk.ColumnEncryptionKeyProvider) {
c.keyProviders[name] = aecmk.NewCekProvider(provider)
}
type Conn struct {
connector *Connector
sess *tdsSession
@ -403,7 +411,7 @@ func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
if err != nil {
return nil, err
}
c := &Connector{params: params}
c := newConnector(params, nil)
return d.connect(ctx, c, params)
}
@ -445,10 +453,11 @@ func (c *Conn) Close() error {
}
type Stmt struct {
c *Conn
query string
paramCount int
notifSub *queryNotifSub
c *Conn
query string
paramCount int
notifSub *queryNotifSub
skipEncryption bool
}
type queryNotifSub struct {
@ -472,7 +481,7 @@ func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error)
if c.processQueryText {
query, paramCount = querytext.ParseParams(query)
}
return &Stmt{c, query, paramCount, nil}, nil
return &Stmt{c, query, paramCount, nil, false}, nil
}
func (s *Stmt) Close() error {
@ -654,16 +663,38 @@ func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string,
if isOutputValue(val.Value) {
output = outputSuffix
}
decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(params[i+offset].ti), output)
tiDecl := params[i+offset].ti
if val.encrypt != nil {
// Encrypted parameters have a few requirements:
// 1. Copy original typeinfo to a block after the data
// 2. Set the parameter type to varbinary(max)
// 3. Append the crypto metadata bytes
params[i+offset].tiOriginal = params[i+offset].ti
params[i+offset].Flags |= fEncrypted
encryptedBytes, metadata, err := val.encrypt(params[i+offset].buffer)
if err != nil {
return nil, nil, err
}
params[i+offset].cipherInfo = metadata
params[i+offset].ti.TypeId = typeBigVarBin
params[i+offset].buffer = encryptedBytes
params[i+offset].ti.Size = 0
}
decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(tiDecl), output)
}
return params, decls, nil
}
// Encrypts the input bytes. Returns the encrypted bytes followed by the encryption metadata to append to the packet.
type valueEncryptor func(bytes []byte) ([]byte, []byte, error)
type namedValue struct {
Name string
Ordinal int
Value driver.Value
encrypt valueEncryptor
}
func convertOldArgs(args []driver.Value) []namedValue {
@ -677,6 +708,10 @@ func convertOldArgs(args []driver.Value) []namedValue {
return list
}
func (s *Stmt) doEncryption() bool {
return !s.skipEncryption && s.c.sess.alwaysEncrypted
}
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
defer s.c.clearOuts()
@ -687,6 +722,12 @@ func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
if s.doEncryption() && len(args) > 0 {
args, err = s.encryptArgs(ctx, args)
}
if err != nil {
return nil, err
}
if err = s.sendQuery(ctx, args); err != nil {
return nil, s.c.checkBadConn(ctx, err, true)
}
@ -754,6 +795,12 @@ func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result,
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
if s.doEncryption() && len(args) > 0 {
args, err = s.encryptArgs(ctx, args)
}
if err != nil {
return nil, err
}
if err = s.sendQuery(ctx, args); err != nil {
return nil, s.c.checkBadConn(ctx, err, true)
}
@ -872,7 +919,7 @@ func (rc *Rows) NextResultSet() error {
// the value type that can be used to scan types into. For example, the database
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
return makeGoLangScanType(r.cols[index].ti)
return makeGoLangScanType(r.cols[index].originalTypeInfo())
}
// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
@ -881,7 +928,7 @@ func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
// "TIMESTAMP".
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
return makeGoLangTypeName(r.cols[index].ti)
return makeGoLangTypeName(r.cols[index].originalTypeInfo())
}
// RowsColumnTypeLength may be implemented by Rows. It should return the length
@ -897,7 +944,7 @@ func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
// int (0, false)
// bytea(30) (30, true)
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
return makeGoLangTypeLength(r.cols[index].ti)
return makeGoLangTypeLength(r.cols[index].originalTypeInfo())
}
// It should return
@ -908,7 +955,7 @@ func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
return makeGoLangTypePrecisionScale(r.cols[index].ti)
return makeGoLangTypePrecisionScale(r.cols[index].originalTypeInfo())
}
// The nullable value should
@ -974,12 +1021,20 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.TypeId = typeIntN
res.ti.Size = 8
res.buffer = []byte{}
case byte:
res.ti.TypeId = typeIntN
res.buffer = []byte{val}
res.ti.Size = 1
case float64:
res.ti.TypeId = typeFltN
res.ti.Size = 8
res.buffer = make([]byte, 8)
binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val))
case float32:
res.ti.TypeId = typeFltN
res.ti.Size = 4
res.buffer = make([]byte, 4)
binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(val))
case sql.NullFloat64:
// only null values should be getting here
res.ti.TypeId = typeFltN
@ -1043,7 +1098,7 @@ func (c *Conn) Ping(ctx context.Context) error {
if !c.connectionGood {
return driver.ErrBadConn
}
stmt := &Stmt{c, `select 1;`, 0, nil}
stmt := &Stmt{c, `select 1;`, 0, nil, true}
_, err := stmt.ExecContext(ctx, nil)
return err
}
@ -1108,7 +1163,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
}
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
list[i] = namedValueFromDriverNamedValue(nv)
}
return s.queryContext(ctx, list)
}
@ -1121,11 +1176,15 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
}
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
list[i] = namedValueFromDriverNamedValue(nv)
}
return s.exec(ctx, list)
}
func namedValueFromDriverNamedValue(v driver.NamedValue) namedValue {
return namedValue{Name: v.Name, Ordinal: v.Ordinal, Value: v.Value, encrypt: nil}
}
// Rowsq implements the sqlexp messages model for Query and QueryContext
// Theory: We could also implement the non-experimental model this way
type Rowsq struct {
@ -1316,7 +1375,7 @@ scan:
// the value type that can be used to scan types into. For example, the database
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
func (r *Rowsq) ColumnTypeScanType(index int) reflect.Type {
return makeGoLangScanType(r.cols[index].ti)
return makeGoLangScanType(r.cols[index].originalTypeInfo())
}
// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
@ -1325,7 +1384,7 @@ func (r *Rowsq) ColumnTypeScanType(index int) reflect.Type {
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
// "TIMESTAMP".
func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string {
return makeGoLangTypeName(r.cols[index].ti)
return makeGoLangTypeName(r.cols[index].originalTypeInfo())
}
// RowsColumnTypeLength may be implemented by Rows. It should return the length
@ -1341,7 +1400,7 @@ func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string {
// int (0, false)
// bytea(30) (30, true)
func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) {
return makeGoLangTypeLength(r.cols[index].ti)
return makeGoLangTypeLength(r.cols[index].originalTypeInfo())
}
// It should return
@ -1352,7 +1411,7 @@ func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) {
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
func (r *Rowsq) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
return makeGoLangTypePrecisionScale(r.cols[index].ti)
return makeGoLangTypePrecisionScale(r.cols[index].originalTypeInfo())
}
// The nullable value should

View file

@ -29,12 +29,18 @@ type MssqlStmt = Stmt // Deprecated: users should transition to th
var _ driver.NamedValueChecker = &Conn{}
// VarChar parameter types.
// VarChar is used to encode a string parameter as VarChar instead of a sized NVarChar
type VarChar string
// NVarCharMax is used to encode a string parameter as NVarChar(max) instead of a sized NVarChar
type NVarCharMax string
// VarCharMax is used to encode a string parameter as VarChar(max) instead of a sized NVarChar
type VarCharMax string
// NChar is used to encode a string parameter as NChar instead of a sized NVarChar
type NChar string
// DateTime1 encodes parameters to original DateTime SQL types.
type DateTime1 time.Time
@ -45,12 +51,16 @@ func convertInputParameter(val interface{}) (interface{}, error) {
switch v := val.(type) {
case int, int16, int32, int64, int8:
return val, nil
case byte:
return val, nil
case VarChar:
return val, nil
case NVarCharMax:
return val, nil
case VarCharMax:
return val, nil
case NChar:
return val, nil
case DateTime1:
return val, nil
case DateTimeOffset:
@ -61,8 +71,10 @@ func convertInputParameter(val interface{}) (interface{}, error) {
return val, nil
case civil.Time:
return val, nil
// case *apd.Decimal:
// return nil
// case *apd.Decimal:
// return nil
case float32:
return val, nil
default:
return driver.DefaultParameterConverter.ConvertValue(v)
}
@ -144,6 +156,10 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) {
res.ti.TypeId = typeNVarChar
res.buffer = str2ucs2(string(val))
res.ti.Size = 0 // currently zero forces nvarchar(max)
case NChar:
res.ti.TypeId = typeNChar
res.buffer = str2ucs2(string(val))
res.ti.Size = len(res.buffer)
case DateTime1:
t := time.Time(val)
res.ti.TypeId = typeDateTimeN

40
vendor/github.com/microsoft/go-mssqldb/quoter.go generated vendored Normal file
View file

@ -0,0 +1,40 @@
package mssql
import (
"strings"
)
// TSQLQuoter implements sqlexp.Quoter
type TSQLQuoter struct {
}
// ID quotes identifiers such as schema, table, or column names.
// This implementation handles multi-part names.
func (TSQLQuoter) ID(name string) string {
return "[" + strings.Replace(name, "]", "]]", -1) + "]"
}
// Value quotes database values such as string or []byte types as strings
// that are suitable and safe to embed in SQL text. The returned value
// of a string will include all surrounding quotes.
//
// If a value type is not supported it must panic.
func (TSQLQuoter) Value(v interface{}) string {
switch v := v.(type) {
default:
panic("unsupported value")
case string:
return sqlString(v)
case VarChar:
return sqlString(string(v))
case VarCharMax:
return sqlString(string(v))
case NVarCharMax:
return sqlString(string(v))
}
}
func sqlString(v string) string {
return "'" + strings.Replace(string(v), "'", "''", -1) + "'"
}

View file

@ -13,13 +13,16 @@ type procId struct {
const (
fByRevValue = 1
fDefaultValue = 2
fEncrypted = 8
)
type param struct {
Name string
Flags uint8
ti typeInfo
buffer []byte
Name string
Flags uint8
ti typeInfo
buffer []byte
tiOriginal typeInfo
cipherInfo []byte
}
var (
@ -78,6 +81,15 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
if err != nil {
return
}
if (param.Flags & fEncrypted) == fEncrypted {
err = writeTypeInfo(buf, &param.tiOriginal)
if err != nil {
return
}
if _, err = buf.Write(param.cipherInfo); err != nil {
return
}
}
}
return buf.FinishPacket()
}

View file

@ -15,6 +15,7 @@ import (
"unicode/utf16"
"unicode/utf8"
"github.com/microsoft/go-mssqldb/aecmk"
"github.com/microsoft/go-mssqldb/integratedauth"
"github.com/microsoft/go-mssqldb/msdsn"
)
@ -102,6 +103,7 @@ const (
verTDS73 = verTDS73A
verTDS73B = 0x730B0003
verTDS74 = 0x74000004
verTDS80 = 0x08000000
)
// packet types
@ -143,6 +145,7 @@ const (
encryptOn = 1 // Encryption is available and on.
encryptNotSup = 2 // Encryption is not available.
encryptReq = 3 // Encryption is required.
encryptStrict = 4
)
const (
@ -157,16 +160,23 @@ const (
)
type tdsSession struct {
buf *tdsBuffer
loginAck loginAckStruct
database string
partner string
columns []columnStruct
tranid uint64
logFlags uint64
logger ContextLogger
routedServer string
routedPort uint16
buf *tdsBuffer
loginAck loginAckStruct
database string
partner string
columns []columnStruct
tranid uint64
logFlags uint64
logger ContextLogger
routedServer string
routedPort uint16
alwaysEncrypted bool
aeSettings *alwaysEncryptedSettings
}
type alwaysEncryptedSettings struct {
enclaveType string
keyProviders aecmk.ColumnEncryptionKeyProviderMap
}
const (
@ -178,10 +188,26 @@ const (
)
type columnStruct struct {
UserType uint32
Flags uint16
ColName string
ti typeInfo
UserType uint32
Flags uint16
ColName string
ti typeInfo
cryptoMeta *cryptoMetadata
}
func (c columnStruct) isEncrypted() bool {
return isEncryptedFlag(c.Flags)
}
func isEncryptedFlag(flags uint16) bool {
return colFlagEncrypted == (flags & colFlagEncrypted)
}
func (c columnStruct) originalTypeInfo() typeInfo {
if c.isEncrypted() {
return c.cryptoMeta.typeInfo
}
return c.ti
}
type keySlice []uint8
@ -577,7 +603,7 @@ func sendLogin(w *tdsBuffer, login *login) error {
language := str2ucs2(login.Language)
database := str2ucs2(login.Database)
atchdbfile := str2ucs2(login.AtchDBFile)
changepassword := str2ucs2(login.ChangePassword)
changepassword := manglePassword(login.ChangePassword)
featureExt := login.FeatureExt.toBytes()
hdr := loginHeader{
@ -638,6 +664,9 @@ func sendLogin(w *tdsBuffer, login *login) error {
offset += hdr.ExtensionLength // DWORD
featureExtOffset = uint32(offset)
}
if len(changepassword) > 0 {
hdr.OptionFlags3 |= fChangePassword
}
hdr.Length = uint32(offset) + uint32(featureExtLen)
var err error
@ -977,6 +1006,8 @@ func preparePreloginFields(p msdsn.Config, fe *featureExtFedAuth) map[uint8][]by
encrypt = encryptOn
case msdsn.EncryptionOff:
encrypt = encryptOff
case msdsn.EncryptionStrict:
encrypt = encryptStrict
}
v := getDriverVersion(driverVersion)
fields := map[uint8][]byte{
@ -1023,6 +1054,12 @@ func interpretPreloginResponse(p msdsn.Config, fe *featureExtFedAuth, fields map
}
func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger ContextLogger, auth integratedauth.IntegratedAuthenticator, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) {
var TDSVersion uint32
if p.Encryption == msdsn.EncryptionStrict {
TDSVersion = verTDS80
} else {
TDSVersion = verTDS74
}
var typeFlags uint8
if p.ReadOnlyIntent {
typeFlags |= fReadOnlyIntent
@ -1035,17 +1072,21 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont
serverName = p.Host
}
l = &login{
TDSVersion: verTDS74,
PacketSize: packetSize,
Database: p.Database,
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
OptionFlags1: fUseDB | fSetLang,
HostName: p.Workstation,
ServerName: serverName,
AppName: p.AppName,
TypeFlags: typeFlags,
CtlIntName: "go-mssqldb",
ClientProgVer: getDriverVersion(driverVersion),
TDSVersion: TDSVersion,
PacketSize: packetSize,
Database: p.Database,
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
OptionFlags1: fUseDB | fSetLang,
HostName: p.Workstation,
ServerName: serverName,
AppName: p.AppName,
TypeFlags: typeFlags,
CtlIntName: "go-mssqldb",
ClientProgVer: getDriverVersion(driverVersion),
ChangePassword: p.ChangePassword,
}
if p.ColumnEncryption {
_ = l.FeatureExt.Add(&featureExtColumnEncryption{})
}
switch {
case fe.FedAuthLibrary == FedAuthLibrarySecurityToken:
@ -1061,14 +1102,14 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont
return nil, err
}
l.FeatureExt.Add(fe)
_ = l.FeatureExt.Add(fe)
case fe.FedAuthLibrary == FedAuthLibraryADAL:
if uint64(p.LogFlags)&logDebug != 0 {
logger.Log(ctx, msdsn.LogDebug, "Starting federated authentication using ADAL")
}
l.FeatureExt.Add(fe)
_ = l.FeatureExt.Add(fe)
case auth != nil:
if uint64(p.LogFlags)&logDebug != 0 {
@ -1092,8 +1133,29 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont
return l, nil
}
func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) {
func getTLSConn(conn *timeoutConn, p msdsn.Config, alpnSeq string) (tlsConn *tls.Conn, err error) {
var config *tls.Config
if pc := p.TLSConfig; pc != nil {
config = pc
}
if config == nil {
config, err = msdsn.SetupTLS("", false, p.Host, "")
if err != nil {
return nil, err
}
}
//Set ALPN Sequence
config.NextProtos = []string{alpnSeq}
tlsConn = tls.Client(conn.c, config)
err = tlsConn.Handshake()
if err != nil {
return nil, fmt.Errorf("TLS Handshake failed: %w", err)
}
return tlsConn, nil
}
func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) {
isTransportEncrypted := false
// if instance is specified use instance resolution service
if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 {
// both instance name and port specified
@ -1133,14 +1195,25 @@ initiate_connection:
}
toconn := newTimeoutConn(conn, p.ConnTimeout)
outbuf := newTdsBuffer(packetSize, toconn)
if p.Encryption == msdsn.EncryptionStrict {
outbuf.transport, err = getTLSConn(toconn, p, "tds/8.0")
if err != nil {
return nil, err
}
isTransportEncrypted = true
}
sess := tdsSession{
buf: outbuf,
logger: logger,
logFlags: uint64(p.LogFlags),
buf: outbuf,
logger: logger,
logFlags: uint64(p.LogFlags),
aeSettings: &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()},
}
for i, p := range c.keyProviders {
sess.aeSettings.keyProviders[i] = p
}
fedAuth := &featureExtFedAuth{
FedAuthLibrary: FedAuthLibraryReserved,
}
@ -1166,42 +1239,47 @@ initiate_connection:
return nil, err
}
if encrypt != encryptNotSup {
var config *tls.Config
if pc := p.TLSConfig; pc != nil {
config = pc
if config.DynamicRecordSizingDisabled == false {
config = config.Clone()
//We need not perform TLS handshake if the communication channel is already encrypted (encrypt=strict)
if !isTransportEncrypted {
if encrypt != encryptNotSup {
var config *tls.Config
if pc := p.TLSConfig; pc != nil {
config = pc
if !config.DynamicRecordSizingDisabled {
config = config.Clone()
// fix for https://github.com/microsoft/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
// fix for https://github.com/microsoft/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
}
}
}
if config == nil {
config, err = msdsn.SetupTLS("", false, p.Host, "")
if config == nil {
config, err = msdsn.SetupTLS("", false, p.Host, "")
if err != nil {
return nil, err
}
}
// setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
handshakeConn := tlsHandshakeConn{buf: outbuf}
passthrough := passthroughConn{c: &handshakeConn}
tlsConn := tls.Client(&passthrough, config)
err = tlsConn.Handshake()
passthrough.c = toconn
outbuf.transport = tlsConn
if err != nil {
return nil, err
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
}
if encrypt == encryptOff {
outbuf.afterFirst = func() {
outbuf.transport = toconn
}
}
}
// setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
handshakeConn := tlsHandshakeConn{buf: outbuf}
passthrough := passthroughConn{c: &handshakeConn}
tlsConn := tls.Client(&passthrough, config)
err = tlsConn.Handshake()
passthrough.c = toconn
outbuf.transport = tlsConn
if err != nil {
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
}
if encrypt == encryptOff {
outbuf.afterFirst = func() {
outbuf.transport = toconn
}
}
}
auth, err := integratedauth.GetIntegratedAuthenticator(p)
@ -1288,6 +1366,18 @@ initiate_connection:
case loginAckStruct:
sess.loginAck = token
loginAck = true
case featureExtAck:
for _, v := range token {
switch v := v.(type) {
case colAckStruct:
if v.Version <= 2 && v.Version > 0 {
sess.alwaysEncrypted = true
if len(v.EnclaveType) > 0 {
sess.aeSettings.enclaveType = string(v.EnclaveType)
}
}
}
}
case doneStruct:
if token.isError() {
tokenErr := token.getError()
@ -1317,3 +1407,21 @@ initiate_connection:
}
return &sess, nil
}
type featureExtColumnEncryption struct {
}
func (f *featureExtColumnEncryption) featureID() byte {
return featExtCOLUMNENCRYPTION
}
func (f *featureExtColumnEncryption) toBytes() []byte {
/*
1 = The client supports column encryption without enclave computations.
2 = The client SHOULD<25> support column encryption when encrypted data require enclave computations.
3 = The client SHOULD<26> support column encryption when encrypted data require enclave computations
with the additional ability to cache column encryption keys that are to be sent to the enclave
and the ability to retry queries when the keys sent by the client do not match what is needed for the query to run.
*/
return []byte{0x01}
}

View file

@ -1,6 +1,7 @@
package mssql
import (
"bytes"
"context"
"encoding/binary"
"fmt"
@ -10,7 +11,11 @@ import (
"strconv"
"github.com/golang-sql/sqlexp"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys"
"github.com/microsoft/go-mssqldb/msdsn"
"golang.org/x/text/encoding/unicode"
)
//go:generate go run golang.org/x/tools/cmd/stringer -type token
@ -92,10 +97,15 @@ const (
fedAuthInfoSPN = 0x02
)
const (
cipherAlgCustom = 0x00
)
// COLMETADATA flags
// https://msdn.microsoft.com/en-us/library/dd357363.aspx
const (
colFlagNullable = 1
colFlagNullable = 1
colFlagEncrypted = 0x0800
// TODO implement more flags
)
@ -533,7 +543,14 @@ type fedAuthAckStruct struct {
Signature []byte
}
func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} {
type colAckStruct struct {
Version int
EnclaveType string
}
type featureExtAck map[byte]interface{}
func parseFeatureExtAck(r *tdsBuffer) featureExtAck {
ack := map[byte]interface{}{}
for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() {
@ -555,7 +572,21 @@ func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} {
length -= 32
}
ack[feature] = fedAuthAck
case featExtCOLUMNENCRYPTION:
colAck := colAckStruct{Version: int(r.byte())}
length--
if length > 0 {
// enclave type is sent as utf16 le
enclaveLength := r.byte() * 2
length--
enclaveBytes := make([]byte, enclaveLength)
r.ReadFull(enclaveBytes)
// if the enclave type is malformed we'll just ignore it
colAck.EnclaveType, _ = ucs22str(enclaveBytes)
length -= uint32(enclaveLength)
}
ack[feature] = colAck
}
// Skip unprocessed bytes
@ -568,34 +599,265 @@ func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} {
}
// http://msdn.microsoft.com/en-us/library/dd357363.aspx
func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) {
func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) {
count := r.uint16()
if count == 0xffff {
// no metadata is sent
return nil
}
columns = make([]columnStruct, count)
var cekTable *cekTable
if s.alwaysEncrypted {
// column encryption key list
cekTable = readCekTable(r)
}
for i := range columns {
column := &columns[i]
column.UserType = r.uint32()
column.Flags = r.uint16()
baseTi := getBaseTypeInfo(r, true)
typeInfo := readTypeInfo(r, baseTi.TypeId, column.cryptoMeta)
typeInfo.UserType = baseTi.UserType
typeInfo.Flags = baseTi.Flags
typeInfo.TypeId = baseTi.TypeId
column.Flags = baseTi.Flags
column.UserType = baseTi.UserType
column.ti = typeInfo
if column.isEncrypted() && s.alwaysEncrypted {
// Read Crypto Metadata
cryptoMeta := parseCryptoMetadata(r, cekTable)
cryptoMeta.typeInfo.Flags = baseTi.Flags
column.cryptoMeta = &cryptoMeta
} else {
column.cryptoMeta = nil
}
// parsing TYPE_INFO structure
column.ti = readTypeInfo(r)
column.ColName = r.BVarChar()
}
return columns
}
// http://msdn.microsoft.com/en-us/library/dd357254.aspx
func parseRow(r *tdsBuffer, columns []columnStruct, row []interface{}) {
for i, column := range columns {
row[i] = column.ti.Reader(&column.ti, r)
func getBaseTypeInfo(r *tdsBuffer, parseFlags bool) typeInfo {
userType := r.uint32()
flags := uint16(0)
if parseFlags {
flags = r.uint16()
}
tId := r.byte()
return typeInfo{
UserType: userType,
Flags: flags,
TypeId: tId}
}
type cryptoMetadata struct {
entry *cekTableEntry
ordinal uint16
algorithmId byte
algorithmName *string
encType byte
normRuleVer byte
typeInfo typeInfo
}
func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable) cryptoMetadata {
ordinal := uint16(0)
if cekTable != nil {
ordinal = r.uint16()
}
typeInfo := getBaseTypeInfo(r, false)
ti := readTypeInfo(r, typeInfo.TypeId, nil)
ti.UserType = typeInfo.UserType
ti.Flags = typeInfo.Flags
ti.TypeId = typeInfo.TypeId
algorithmId := r.byte()
var algName *string = nil
if algorithmId == cipherAlgCustom {
// Read the name when a custom algorithm is used
nameLen := int(r.byte())
var algNameUtf16 = make([]byte, nameLen*2)
r.ReadFull(algNameUtf16)
algNameBytes, _ := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder().Bytes(algNameUtf16)
mAlgName := string(algNameBytes)
algName = &mAlgName
}
encType := r.byte()
normRuleVer := r.byte()
var entry *cekTableEntry = nil
if cekTable != nil {
if int(ordinal) > len(cekTable.entries)-1 {
panic(fmt.Errorf("invalid ordinal, cekTable only has %d entries", len(cekTable.entries)))
}
entry = &cekTable.entries[ordinal]
}
return cryptoMetadata{
entry: entry,
ordinal: ordinal,
algorithmId: algorithmId,
algorithmName: algName,
encType: encType,
normRuleVer: normRuleVer,
typeInfo: ti,
}
}
func readCekTable(r *tdsBuffer) *cekTable {
tableSize := r.uint16()
var cekTable *cekTable = nil
if tableSize != 0 {
mCekTable := newCekTable(tableSize)
for i := uint16(0); i < tableSize; i++ {
mCekTable.entries[i] = readCekTableEntry(r)
}
cekTable = &mCekTable
}
return cekTable
}
func readCekTableEntry(r *tdsBuffer) cekTableEntry {
databaseId := r.int32()
cekID := r.int32()
cekVersion := r.int32()
var cekMdVersion = make([]byte, 8)
_, err := r.Read(cekMdVersion)
if err != nil {
panic("unable to read cekMdVersion")
}
cekValueCount := uint(r.byte())
// not using ucs22str because we already know the data is utf16
enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM)
utf16dec := enc.NewDecoder()
cekValues := make([]encryptionKeyInfo, cekValueCount)
for i := uint(0); i < cekValueCount; i++ {
encryptedCekLength := r.uint16()
encryptedCek := make([]byte, encryptedCekLength)
r.ReadFull(encryptedCek)
keyStoreLength := r.byte()
keyStoreNameUtf16 := make([]byte, keyStoreLength*2)
r.ReadFull(keyStoreNameUtf16)
keyStoreName, _ := utf16dec.Bytes(keyStoreNameUtf16)
keyPathLength := r.uint16()
keyPathUtf16 := make([]byte, keyPathLength*2)
r.ReadFull(keyPathUtf16)
keyPath, _ := utf16dec.Bytes(keyPathUtf16)
algLength := r.byte()
algNameUtf16 := make([]byte, algLength*2)
r.ReadFull(algNameUtf16)
algName, _ := utf16dec.Bytes(algNameUtf16)
cekValues[i] = encryptionKeyInfo{
encryptedKey: encryptedCek,
databaseID: int(databaseId),
cekID: int(cekID),
cekVersion: int(cekVersion),
cekMdVersion: cekMdVersion,
keyPath: string(keyPath),
keyStoreName: string(keyStoreName),
algorithmName: string(algName),
}
}
return cekTableEntry{
databaseID: int(databaseId),
keyId: int(cekID),
keyVersion: int(cekVersion),
mdVersion: cekMdVersion,
valueCount: int(cekValueCount),
cekValues: cekValues,
}
}
// http://msdn.microsoft.com/en-us/library/dd357254.aspx
func parseRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) {
for i, column := range columns {
columnContent := column.ti.Reader(&column.ti, r, nil)
if columnContent == nil {
row[i] = columnContent
continue
}
if column.isEncrypted() {
buffer := decryptColumn(column, s, columnContent)
// Decrypt
row[i] = column.cryptoMeta.typeInfo.Reader(&column.cryptoMeta.typeInfo, &buffer, column.cryptoMeta)
} else {
row[i] = columnContent
}
}
}
type RWCBuffer struct {
buffer *bytes.Reader
}
func (R RWCBuffer) Read(p []byte) (n int, err error) {
return R.buffer.Read(p)
}
func (R RWCBuffer) Write(p []byte) (n int, err error) {
return 0, nil
}
func (R RWCBuffer) Close() error {
return nil
}
func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{}) tdsBuffer {
encType := encryption.From(column.cryptoMeta.encType)
cekValue := column.cryptoMeta.entry.cekValues[column.cryptoMeta.ordinal]
if (s.logFlags & uint64(msdsn.LogDebug)) == uint64(msdsn.LogDebug) {
s.logger.Log(context.Background(), msdsn.LogDebug, fmt.Sprintf("Decrypting column %s. Key path: %s, Key store:%s, Algo: %s", column.ColName, cekValue.keyPath, cekValue.keyStoreName, cekValue.algorithmName))
}
cekProvider, ok := s.aeSettings.keyProviders[cekValue.keyStoreName]
if !ok {
panic(fmt.Errorf("Unable to find provider %s to decrypt CEK", cekValue.keyStoreName))
}
cek, err := cekProvider.GetDecryptedKey(cekValue.keyPath, column.cryptoMeta.entry.cekValues[0].encryptedKey)
if err != nil {
panic(err)
}
k := keys.NewAeadAes256CbcHmac256(cek)
alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encType, byte(cekValue.cekVersion))
d, err := alg.Decrypt(columnContent.([]byte))
if err != nil {
panic(err)
}
// Decrypt returns a minimum of 8 bytes so truncate to the actual data size
if column.cryptoMeta.typeInfo.Size > 0 && column.cryptoMeta.typeInfo.Size < len(d) {
d = d[:column.cryptoMeta.typeInfo.Size]
}
var newBuff []byte
newBuff = append(newBuff, d...)
rwc := RWCBuffer{
buffer: bytes.NewReader(newBuff),
}
column.cryptoMeta.typeInfo.Buffer = d
buffer := tdsBuffer{rpos: 0, rsize: len(newBuff), rbuf: newBuff, transport: rwc}
return buffer
}
// http://msdn.microsoft.com/en-us/library/dd304783.aspx
func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) {
func parseNbcRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) {
bitlen := (len(columns) + 7) / 8
pres := make([]byte, bitlen)
r.ReadFull(pres)
@ -604,7 +866,15 @@ func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) {
row[i] = nil
continue
}
row[i] = col.ti.Reader(&col.ti, r)
columnContent := col.ti.Reader(&col.ti, r, nil)
if col.isEncrypted() {
buffer := decryptColumn(col, s, columnContent)
// Decrypt
row[i] = col.cryptoMeta.typeInfo.Reader(&col.cryptoMeta.typeInfo, &buffer, col.cryptoMeta)
} else {
row[i] = columnContent
}
}
}
@ -637,7 +907,7 @@ func parseInfo(r *tdsBuffer) (res Error) {
}
// https://msdn.microsoft.com/en-us/library/dd303881.aspx
func parseReturnValue(r *tdsBuffer) (nv namedValue) {
func parseReturnValue(r *tdsBuffer, s *tdsSession) (nv namedValue) {
/*
ParamOrdinal
ParamName
@ -648,13 +918,21 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) {
CryptoMetadata
Value
*/
r.uint16()
nv.Name = r.BVarChar()
r.byte()
r.uint32() // UserType (uint16 prior to 7.2)
r.uint16()
ti := readTypeInfo(r)
nv.Value = ti.Reader(&ti, r)
_ = r.uint16() // ParamOrdinal
nv.Name = r.BVarChar() // ParamName
_ = r.byte() // Status
ti := getBaseTypeInfo(r, true) // UserType + Flags + TypeInfo
var cryptoMetadata *cryptoMetadata = nil
if s.alwaysEncrypted && (ti.Flags&fEncrypted) == fEncrypted {
cm := parseCryptoMetadata(r, nil) // CryptoMetadata
cryptoMetadata = &cm
}
ti2 := readTypeInfo(r, ti.TypeId, cryptoMetadata)
nv.Value = ti2.Reader(&ti2, r, cryptoMetadata)
return
}
@ -664,6 +942,17 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
if sess.logFlags&logErrors != 0 {
sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Intercepted panic %v", err))
}
if outs.msgq != nil {
var derr error
switch e := err.(type) {
case error:
derr = e
default:
derr = fmt.Errorf("Unhandled session error %v", e)
}
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgError{Error: derr})
}
ch <- err
}
close(ch)
@ -760,7 +1049,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
ch <- done
if done.Status&doneCount != 0 {
if sess.logFlags&logRows != 0 {
sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d row(s) affected)", done.RowCount))
sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(Rows affected: %d)", done.RowCount))
}
if (colsReceived || done.CurCmd != cmdSelect) && outs.msgq != nil {
@ -781,7 +1070,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
return
}
case tokenColMetadata:
columns = parseColMetadata72(sess.buf)
columns = parseColMetadata72(sess.buf, sess)
ch <- columns
colsReceived = true
if outs.msgq != nil {
@ -790,11 +1079,11 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
case tokenRow:
row := make([]interface{}, len(columns))
parseRow(sess.buf, columns, row)
parseRow(sess.buf, sess, columns, row)
ch <- row
case tokenNbcRow:
row := make([]interface{}, len(columns))
parseNbcRow(sess.buf, columns, row)
parseNbcRow(sess.buf, sess, columns, row)
ch <- row
case tokenEnvChange:
processEnvChg(ctx, sess)
@ -822,7 +1111,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNotice{Message: info})
}
case tokenReturnValue:
nv := parseReturnValue(sess.buf)
nv := parseReturnValue(sess.buf, sess)
if len(nv.Name) > 0 {
name := nv.Name[1:] // Remove the leading "@".
if ov, has := outs.params[name]; has {

View file

@ -89,6 +89,8 @@ const (
// http://msdn.microsoft.com/en-us/library/dd358284.aspx
type typeInfo struct {
TypeId uint8
UserType uint32
Flags uint16
Size int
Scale uint8
Prec uint8
@ -96,7 +98,7 @@ type typeInfo struct {
Collation cp.Collation
UdtInfo udtInfo
XmlInfo xmlInfo
Reader func(ti *typeInfo, r *tdsBuffer) (res interface{})
Reader func(ti *typeInfo, r *tdsBuffer, cryptoMeta *cryptoMetadata) (res interface{})
Writer func(w io.Writer, ti typeInfo, buf []byte) (err error)
}
@ -119,9 +121,9 @@ type xmlInfo struct {
XmlSchemaCollection string
}
func readTypeInfo(r *tdsBuffer) (res typeInfo) {
res.TypeId = r.byte()
switch res.TypeId {
func readTypeInfo(r *tdsBuffer, typeId byte, c *cryptoMetadata) (res typeInfo) {
res.TypeId = typeId
switch typeId {
case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4,
typeFlt4, typeMoney, typeDateTime, typeFlt8, typeMoney4, typeInt8:
// those are fixed length types
@ -140,7 +142,7 @@ func readTypeInfo(r *tdsBuffer) (res typeInfo) {
res.Reader = readFixedType
res.Buffer = make([]byte, res.Size)
default: // all others are VARLENTYPE
readVarLen(&res, r)
readVarLen(&res, r, c)
}
return
}
@ -315,7 +317,7 @@ func decodeDateTime(buf []byte) time.Time {
0, 0, secs, ns, time.UTC)
}
func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} {
func readFixedType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} {
r.ReadFull(ti.Buffer)
buf := ti.Buffer
switch ti.TypeId {
@ -349,8 +351,13 @@ func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} {
panic("shoulnd't get here")
}
func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.byte()
func readByteLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} {
var size byte
if c != nil {
size = byte(r.rsize)
} else {
size = r.byte()
}
if size == 0 {
return nil
}
@ -433,7 +440,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} {
default:
badStreamPanicf("Invalid typeid")
}
panic("shoulnd't get here")
panic("shouldn't get here")
}
func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
@ -448,8 +455,13 @@ func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
return
}
func readShortLenType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.uint16()
func readShortLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} {
var size uint16
if c != nil {
size = uint16(r.rsize)
} else {
size = r.uint16()
}
if size == 0xffff {
return nil
}
@ -491,7 +503,7 @@ func writeShortLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
return
}
func readLongLenType(ti *typeInfo, r *tdsBuffer) interface{} {
func readLongLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} {
// information about this format can be found here:
// http://msdn.microsoft.com/en-us/library/dd304783.aspx
// and here:
@ -566,7 +578,7 @@ func writeCollation(w io.Writer, col cp.Collation) (err error) {
// reads variant value
// http://msdn.microsoft.com/en-us/library/dd303302.aspx
func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} {
func readVariantType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} {
size := r.int32()
if size == 0 {
return nil
@ -658,41 +670,47 @@ func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} {
// partially length prefixed stream
// http://msdn.microsoft.com/en-us/library/dd340469.aspx
func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.uint64()
var buf *bytes.Buffer
switch size {
case _PLP_NULL:
// null
return nil
case _UNKNOWN_PLP_LEN:
// size unknown
buf = bytes.NewBuffer(make([]byte, 0, 1000))
default:
buf = bytes.NewBuffer(make([]byte, 0, size))
}
for {
chunksize := r.uint32()
if chunksize == 0 {
break
func readPLPType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} {
var bytesToDecode []byte
if c == nil {
size := r.uint64()
var buf *bytes.Buffer
switch size {
case _PLP_NULL:
// null
return nil
case _UNKNOWN_PLP_LEN:
// size unknown
buf = bytes.NewBuffer(make([]byte, 0, 1000))
default:
buf = bytes.NewBuffer(make([]byte, 0, size))
}
if _, err := io.CopyN(buf, r, int64(chunksize)); err != nil {
badStreamPanicf("Reading PLP type failed: %s", err.Error())
for {
chunksize := r.uint32()
if chunksize == 0 {
break
}
if _, err := io.CopyN(buf, r, int64(chunksize)); err != nil {
badStreamPanicf("Reading PLP type failed: %s", err.Error())
}
}
bytesToDecode = buf.Bytes()
} else {
bytesToDecode = r.rbuf
}
switch ti.TypeId {
case typeXml:
return decodeXml(*ti, buf.Bytes())
return decodeXml(*ti, bytesToDecode)
case typeBigVarChar, typeBigChar, typeText:
return decodeChar(ti.Collation, buf.Bytes())
return decodeChar(ti.Collation, bytesToDecode)
case typeBigVarBin, typeBigBinary, typeImage:
return buf.Bytes()
return bytesToDecode
case typeNVarChar, typeNChar, typeNText:
return decodeNChar(buf.Bytes())
return decodeNChar(bytesToDecode)
case typeUdt:
return decodeUdt(*ti, buf.Bytes())
return decodeUdt(*ti, bytesToDecode)
}
panic("shoulnd't get here")
panic("shouldn't get here")
}
func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) {
@ -719,7 +737,7 @@ func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) {
}
}
func readVarLen(ti *typeInfo, r *tdsBuffer) {
func readVarLen(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) {
switch ti.TypeId {
case typeDateN:
ti.Size = 3

View file

@ -4,7 +4,7 @@ import "fmt"
// Update this variable with the release tag before pushing the tag
// This value is written to the prelogin and login7 packets during a new connection
const driverVersion = "v1.5.0"
const driverVersion = "v1.6.0"
func getDriverVersion(ver string) uint32 {
var majorVersion uint32