1
0
Fork 0
mirror of https://github.com/documize/community.git synced 2025-07-19 05:09:42 +02:00

Prep 5.11.3

This commit is contained in:
Harvey Kandola 2024-02-16 10:53:02 -05:00
parent 69940cb7f1
commit f2ba294be8
47 changed files with 169 additions and 3090 deletions

View file

@ -2,10 +2,10 @@ package ldap
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/url"
"sync"
@ -61,21 +61,13 @@ type messageContext struct {
// sendResponse should only be called within the processMessages() loop which
// is also responsible for closing the responses channel.
func (msgCtx *messageContext) sendResponse(packet *PacketResponse, timeout time.Duration) {
timeoutCtx := context.Background()
if timeout > 0 {
var cancelFunc context.CancelFunc
timeoutCtx, cancelFunc = context.WithTimeout(context.Background(), timeout)
defer cancelFunc()
}
func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
select {
case msgCtx.responses <- packet:
// Successfully sent packet to message handler.
case <-msgCtx.done:
// The request handler is done and will not receive more
// packets.
case <-timeoutCtx.Done():
// The timeout was reached before the packet was sent.
}
}
@ -96,7 +88,6 @@ const (
type Conn struct {
// requestTimeout is loaded atomically
// so we need to ensure 64-bit alignment on 32-bit platforms.
// https://github.com/go-ldap/ldap/pull/199
requestTimeout int64
conn net.Conn
isTLS bool
@ -111,8 +102,6 @@ type Conn struct {
wgClose sync.WaitGroup
outstandingRequests uint
messageMutex sync.Mutex
err error
}
var _ Client = &Conn{}
@ -130,31 +119,30 @@ type DialOpt func(*DialContext)
// DialWithDialer updates net.Dialer in DialContext.
func DialWithDialer(d *net.Dialer) DialOpt {
return func(dc *DialContext) {
dc.dialer = d
dc.d = d
}
}
// DialWithTLSConfig updates tls.Config in DialContext.
func DialWithTLSConfig(tc *tls.Config) DialOpt {
return func(dc *DialContext) {
dc.tlsConfig = tc
dc.tc = tc
}
}
// DialWithTLSDialer is a wrapper for DialWithTLSConfig with the option to
// specify a net.Dialer to for example define a timeout or a custom resolver.
// @deprecated Use DialWithDialer and DialWithTLSConfig instead
func DialWithTLSDialer(tlsConfig *tls.Config, dialer *net.Dialer) DialOpt {
return func(dc *DialContext) {
dc.tlsConfig = tlsConfig
dc.dialer = dialer
dc.tc = tlsConfig
dc.d = dialer
}
}
// DialContext contains necessary parameters to dial the given ldap URL.
type DialContext struct {
dialer *net.Dialer
tlsConfig *tls.Config
d *net.Dialer
tc *tls.Config
}
func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
@ -162,7 +150,7 @@ func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
if u.Path == "" || u.Path == "/" {
u.Path = "/var/run/slapd/ldapi"
}
return dc.dialer.Dial("unix", u.Path)
return dc.d.Dial("unix", u.Path)
}
host, port, err := net.SplitHostPort(u.Host)
@ -173,21 +161,16 @@ func (dc *DialContext) dial(u *url.URL) (net.Conn, error) {
}
switch u.Scheme {
case "cldap":
if port == "" {
port = DefaultLdapPort
}
return dc.dialer.Dial("udp", net.JoinHostPort(host, port))
case "ldap":
if port == "" {
port = DefaultLdapPort
}
return dc.dialer.Dial("tcp", net.JoinHostPort(host, port))
return dc.d.Dial("tcp", net.JoinHostPort(host, port))
case "ldaps":
if port == "" {
port = DefaultLdapsPort
}
return tls.DialWithDialer(dc.dialer, "tcp", net.JoinHostPort(host, port), dc.tlsConfig)
return tls.DialWithDialer(dc.d, "tcp", net.JoinHostPort(host, port), dc.tc)
}
return nil, fmt.Errorf("Unknown scheme '%s'", u.Scheme)
@ -220,8 +203,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
}
// DialURL connects to the given ldap URL.
// The following schemas are supported: ldap://, ldaps://, ldapi://,
// and cldap:// (RFC1798, deprecated but used by Active Directory).
// The following schemas are supported: ldap://, ldaps://, ldapi://.
// On success a new Conn for the connection is returned.
func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
u, err := url.Parse(addr)
@ -233,8 +215,8 @@ func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
for _, opt := range opts {
opt(&dc)
}
if dc.dialer == nil {
dc.dialer = &net.Dialer{Timeout: DefaultTimeout}
if dc.d == nil {
dc.d = &net.Dialer{Timeout: DefaultTimeout}
}
c, err := dc.dial(u)
@ -249,7 +231,7 @@ func DialURL(addr string, opts ...DialOpt) (*Conn, error) {
// NewConn returns a new Conn using conn for network I/O.
func NewConn(conn net.Conn, isTLS bool) *Conn {
l := &Conn{
return &Conn{
conn: conn,
chanConfirm: make(chan struct{}),
chanMessageID: make(chan int64),
@ -258,12 +240,11 @@ func NewConn(conn net.Conn, isTLS bool) *Conn {
requestTimeout: 0,
isTLS: isTLS,
}
l.wgClose.Add(1)
return l
}
// Start initializes goroutines to read responses and process messages
func (l *Conn) Start() {
l.wgClose.Add(1)
go l.reader()
go l.processMessages()
}
@ -279,45 +260,31 @@ func (l *Conn) setClosing() bool {
}
// Close closes the connection.
func (l *Conn) Close() (err error) {
func (l *Conn) Close() {
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
if l.setClosing() {
l.Debug.Printf("Sending quit message and waiting for confirmation")
l.chanMessage <- &messagePacket{Op: MessageQuit}
timeoutCtx := context.Background()
if l.getTimeout() > 0 {
var cancelFunc context.CancelFunc
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.getTimeout()))
defer cancelFunc()
}
select {
case <-l.chanConfirm:
// Confirmation was received.
case <-timeoutCtx.Done():
// The timeout was reached before confirmation was received.
}
<-l.chanConfirm
close(l.chanMessage)
l.Debug.Printf("Closing network connection")
err = l.conn.Close()
if err := l.conn.Close(); err != nil {
log.Println(err)
}
l.wgClose.Done()
}
l.wgClose.Wait()
return err
}
// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
func (l *Conn) SetTimeout(timeout time.Duration) {
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}
func (l *Conn) getTimeout() int64 {
return atomic.LoadInt64(&l.requestTimeout)
if timeout > 0 {
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}
}
// Returns the next available messageID
@ -328,14 +295,6 @@ func (l *Conn) nextMessageID() int64 {
return 0
}
// GetLastError returns the last recorded error from goroutines like processMessages and reader.
// Only the last recorded error will be returned.
func (l *Conn) GetLastError() error {
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
return l.err
}
// StartTLS sends the command to start a TLS session and then creates a new TLS Client
func (l *Conn) StartTLS(config *tls.Config) error {
if l.isTLS {
@ -484,13 +443,13 @@ func (l *Conn) sendProcessMessage(message *messagePacket) bool {
func (l *Conn) processMessages() {
defer func() {
if err := recover(); err != nil {
l.err = fmt.Errorf("ldap: recovered panic in processMessages: %v", err)
log.Printf("ldap: recovered panic in processMessages: %v", err)
}
for messageID, msgCtx := range l.messageContexts {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l.IsClosing() && l.closeErr.Load() != nil {
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout()))
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
}
l.Debug.Printf("Closing channel for MessageID %d", messageID)
close(msgCtx.responses)
@ -518,7 +477,7 @@ func (l *Conn) processMessages() {
_, err := l.conn.Write(buf)
if err != nil {
l.Debug.Printf("Error Sending Message: %s", err.Error())
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout()))
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
close(message.Context.responses)
break
}
@ -528,35 +487,28 @@ func (l *Conn) processMessages() {
l.messageContexts[message.MessageID] = message.Context
// Add timeout if defined
requestTimeout := l.getTimeout()
requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
if requestTimeout > 0 {
go func() {
timer := time.NewTimer(time.Duration(requestTimeout))
defer func() {
if err := recover(); err != nil {
l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err)
log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
}
timer.Stop()
}()
select {
case <-timer.C:
timeoutMessage := &messagePacket{
Op: MessageTimeout,
MessageID: message.MessageID,
}
l.sendProcessMessage(timeoutMessage)
case <-message.Context.done:
time.Sleep(requestTimeout)
timeoutMessage := &messagePacket{
Op: MessageTimeout,
MessageID: message.MessageID,
}
l.sendProcessMessage(timeoutMessage)
}()
}
case MessageResponse:
l.Debug.Printf("Receiving message %d", message.MessageID)
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout()))
msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
} else {
l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing())
log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing())
l.Debug.PrintPacket(message.Packet)
}
case MessageTimeout:
@ -564,7 +516,7 @@ func (l *Conn) processMessages() {
// All reads will return immediately
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout()))
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))})
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
}
@ -583,7 +535,7 @@ func (l *Conn) reader() {
cleanstop := false
defer func() {
if err := recover(); err != nil {
l.err = fmt.Errorf("ldap: recovered panic in reader: %v", err)
log.Printf("ldap: recovered panic in reader: %v", err)
}
if !cleanstop {
l.Close()