1
0
Fork 0
mirror of https://github.com/documize/community.git synced 2025-08-07 22:45:24 +02:00
documize/vendor/gopkg.in/cas.v2/client.go
2019-08-09 13:44:03 +08:00

503 lines
11 KiB
Go

package cas
import (
	"crypto/rand"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"path"
	"strings"
	"sync"

	"github.com/golang/glog"
)
// Options configures a Client created by NewClient.
type Options struct {
	// URL is the base URL of the CAS service. The endpoint names "login",
	// "logout", "serviceValidate" and "validate" are joined onto its path.
	URL *url.URL

	// Store holds validated tickets. If nil, NewClient uses a MemoryStore.
	Store TicketStore

	// Client is the HTTP client used for ticket validation requests.
	// If nil, NewClient uses a zero-value http.Client (which has no timeout).
	Client *http.Client

	// SendService controls whether the URL built by LogoutUrlForRequest
	// carries a "service" query parameter derived from the current request.
	SendService bool
}
// Client implements the main CAS protocol: it builds CAS endpoint URLs,
// validates service tickets, and tracks which session cookie refers to
// which validated ticket.
type Client struct {
	url         *url.URL          // base URL of the CAS server
	tickets     TicketStore       // validated tickets, keyed by ticket string
	client      *http.Client      // used for ticket validation requests
	mu          sync.Mutex        // protects sessions
	sessions    map[string]string // session cookie value -> service ticket
	sendService bool              // whether logout URLs include the service parameter
}
// NewClient returns a Client configured from options.
//
// A nil options.Store is replaced with a fresh MemoryStore and a nil
// options.Client with a zero-value http.Client. options itself must be
// non-nil.
func NewClient(options *Options) *Client {
	if glog.V(2) {
		glog.Infof("cas: new client with options %v", options)
	}

	c := &Client{
		url:         options.URL,
		tickets:     options.Store,
		client:      options.Client,
		sessions:    make(map[string]string),
		sendService: options.SendService,
	}

	// Fill in defaults for anything the caller left unset.
	if c.tickets == nil {
		c.tickets = &MemoryStore{}
	}
	if c.client == nil {
		c.client = &http.Client{}
	}
	return c
}
// Handle returns an http.Handler that wraps h with this client's CAS
// authentication.
func (c *Client) Handle(h http.Handler) http.Handler {
	wrapped := &clientHandler{c: c, h: h}
	return wrapped
}
// HandleFunc is the function counterpart of Handle: h is adapted to an
// http.Handler and then wrapped with CAS authentication.
func (c *Client) HandleFunc(h func(http.ResponseWriter, *http.Request)) http.Handler {
	var handler http.Handler = http.HandlerFunc(h)
	return c.Handle(handler)
}
// requestURL reconstructs the absolute URL of the incoming request r.
//
// The path and query come from r.URL and the host from r.Host. The scheme
// is selected as follows:
//   - from the X-Forwarded-Proto header when it is non-empty, so the URL is
//     correct behind a TLS-terminating proxy;
//   - otherwise "https" when the connection used TLS;
//   - otherwise "http".
//
// Proxies may append to X-Forwarded-Proto, producing a list such as
// "https, http". Only the first entry is used, trimmed of spaces; using the
// raw list would yield an invalid scheme.
func requestURL(r *http.Request) (*url.URL, error) {
	u, err := url.Parse(r.URL.String())
	if err != nil {
		return nil, err
	}
	u.Host = r.Host

	proto := r.Header.Get("X-Forwarded-Proto")
	if i := strings.IndexByte(proto, ','); i >= 0 {
		proto = proto[:i]
	}
	proto = strings.TrimSpace(proto)

	switch {
	case proto != "":
		u.Scheme = proto
	case r.TLS != nil:
		u.Scheme = "https"
	default:
		u.Scheme = "http"
	}
	return u, nil
}
// LoginUrlForRequest returns the CAS "login" URL for r. The URL carries a
// "service" query parameter holding r's absolute URL, passed through
// sanitisedURLString.
func (c *Client) LoginUrlForRequest(r *http.Request) (string, error) {
	login, err := c.url.Parse(path.Join(c.url.Path, "login"))
	if err != nil {
		return "", err
	}

	serviceURL, err := requestURL(r)
	if err != nil {
		return "", err
	}

	params := login.Query()
	params.Add("service", sanitisedURLString(serviceURL))
	login.RawQuery = params.Encode()

	return login.String(), nil
}
// LogoutUrlForRequest returns the CAS "logout" URL. Only when the client was
// built with Options.SendService does the URL carry a "service" query
// parameter derived from r; otherwise r is not consulted.
func (c *Client) LogoutUrlForRequest(r *http.Request) (string, error) {
	logout, err := c.url.Parse(path.Join(c.url.Path, "logout"))
	if err != nil {
		return "", err
	}
	if !c.sendService {
		return logout.String(), nil
	}

	serviceURL, err := requestURL(r)
	if err != nil {
		return "", err
	}
	params := logout.Query()
	params.Add("service", sanitisedURLString(serviceURL))
	logout.RawQuery = params.Encode()
	return logout.String(), nil
}
// ServiceValidateUrlForRequest returns the CAS "serviceValidate" URL that
// checks ticket. The service parameter is derived from r.
func (c *Client) ServiceValidateUrlForRequest(ticket string, r *http.Request) (string, error) {
	endpoint, err := c.url.Parse(path.Join(c.url.Path, "serviceValidate"))
	if err != nil {
		return "", err
	}

	serviceURL, err := requestURL(r)
	if err != nil {
		return "", err
	}

	params := endpoint.Query()
	params.Add("service", sanitisedURLString(serviceURL))
	params.Add("ticket", ticket)
	endpoint.RawQuery = params.Encode()

	return endpoint.String(), nil
}
// ValidateUrlForRequest returns the CAS "validate" URL that checks ticket.
// This is the endpoint used by the CAS 1.0 fallback. The service parameter
// is derived from r.
func (c *Client) ValidateUrlForRequest(ticket string, r *http.Request) (string, error) {
	validate, err := c.url.Parse(path.Join(c.url.Path, "validate"))
	if err != nil {
		return "", err
	}

	origin, err := requestURL(r)
	if err != nil {
		return "", err
	}

	values := validate.Query()
	values.Add("service", sanitisedURLString(origin))
	values.Add("ticket", ticket)
	validate.RawQuery = values.Encode()

	return validate.String(), nil
}
// RedirectToLogout logs the user out locally and redirects them to the CAS
// logout URL with 302 Found.
//
// Before the redirect is written, the request's ticket is deleted from the
// ticket store and its session cookie is expired (see clearSession). If the
// logout URL cannot be built, a 500 response is written and the session is
// left untouched.
func (c *Client) RedirectToLogout(w http.ResponseWriter, r *http.Request) {
	u, err := c.LogoutUrlForRequest(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	if glog.V(2) {
		// Infof, not Info: the message contains format verbs.
		glog.Infof("Logging out, redirecting client to %v with status %v",
			u, http.StatusFound)
	}

	c.clearSession(w, r)
	http.Redirect(w, r, u, http.StatusFound)
}
// RedirectToLogin redirects the client to the CAS login URL with 302 Found.
//
// The login URL's "service" parameter points back at r. If the URL cannot
// be built, a 500 response carrying the error text is written instead.
func (c *Client) RedirectToLogin(w http.ResponseWriter, r *http.Request) {
	u, err := c.LoginUrlForRequest(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	if glog.V(2) {
		glog.Infof("Redirecting client to %v with status %v", u, http.StatusFound)
	}

	http.Redirect(w, r, u, http.StatusFound)
}
// validateTicket validates ticket against the CAS "serviceValidate" endpoint.
// service is the incoming request and is used to build the service URL.
//
// On success the response parsed by ParseServiceResponse is written to the
// ticket store under ticket. If the server answers 404 Not Found,
// validation is retried against the CAS 1.0 "validate" endpoint via
// validateTicketCas1.
//
// Errors are returned for:
//   - request construction or transport failures;
//   - body read failures;
//   - non-200 statuses;
//   - parse failures;
//   - ticket-store write failures.
func (c *Client) validateTicket(ticket string, service *http.Request) error {
	if glog.V(2) {
		serviceUrl, _ := requestURL(service)
		glog.Infof("Validating ticket %v for service %v", ticket, serviceUrl)
	}

	u, err := c.ServiceValidateUrlForRequest(ticket, service)
	if err != nil {
		return err
	}

	r, err := http.NewRequest("GET", u, nil)
	if err != nil {
		return err
	}
	r.Header.Add("User-Agent", "Golang CAS client gopkg.in/cas")

	if glog.V(2) {
		glog.Infof("Attempting ticket validation with %v", r.URL)
	}

	resp, err := c.client.Do(r)
	if err != nil {
		return err
	}

	if glog.V(2) {
		glog.Infof("Request %v %v returned %v",
			r.Method, r.URL,
			resp.Status)
	}

	if resp.StatusCode == http.StatusNotFound {
		// Release this response before falling back; the CAS1 path issues a
		// second request, and an unclosed body leaks its connection.
		resp.Body.Close()
		return c.validateTicketCas1(ticket, service)
	}

	body, err := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		return err
	}

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("cas: validate ticket: %v", string(body))
	}

	if glog.V(2) {
		glog.Infof("Received authentication response\n%v", string(body))
	}

	success, err := ParseServiceResponse(body)
	if err != nil {
		return err
	}

	if glog.V(2) {
		glog.Infof("Parsed ServiceResponse: %#v", success)
	}

	if err := c.tickets.Write(ticket, success); err != nil {
		return err
	}
	return nil
}
// validateTicketCas1 validates ticket against the CAS 1.0 "validate"
// endpoint. service is the incoming request and is used to build the
// service URL.
//
// A CAS 1.0 server answers "yes\n<username>\n" for an accepted ticket and
// "no\n\n" for a rejected one:
//   - On "yes" the username is stored in the ticket store under ticket.
//   - On a "no" answer, nil is returned and nothing is stored.
//   - Any other body, including "yes" with an empty username, is reported
//     as an error rather than trusted.
func (c *Client) validateTicketCas1(ticket string, service *http.Request) error {
	u, err := c.ValidateUrlForRequest(ticket, service)
	if err != nil {
		return err
	}

	r, err := http.NewRequest("GET", u, nil)
	if err != nil {
		return err
	}
	r.Header.Add("User-Agent", "Golang CAS client gopkg.in/cas")

	if glog.V(2) {
		// Infof, not Info: the message contains format verbs.
		glog.Infof("Attempting ticket validation with %v", r.URL)
	}

	resp, err := c.client.Do(r)
	if err != nil {
		return err
	}

	if glog.V(2) {
		glog.Infof("Request %v %v returned %v",
			r.Method, r.URL,
			resp.Status)
	}

	data, err := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		return err
	}

	body := string(data)
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("cas: validate ticket: %v", body)
	}

	if glog.V(2) {
		glog.Infof("Received authentication response\n%v", body)
	}

	if strings.HasPrefix(body, "no\n") {
		return nil // not logged in
	}

	// Parse "yes\n<username>\n" explicitly instead of slicing by fixed
	// offsets: short bodies used to panic, and unexpected bodies were
	// accepted as a username.
	if !strings.HasPrefix(body, "yes\n") {
		return fmt.Errorf("cas: validate ticket: unexpected response %q", body)
	}
	user := strings.TrimPrefix(body, "yes\n")
	if i := strings.IndexByte(user, '\n'); i >= 0 {
		user = user[:i]
	}
	if user == "" {
		return fmt.Errorf("cas: validate ticket: empty user in response %q", body)
	}

	success := &AuthenticationResponse{
		User: user,
	}

	if glog.V(2) {
		glog.Infof("Parsed ServiceResponse: %#v", success)
	}

	if err := c.tickets.Write(ticket, success); err != nil {
		return err
	}
	return nil
}
// getSession finds or creates a session for the request.
//
// getCookie supplies the session cookie, setting a new one on the response
// if the request had none.
//
// If the cookie maps to a ticket that is still in the ticket store, the
// stored response is attached to r and the function returns. A mapping
// whose ticket is no longer in the store is removed and the cookie is
// cleared.
//
// Next, if the URL has a "ticket" query parameter, that ticket is
// validated. On success the cookie is mapped to the ticket and its stored
// response is attached to r. Validation failures are only logged; the
// request then proceeds without an authentication response.
func (c *Client) getSession(w http.ResponseWriter, r *http.Request) {
	cookie := getCookie(w, r)

	// sessions is written by setSession/deleteSession under c.mu; reading it
	// without the lock is a data race (and can crash the runtime).
	c.mu.Lock()
	s, ok := c.sessions[cookie.Value]
	c.mu.Unlock()

	if ok {
		t, err := c.tickets.Read(s)
		if err == nil {
			if glog.V(1) {
				glog.Infof("Re-used ticket %s for %s", s, t.User)
			}
			setAuthenticationResponse(r, t)
			return
		}

		if glog.V(2) {
			glog.Infof("Ticket %v not in %T: %v", s, c.tickets, err)
		}
		if glog.V(1) {
			glog.Infof("Clearing ticket %s, no longer exists in ticket store", s)
		}
		// Drop the stale mapping as well; otherwise it stays in sessions forever.
		c.deleteSession(cookie.Value)
		clearCookie(w, cookie)
	}

	ticket := r.URL.Query().Get("ticket")
	if ticket == "" {
		return
	}

	if err := c.validateTicket(ticket, r); err != nil {
		if glog.V(2) {
			glog.Infof("Error validating ticket: %v", err)
		}
		return // allow ServeHTTP()
	}

	// Validation may succeed without storing anything (e.g. a CAS 1.0 "no"
	// answer), so only record the session once the ticket is in the store.
	t, err := c.tickets.Read(ticket)
	if err != nil {
		if glog.V(2) {
			glog.Infof("Ticket %v not in %T: %v", ticket, c.tickets, err)
		}
		if glog.V(1) {
			glog.Infof("Clearing ticket %s, no longer exists in ticket store", ticket)
		}
		clearCookie(w, cookie)
		return
	}

	c.setSession(cookie.Value, ticket)
	if glog.V(1) {
		glog.Infof("Validated ticket %s for %s", ticket, t.User)
	}
	setAuthenticationResponse(r, t)
}
// getCookie returns the request's session cookie.
//
// When r carries none, a new cookie is created with a fresh id from
// newSessionId and a MaxAge of 86400 seconds (one day). It is added to r,
// so later lookups on the same request find it, and set on w.
func getCookie(w http.ResponseWriter, r *http.Request) *http.Cookie {
	if existing, err := r.Cookie(sessionCookieName); err == nil {
		return existing
	}

	// NOTE: Intentionally not enabling HttpOnly so the cookie can
	// still be used by Ajax requests.
	fresh := &http.Cookie{
		Name:     sessionCookieName,
		Value:    newSessionId(),
		MaxAge:   86400,
		HttpOnly: false,
	}

	if glog.V(2) {
		glog.Infof("Setting %v cookie with value: %v", fresh.Name, fresh.Value)
	}

	r.AddCookie(fresh) // so we can find it later if required
	http.SetCookie(w, fresh)
	return fresh
}
// newSessionId returns a new opaque session identifier for the session
// cookie. The id is 64 characters drawn uniformly from [a-zA-Z0-9] using
// crypto/rand.
//
// It panics if the secure random source fails. Continuing with an unfilled
// buffer would produce a predictable id ("aaaa..."), which is worse than
// failing the request.
func newSessionId() string {
	// Random bytes at or above limit are discarded (rejection sampling).
	// 256 is not a multiple of len(alphabet), so reducing every byte with %
	// would make the first 256%len(alphabet) characters slightly more likely.
	const (
		alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
		length   = 64
		limit    = 256 - 256%len(alphabet)
	)

	id := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(id) < length {
		if _, err := rand.Read(buf); err != nil {
			panic("cas: reading random session id: " + err.Error())
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue
			}
			id = append(id, alphabet[int(b)%len(alphabet)])
			if len(id) == length {
				break
			}
		}
	}
	return string(id)
}
// clearCookie expires c on the client by re-sending it with a negative
// MaxAge, which net/http renders as "Max-Age=0". c is modified in place.
func clearCookie(w http.ResponseWriter, c *http.Cookie) {
	expired := c
	expired.MaxAge = -1
	http.SetCookie(w, expired)
}
// setSession records, under c.mu, that session id refers to ticket,
// replacing any previous mapping for id.
func (c *Client) setSession(id string, ticket string) {
	if glog.V(2) {
		glog.Infof("Recording session, %v -> %v", id, ticket)
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	c.sessions[id] = ticket
}
// clearSession ends the request's local session. It deletes the ticket the
// session refers to from the ticket store, removes the session-id -> ticket
// mapping, and expires the session cookie on the client.
//
// A ticket-store deletion failure is logged but does not stop the mapping
// and cookie from being cleared.
func (c *Client) clearSession(w http.ResponseWriter, r *http.Request) {
	cookie := getCookie(w, r)

	// Read under c.mu: sessions is mutated concurrently by other requests.
	c.mu.Lock()
	ticket, ok := c.sessions[cookie.Value]
	c.mu.Unlock()

	if ok {
		if err := c.tickets.Delete(ticket); err != nil {
			glog.Errorf("Failed to remove %v from %T: %v", ticket, c.tickets, err)
		}
		// sessions is keyed by the cookie value, not by the ticket.
		c.deleteSession(cookie.Value)
	}

	clearCookie(w, cookie)
}
// deleteSession forgets the session with the given id, under c.mu.
// Deleting an id that is not present is a no-op.
func (c *Client) deleteSession(id string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.sessions, id)
}
// findAndDeleteSessionWithTicket removes every session that refers to
// ticket. It is used for Single Log Out.
//
// A Single Log Out request identifies the user's session only by its
// service ticket identifier. The sessions map (session id -> ticket) is
// therefore scanned for matching entries. Once an entry is removed,
// getSession no longer finds a ticket for that session cookie.
func (c *Client) findAndDeleteSessionWithTicket(ticket string) {
	// Hold c.mu for the whole scan: sessions is mutated concurrently by
	// setSession/deleteSession. deleteSession cannot be called from here
	// because it takes the same non-reentrant lock. Deleting entries while
	// ranging over a map is permitted by the Go spec.
	c.mu.Lock()
	defer c.mu.Unlock()
	for id, t := range c.sessions {
		if t == ticket {
			delete(c.sessions, id)
		}
	}
}